├── README.md ├── featureExtractor ├── bcat_client │ ├── pom.xml │ └── src │ │ └── main │ │ └── java │ │ └── thusca │ │ └── bcat │ │ └── client │ │ ├── ClientApplication.java │ │ ├── consumer │ │ ├── BinFileFeatureExtractTest.java │ │ ├── Task2ExtractCoreFedora.java │ │ ├── TaskExtractFeatureLibs13.java │ │ └── TaskProcessTargets.java │ │ ├── entity │ │ ├── BaseFile.java │ │ ├── BinFileFeature.java │ │ ├── BinaryFile.java │ │ ├── FeatureExtractStatus.java │ │ └── FunctionFeature.java │ │ ├── model │ │ └── CIdModel.java │ │ ├── service │ │ ├── ExtractService.java │ │ ├── GetBinFeatureService.java │ │ ├── GetBinFileService.java │ │ └── SaveToJsonService.java │ │ └── utils │ │ ├── BinaryAnalyzer.java │ │ ├── FileUtil.java │ │ ├── LibmagicJnaWrapper.java │ │ ├── LibmagicJnaWrapperBean.java │ │ ├── StatusMsg.java │ │ └── libghidra │ │ ├── LibGhidra.java │ │ ├── LibHeadlessAnalyzer.java │ │ ├── LibHeadlessErrorLogger.java │ │ ├── LibHeadlessOptions.java │ │ ├── LibHeadlessScript.java │ │ ├── LibHeadlessTimedTaskMonitor.java │ │ └── LibProgramHandler.java └── pom.xml ├── main ├── __init__.py └── torch │ ├── __init__.py │ ├── analyze_results.py │ ├── b2sfinder_afcg.py │ ├── base_afcg.py │ ├── build_milvus_database.py │ ├── core_fedora_embeddings.py │ ├── dataset.py │ ├── eval.py │ ├── eval_re_large.py │ ├── func2vec.py │ ├── function_vector_channel.py │ ├── generate_vec_index.py │ ├── get_data_gemini_format.py │ ├── get_threshold.py │ ├── get_validation_pairs.py │ ├── milvus_mod.py │ ├── run.sh │ ├── torch_main.py │ ├── torch_model.py │ ├── utils.py │ └── utils_loss.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # binary_tpl_detection 2 | 3 | Dataset url: https://figshare.com/s/4a007e78f29243531b8c 4 | 5 | ## Feature Extractor 6 | - The extractor extracts features from all binary files under a given directory and save features to a json file. 7 | - Input: directory 8 | - Output: two files, stored in a given target directory. 9 | - Information such as running time is stored in the `status` file. 10 | - Extracted features are stored in the features file, such as `9760608.json`. The format of this json is a list of BinaryFile entity. 11 | - It is recommended to put your task code under `consumer` directory (in `featureExtractor/bcat_client/src/main/java/thusca/bcat/client/consumer`). See the example in `consumer/BinFileFeatureExtractTest.java` 12 | 13 | ### Pre-requisites 14 | Basic knowledge about Java Development, Springboot and Annotation Development.
15 | For example, if you use an IDE such as VS Code or IntelliJ IDEA, a basic Java development environment needs to be installed, e.g. the `Java Extension Pack` and `Maven for Java` extensions. Note that the code uses Lombok annotations and Spring Boot, so the IDE may need the `Lombok Annotations Support` and `Spring Boot Tools` extensions to debug or run it. In addition, `LibmagicJnaWrapper` depends on libmagic to determine file types; please install this library (easily done with apt on Linux or brew on macOS) and modify the paths in `LibmagicJnaWrapper.java` accordingly.
16 | 
17 | ### Build Artifact
18 | Env:
19 | - Java: Java 11.
20 | - IntelliJ IDEA. (We have found that the extractor artifact builds reliably only under IntelliJ IDEA. Tested successfully with IntelliJ IDEA 2021.2 on Windows.)
21 | 
22 | Steps:
23 | 1. Ghidra 9.1.2: the file `ghidra.jar` is stored under `/user/lib/ghidra.jar`; copy it to `/featureExtractor/bcat_client/lib` first.
24 | 2. Open IDEA and open the project "binary_lib_detection-main\featureExtractor". Wait until indexing finishes; if errors occur, try reopening or cleaning the project.
25 | 3. File -> Project Structure -> Project SDK, select Java SDK 11.
26 | 4. File -> Project Structure -> Artifacts -> "+" -> jar -> from modules with dependencies -> Module ("bcat_client") -> Main Class ("ClientApplication") -> JAR files from libraries (select `copy to the output directory and link via manifest`)
27 | 5. The jars will be generated under `featureExtractor\out\artifacts\bcat_client_jar`, with `bcat_client.jar` inside.
28 | 
29 | ### Task
30 | Methods for all tasks are stored under the directory `/consumer`.
31 | Building the database: code: `Task2ExtractCoreFedora.java`; data: `FedoraLib_Dataset`. Set the save path and extract all features to build the TPL feature database. We use the directory `../data/CoreFedoraFeatureJson0505` to represent the save path.
32 | 
33 | ### Run
34 | Zip the `bcat_client_jar` folder, upload it to a Linux server, unzip it, and run:
35 | ```shell
36 | java -jar bcat_client.jar
37 | ```
38 | 
39 | Note: Java 11 is required.
40 | 
41 | ## Function Similarity Model
42 | This model determines whether two functions are similar, based on the [Gemini](https://github.com/xiaojunxu/dnn-binary-code-similarity) network.
43 | 
44 | ### Preparation and Data
45 | Data is stored in `../data/vector_deduplicate_gemini_format_less_compilation_cases`,
46 | or in `Cross-5C_Dataset.7z` on figshare.
47 | 
48 | By default, we use the path `../data` under `main/torch` to store the data. Please copy the files there.
49 | 
50 | ### Environment Setup
51 | The network is written with PyTorch 1.8 in Python 3.8. The Torch installation below is based on CUDA 11.
52 | 
53 | ```
54 | conda create -n tpldetection python=3.8 ipykernel
55 | bash
56 | conda activate tpldetection
57 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
58 | pip install -r requirements.txt
59 | ```
60 | 
61 | Milvus v1.1.1 (a vector search engine) is necessary for function retrieval. It requires Docker 19.03 or higher.
62 | ref: https://milvus.io/docs/v1.1.1/milvus_docker-gpu.md
63 | ```shell
64 | sudo docker pull milvusdb/milvus:1.1.1-gpu-d061621-330cc6
65 | mkdir -p /home/$USER/milvus/conf
66 | cd /home/$USER/milvus/conf
67 | wget https://raw.githubusercontent.com/milvus-io/milvus/v1.1.1/core/conf/demo/server_config.yaml
68 | 
69 | sudo docker run -d --name milvus_gpu_1.1.1 --gpus all \
70 | -p 19530:19530 \
71 | -p 19121:19121 \
72 | -v /home/$USER/milvus/db:/var/lib/milvus/db \
73 | -v /home/$USER/milvus/conf:/var/lib/milvus/conf \
74 | -v /home/$USER/milvus/logs:/var/lib/milvus/logs \
75 | -v /home/$USER/milvus/wal:/var/lib/milvus/wal \
76 | milvusdb/milvus:1.1.1-gpu-d061621-330cc6
77 | ```
78 | 
79 | ## Run
80 | Run the following command to train the model:
81 | ```shell
82 | # train/validation dataset: /data/func_comparison/vector_deduplicate_our_format_less_compilation_cases/train_test
83 | # test dataset: /data/func_comparison/vector_deduplicate_our_format_less_compilation_cases/valid
84 | cd main/torch
85 | bash run.sh
86 | ```
87 | A trained model is saved under `../data/7fea_contra_torch_b128/saved_model/`.
88 | 
89 | ## Library Detection
90 | 
91 | ### Database
92 | #### Embedding
93 | Raw feature database: `../data/CoreFedoraFeatureJson0505`
94 | 
95 | Embeddings:
96 | Set the path `../data/CoreFedoraFeatureJson0505` as `args.fedora_js`.
97 | You can use multiprocessing to speed things up; the code is written in `core_fedora_embeddings.py` as follows:
98 | ```python
99 | with Pool(10) as p:
100 |     p.starmap(core_fedora_embedding, [(i, True) for i in range(10)])
101 | ```
102 | All embeddings are saved under `args.save_path`.
103 | We use the path `../data/7fea_contra_torch_b128/core_funcs` to represent it.
104 | 
105 | #### Indexing and Building the Milvus Database
106 | Run `build_milvus_database.py` to build the function vector database using Milvus.
107 | 
108 | The function `get_bin_fcg` generates an index file mapping each binary to its functions, to accelerate retrieval.
109 | 
110 | `get_bin2func_num` generates an index from each binary to the number of functions it contains.
111 | 
112 | 
113 | #### Detection
114 | Data: detection_targets. First, extract features from the APKs. See the method `localExtractOSSPoliceApks` in `TaskProcessTargets.java` under the directory `consumer`. We use the directory `../data/detection_targets/feature_json` to save all extracted features.
115 | 
116 | See the function `detect_v2` in `function_vector_channel.py`.
117 | Other methods + FCG Filter can be found in the `xxx_afcg.py` files.
118 | Baselines are under the directory `/related_work`.
119 | 
120 | We combine the basic feature channel (B2SFinder basic features + FCG Filter) and the function vector channel to report the final results.
121 | 
122 | All files named `analyze_results.py` are used to calculate precision and recall.
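Below is a minimal sketch of how precision and recall can be computed from detection results. The `precision_recall` helper, the dict-of-sets layout, and the example library names are illustrative assumptions only, not the exact format or logic used by `analyze_results.py`.

```python
# Hypothetical sketch: score detected TPLs against ground truth per target binary.
# The {target_name: set_of_library_names} layout is an assumption for illustration.

def precision_recall(detected: dict, ground_truth: dict):
    """Compute micro-averaged precision/recall over all targets."""
    tp = fp = fn = 0
    for target, truth in ground_truth.items():
        found = detected.get(target, set())
        tp += len(found & truth)   # libraries correctly reported
        fp += len(found - truth)   # reported but not actually present
        fn += len(truth - found)   # present but missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


if __name__ == "__main__":
    detected = {"app.apk": {"libpng", "zlib", "openssl"}}       # hypothetical output
    ground_truth = {"app.apk": {"libpng", "zlib", "libjpeg"}}   # hypothetical labels
    p, r = precision_recall(detected, ground_truth)
    print(f"precision={p:.2f} recall={r:.2f}")  # precision=0.67 recall=0.67
```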
123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | bcat 5 | thusca 6 | 0.0.1-SNAPSHOT 7 | 8 | 4.0.0 9 | 10 | bcat_client 11 | 12 | 13 | 14 | com.google.guava 15 | guava 16 | 20.0 17 | 18 | 19 | 20 | ghidra 21 | ghidra 22 | 1.0 23 | system 24 | ${project.basedir}/lib/ghidra.jar 25 | 26 | 27 | com.sun.jna 28 | jna 29 | 3.0.9 30 | compile 31 | 32 | 33 | com.github.junrar 34 | junrar 35 | 3.0.0 36 | 37 | 38 | 39 | org.springframework.boot 40 | spring-boot-configuration-processor 41 | true 42 | 43 | 44 | org.springframework.boot 45 | spring-boot-test 46 | 47 | 48 | org.springframework.boot 49 | spring-boot-starter-test 50 | test 51 | 52 | 53 | org.junit.jupiter 54 | junit-jupiter-api 55 | 5.6.2 56 | 57 | 58 | net.sf.sevenzipjbinding 59 | sevenzipjbinding 60 | 9.20-2.00beta 61 | 62 | 63 | org.apache.logging.log4j 64 | log4j-api 65 | 2.13.3 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | src/main/java 74 | 75 | **/*.properties 76 | **/*.xml 77 | 78 | false 79 | 80 | 81 | src/main/resources 82 | 83 | 84 | 85 | 86 | 87 | 88 | org.springframework.boot 89 | spring-boot-maven-plugin 90 | 2.3.4.RELEASE 91 | 92 | thusca.bcat.client.ClientApplication 93 | 94 | 95 | 96 | 97 | repackage 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/ClientApplication.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client; 2 | 3 | import org.springframework.boot.SpringApplication; 4 | import org.springframework.boot.autoconfigure.SpringBootApplication; 5 | 6 | @SpringBootApplication 7 | public class ClientApplication { 8 | public static void main(String[] args){ 9 | SpringApplication.run(ClientApplication.class, args); 10 | } 11 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/consumer/BinFileFeatureExtractTest.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.consumer; 2 | 3 | import org.slf4j.Logger; 4 | import org.slf4j.LoggerFactory; 5 | import org.springframework.beans.factory.ObjectFactory; 6 | import org.springframework.beans.factory.annotation.Autowired; 7 | import org.springframework.beans.factory.annotation.Qualifier; 8 | import org.springframework.boot.ApplicationArguments; 9 | import org.springframework.boot.ApplicationRunner; 10 | import org.springframework.stereotype.Component; 11 | import thusca.bcat.client.service.ExtractService; 12 | 13 | //@Component 14 | public class BinFileFeatureExtractTest implements ApplicationRunner { 15 | @Autowired 16 | @Qualifier("ExtractService") 17 | ObjectFactory extractServiceObjectFactory; 18 | 19 | @Override 20 | public void run(ApplicationArguments args) throws Exception { 21 | Logger logger = LoggerFactory.getLogger(this.getClass()); 22 | String ghidraTmp = "/mnt/c/Users/user/Desktop/tmp/ghidraTmp"; 23 | String unzippedPackagePath = "/mnt/c/Users/user/Desktop/tmp/binaryTarget/test"; 24 | String jsonFileRootPath = "/mnt/c/Users/user/Desktop/tmp/saveJson"; 25 | int packageId = 12345678; 26 | long startTime = System.currentTimeMillis(); 27 | try { 28 | ExtractService extractService = 
extractServiceObjectFactory.getObject(); 29 | extractService.init(unzippedPackagePath, jsonFileRootPath, ghidraTmp, packageId); 30 | extractService.executable(); 31 | logger.info(Thread.currentThread().getName() + " [Done]: "+ packageId); 32 | } catch (Exception e) { 33 | e.printStackTrace(); 34 | } 35 | long endTime = System.currentTimeMillis(); 36 | logger.info("running time: " + (endTime - startTime)/1000 + "s"); 37 | } 38 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/consumer/Task2ExtractCoreFedora.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.consumer; 2 | 3 | import org.slf4j.Logger; 4 | import org.slf4j.LoggerFactory; 5 | import org.springframework.beans.factory.annotation.Autowired; 6 | import org.springframework.beans.factory.annotation.Value; 7 | import org.springframework.boot.ApplicationArguments; 8 | import org.springframework.boot.ApplicationRunner; 9 | import org.springframework.stereotype.Component; 10 | import org.springframework.beans.factory.ObjectFactory; 11 | import thusca.bcat.client.service.ExtractService; 12 | 13 | import java.io.File; 14 | import java.util.concurrent.*; 15 | 16 | // @Component 17 | public class Task2ExtractCoreFedora implements ApplicationRunner { 18 | 19 | private final Logger logger = LoggerFactory.getLogger(this.getClass()); 20 | 21 | // tmp path 22 | @Value("${ghidra.tmp.path}") 23 | private String ghidraTmp; 24 | 25 | // save path 26 | @Value("${json.file.path}") 27 | private String jsonFilePath; 28 | 29 | // set pool size 30 | @Value("${core.pool.size}") 31 | private int CORE_POOL_SIZE; 32 | @Value("${core.pool.size}") 33 | private int MAX_POOL_SIZE; 34 | private static final int QUEUE_CAPACITY = 150; 35 | private static final Long KEEP_ALIVE_TIME = 1L; 36 | 37 | @Autowired 38 | ObjectFactory extractServiceObjectFactory; 39 | 40 | private String rootPath = "../data/FedoraLib_Dataset"; 41 | 42 | @Override 43 | public void run(ApplicationArguments args) throws Exception { 44 | logger.info("Client start......"); 45 | long startTime = System.currentTimeMillis(); 46 | 47 | extractPackage(); 48 | 49 | long endTime = System.currentTimeMillis(); 50 | logger.info("run time:" + (endTime - startTime) + "ms"); 51 | System.exit(0); 52 | } 53 | 54 | public void extractPackage() { 55 | ThreadPoolExecutor cachedThreadPool = new ThreadPoolExecutor(CORE_POOL_SIZE, MAX_POOL_SIZE, KEEP_ALIVE_TIME, 56 | TimeUnit.SECONDS, new ArrayBlockingQueue<>(QUEUE_CAPACITY), new ThreadPoolExecutor.CallerRunsPolicy()); 57 | 58 | File rootDir = new File(rootPath); 59 | 60 | for (File firstLevel : rootDir.listFiles()) { 61 | if (!firstLevel.isDirectory()) { 62 | continue; 63 | } 64 | String[] firstLevelStrings = firstLevel.toString().split("/", -1); 65 | String firstLevelId = firstLevelStrings[firstLevelStrings.length - 1]; 66 | for (File secondLevel : firstLevel.listFiles()) { 67 | if (!secondLevel.isDirectory()) { 68 | continue; 69 | } 70 | String[] secondLevelStrings = secondLevel.toString().split("/", -1); 71 | String secondLevelId = secondLevelStrings[secondLevelStrings.length - 1]; 72 | for (File packageDir : secondLevel.listFiles()) { 73 | if (!packageDir.isDirectory()) { 74 | continue; 75 | } 76 | String[] packageStrings = packageDir.toString().split("/", -1); 77 | String packageId = packageStrings[packageStrings.length - 1]; 78 | String jsonFileName = packageId+ ".json"; 79 | String savePath = 
jsonFilePath + "/" + firstLevelId + "/" + secondLevelId + "/" + packageId; 80 | File targetJsonFile = new File(savePath, jsonFileName); 81 | if (targetJsonFile.exists()) { 82 | logger.info("package has been processed: " + packageId); 83 | continue; 84 | } 85 | logger.info("package to be processed: " + packageId); 86 | process(packageDir.toString(), savePath, ghidraTmp, Integer.parseInt(packageId)); 87 | // CountDownLatch threadSignal = new CountDownLatch(1); 88 | // cachedThreadPool.submit(new Runnable() { 89 | // @Override 90 | // public void run() { 91 | // try { 92 | // if (!packageDir.exists()) { 93 | // System.out.println("no exist"); 94 | // } 95 | 96 | // process(packageDir.toString(), savePath, ghidraTmp, Integer.parseInt(packageId)); 97 | 98 | // } catch (Exception e) { 99 | // logger.info("error: " + e + packageDir.toString()); 100 | // } finally { 101 | // threadSignal.countDown(); 102 | // } 103 | // } 104 | // }); 105 | } 106 | } 107 | } 108 | 109 | // cachedThreadPool.shutdown(); 110 | // try { 111 | // cachedThreadPool.awaitTermination(Long.MAX_VALUE, TimeUnit.MINUTES); 112 | // } catch (InterruptedException e) { 113 | // e.printStackTrace(); 114 | // } 115 | } 116 | 117 | public void process(String packageDir, String savePath, String ghidraTmp, int packageId) { 118 | long startTime = System.currentTimeMillis(); 119 | try { 120 | ExtractService extractService = extractServiceObjectFactory.getObject(); 121 | extractService.init(packageDir.toString(), savePath, ghidraTmp, packageId); 122 | extractService.executable(); 123 | logger.info(Thread.currentThread().getName() + " extracted:" + packageDir.toString()); 124 | } catch (Exception e) { 125 | logger.info("exception in processing:" + e); 126 | } 127 | logger.info("run time:" + (System.currentTimeMillis() - startTime) / 1000 + "s"); 128 | } 129 | } 130 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/consumer/TaskExtractFeatureLibs13.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.consumer; 2 | 3 | import org.slf4j.Logger; 4 | import org.slf4j.LoggerFactory; 5 | import org.springframework.beans.factory.annotation.Autowired; 6 | import org.springframework.boot.ApplicationArguments; 7 | import org.springframework.boot.ApplicationRunner; 8 | import org.springframework.beans.factory.annotation.Value; 9 | import org.springframework.stereotype.Component; 10 | import thusca.bcat.client.entity.BinaryFile; 11 | import thusca.bcat.client.entity.FeatureExtractStatus; 12 | import thusca.bcat.client.service.GetBinFileService; 13 | import thusca.bcat.client.utils.FileUtil; 14 | import org.springframework.beans.factory.ObjectFactory; 15 | import thusca.bcat.client.service.ExtractService; 16 | 17 | import java.io.File; 18 | import java.io.IOException; 19 | import java.util.List; 20 | 21 | / @Component 22 | public class TaskExtractFeatureLibs13 implements ApplicationRunner { 23 | private final Logger logger = LoggerFactory.getLogger(this.getClass()); 24 | @Autowired 25 | ObjectFactory extractServiceObjectFactory; 26 | 27 | @Override 28 | public void run(ApplicationArguments args) throws Exception { 29 | logger.info("Client start......"); 30 | long startTime = System.currentTimeMillis(); 31 | 32 | localExtract(); 33 | 34 | long endTime = System.currentTimeMillis(); 35 | logger.info(" " + (endTime - startTime) + "ms"); 36 | System.exit(0); 37 | } 38 | 39 | public void 
localExtract() { 40 | String libsPath = "/mnt/c/Users/user/Desktop/data/binaryfiles13repos"; 41 | String ghidraTmp = "/mnt/c/Users/user/Desktop/tmp/ghidraTmp"; 42 | String jsonFileRootPath = "/mnt/c/Users/user/Desktop/data/featureJson"; 43 | File prefixFile = new File(libsPath); 44 | 45 | for (File lib : prefixFile.listFiles()) { 46 | if (!lib.isDirectory()) { 47 | continue; 48 | } 49 | String[] sufNames = lib.toString().split("/", -1); 50 | String libName = sufNames[sufNames.length - 1]; 51 | System.out.println(libName); 52 | 53 | for (File compilationCase : lib.listFiles()){ 54 | if (!compilationCase.isDirectory()) { 55 | continue; 56 | } 57 | sufNames = compilationCase.toString().split("/", -1); 58 | String caseName = sufNames[sufNames.length - 1]; 59 | System.out.println(caseName); 60 | long startTime = System.currentTimeMillis(); 61 | String savePath = jsonFileRootPath + "/" + libName + "/" + caseName; 62 | try{ 63 | ExtractService extractService = extractServiceObjectFactory.getObject(); 64 | extractService.init(compilationCase.toString(), savePath, ghidraTmp, 0); 65 | extractService.executable(); 66 | logger.info(Thread.currentThread().getName() + " 提取完成: " + (System.currentTimeMillis()-startTime) / 1000 + "s"); 67 | } catch (Exception e) { 68 | e.printStackTrace(); 69 | } 70 | } 71 | } 72 | } 73 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/consumer/TaskProcessTargets.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.consumer; 2 | 3 | import org.slf4j.Logger; 4 | import org.slf4j.LoggerFactory; 5 | import org.springframework.beans.factory.annotation.Autowired; 6 | import org.springframework.boot.ApplicationArguments; 7 | import org.springframework.boot.ApplicationRunner; 8 | import org.springframework.beans.factory.annotation.Value; 9 | import org.springframework.stereotype.Component; 10 | import thusca.bcat.client.entity.BinaryFile; 11 | import thusca.bcat.client.entity.FeatureExtractStatus; 12 | import thusca.bcat.client.service.GetBinFileService; 13 | import thusca.bcat.client.utils.FileUtil; 14 | import org.springframework.beans.factory.ObjectFactory; 15 | import thusca.bcat.client.service.ExtractService; 16 | 17 | import java.io.File; 18 | import java.io.IOException; 19 | import java.util.List; 20 | 21 | // @Component 22 | public class TaskProcessTargets implements ApplicationRunner { 23 | private final Logger logger = LoggerFactory.getLogger(this.getClass()); 24 | @Autowired 25 | ObjectFactory extractServiceObjectFactory; 26 | 27 | @Override 28 | public void run(ApplicationArguments args) throws Exception { 29 | logger.info("Client start......"); 30 | long startTime = System.currentTimeMillis(); 31 | 32 | localExtractOSSPoliceApks(); 33 | // localExtractLibDXApks(); 34 | 35 | long endTime = System.currentTimeMillis(); 36 | logger.info("time cost: " + (endTime - startTime) + "ms"); 37 | System.exit(0); 38 | } 39 | 40 | public void localExtractOSSPoliceApks() { 41 | String libsPath = "/mnt/c/Users/user/Desktop/detection_targets"; 42 | String ghidraTmp = "/mnt/c/Users/user/Desktop/tmp/ghidraTmp"; 43 | String jsonFileRootPath = "/mnt/c/Users/user/Desktop/data/featureJson"; 44 | File prefixFile = new File(libsPath); 45 | 46 | for (File lib : prefixFile.listFiles()) { 47 | String[] sufNames = lib.toString().split("/", -1); 48 | String libName = sufNames[sufNames.length - 1]; 49 | System.out.println(libName); 50 | 
51 | long startTime = System.currentTimeMillis(); 52 | String savePath = jsonFileRootPath + "/" + libName + "/"; 53 | try{ 54 | ExtractService extractService = extractServiceObjectFactory.getObject(); 55 | extractService.init(lib.toString(), savePath, ghidraTmp, 0); 56 | extractService.executable(); 57 | logger.info(Thread.currentThread().getName() + " done: " + (System.currentTimeMillis()-startTime) / 1000 + "s"); 58 | } catch (Exception e) { 59 | e.printStackTrace(); 60 | } 61 | } 62 | } 63 | 64 | public void localExtractLibDXApks(){ 65 | String libsPath = "/mnt/c/Users/user/Desktop/detection_targets/unzipped_packages/DesktopApps"; 66 | String ghidraTmp = "/mnt/c/Users/user/Desktop/tmp/ghidraTmp"; 67 | String jsonFileRootPath = "/mnt/c/Users/user/Desktop/detection_targets/features/libdx_desktop"; 68 | File prefixFile = new File(libsPath); 69 | for (File app : prefixFile.listFiles()) { 70 | String[] sufNames = app.toString().split("/", -1); 71 | String appName = sufNames[sufNames.length - 1]; 72 | System.out.println(appName); 73 | for (File target: app.listFiles()) { 74 | sufNames = target.toString().split("/", -1); 75 | String targetName = sufNames[sufNames.length - 1]; 76 | System.out.println(targetName); 77 | long startTime = System.currentTimeMillis(); 78 | String savePath = jsonFileRootPath + "/" + appName + "/" + targetName + "/"; 79 | File savePathFile = new File(savePath); 80 | if (savePathFile.exists()) { 81 | continue; 82 | } 83 | 84 | try{ 85 | ExtractService extractService = extractServiceObjectFactory.getObject(); 86 | extractService.init(target.toString(), savePath, ghidraTmp, 0); 87 | extractService.executable(); 88 | logger.info(Thread.currentThread().getName() + " 提取完成: " + (System.currentTimeMillis()-startTime) / 1000 + "s"); 89 | } catch (Exception e) { 90 | e.printStackTrace(); 91 | } 92 | } 93 | } 94 | } 95 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/entity/BaseFile.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.entity; 2 | 3 | import java.io.File; 4 | 5 | import lombok.Data; 6 | 7 | @Data 8 | public class BaseFile { 9 | protected String filePath; 10 | protected String fileName; 11 | protected Boolean isProcessed = false; 12 | protected long byteSize; 13 | 14 | public BaseFile(String filePath) { 15 | File tempFile = new File(filePath); 16 | this.filePath = filePath; 17 | this.fileName = tempFile.getName(); 18 | this.byteSize = tempFile.length(); 19 | } 20 | 21 | public BaseFile(String filePath, String fileName) { 22 | this.filePath = filePath; 23 | this.fileName = fileName; 24 | this.byteSize = new File(filePath).length(); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/entity/BinFileFeature.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.entity; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import lombok.Data; 7 | 8 | @Data 9 | public class BinFileFeature { 10 | private String fileName; 11 | private String fileType; 12 | private List importFunctionNames = new ArrayList<>(); 13 | private List exportFunctionNames = new ArrayList<>(); 14 | private List stringConstants = new ArrayList<>(); 15 | private List functions = new ArrayList<>(); 16 | } 
-------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/entity/BinaryFile.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.entity; 2 | 3 | import lombok.Data; 4 | @Data 5 | public class BinaryFile extends BaseFile { 6 | protected BinFileFeature binFileFeature; 7 | private String formattedFileName; 8 | private String fileType; 9 | 10 | public BinaryFile(String filePath) { 11 | super(filePath); 12 | } 13 | 14 | public BinaryFile(String filePath, String fileName) { 15 | super(filePath, fileName); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/entity/FeatureExtractStatus.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.entity; 2 | 3 | import lombok.Data; 4 | 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | 8 | @Data 9 | public class FeatureExtractStatus { 10 | private boolean getBinFiles = false; 11 | private List binFileNameList = new ArrayList<>(); 12 | private long getBinFileTime = 0; 13 | 14 | private List successfullyExtractedBinFeatureList = new ArrayList<>(); 15 | private List failedExtractedBinFeatureList = new ArrayList<>(); 16 | private long getBinFeatureTime = 0; 17 | 18 | private List successfullySavedJsonList = new ArrayList<>(); 19 | private List failedSavedJsonList = new ArrayList<>(); 20 | private long saveJsonTime = 0; 21 | 22 | public int extracted = 0; 23 | private int extractedStatus = 0; 24 | private long extractedTime = 0; 25 | 26 | private List errorMessages = new ArrayList<>(); 27 | 28 | public void addSuccessfullyExtractedBinFeature(String binFileName, long time, long byteSize) { 29 | successfullyExtractedBinFeatureList.add(new successfullyExtractedBinFeature(binFileName, time, byteSize)); 30 | } 31 | 32 | public void addFailedExtractedBinFeature(String binFileName, String errorMessage) { 33 | failedExtractedBinFeatureList.add(new failedExtractedBinFeature(binFileName, errorMessage)); 34 | } 35 | 36 | public void addSuccessfullySavedJson(String binFileName, long time, long byteSize) { 37 | successfullySavedJsonList.add(new successfullySavedJson(binFileName, time, byteSize)); 38 | } 39 | 40 | public void addfailedSavedJson(String binFileName, String errorMessage) { 41 | failedSavedJsonList.add(new failedSavedJson(binFileName, errorMessage)); 42 | } 43 | 44 | 45 | } 46 | 47 | @Data 48 | class successfullyExtractedBinFeature { 49 | private String binFileName; 50 | private long byteSize; 51 | private long time; 52 | successfullyExtractedBinFeature(String binFileName, long time, long byteSize) { 53 | this.binFileName = binFileName; 54 | this.time = time; 55 | this.byteSize = byteSize; 56 | } 57 | } 58 | 59 | @Data 60 | class failedExtractedBinFeature { 61 | private String binFileName; 62 | private String errorMessage; 63 | failedExtractedBinFeature(String binFileName, String errorMessage) { 64 | this.binFileName = binFileName; 65 | this.errorMessage = errorMessage; 66 | } 67 | } 68 | 69 | @Data 70 | class successfullySavedJson { 71 | private String binFileName; 72 | private long time; 73 | private long byteSize; 74 | successfullySavedJson(String binFileName, long time, long byteSize) { 75 | this.binFileName = binFileName; 76 | this.time = time; 77 | this.byteSize = byteSize; 78 | } 79 | } 80 | 81 | @Data 82 | class 
failedSavedJson { 83 | private String binFileName; 84 | private String errorMessage; 85 | failedSavedJson(String binFileName, String errorMessage) { 86 | this.binFileName = binFileName; 87 | this.errorMessage = errorMessage; 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/entity/FunctionFeature.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.entity; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | import java.util.Map; 7 | 8 | import lombok.Data; 9 | 10 | @Data 11 | public class FunctionFeature { 12 | private String functionName = ""; 13 | private String functionType = ""; 14 | private List args = new ArrayList<>(); 15 | private String functionSignature = ""; 16 | private String entryPoint = ""; 17 | private Boolean isExportFunction; 18 | private Boolean isImportFunction; 19 | private Boolean isThunkFunction; 20 | private Boolean isInline; 21 | private String memoryBlock = ""; 22 | private int edges; 23 | private int nodes; 24 | private int exits; 25 | private int complexity; 26 | private String cfSignature = ""; 27 | private String cfBody = ""; 28 | private int variables; 29 | 30 | private List instructionBytes = new ArrayList<>(); 31 | private List instructions = new ArrayList<>(); 32 | private List opcodes = new ArrayList<>(); 33 | private List pcodeInstr = new ArrayList<>(); 34 | private List callingFunctionAddresses = new ArrayList<>(); 35 | private List callingFunctionsByPointer = new ArrayList<>(); 36 | private List calledFunctionAddresses = new ArrayList<>(); 37 | private List calledFunctionsByPointer = new ArrayList<>(); 38 | private List calledStrings = new ArrayList<>(); 39 | private List calledImports = new ArrayList<>(); 40 | private List calledData = new ArrayList<>(); 41 | private int[] pcodes = new int[]{}; 42 | 43 | private List> nodesCFG = new ArrayList<>(); 44 | private List> edgesCFG = new ArrayList<>(); 45 | 46 | private List> edgePairs = new ArrayList<>(); 47 | private List> nodeGeminiVectors = new ArrayList<>(); 48 | private List> nodeGhidraVectors = new ArrayList<>(); 49 | private List> nodesAsm = new ArrayList<>(); 50 | private List> nodesPcode = new ArrayList<>(); 51 | private List> intConstants = new ArrayList<>(); 52 | private List> stringConstants = new ArrayList<>(); 53 | 54 | 55 | // private List pcodes = new ArrayList(Collections.nCopies(73, 0)); 56 | 57 | // static Map pcodeIndex = Map.ofEntries( 58 | // Map.entry("COPY", 0), 59 | // Map.entry("INT_ADD", 1), 60 | // Map.entry("BOOL_OR", 2), 61 | // Map.entry("LOAD", 3), 62 | // Map.entry("INT_SUB", 4), 63 | // Map.entry("FLOAT_EQUAL", 5), 64 | // Map.entry("STORE", 6), 65 | // Map.entry("INT_CARRY", 7), 66 | // Map.entry("FLOAT_NOTEQUAL", 8), 67 | // Map.entry("BRANCH", 9), 68 | // Map.entry("INT_SCARRY", 10), 69 | // Map.entry("FLOAT_LESS", 11), 70 | // Map.entry("CBRANCH", 12), 71 | // Map.entry("INT_SBORROW", 13), 72 | // Map.entry("FLOAT_LESSEQUAL", 14), 73 | // Map.entry("BRANCHIND", 15), 74 | // Map.entry("INT_2COMP", 16), 75 | // Map.entry("FLOAT_ADD", 17), 76 | // Map.entry("CALL", 18), 77 | // Map.entry("INT_NEGATE", 19), 78 | // Map.entry("FLOAT_SUB", 20), 79 | // Map.entry("CALLIND", 21), 80 | // Map.entry("INT_XOR", 22), 81 | // Map.entry("FLOAT_MULT", 23), 82 | // Map.entry("USERDEFINED", 24), 83 | // Map.entry("INT_AND", 25), 84 | // Map.entry("FLOAT_DIV", 26), 
85 | // Map.entry("RETURN", 27), 86 | // Map.entry("INT_OR", 28), 87 | // Map.entry("FLOAT_NEG", 29), 88 | // Map.entry("PIECE", 30), 89 | // Map.entry("INT_LEFT", 31), 90 | // Map.entry("FLOAT_ABS", 32), 91 | // Map.entry("SUBPIECE", 33), 92 | // Map.entry("INT_RIGHT", 34), 93 | // Map.entry("FLOAT_SQRT", 35), 94 | // Map.entry("INT_EQUAL", 36), 95 | // Map.entry("INT_SRIGHT", 37), 96 | // Map.entry("FLOAT_CEIL", 38), 97 | // Map.entry("INT_NOTEQUAL", 39), 98 | // Map.entry("INT_MULT", 40), 99 | // Map.entry("FLOAT_FLOOR", 41), 100 | // Map.entry("INT_LESS", 42), 101 | // Map.entry("INT_DIV", 43), 102 | // Map.entry("FLOAT_ROUND", 44), 103 | // Map.entry("INT_SLESS", 45), 104 | // Map.entry("INT_REM", 46), 105 | // Map.entry("FLOAT_NAN", 47), 106 | // Map.entry("INT_LESSEQUAL", 48), 107 | // Map.entry("INT_SDIV", 49), 108 | // Map.entry("INT2FLOAT", 50), 109 | // Map.entry("INT_SLESSEQUAL", 51), 110 | // Map.entry("INT_SREM", 52), 111 | // Map.entry("FLOAT2FLOAT", 53), 112 | // Map.entry("INT_ZEXT", 54), 113 | // Map.entry("BOOL_NEGATE", 55), 114 | // Map.entry("TRUNC", 56), 115 | // Map.entry("INT_SEXT", 57), 116 | // Map.entry("BOOL_XOR", 58), 117 | // Map.entry("CPOOLREF", 59), 118 | // Map.entry("BOOL_AND", 60), 119 | // Map.entry("NEW", 61) 120 | // ); 121 | // static List PCODES = Arrays.asList("COPY", "INT_ADD", "BOOL_OR", "LOAD","INT_SUB", "FLOAT_EQUAL", "STORE", "INT_CARRY", "FLOAT_NOTEQUAL", "BRANCH", "INT_SCARRY", "FLOAT_LESS", "CBRANCH", "INT_SBORROW", "FLOAT_LESSEQUAL", "BRANCHIND", "INT_2COMP", "FLOAT_ADD", "CALL", "INT_NEGATE", "FLOAT_SUB", "CALLIND", "INT_XOR", "FLOAT_MULT", "USERDEFINED", "INT_AND", "FLOAT_DIV", "RETURN", "INT_OR", "FLOAT_NEG", "PIECE", "INT_LEFT", "FLOAT_ABS", "SUBPIECE", "INT_RIGHT", "FLOAT_SQRT", "INT_EQUAL", "INT_SRIGHT", "FLOAT_CEIL", "INT_NOTEQUAL", "INT_MULT", "FLOAT_FLOOR", "INT_LESS", "INT_DIV", "FLOAT_ROUND", "INT_SLESS", "INT_REM", "FLOAT_NAN", "INT_LESSEQUAL", "INT_SDIV", "INT2FLOAT", "INT_SLESSEQUAL", "INT_SREM", "FLOAT2FLOAT", "INT_ZEXT", "BOOL_NEGATE", "TRUNC", "INT_SEXT", "BOOL_XOR", "CPOOLREF", "BOOL_AND", "NEW"); 122 | 123 | @Override 124 | public boolean equals(Object obj) { 125 | if (this == obj) 126 | return true; 127 | if (obj == null) 128 | return false; 129 | if (getClass() != obj.getClass()) 130 | return false; 131 | FunctionFeature other = (FunctionFeature) obj; 132 | if (args == null) { 133 | if (other.args != null) 134 | return false; 135 | } else if (!args.equals(other.args)) 136 | return false; 137 | if (calledFunctionAddresses == null) { 138 | if (other.calledFunctionAddresses != null) 139 | return false; 140 | // } else if (!calledFunctionAddresses.equals(other.calledFunctionAddresses)) 141 | } else if (!isListEqual(calledFunctionAddresses, other.calledFunctionAddresses)) 142 | return false; 143 | if (calledImports == null) { 144 | if (other.calledImports != null) 145 | return false; 146 | // } else if (!calledImports.equals(other.calledImports)) 147 | } else if (!isListEqual(calledImports, other.calledImports)) 148 | return false; 149 | if (calledStrings == null) { 150 | if (other.calledStrings != null) 151 | return false; 152 | // } else if (!calledStrings.equals(other.calledStrings)) 153 | } else if (!isListEqual(calledStrings, other.calledStrings)) 154 | return false; 155 | if (callingFunctionAddresses == null) { 156 | if (other.callingFunctionAddresses != null) 157 | return false; 158 | // } else if (!callingFunctionAddresses.equals(other.callingFunctionAddresses)) 159 | } else if 
(!isListEqual(callingFunctionAddresses, other.callingFunctionAddresses)) 160 | return false; 161 | if (entryPoint == null) { 162 | if (other.entryPoint != null) 163 | return false; 164 | } else if (!entryPoint.equals(other.entryPoint)) 165 | return false; 166 | if (functionName == null) { 167 | if (other.functionName != null) 168 | return false; 169 | } else if (!functionName.equals(other.functionName)) 170 | return false; 171 | if (functionSignature == null) { 172 | if (other.functionSignature != null) 173 | return false; 174 | } else if (!functionSignature.equals(other.functionSignature)) 175 | return false; 176 | if (functionType == null) { 177 | if (other.functionType != null) 178 | return false; 179 | } else if (!functionType.equals(other.functionType)) 180 | return false; 181 | if (instructionBytes == null) { 182 | if (other.instructionBytes != null) 183 | return false; 184 | } else if (!instructionBytes.equals(other.instructionBytes)) 185 | return false; 186 | if (instructions == null) { 187 | if (other.instructions != null) 188 | return false; 189 | } else if (!instructions.equals(other.instructions)) 190 | return false; 191 | if (isExportFunction == null) { 192 | if (other.isExportFunction != null) 193 | return false; 194 | } else if (!isExportFunction.equals(other.isExportFunction)) 195 | return false; 196 | if (isImportFunction == null) { 197 | if (other.isImportFunction != null) 198 | return false; 199 | } else if (!isImportFunction.equals(other.isImportFunction)) 200 | return false; 201 | if (isThunkFunction == null) { 202 | if (other.isThunkFunction != null) 203 | return false; 204 | } else if (!isThunkFunction.equals(other.isThunkFunction)) 205 | return false; 206 | if (memoryBlock == null) { 207 | if (other.memoryBlock != null) 208 | return false; 209 | } else if (!memoryBlock.equals(other.memoryBlock)) 210 | return false; 211 | if (opcodes == null) { 212 | if (other.opcodes != null) 213 | return false; 214 | } else if (!opcodes.equals(other.opcodes)) 215 | return false; 216 | return true; 217 | } 218 | 219 | public static boolean isListEqual(List l0, List l1) { 220 | if (l0 == l1) 221 | return true; 222 | if (l0 == null && l1 == null) 223 | return true; 224 | if (l0 == null || l1 == null) 225 | return false; 226 | if (l0.size() != l1.size()) 227 | return false; 228 | for (Object o : l0) { 229 | if (!l1.contains(o)) 230 | return false; 231 | } 232 | for (Object o : l1) { 233 | if (!l0.contains(o)) 234 | return false; 235 | } 236 | return true; 237 | } 238 | 239 | @Override 240 | public int hashCode() { 241 | final int prime = 31; 242 | int result = 1; 243 | result = prime * result + ((args == null) ? 0 : args.hashCode()); 244 | result = prime * result + ((calledFunctionAddresses == null) ? 0 : calledFunctionAddresses.hashCode()); 245 | result = prime * result + ((calledImports == null) ? 0 : calledImports.hashCode()); 246 | result = prime * result + ((calledStrings == null) ? 0 : calledStrings.hashCode()); 247 | result = prime * result + ((callingFunctionAddresses == null) ? 0 : callingFunctionAddresses.hashCode()); 248 | result = prime * result + ((entryPoint == null) ? 0 : entryPoint.hashCode()); 249 | result = prime * result + ((functionName == null) ? 0 : functionName.hashCode()); 250 | result = prime * result + ((functionSignature == null) ? 0 : functionSignature.hashCode()); 251 | result = prime * result + ((functionType == null) ? 0 : functionType.hashCode()); 252 | result = prime * result + ((instructionBytes == null) ? 
0 : instructionBytes.hashCode()); 253 | result = prime * result + ((instructions == null) ? 0 : instructions.hashCode()); 254 | result = prime * result + ((isExportFunction == null) ? 0 : isExportFunction.hashCode()); 255 | result = prime * result + ((isImportFunction == null) ? 0 : isImportFunction.hashCode()); 256 | result = prime * result + ((isThunkFunction == null) ? 0 : isThunkFunction.hashCode()); 257 | result = prime * result + ((memoryBlock == null) ? 0 : memoryBlock.hashCode()); 258 | result = prime * result + ((opcodes == null) ? 0 : opcodes.hashCode()); 259 | return result; 260 | } 261 | } 262 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/model/CIdModel.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.model; 2 | 3 | public class CIdModel { 4 | int cid; 5 | 6 | public int getCid() { 7 | return cid; 8 | } 9 | 10 | public void setCid(int cid) { 11 | this.cid = cid; 12 | } 13 | } 14 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/service/ExtractService.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.service; 2 | 3 | import org.apache.log4j.Logger; 4 | import org.springframework.beans.factory.annotation.Autowired; 5 | import org.springframework.context.annotation.Scope; 6 | import org.springframework.stereotype.Service; 7 | import thusca.bcat.client.entity.BinaryFile; 8 | import thusca.bcat.client.entity.FeatureExtractStatus; 9 | import thusca.bcat.client.utils.FileUtil; 10 | import thusca.bcat.client.utils.LibmagicJnaWrapper; 11 | import thusca.bcat.client.utils.StatusMsg; 12 | 13 | import java.io.File; 14 | 15 | import java.io.IOException; 16 | import java.nio.file.Files; 17 | import java.nio.file.Path; 18 | import java.nio.file.Paths; 19 | import java.util.ArrayList; 20 | import java.util.Iterator; 21 | import java.util.List; 22 | import java.util.concurrent.*; 23 | 24 | @Service("ExtractService") 25 | @Scope("prototype") 26 | public class ExtractService { 27 | private List binaryFileList = new ArrayList<>(); 28 | protected String COMPONENT_PATH; 29 | protected String JSON_SAVE_ROOT_PATH; 30 | protected String GHIDRA_TMP_PATH; 31 | protected Integer PACKAGE_ID; 32 | public FeatureExtractStatus STATUS = new FeatureExtractStatus(); 33 | 34 | private final Logger logger = Logger.getLogger(this.getClass()); 35 | 36 | // private ExecutorService cachedThreadPool; 37 | 38 | @Autowired 39 | GetBinFileService getBinFileService; 40 | 41 | @Autowired 42 | GetBinFeatureService getBinFeatureService; 43 | 44 | @Autowired 45 | SaveToJsonService saveToJsonService; 46 | 47 | public void init(String componentPath, String jsonSaveRootPath, String ghidraTmp, int packageId) { 48 | STATUS = new FeatureExtractStatus(); 49 | COMPONENT_PATH = componentPath; 50 | PACKAGE_ID = packageId; 51 | JSON_SAVE_ROOT_PATH = jsonSaveRootPath; 52 | GHIDRA_TMP_PATH = ghidraTmp; 53 | } 54 | 55 | 56 | public void executable() { 57 | long startTime = System.currentTimeMillis(); 58 | try { 59 | STATUS.setExtractedStatus(1); 60 | executableDetail(); 61 | } catch (Exception e) { 62 | logger.info(PACKAGE_ID + " : executable Error: " + e); 63 | STATUS.setExtractedStatus(2); 64 | e.printStackTrace(); 65 | } finally { 66 | STATUS.setExtractedTime(System.currentTimeMillis() - startTime); 67 | 
logger.info("write status to json..." + JSON_SAVE_ROOT_PATH); 68 | saveToJsonService.saveStatusToJson(STATUS, JSON_SAVE_ROOT_PATH); 69 | } 70 | } 71 | 72 | protected void executableDetail() { 73 | try { 74 | binaryFileList = getBinFileService.getBinaryFiles(COMPONENT_PATH, STATUS); 75 | } catch (Exception e) { 76 | System.out.println("exception: " + binaryFileList.size()); 77 | e.printStackTrace(); 78 | } 79 | 80 | System.out.println("start extract file:" + COMPONENT_PATH); 81 | // 获取特征 82 | long startTime = System.currentTimeMillis(); 83 | File ghidraProjectTmpDir = null; 84 | String fileName = ""; 85 | String fileType = ""; 86 | List binaryFiles = new ArrayList<>(); 87 | 88 | try { 89 | for (BinaryFile binaryFile : binaryFileList) { 90 | ghidraProjectTmpDir = FileUtil.createTempFile(GHIDRA_TMP_PATH); 91 | fileName = binaryFile.getFileName(); 92 | fileType = binaryFile.getFileType(); 93 | StatusMsg[] statusMsgs = {new StatusMsg()}; 94 | try { 95 | File subFile = new File(binaryFile.getFilePath()); 96 | if (!subFile.isFile()) { 97 | logger.info("ERROR: " + PACKAGE_ID + " _ " + binaryFile.getFilePath() + " - - no such file..."); 98 | } 99 | ExecutorService executor = Executors.newSingleThreadExecutor(); 100 | TimerTask task = new TimerTask(binaryFile, ghidraProjectTmpDir, statusMsgs, fileType); 101 | Future f1 = executor.submit(task); 102 | if (f1.get(30, TimeUnit.MINUTES)) { 103 | logger.info(PACKAGE_ID + " _ " + binaryFile.getFilePath() + " done within 30 minutes..."); 104 | } else { 105 | statusMsgs[0].setErrMsg("over time"); 106 | statusMsgs[0].setOK(false); 107 | logger.info(PACKAGE_ID + " _ " + binaryFile.getFilePath() + " over time: more than 30 minutes..."); 108 | STATUS.addFailedExtractedBinFeature(binaryFile.getFileName(), "over time"); 109 | } 110 | } catch (Exception e) { 111 | e.printStackTrace(); 112 | statusMsgs[0].setOK(false); 113 | logger.info(PACKAGE_ID + " _ " + binaryFile.getFileName() + " ERROR... 
maybe, time over: " + e); 114 | STATUS.getErrorMessages().add(fileName + " : " + e.getMessage()); 115 | STATUS.setExtractedStatus(2); 116 | } finally { 117 | if (!statusMsgs[0].isOK()) { 118 | STATUS.setExtractedStatus(2); 119 | } else { 120 | binaryFiles.add(binaryFile); 121 | } 122 | if (ghidraProjectTmpDir != null && ghidraProjectTmpDir.isDirectory()) { 123 | ghidraProjectTmpDir.delete(); 124 | } 125 | } 126 | } 127 | } catch (Exception e) { 128 | e.printStackTrace(); 129 | logger.info(PACKAGE_ID + " ERROR...excutableDetail to extract feature: " + " : " + e); 130 | STATUS.setExtractedStatus(2); 131 | } finally { 132 | STATUS.setGetBinFeatureTime(System.currentTimeMillis() - startTime); 133 | postExtract(binaryFiles, STATUS); 134 | } 135 | } 136 | 137 | private StatusMsg getFeatureNoTimer(BinaryFile binaryFile, File ghidraProjectTmpDir) { 138 | StatusMsg statusMsg = new StatusMsg(); 139 | statusMsg = getBinFeatureService.getBinFileFeature(binaryFile, ghidraProjectTmpDir, JSON_SAVE_ROOT_PATH, binaryFile.getFileType()); 140 | return statusMsg; 141 | } 142 | 143 | protected void postExtract(List binaryFiles, FeatureExtractStatus status) { 144 | try { 145 | File featureJson = saveToJsonService.saveBinaryFileListToJson(binaryFiles, JSON_SAVE_ROOT_PATH, PACKAGE_ID.toString()); 146 | } catch (Exception e) { 147 | STATUS.addfailedSavedJson(PACKAGE_ID.toString(), e.getMessage()); 148 | } 149 | } 150 | 151 | 152 | class TimerTask implements Callable { 153 | BinaryFile binaryFile; 154 | File ghidraProjectTmpDir; 155 | StatusMsg[] statusMsg; 156 | String fileType; 157 | 158 | public TimerTask(BinaryFile binaryFile, File ghidraProjectTmpDir, StatusMsg[] statusMsg, String fileType) { 159 | this.binaryFile = binaryFile; 160 | this.ghidraProjectTmpDir = ghidraProjectTmpDir; 161 | this.statusMsg = statusMsg; 162 | this.fileType = fileType; 163 | } 164 | 165 | @Override 166 | public Boolean call() throws Exception { 167 | statusMsg[0] = getBinFeatureService.getBinFileFeature(binaryFile, ghidraProjectTmpDir, JSON_SAVE_ROOT_PATH, fileType); 168 | return true; 169 | } 170 | } 171 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/service/GetBinFeatureService.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.service; 2 | 3 | import org.springframework.stereotype.Component; 4 | import org.springframework.stereotype.Service; 5 | import thusca.bcat.client.entity.BinFileFeature; 6 | import thusca.bcat.client.entity.BinaryFile; 7 | import thusca.bcat.client.utils.BinaryAnalyzer; 8 | import thusca.bcat.client.utils.StatusMsg; 9 | 10 | import java.io.File; 11 | 12 | @Component 13 | public class GetBinFeatureService { 14 | public StatusMsg getBinFileFeature(BinaryFile binaryFile, File ghidraProjectTmpDir, String jsonPath, String fileType) { 15 | BinaryAnalyzer binaryAnalyzer = new BinaryAnalyzer(binaryFile.getFilePath(), ghidraProjectTmpDir.getAbsolutePath(), jsonPath, fileType); 16 | StatusMsg statusMsg = binaryAnalyzer.extractFeatures(); 17 | BinFileFeature binFileFeature = binaryAnalyzer.getBinFileFeature(); 18 | binaryFile.setBinFileFeature(binFileFeature); 19 | binaryFile.setIsProcessed(true); 20 | statusMsg.setFilePath(binaryFile.getFilePath()); 21 | return statusMsg; 22 | } 23 | } -------------------------------------------------------------------------------- 
/featureExtractor/bcat_client/src/main/java/thusca/bcat/client/service/GetBinFileService.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.service; 2 | 3 | import org.springframework.beans.factory.annotation.Autowired; 4 | import org.springframework.stereotype.Service; 5 | import thusca.bcat.client.entity.BinaryFile; 6 | import thusca.bcat.client.entity.FeatureExtractStatus; 7 | import thusca.bcat.client.utils.LibmagicJnaWrapperBean; 8 | 9 | import java.io.File; 10 | import java.io.IOException; 11 | import java.nio.file.Files; 12 | import java.util.ArrayList; 13 | import java.util.Arrays; 14 | import java.util.Iterator; 15 | import java.util.List; 16 | 17 | @Service 18 | public class GetBinFileService { 19 | 20 | @Autowired 21 | LibmagicJnaWrapperBean LIB_MAGIC_WRAPPER; 22 | 23 | public String fileType; 24 | 25 | private List fileTypes = new ArrayList<>(Arrays.asList("ELF", "Mach-O", "PE")); 26 | 27 | public void getFiles(File rootFile, List fileList) { 28 | File[] files = rootFile.listFiles(); 29 | for (File f : files) { 30 | if (f.isDirectory() && !Files.isSymbolicLink(f.toPath())) { 31 | getFiles(f, fileList); 32 | } else if (f.isFile() && !Files.isSymbolicLink(f.toPath())) { 33 | fileList.add(f); 34 | } 35 | } 36 | } 37 | 38 | public List getBinaryFiles(String componentPath, FeatureExtractStatus status) { 39 | List binaryFileList = new ArrayList<>(); 40 | long startTime = System.currentTimeMillis(); 41 | File rootFile = new File(componentPath); 42 | if (!rootFile.exists()) return binaryFileList; 43 | 44 | List fileList = new ArrayList<>(); 45 | getFiles(rootFile, fileList); 46 | 47 | Iterator fileIterator = fileList.iterator(); 48 | while (fileIterator.hasNext()) { 49 | File subFile = fileIterator.next(); 50 | if (subFile.isFile()) { 51 | fileType = LIB_MAGIC_WRAPPER.getMimeType(subFile.getAbsolutePath()); 52 | for (String fileTypePrefix : fileTypes) { 53 | if (fileType.startsWith(fileTypePrefix)) { 54 | status = getFileList(componentPath, status, subFile, fileTypePrefix, binaryFileList); 55 | } 56 | } 57 | } 58 | } 59 | status.setGetBinFiles(true); 60 | status.setGetBinFileTime(System.currentTimeMillis() - startTime); 61 | return binaryFileList; 62 | } 63 | 64 | public FeatureExtractStatus getFileList(String componentPath, FeatureExtractStatus status, File subFile, String fileType, List binaryFileList) { 65 | try { 66 | BinaryFile binFile = new BinaryFile(subFile.getCanonicalPath(), subFile.getName()); 67 | binFile.setFileType(fileType); 68 | String formattedFileName = binFile.getFilePath().replace(componentPath, "").replace("/", "____"); 69 | if (formattedFileName.startsWith("____")) { 70 | formattedFileName = formattedFileName.substring(4); 71 | } 72 | binFile.setFormattedFileName(formattedFileName); 73 | 74 | binaryFileList.add(binFile); 75 | status.getBinFileNameList().add(formattedFileName); 76 | } catch (IOException e) { 77 | status.getErrorMessages().add(e.getMessage()); 78 | } 79 | return status; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/service/SaveToJsonService.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.service; 2 | 3 | import com.alibaba.fastjson.JSON; 4 | import org.springframework.stereotype.Service; 5 | import thusca.bcat.client.entity.BinaryFile; 6 | import thusca.bcat.client.entity.FeatureExtractStatus; 7 | 
import thusca.bcat.client.utils.FileUtil; 8 | 9 | import java.io.File; 10 | import java.io.FileWriter; 11 | import java.io.IOException; 12 | import java.util.List; 13 | import java.util.UUID; 14 | 15 | import org.apache.log4j.Logger; 16 | 17 | @Service 18 | public class SaveToJsonService { 19 | private final Logger logger = Logger.getLogger(this.getClass()); 20 | 21 | public File saveBinaryFileToJson (BinaryFile binaryFile, String targetDirPath) throws IOException { 22 | File targetDir = new File(targetDirPath); 23 | if (!targetDir.exists()) { 24 | targetDir.mkdirs(); 25 | } 26 | String jsonFileName = binaryFile.getFormattedFileName() + ".json"; 27 | File targetJsonFile = new File(targetDir, jsonFileName); 28 | logger.info("write feature to json: " + targetJsonFile.getCanonicalPath()); 29 | FileUtil.saveStringToFile(targetJsonFile.getCanonicalPath(), JSON.toJSONString(binaryFile), false); 30 | return targetJsonFile; 31 | } 32 | 33 | public File saveBinaryFileListToJson(List binaryFiles, String targetDirPath, String id) throws IOException{ 34 | File targetDir = new File(targetDirPath); 35 | if (!targetDir.exists()) { 36 | targetDir.mkdirs(); 37 | } 38 | String jsonFileName = id+ ".json"; 39 | File targetJsonFile = new File(targetDir, jsonFileName); 40 | logger.info("write feature to json: " + targetJsonFile.getCanonicalPath()); 41 | FileUtil.saveStringToFile(targetJsonFile.getCanonicalPath(), JSON.toJSONString(binaryFiles), false); 42 | return targetJsonFile; 43 | } 44 | 45 | public void saveStatusToJson(FeatureExtractStatus status, String targetDirPath) { 46 | File targetDir = new File(targetDirPath); 47 | if (!targetDir.exists()) { 48 | targetDir.mkdirs(); 49 | } 50 | try { 51 | File targetJsonFile = new File(targetDir, "status"); 52 | FileUtil.saveStringToFile(targetJsonFile.getCanonicalPath(), JSON.toJSONString(status), false); 53 | } catch (Exception e) { 54 | e.printStackTrace(); 55 | } 56 | } 57 | 58 | 59 | } 60 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/FileUtil.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.utils; 2 | 3 | import org.apache.log4j.Logger; 4 | 5 | import java.io.*; 6 | import java.nio.file.Files; 7 | import java.nio.file.Path; 8 | import java.nio.file.Paths; 9 | import java.util.ArrayList; 10 | import java.util.Map; 11 | import java.util.UUID; 12 | 13 | import com.alibaba.fastjson.JSON; 14 | 15 | public class FileUtil { 16 | 17 | private static Logger logger = Logger.getLogger(FileUtil.class); 18 | 19 | /** 20 | * 获取文件夹下所有的文件路径 21 | * 22 | * @param dir_path 23 | * @return 24 | */ 25 | public static ArrayList getAllDir(String dir_path) { 26 | ArrayList all_dir = new ArrayList(); 27 | File file = new File(dir_path); 28 | if (file.exists()) { 29 | for (File tempfile : file.listFiles()) { 30 | if (tempfile.isDirectory()) { 31 | all_dir.add(tempfile.getAbsolutePath()); 32 | } 33 | } 34 | } 35 | return all_dir; 36 | } 37 | 38 | /** 39 | * 获取所有文件路径 40 | * 41 | * @param path 42 | * @return 43 | */ 44 | public static ArrayList getAllPath(String path) { 45 | ArrayList paths = new ArrayList(); 46 | File fpath = new File(path); 47 | getAllPath(fpath, paths); 48 | return paths; 49 | } 50 | 51 | /** 52 | * 递归获取一个目录下的所有文件 53 | * 54 | * @param path 55 | * @param paths 56 | */ 57 | public static void getAllPath(File path, ArrayList paths) { 58 | File fs[] = path.listFiles(); 59 | if (fs != null) { 60 | for 
(int i = 0; i < fs.length; i++) { 61 | if (fs[i].isDirectory()) { 62 | getAllPath(fs[i], paths); 63 | } 64 | if (fs[i].isFile()) { 65 | paths.add(fs[i].toString()); 66 | } 67 | } 68 | } 69 | } 70 | 71 | /** 72 | * 将内容写入指定路径 73 | * 74 | * @param filepath 75 | * @param content 76 | */ 77 | public static void saveStringToFile(String filepath, String content, Boolean append) { 78 | FileWriter fw = null; 79 | PrintWriter out = null; 80 | 81 | try { 82 | File file = new File(filepath); 83 | if (!file.exists()) { 84 | File fileParent = file.getParentFile(); 85 | if (!fileParent.exists()) { 86 | fileParent.mkdirs(); 87 | } 88 | } 89 | file.createNewFile(); 90 | fw = new FileWriter(file, append); 91 | out = new PrintWriter(fw); 92 | out.write(content); 93 | out.println(); 94 | } catch (Exception ex) { 95 | logger.error("save string error: " + filepath, ex); 96 | } finally { 97 | try { 98 | if (out != null) { 99 | out.close(); 100 | } 101 | if (fw != null) { 102 | fw.close(); 103 | } 104 | } catch (Exception ex) { 105 | logger.error("resource close error: ", ex); 106 | } 107 | 108 | } 109 | } 110 | 111 | public static void saveJsonToFile(String filepath, Map content) { 112 | String json = JSON.toJSONString(content); 113 | saveStringToFile(filepath, json, false); 114 | } 115 | 116 | public static String readFileToString(String filePath) throws IOException { 117 | String fileString = Files.readString(Paths.get(filePath)); 118 | return fileString; 119 | } 120 | 121 | public static File createTempFile(String parentFolder) throws IOException { 122 | String zipfileStr = UUID.randomUUID().toString().replace("-", ""); 123 | zipfileStr += UUID.randomUUID().toString().replace("-", ""); 124 | String folder = parentFolder; 125 | File file = new File(folder + File.separator + zipfileStr); 126 | file.mkdir(); 127 | return file; 128 | } 129 | 130 | public static String getPathUnderIndexedFile(String rootDir, int packageId) { 131 | String packageIdString = ""+packageId; 132 | int idLength = packageIdString.length(); 133 | String firstLevel = packageIdString.substring(idLength - 2); 134 | String secondLevel = packageIdString.substring(idLength - 4, idLength - 2); 135 | Path path; 136 | path = Paths.get(rootDir, firstLevel, secondLevel, packageIdString); 137 | return path.toString(); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/LibmagicJnaWrapper.java: -------------------------------------------------------------------------------- 1 | /** 2 | * 3 | */ 4 | package thusca.bcat.client.utils; 5 | 6 | import java.io.IOException; 7 | import java.io.InputStream; 8 | import java.nio.Buffer; 9 | import java.nio.ByteBuffer; 10 | 11 | import com.sun.jna.Library; 12 | import com.sun.jna.Native; 13 | import com.sun.jna.NativeLong; 14 | import com.sun.jna.Platform; 15 | import com.sun.jna.Pointer; 16 | 17 | /** 18 | * A wrapper for the libmagic library that relies on JNA. This is shamelessly 19 | * borrowed from the JHOVE2 code base, with reliance on JHOVE exceptions and 20 | * configuration removed so that it is easily called from other Maven projects. 
21 | * 22 | * The original file can be found here https://bitbucket.org/jhove2/main/src/96b706cbc3e1bd727239d4bfd378690545b37264/src/main/java/org/jhove2/module/identify/file/LibmagicJnaWrapper.java?at=default 23 | * 24 | * @author hbian 25 | * @author carl@openplanetsfoundation.org 26 | */ 27 | 28 | 29 | public class LibmagicJnaWrapper { 30 | /** The default path for the magic file, taken from an Ubuntu installation */ 31 | //TODO: put magic.mgc and libmagic.so libmagic.dylib under resources. 32 | public static final String DEFAULT_MAGIC_PATH = (Platform.isLinux())? "/usr/share/misc/magic.mgc" : "/usr/local/Cellar/libmagic/5.39/share/misc/magic.mgc"; 33 | /** The default buffer size, the number of bytes to pass to file */ 34 | public static final int DEFAULT_BUFFER_SIZE = 8192; 35 | 36 | 37 | /** 38 | * Load the source library and 39 | */ 40 | public interface LibmagicDll extends Library { 41 | //TODO: the library name is "libmagic.so.1". However the name generated here is "libmagic.so" 42 | 43 | // String LIBRARY_NAME = (Platform.isWindows()) ? "magic1" : "magic"; 44 | String LIBRARY_NAME = (Platform.isLinux()) ? "/usr/lib/x86_64-linux-gnu/libmagic.so.1" : "/usr/local/lib/libmagic.dylib"; 45 | 46 | LibmagicDll BASE = (LibmagicDll) Native.loadLibrary(LIBRARY_NAME, LibmagicDll.class); 47 | 48 | //LibmagicDll BASE = (LibmagicDll) Native.loadLibrary(LIBRARY_NAME, LibmagicDll.class); 49 | // LibmagicDll BASE = (LibmagicDll) Native.loadLibrary("/usr/lib/x86_64-linux-gnu/libmagic.so.1", LibmagicDll.class); 50 | 51 | LibmagicDll INSTANCE = (LibmagicDll) Native.synchronizedLibrary(BASE); 52 | 53 | Pointer magic_open(int flags); 54 | 55 | void magic_close(Pointer cookie); 56 | 57 | int magic_setflags(Pointer cookie, int flags); 58 | 59 | String magic_file(Pointer cookie, String fileName); 60 | 61 | String magic_buffer(Pointer cookie, Buffer buffer, NativeLong length); 62 | 63 | int magic_compile(Pointer cookie, String magicFileName); 64 | 65 | int magic_check(Pointer cookie, String magicFileName); 66 | 67 | int magic_load(Pointer cookie, String magicFileName); 68 | 69 | int magic_errno(Pointer cookie); 70 | 71 | String magic_error(Pointer cookie); 72 | } 73 | 74 | /** Libmagic flag: No flags. */ 75 | public final static int MAGIC_NONE = 0x000000; 76 | /** Libmagic flag: Turn on debugging. */ 77 | public final static int MAGIC_DEBUG = 0x000001; 78 | /** Libmagic flag: Follow symlinks. */ 79 | public final static int MAGIC_SYMLINK = 0x000002; 80 | /** Libmagic flag: Check inside compressed files. */ 81 | public final static int MAGIC_COMPRESS = 0x000004; 82 | /** Libmagic flag: Look at the contents of devices. */ 83 | public final static int MAGIC_DEVICES = 0x000008; 84 | /** Libmagic flag: Return the MIME type. */ 85 | public final static int MAGIC_MIME_TYPE = 0x000010; 86 | /** Libmagic flag: Return all matches. */ 87 | public final static int MAGIC_CONTINUE = 0x000020; 88 | /** Libmagic flag: Print warnings to stderr. */ 89 | public final static int MAGIC_CHECK = 0x000040; 90 | /** Libmagic flag: Restore access time on exit. */ 91 | public final static int MAGIC_PRESERVE_ATIME = 0x000080; 92 | /** Libmagic flag: Don't translate unprintable chars. */ 93 | public final static int MAGIC_RAW = 0x000100; 94 | /** Libmagic flag: Handle ENOENT etc as real errors. */ 95 | public final static int MAGIC_ERROR = 0x000200; 96 | /** Libmagic flag: Return the MIME encoding. */ 97 | public final static int MAGIC_MIME_ENCODING = 0x000400; 98 | /** Libmagic flag: Return both MIME type and encoding. 
*/ 99 | public final static int MAGIC_MIME = (MAGIC_MIME_TYPE | MAGIC_MIME_ENCODING); 100 | /** Libmagic flag: Return the Apple creator and type. */ 101 | public final static int MAGIC_APPLE = 0x000800; 102 | /** Libmagic flag: Don't check for compressed files. */ 103 | public final static int MAGIC_NO_CHECK_COMPRESS = 0x001000; 104 | /** Libmagic flag: Don't check for tar files. */ 105 | public final static int MAGIC_NO_CHECK_TAR = 0x002000; 106 | /** Libmagic flag: Don't check magic entries. */ 107 | public final static int MAGIC_NO_CHECK_SOFT = 0x004000; 108 | /** Libmagic flag: Don't check application type. */ 109 | public final static int MAGIC_NO_CHECK_APPTYPE = 0x008000; 110 | /** Libmagic flag: Don't check for elf details. */ 111 | public final static int MAGIC_NO_CHECK_ELF = 0x010000; 112 | /** Libmagic flag: Don't check for text files. */ 113 | public final static int MAGIC_NO_CHECK_TEXT = 0x020000; 114 | /** Libmagic flag: Don't check for cdf files. */ 115 | public final static int MAGIC_NO_CHECK_CDF = 0x040000; 116 | /** Libmagic flag: Don't check tokens. */ 117 | public final static int MAGIC_NO_CHECK_TOKENS = 0x100000; 118 | /** Libmagic flag: Don't check text encodings. */ 119 | public final static int MAGIC_NO_CHECK_ENCODING = 0x200000; 120 | 121 | /** Magic cookie pointer. */ 122 | private final Pointer cookie; 123 | 124 | /** 125 | * Creates a new instance returning the default information: MIME type and 126 | * character encoding. 127 | * 128 | * @throws 129 | * if any error occurred while initializing the libmagic. 130 | * 131 | * @see #LibmagicJnaWrapper(int) 132 | * @see #MAGIC_MIME 133 | */ 134 | public LibmagicJnaWrapper() { 135 | this(MAGIC_MIME | MAGIC_SYMLINK); 136 | } 137 | 138 | /** 139 | * Creates a new instance returning the information specified in the 140 | * flag argument 141 | * 142 | */ 143 | public LibmagicJnaWrapper(int flag) { 144 | this.cookie = LibmagicDll.INSTANCE.magic_open(flag); 145 | if (this.cookie == null) { 146 | throw new IllegalStateException("Libmagic initialization failed"); 147 | } 148 | } 149 | 150 | /** 151 | * Closes the magic database and deallocates any resources used. 152 | */ 153 | public void close() { 154 | LibmagicDll.INSTANCE.magic_close(cookie); 155 | } 156 | 157 | /** 158 | * Returns a textual explanation of the last error. 159 | * 160 | * @return the textual description of the last error, or null 161 | * if there was no error. 162 | */ 163 | public String getError() { 164 | return LibmagicDll.INSTANCE.magic_error(cookie); 165 | } 166 | 167 | /** 168 | * Returns the textual description of the contents of the specified file. 169 | * 170 | * @param filePath 171 | * the path of the file to be identified. 172 | * 173 | * @return the textual description of the file, or null if an 174 | * error occurred. 175 | */ 176 | public String getMimeType(String filePath) { 177 | if ((filePath == null) || (filePath.length() == 0)) { 178 | throw new IllegalArgumentException("filePath"); 179 | } 180 | return LibmagicDll.INSTANCE.magic_file(cookie, filePath); 181 | } 182 | 183 | /** 184 | * Returns textual description of the contents of the buffer 185 | * argument. 186 | * 187 | * @param buffer 188 | * the data to analyze. 189 | * @param length 190 | * the length, in bytes, of the buffer. 191 | * 192 | * @return the textual description of the buffer data, or null 193 | * if an error occurred. 
194 | */ 195 | public String getMimeType(Buffer buffer, long length) { 196 | return LibmagicDll.INSTANCE.magic_buffer(cookie, buffer, 197 | new NativeLong(length)); 198 | } 199 | 200 | /** 201 | * Identify the MIME type of an input stream, using the default buffer size 202 | * {@link #DEFAULT_BUFFER_SIZE}. 203 | * 204 | * @param stream 205 | * a java.io.InputStream to be identified 206 | * @return the textual description of the buffer data, or null 207 | * if an error occurred. 208 | * @throws IOException 209 | * if there is a problem identifying a stream 210 | */ 211 | public String getMimeType(InputStream stream) throws IOException { 212 | return this.getMimeType(stream, DEFAULT_BUFFER_SIZE); 213 | } 214 | 215 | /** 216 | * Identify the MIME type of an input stream, using the passed buffer size. 217 | * 218 | * @param stream 219 | * a java.io.InputStream to be identified 220 | * @param bufferSize 221 | * effectively the number of bytes to pass to file, or the length 222 | * of the file if shorter 223 | * @return the textual description of the buffer data, or null 224 | * if an error occurred. 225 | * @throws IOException 226 | * if there is a problem identifying a stream 227 | */ 228 | public String getMimeType(InputStream stream, int bufferSize) 229 | throws IOException { 230 | // create buffer with capacity of bufferSize 231 | byte[] buffer = new byte[bufferSize]; 232 | int len = stream.read(buffer); 233 | ByteBuffer byteBuf = ByteBuffer.wrap(buffer); 234 | return this.getMimeType(byteBuf, len); 235 | } 236 | 237 | /** 238 | * Compiles the colon-separated list of database text files passed in as 239 | * magicFiles. 240 | * 241 | * @param magicFiles 242 | * the magic database file(s), or null to use the 243 | * default database. 244 | * @return 0 on success and -1 on failure. 245 | */ 246 | public int compile(String magicFiles) { 247 | return LibmagicDll.INSTANCE.magic_compile(cookie, magicFiles); 248 | } 249 | 250 | /** 251 | * Loads the colon-separated list of database files passed in as 252 | * magicFiles. This method must be used before any magic 253 | * queries be performed. 254 | * 255 | * @param magicFiles 256 | * the magic database file(s), or null to use the 257 | * default database. 258 | * @return 0 on success and -1 on failure. 259 | */ 260 | public int load(String magicFiles) { 261 | return LibmagicDll.INSTANCE.magic_load(cookie, magicFiles); 262 | } 263 | 264 | /** 265 | * Loads the magic file located at the path held in 266 | * {@link #DEFAULT_MAGIC_PATH} 267 | * 268 | * @return 0 on success and -1 on failure. 
269 | */ 270 | public int loadCompiledMagic() { 271 | return this.load(DEFAULT_MAGIC_PATH); 272 | } 273 | } -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/LibmagicJnaWrapperBean.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.utils; 2 | 3 | import org.springframework.beans.factory.InitializingBean; 4 | import org.springframework.stereotype.Component; 5 | 6 | @Component 7 | public class LibmagicJnaWrapperBean implements InitializingBean { 8 | 9 | private LibmagicJnaWrapper libmagicJnaWrapper; 10 | 11 | /* User bean 初始化操作 */ 12 | @Override 13 | public void afterPropertiesSet() throws Exception { 14 | libmagicJnaWrapper = new LibmagicJnaWrapper( 15 | LibmagicJnaWrapper.MAGIC_NO_CHECK_ENCODING | LibmagicJnaWrapper.MAGIC_NO_CHECK_APPTYPE 16 | | LibmagicJnaWrapper.MAGIC_NO_CHECK_TOKENS); 17 | 18 | libmagicJnaWrapper.loadCompiledMagic(); 19 | } 20 | 21 | public String getMimeType(String filePath) { 22 | return libmagicJnaWrapper.getMimeType(filePath); 23 | } 24 | 25 | } 26 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/StatusMsg.java: -------------------------------------------------------------------------------- 1 | package thusca.bcat.client.utils; 2 | 3 | import lombok.Data; 4 | 5 | @Data 6 | public class StatusMsg { 7 | 8 | boolean isOK; 9 | String errMsg; 10 | String filePath; 11 | 12 | 13 | public StatusMsg(String msg, String path) { 14 | isOK = true; 15 | errMsg = msg; 16 | filePath = path; 17 | } 18 | 19 | public StatusMsg(boolean isOK, String msg, String path) { 20 | this.isOK = isOK; 21 | errMsg = msg; 22 | filePath = path; 23 | } 24 | 25 | public StatusMsg() { 26 | isOK = true; 27 | } 28 | 29 | public void setOKMsg(String msg, String path) { 30 | isOK = true; 31 | errMsg = msg; 32 | filePath = path; 33 | } 34 | 35 | public void setErrorMsg(String msg, String path) { 36 | isOK = false; 37 | errMsg = msg; 38 | filePath = path; 39 | } 40 | 41 | public String getMsg() { 42 | return filePath + " : " + errMsg; 43 | } 44 | 45 | public boolean isOK() { 46 | return isOK; 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/libghidra/LibHeadlessErrorLogger.java: -------------------------------------------------------------------------------- 1 | /* ### 2 | * IP: GHIDRA 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /* 18 | Modified version of the Headless analyzer code, for any feedback, contact NADER SHALLABI at nader@nosecurecode.com 19 | */ 20 | 21 | package thusca.bcat.client.utils.libghidra; 22 | 23 | import java.io.*; 24 | 25 | import ghidra.util.ErrorLogger; 26 | 27 | /** 28 | * Custom headless error logger which is used when log4j is disabled. 
29 | */ 30 | class LibHeadlessErrorLogger implements ErrorLogger { 31 | 32 | private PrintWriter logWriter; 33 | 34 | LibHeadlessErrorLogger(File logFile) { 35 | if (logFile != null) { 36 | setLogFile(logFile); 37 | } 38 | } 39 | 40 | synchronized void setLogFile(File logFile) { 41 | try { 42 | if (logFile == null) { 43 | if (logWriter != null) { 44 | writeLog("INFO", "File logging disabled"); 45 | logWriter.close(); 46 | logWriter = null; 47 | } 48 | return; 49 | } 50 | PrintWriter w = new PrintWriter(new FileWriter(logFile)); 51 | if (logWriter != null) { 52 | writeLog("INFO ", "Switching log file to: " + logFile); 53 | logWriter.close(); 54 | } 55 | logWriter = w; 56 | } 57 | catch (IOException e) { 58 | System.err.println("Failed to open log file " + logFile + ": " + e.getMessage()); 59 | } 60 | } 61 | 62 | private synchronized void writeLog(String line) { 63 | if (logWriter == null) { 64 | return; 65 | } 66 | logWriter.println(line); 67 | } 68 | 69 | private synchronized void writeLog(String level, String[] lines) { 70 | if (logWriter == null) { 71 | return; 72 | } 73 | for (String line : lines) { 74 | writeLog(level + " " + line); 75 | } 76 | logWriter.flush(); 77 | } 78 | 79 | private synchronized void writeLog(String level, String text) { 80 | if (logWriter == null) { 81 | return; 82 | } 83 | writeLog(level, chopLines(text)); 84 | } 85 | 86 | private synchronized void writeLog(String level, String text, Throwable throwable) { 87 | if (logWriter == null) { 88 | return; 89 | } 90 | writeLog(level, chopLines(text)); 91 | for (StackTraceElement element : throwable.getStackTrace()) { 92 | writeLog(level + " " + element.toString()); 93 | } 94 | logWriter.flush(); 95 | } 96 | 97 | private String[] chopLines(String text) { 98 | text = text.replace("\r", ""); 99 | return text.split("\n"); 100 | } 101 | 102 | @Override 103 | public void debug(Object originator, Object message) { 104 | // TODO for some reason debug is off 105 | // writeLog("DEBUG", message.toString()); 106 | } 107 | 108 | @Override 109 | public void debug(Object originator, Object message, Throwable throwable) { 110 | // TODO for some reason debug is off 111 | // writeLog("DEBUG", message.toString(), throwable); 112 | } 113 | 114 | @Override 115 | public void error(Object originator, Object message) { 116 | writeLog("ERROR", message.toString()); 117 | } 118 | 119 | @Override 120 | public void error(Object originator, Object message, Throwable throwable) { 121 | writeLog("ERROR", message.toString(), throwable); 122 | } 123 | 124 | @Override 125 | public void info(Object originator, Object message) { 126 | writeLog("INFO ", message.toString()); 127 | } 128 | 129 | @Override 130 | public void info(Object originator, Object message, Throwable throwable) { 131 | // TODO for some reason tracing is off 132 | // writeLog("INFO ", message.toString(), throwable); 133 | } 134 | 135 | @Override 136 | public void trace(Object originator, Object message) { 137 | // TODO for some reason tracing i soff 138 | // writeLog("TRACE", message.toString()); 139 | } 140 | 141 | @Override 142 | public void trace(Object originator, Object message, Throwable throwable) { 143 | // TODO for some reason tracing is off 144 | // writeLog("TRACE", message.toString(), throwable); 145 | } 146 | 147 | @Override 148 | public void warn(Object originator, Object message) { 149 | writeLog("WARN ", message.toString()); 150 | } 151 | 152 | @Override 153 | public void warn(Object originator, Object message, Throwable throwable) { 154 | writeLog("WARN ", 
message.toString(), throwable); 155 | } 156 | 157 | } 158 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/libghidra/LibHeadlessTimedTaskMonitor.java: -------------------------------------------------------------------------------- 1 | /* ### 2 | * IP: GHIDRA 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /* 18 | Modified version of the Headless analyzer code, for any feedback, contact NADER SHALLABI at nader@nosecurecode.com 19 | */ 20 | 21 | 22 | package thusca.bcat.client.utils.libghidra; 23 | 24 | import java.util.Timer; 25 | import java.util.TimerTask; 26 | 27 | import ghidra.util.exception.CancelledException; 28 | import ghidra.util.task.CancelledListener; 29 | import ghidra.util.task.TaskMonitor; 30 | 31 | /** 32 | * Monitor used by Headless Analyzer for "timeout" functionality 33 | */ 34 | public class LibHeadlessTimedTaskMonitor implements TaskMonitor { 35 | 36 | private Timer timer = new Timer(); 37 | private volatile boolean isCancelled; 38 | 39 | LibHeadlessTimedTaskMonitor(int timeoutSecs) { 40 | isCancelled = false; 41 | timer.schedule(new TimeOutTask(), timeoutSecs * 1000); 42 | } 43 | 44 | 45 | private class TimeOutTask extends TimerTask { 46 | @Override 47 | public void run() { 48 | LibHeadlessTimedTaskMonitor.this.cancel(); 49 | } 50 | } 51 | 52 | @Override 53 | public boolean isCancelled() { 54 | return isCancelled; 55 | } 56 | 57 | @Override 58 | public void setShowProgressValue(boolean showProgressValue) { 59 | // stub 60 | } 61 | 62 | @Override 63 | public void setMessage(String message) { 64 | // stub 65 | } 66 | 67 | @Override 68 | public String getMessage() { 69 | return null; 70 | } 71 | 72 | @Override 73 | public void setProgress(long value) { 74 | // stub 75 | } 76 | 77 | @Override 78 | public void initialize(long max) { 79 | // stub 80 | } 81 | 82 | @Override 83 | public void setMaximum(long max) { 84 | // stub 85 | } 86 | 87 | @Override 88 | public long getMaximum() { 89 | return 0; 90 | } 91 | 92 | @Override 93 | public void setIndeterminate(boolean indeterminate) { 94 | // stub 95 | } 96 | 97 | @Override 98 | public boolean isIndeterminate() { 99 | return false; 100 | } 101 | 102 | @Override 103 | public void checkCanceled() throws CancelledException { 104 | if (isCancelled()) { 105 | throw new CancelledException(); 106 | } 107 | } 108 | 109 | @Override 110 | public void incrementProgress(long incrementAmount) { 111 | // stub 112 | } 113 | 114 | @Override 115 | public long getProgress() { 116 | return 0; 117 | } 118 | 119 | @Override 120 | public void cancel() { 121 | timer.cancel(); // Terminate the timer thread 122 | isCancelled = true; 123 | } 124 | 125 | @Override 126 | public void addCancelledListener(CancelledListener listener) { 127 | // stub 128 | } 129 | 130 | @Override 131 | public void removeCancelledListener(CancelledListener listener) { 132 | // stub 133 | } 134 | 135 | @Override 136 | 
public void setCancelEnabled(boolean enable) { 137 | // stub 138 | } 139 | 140 | @Override 141 | public boolean isCancelEnabled() { 142 | return true; 143 | } 144 | 145 | @Override 146 | public void clearCanceled() { 147 | isCancelled = false; 148 | } 149 | } 150 | 151 | -------------------------------------------------------------------------------- /featureExtractor/bcat_client/src/main/java/thusca/bcat/client/utils/libghidra/LibProgramHandler.java: -------------------------------------------------------------------------------- 1 | /* ### 2 | * IP: GHIDRA 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | /* 18 | Modified version of the Headless analyzer code, for any feedback, contact NADER SHALLABI at nader@nosecurecode.com 19 | */ 20 | 21 | package thusca.bcat.client.utils.libghidra; 22 | 23 | import ghidra.program.model.listing.Program; 24 | 25 | public interface LibProgramHandler { 26 | public void PostProcessHandler(Program program); 27 | } 28 | -------------------------------------------------------------------------------- /featureExtractor/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | pom 5 | 6 | bcat_client 7 | 8 | 9 | org.springframework.boot 10 | spring-boot-starter-parent 11 | 2.3.4.RELEASE 12 | 13 | 14 | thusca 15 | bcat 16 | 0.0.1-SNAPSHOT 17 | bcat 18 | thusca.bcat 19 | 20 | 21 | 11 22 | 23 | 24 | 25 | 26 | com.sun.jna 27 | jna 28 | 3.0.9 29 | compile 30 | 31 | 32 | 33 | org.springframework 34 | spring-web 35 | 36 | 37 | org.apache.logging.log4j 38 | log4j-to-slf4j 39 | 40 | 41 | 42 | 43 | log4j 44 | log4j 45 | 1.2.17 46 | 47 | 48 | 49 | com.alibaba 50 | fastjson 51 | 1.2.73 52 | 53 | 54 | 55 | 56 | org.apache.logging.log4j 57 | log4j-api 58 | 2.13.3 59 | 60 | 61 | 62 | 63 | org.springframework 64 | spring-beans 65 | 5.2.9.RELEASE 66 | 67 | 68 | 69 | org.springframework.boot 70 | spring-boot-starter-test 71 | test 72 | 73 | 74 | org.apache.logging.log4j 75 | log4j-to-slf4j 76 | 77 | 78 | 79 | 80 | 81 | 82 | org.springframework.boot 83 | spring-boot-starter-web 84 | 85 | 86 | org.apache.logging.log4j 87 | log4j-to-slf4j 88 | 89 | 90 | 91 | 92 | 93 | 94 | org.projectlombok 95 | lombok 96 | 1.18.12 97 | 98 | 99 | 100 | 101 | 102 | 103 | org.springframework.boot 104 | spring-boot-maven-plugin 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSoftwareAnalytics/LibDB/fed22dc4cdcd78d250526cbcf390e23e213c2383/main/__init__.py -------------------------------------------------------------------------------- /main/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DeepSoftwareAnalytics/LibDB/fed22dc4cdcd78d250526cbcf390e23e213c2383/main/torch/__init__.py 
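The torch scripts below consume the feature JSON files written by the extractor. As a minimal sketch (the path is a placeholder, not a file shipped with the repository), each feature file is a JSON list of binary-file entries, and functions are filtered with the same checks that `b2sfinder_afcg.py` and `core_fedora_embeddings.py` apply before embedding:

```python
import json

# Placeholder path to one extracted feature file (a JSON list of BinaryFile entries).
feature_path = "../data/CoreFedoraFeatureJson0505/<first_level>/<second_level>/<package>/<package>.json"

with open(feature_path, "r") as f:
    binary_files = json.load(f)

for binary_file in binary_files:
    if "binFileFeature" not in binary_file:
        continue
    for func in binary_file["binFileFeature"]["functions"]:
        # Same filters used in the scripts below: skip tiny CFGs, thunk
        # functions, and functions outside a .text memory block.
        if func["nodes"] < 5:
            continue
        if func["isThunkFunction"] is True or "text" not in func["memoryBlock"]:
            continue
        print(binary_file["formattedFileName"], func["entryPoint"])
```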
-------------------------------------------------------------------------------- /main/torch/b2sfinder_afcg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle as pkl 3 | from utils import * 4 | from milvus_mod import mil 5 | from milvus import MetricType 6 | import os 7 | from func2vec import func2vec 8 | from function_vector_channel import com2bins, func_tar 9 | import copy 10 | import time 11 | import uuid 12 | from datetime import datetime 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--model', type=str, default='7fea_contra_torch_b128', 16 | help='network model') 17 | parser.add_argument('--fea_dim', type=int, default=7, 18 | help='feature dimension') 19 | parser.add_argument('--load_path', type=str, default='../data/{0}/saved_model/model-inter-best.pt', 20 | help='path for model loading, "#LATEST#" for the latest checkpoint') 21 | parser.add_argument('--base_result', type=str, default='../data/detection/b2sfinder/b2sfinder_results.json', 22 | help='result of base matching method') 23 | parser.add_argument('--id2key', type=str, default='../data/{0}/core_funcs/id2key.pkl', 24 | help='id to key') 25 | parser.add_argument('--id2vec', type=str, default='../data/{0}/core_funcs/id2vec.pkl', 26 | help='id to vec') 27 | parser.add_argument('--pid2bin2vecid', type=str, default='../data/{0}/pid2bin2vecid.pkl', 28 | help='dict:pid to bin to vecid') 29 | parser.add_argument('--test_app_dir', type=str, 30 | default='../data/detection_targets/osspolice_testdata/oss_featureJson', help='test data path') 31 | parser.add_argument('--bin2func_num', type=str, 32 | default='../data/bin2func_num.pkl', help='function numbers of binaries') 33 | parser.add_argument('--bin2fcgs', type=str, 34 | default='../data/funcs_fcg.pkl', help='fcgs of binaries') 35 | parser.add_argument('--id2package', type=str, 36 | default='../data/pid2package.json', help='') 37 | parser.add_argument('--k', type=int, 38 | default=1, help='topk') 39 | parser.add_argument('--fs_thres', type=float, 40 | default=0.8, help='threshold to decide similar functions') 41 | parser.add_argument('--allfea', type=bool, 42 | default=False, help='use all features in b2sfinder') 43 | parser.add_argument('--save_path', type=str, 44 | default='./b2sfinder_afcg_7fea_contra_torch_b128_k1_common5_fsthres0.8.json', help='the path to save results') 45 | 46 | ARGS = parser.parse_args() 47 | 48 | 49 | def b2sfinder_afcg(): 50 | MODEL = ARGS.model 51 | print('loading data...') 52 | id2keys = read_pkl(ARGS.id2key.format(MODEL)) 53 | base_results = read_json(ARGS.base_result) 54 | all_id2vecs = read_pkl(ARGS.id2vec.format(MODEL)) 55 | bin2func_num = read_pkl(ARGS.bin2func_num) 56 | bin2fcgs = read_pkl(ARGS.bin2fcgs) 57 | id2package = read_json(ARGS.id2package) 58 | pid2bin2vecid = read_pkl(ARGS.pid2bin2vecid.format(MODEL)) 59 | TEST_APP_DIR = ARGS.test_app_dir 60 | k = ARGS.k 61 | fs_thres = ARGS.fs_thres 62 | save_path = ARGS.save_path 63 | allfea = ARGS.allfea 64 | print('data loading finished!') 65 | 66 | m = mil() 67 | net = func2vec(ARGS.load_path.format(MODEL), gpu=True, fea_dim=ARGS.fea_dim) 68 | base_afcg_result = {} 69 | 70 | load_time = 0 71 | com_fcg_time = 0 72 | 73 | for app in base_results: 74 | print(datetime.now()) 75 | start_time = time.time() 76 | if app == 'net.avs234_16': 77 | continue 78 | print(app) 79 | base_afcg_result[app] = {} 80 | test_file_path = os.path.join(TEST_APP_DIR, app, '0.json') 81 | test_file_list = read_json(test_file_path) 82 | 
for bin_tar in base_results[app]: 83 | print(datetime.now()) 84 | start = time.time() 85 | print(' ', bin_tar) 86 | matched_libs = {} 87 | base_afcg_result[app][bin_tar] = [] 88 | for test_file in test_file_list: 89 | if test_file['formattedFileName'] == bin_tar: 90 | break 91 | 92 | funcs = {} 93 | entries = [] 94 | tar_vecs = [] 95 | for f in test_file['binFileFeature']['functions']: 96 | if f['nodes'] < 5: 97 | continue 98 | if f['isThunkFunction'] is True or 'text' not in f['memoryBlock']: 99 | continue 100 | cur_f = func_tar(f) 101 | funcs[f['entryPoint']] = cur_f 102 | vec = net.get_embedding_from_func_fea(cur_f.fea, True) 103 | vec = vec.cpu().detach().numpy() 104 | vec = norm_vec(vec) 105 | tar_vecs.append(vec) 106 | entries.append(cur_f.entry) 107 | 108 | match_result = base_results[app][bin_tar] 109 | print('b2sfinder matched results: ', len(match_result)) 110 | for item in match_result: 111 | if 'string' not in match_result[item]['match']: 112 | match_string = False 113 | else: 114 | match_string = match_result[item]['match']['string'] 115 | if 'export' not in match_result[item]['match']: 116 | match_export = False 117 | else: 118 | match_export = match_result[item]['match']['export'] 119 | if not allfea: 120 | if not match_string and not match_export: 121 | continue 122 | pid = tuple(eval(item))[0] 123 | lib_name = id2package[pid] 124 | if lib_name not in matched_libs: 125 | matched_libs[lib_name] = [] 126 | if 'string' not in match_result[item]['similarity']: 127 | sim_string = 0 128 | else: 129 | sim_string = match_result[item]['similarity']['string'] 130 | if 'export' not in match_result[item]['similarity']: 131 | sim_export = 0 132 | else: 133 | sim_export = match_result[item]['similarity']['export'] 134 | matched_libs[lib_name].append([item, sim_string / 0.8 + sim_export / 0.2]) 135 | for lib_name in matched_libs: 136 | matched_libs[lib_name] = sorted(matched_libs[lib_name], key=lambda x:x[1], reverse=True) 137 | 138 | for lib_name in matched_libs: 139 | for item_score in matched_libs[lib_name]: 140 | item = item_score[0] 141 | # test_file_tmp = copy.deepcopy(test_file) 142 | pid = tuple(eval(item))[0] 143 | fname = tuple(eval(item))[1] 144 | 145 | if pid not in pid2bin2vecid or fname not in pid2bin2vecid[pid]: 146 | continue 147 | vecids = pid2bin2vecid[pid][fname] 148 | vecs = {'vecs': [], 'ids': []} 149 | for id in vecids: 150 | vecs['ids'].append(id) 151 | vecs['vecs'].append(all_id2vecs[id]) 152 | try: 153 | print(datetime.now()) 154 | tmp_collection = 'tmp'+str(uuid.uuid4()).replace("-", '') 155 | print('milvus loads data...'+fname) 156 | start = time.time() 157 | m.load_data(list(vecs['vecs']), list(vecs['ids']), 158 | tmp_collection, clear_old=True, metrictype=MetricType.IP) 159 | load_time += time.time() - start 160 | start = time.time() 161 | correct_edges = False 162 | common, afcg_rate = com2bins( 163 | m, bin2fcgs, test_file, funcs, entries, tar_vecs, tmp_collection, pid, fname, net, id2keys, id2package, bin2func_num, k, fs_thres, correct_edges) 164 | 165 | com_fcg_time += time.time() - start 166 | print(common, ' ', afcg_rate, ' ', fname.split('____')[-1]) 167 | base_afcg_result[app][bin_tar].append( 168 | [id2package[pid], pid, fname, common, afcg_rate]) 169 | m.delete_collection(tmp_collection) 170 | print('load_time/com: ', load_time/com_fcg_time) 171 | except Exception as e: 172 | print(e) 173 | continue 174 | if common > 10 or (common > 5 and afcg_rate > 0.2): 175 | break 176 | save_json(base_afcg_result, save_path) 177 | print(load_time) 178 | 
print(com_fcg_time) 179 | print(load_time/com_fcg_time) 180 | 181 | if __name__ == '__main__': 182 | allfea=False 183 | fs_thres = 0.8 184 | 185 | # ARGS.base_results = '/home/user/binary_lib_detection/related_work/b2sfinder/FeatureMatch/b2sfinder_s_subset0.5.json' 186 | ARGS.save_path = './b2sfinder_afcg_allversions_7fea_contra_torch_b128_k1_com10_com5rate0.2_fsthres0.8.json' 187 | b2sfinder_afcg() 188 | -------------------------------------------------------------------------------- /main/torch/base_afcg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle as pkl 3 | from utils import * 4 | from milvus_mod import mil 5 | from milvus import MetricType 6 | import os 7 | from func2vec import func2vec 8 | from function_vector_channel import com2bins 9 | import copy 10 | import time 11 | import uuid 12 | from datetime import datetime 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--model', type=str, default='7fea_contra_torch_b128', 16 | help='network model') 17 | parser.add_argument('--fea_dim', type=int, default=7, 18 | help='feature dimension') 19 | parser.add_argument('--load_path', type=str, default='../data/{0}/saved_model/model-inter-best.pt', 20 | help='path for model loading, "#LATEST#" for the latest checkpoint') 21 | parser.add_argument('--base_result', type=str, default='../data/detection/base/base_result_b2sfinder_rule.json', 22 | help='result of base matching method') 23 | parser.add_argument('--id2key', type=str, default='../data/{0}/core_funcs/id2key.pkl', 24 | help='id to key') 25 | parser.add_argument('--id2vec', type=str, default='../data/{0}/core_funcs/id2vec.pkl', 26 | help='id to vec') 27 | parser.add_argument('--pid2bin2vecid', type=str, default='../data/{0}/pid2bin2vecid.pkl', 28 | help='dict:pid to bin to vecid') 29 | parser.add_argument('--test_app_dir', type=str, 30 | default='../data/detection_targets/features/osspolice/oss_featureJson', help='test data path') 31 | parser.add_argument('--bin2func_num', type=str, 32 | default='../data/bin2func_num.pkl', help='function numbers of binaries') 33 | parser.add_argument('--bin2fcgs', type=str, 34 | default='../data/funcs_fcg.pkl', help='fcgs of binaries') 35 | parser.add_argument('--id2package', type=str, 36 | default='../data/pid2package.json', help='') 37 | parser.add_argument('--k', type=int, 38 | default=1, help='topk') 39 | parser.add_argument('--fs_thres', type=float, 40 | default=0.8, help='threshold to decide similar functions') 41 | 42 | 43 | ARGS = parser.parse_args() 44 | 45 | print(datetime.now()) 46 | 47 | MODEL = ARGS.model 48 | id2keys = read_pkl(ARGS.id2key.format(MODEL)) 49 | base_results = read_json(ARGS.base_result) 50 | all_id2vecs = read_pkl(ARGS.id2vec.format(MODEL)) 51 | bin2func_num = read_pkl(ARGS.bin2func_num) 52 | bin2fcgs = read_pkl(ARGS.bin2fcgs) 53 | id2package = read_json(ARGS.id2package) 54 | pid2bin2vecid = read_pkl(ARGS.pid2bin2vecid.format(MODEL)) 55 | TEST_APP_DIR = ARGS.test_app_dir 56 | k = ARGS.k 57 | fs_thres = ARGS.fs_thres 58 | 59 | m = mil() 60 | net = func2vec(ARGS.load_path.format(MODEL), gpu=True, fea_dim=ARGS.fea_dim) 61 | base_afcg_result = {} 62 | 63 | for app in base_results: 64 | print(datetime.now()) 65 | if app == 'net.avs234_16': 66 | continue 67 | print(app) 68 | base_afcg_result[app] = {} 69 | test_file_path = os.path.join(TEST_APP_DIR, app, '0.json') 70 | test_file_list = read_json(test_file_path) 71 | for bin_tar in base_results[app]: 72 | print(datetime.now()) 73 | 
print(' ', bin_tar) 74 | matched_libs = {} 75 | base_afcg_result[app][bin_tar] = [] 76 | for test_file in test_file_list: 77 | if test_file['formattedFileName'] == bin_tar: 78 | break 79 | match_result = base_results[app][bin_tar]['match_result'] 80 | for item in match_result: 81 | if match_result[item]['strs'] < 0.8 and match_result[item]['exps'] < 0.2: 82 | continue 83 | pid = tuple(eval(item))[0] 84 | lib_name = id2package[pid] 85 | if lib_name not in matched_libs: 86 | matched_libs[lib_name] = [] 87 | matched_libs[lib_name].append([item, match_result[item]['strs'] / 0.8 + match_result[item]['exps'] / 0.2]) 88 | for lib_name in matched_libs: 89 | matched_libs[lib_name] = sorted(matched_libs[lib_name], key=lambda x:x[1], reverse=True) 90 | 91 | filtered_matched_libs = [] 92 | for lib_name in matched_libs: 93 | for item_score in matched_libs[lib_name]: 94 | item = item_score[0] 95 | test_file_tmp = copy.deepcopy(test_file) 96 | pid = tuple(eval(item))[0] 97 | fname = tuple(eval(item))[1] 98 | vecids = pid2bin2vecid[pid][fname] 99 | vecs = {'vecs': [], 'ids': []} 100 | for id in vecids: 101 | vecs['ids'].append(id) 102 | vecs['vecs'].append(all_id2vecs[id]) 103 | try: 104 | print(datetime.now()) 105 | tmp_collection = 'tmp'+str(uuid.uuid4()).replace("-", '') 106 | print('milvus loads data...'+fname) 107 | m.load_data(list(vecs['vecs']), list(vecs['ids']), 108 | tmp_collection, clear_old=True, metrictype=MetricType.IP) 109 | common, afcg_rate = com2bins( 110 | m, bin2fcgs, test_file_tmp, tmp_collection, pid, fname, net, id2keys, id2package, bin2func_num, k, fs_thres) 111 | print(common, ' ', afcg_rate, ' ', fname.split('____')[-1]) 112 | base_afcg_result[app][bin_tar].append( 113 | [id2package[pid], pid, fname, common, afcg_rate]) 114 | m.delete_collection(tmp_collection) 115 | print('delete collection...') 116 | except Exception as e: 117 | print(e) 118 | continue 119 | if common > 5 or (common > 2 and afcg_rate > 0.1): 120 | filtered_matched_libs.append(lib_name) 121 | break 122 | save_json(base_afcg_result, './base_afcg_7fea_contra_torch_b128_k1_common5_nofsthres.json') 123 | -------------------------------------------------------------------------------- /main/torch/core_fedora_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from datetime import datetime 4 | import os 5 | import argparse 6 | import json 7 | import pickle as pkl 8 | from multiprocessing import Pool 9 | 10 | from torch_model import graphnn 11 | from utils_loss import graph 12 | from func2vec import func2vec 13 | from utils import * 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--device', type=str, default='0', 18 | help='visible gpu device') 19 | parser.add_argument('--gap', type=str, default='0.5', 20 | help='triple loss gap') 21 | parser.add_argument('--fea_dim', type=int, default=7, 22 | help='feature dimension') 23 | parser.add_argument('--fedora_js', type=str, default='../data/CoreFedoraFeatureJson0505', 24 | help='feature json of fedora packages') 25 | parser.add_argument('--valid_pairs', type=str, default='../data/validation_pairs/valid_pairs_v1/76_fea_dim', 26 | help='valid pair json') 27 | 28 | parser.add_argument('--load_path', type=str, default='../data/7fea_contra_torch_b128/saved_model/model-inter-best.pt', 29 | help='path for model loading, "#LATEST#" for the latest checkpoint') 30 | parser.add_argument('--save_path', type=str, 31 | default='../data/core_funcs', help='path for pkl 
saving') 32 | parser.add_argument('--valid_save_path', type=str, 33 | default='../data/76fea_triple_torch_epo150_gap0.5/valid_pairs', help='path for pkl saving') 34 | 35 | 36 | def core_fedora_embedding(prefix, norm): 37 | print("start process with prefix: ", prefix) 38 | args = parser.parse_args() 39 | print("=================================") 40 | print(args) 41 | print("=================================") 42 | 43 | # all features extracted from core fedora packages. this dir is generated by unzipping coreFedoraFeatureZip 44 | FEATURE_JSON_CORE_FEDORA = args.fedora_js 45 | if not os.path.exists(args.save_path): 46 | os.makedirs(args.save_path) 47 | core_fedora_fedora_model_7fea = args.save_path + \ 48 | '/core_funcs_{0}.pkl'.format(str(prefix)) 49 | # Model 50 | # os.environ["CUDA_VISIBLE_DEVICES"] = args.device 51 | gpu = prefix % 4 52 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) 53 | LOAD_PATH = args.load_path 54 | level_ids = os.listdir(FEATURE_JSON_CORE_FEDORA) 55 | 56 | model = func2vec(LOAD_PATH, True, args.fea_dim) 57 | 58 | func_vectors = {} 59 | count = 0 60 | package_count = 0 61 | for first_level_id in level_ids: 62 | if not first_level_id.startswith(str(prefix)): 63 | continue 64 | for second_level_id in level_ids: 65 | second_level_path = os.path.join( 66 | FEATURE_JSON_CORE_FEDORA, first_level_id, second_level_id) 67 | if not os.path.exists(second_level_path): 68 | continue 69 | packages = os.listdir(second_level_path) 70 | for package in packages: 71 | package_count += 1 72 | if package_count % 100 == 0: 73 | print("package: ", package_count, "----", prefix) 74 | package_path = os.path.join(second_level_path, package) 75 | if not os.path.isdir(package_path): 76 | continue 77 | package_feature_json = os.path.join( 78 | package_path, package+'.json') 79 | if not os.path.exists(package_feature_json): 80 | continue 81 | try: 82 | with open(package_feature_json, 'r') as load_f: 83 | content = json.load(load_f) 84 | except: 85 | print("error package", package) 86 | continue 87 | 88 | for binary_file in content: 89 | if 'binFileFeature' not in binary_file: 90 | continue 91 | for func in binary_file['binFileFeature']['functions']: 92 | if func['nodes'] < 5: 93 | continue 94 | if func['isThunkFunction'] is True or 'text' not in func['memoryBlock']: 95 | continue 96 | key = ( 97 | package, binary_file['formattedFileName'], func['entryPoint']) 98 | # key2func[key] = func 99 | vec = model.get_embedding_from_func_fea( 100 | func, correct_edges=True) 101 | vec = vec.cpu().detach().numpy() 102 | if norm: 103 | vec = norm_vec(vec) 104 | func_vectors[key] = vec 105 | count += 1 106 | if count % 10000 == 0: 107 | print(count, "----", prefix) 108 | 109 | # key2func = {} 110 | # keys = list(key2func.keys()) 111 | # st = 0 112 | # batch = 128 113 | # while(st < len(keys)): 114 | # if st+batch > len(keys): 115 | # end = len(keys) 116 | # else: 117 | # end = st + batch 118 | # func_fea_l = [] 119 | # for i in range(st, end): 120 | # func_fea_l.append(key2func[keys[i]]) 121 | 122 | # vecs = model.get_embeddings_from_func_fea_l(func_fea_l, correct_edges=True, func_sig=None) 123 | # vecs = vecs.cpu().detach().numpy() 124 | # # vecs = vecs.detach().numpy() 125 | # for i in range(st, end): 126 | # if norm: 127 | # vec = norm_vec(vecs[i-st]) 128 | # else: 129 | # vec = vecs[i-st] 130 | # func_vectors[keys[i]] = vec 131 | # count += end-st+1 132 | # st = end 133 | # print("count: ", count, "----", prefix) 134 | 135 | print("func nums : ", len(func_vectors), "----", prefix) 136 | 137 | with 
open(core_fedora_fedora_model_7fea, 'wb') as fo: 138 | pkl.dump(func_vectors, fo) 139 | 140 | 141 | def get_true_pairs(js_path): 142 | true_pairs = [] 143 | with open(js_path) as load_f: 144 | for line in load_f: 145 | pair = json.loads(line.strip()) 146 | true_pairs.append(pair) 147 | return true_pairs 148 | 149 | 150 | def save_data(data, save_path): 151 | for item in data: 152 | with open(save_path, 'a+') as f: 153 | line = json.dumps(item) 154 | f.write(line+'\n') 155 | 156 | 157 | def valid_embedding_pairs(norm): 158 | args = parser.parse_args() 159 | print("=================================") 160 | print(args) 161 | print("=================================") 162 | 163 | valid_pairs_7fea = args.valid_pairs 164 | valid_vec_pairs_dir = args.valid_save_path 165 | if not os.path.exists(args.valid_save_path): 166 | os.makedirs(args.valid_save_path) 167 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 168 | LOAD_PATH = args.load_path 169 | 170 | # Model 171 | model = func2vec(LOAD_PATH, True, args.fea_dim) 172 | 173 | count = 0 174 | js_files_7fea = os.listdir(valid_pairs_7fea) 175 | for js_file in js_files_7fea: 176 | print('start to process:', js_file) 177 | valid_vec_pairs = [] 178 | pairs_7_fea_js = os.path.join(valid_pairs_7fea, js_file) 179 | true_pairs_7_fea = get_true_pairs(pairs_7_fea_js) 180 | for pair in true_pairs_7_fea: 181 | count += 1 182 | if count % 10000 == 0: 183 | print(count) 184 | key = pair[0]['fname'] 185 | l_graph = model.get_graph(pair[0]) 186 | l_vec = model.get_embedding_from_func_graph(l_graph) 187 | l_vec = l_vec.cpu().detach().numpy() 188 | if norm: 189 | l_vec = norm_vec(l_vec) 190 | r_graph = model.get_graph(pair[1]) 191 | r_vec = model.get_embedding_from_func_graph(r_graph) 192 | r_vec = r_vec.cpu().detach().numpy() 193 | if norm: 194 | r_vec = norm_vec(r_vec) 195 | valid_vec_pairs.append({key: [l_vec, r_vec]}) 196 | true_vec_pairs_7_fea_js = os.path.join( 197 | valid_vec_pairs_dir, js_file.replace('.json', '.pkl')) 198 | with open(true_vec_pairs_7_fea_js, 'wb') as fo: 199 | pkl.dump(valid_vec_pairs, fo) 200 | 201 | 202 | def generate_vec_index(data_path, id2key_save_path, id2vec_save_path, norm): 203 | raw_data_files = os.listdir(data_path) 204 | raw_data = {} 205 | for f in raw_data_files: 206 | with open(os.path.join(data_path, f), 'rb') as load_f: 207 | raw_data.update(pkl.load(load_f)) 208 | id = 0 209 | id2key = {} 210 | id2vec = {} 211 | for key in raw_data: 212 | id += 1 213 | if id % 100000 == 0: 214 | print(id) 215 | id2key[id] = key 216 | if norm: 217 | id2vec[id] = norm_vec(raw_data[key]) 218 | else: 219 | id2vec[id] = raw_data[key] 220 | print(id) 221 | save_pkl(id2key, id2key_save_path) 222 | save_pkl(id2vec, id2vec_save_path) 223 | 224 | 225 | def vec_index_all_fedora(data_path, prefix, index_save_path, norm): 226 | file_path = os.path.join(data_path, 'core_funcs_{0}.pkl'.format(prefix)) 227 | with open(file_path, 'rb') as load_f: 228 | content = pkl.load(load_f) 229 | if len(content) > 100000000: 230 | print('error, more than 100000000 functions') 231 | id = 100000000 * prefix 232 | id2key = {} 233 | id2vec = {} 234 | for key in content: 235 | id+=1 236 | if id % 1000000 == 0: 237 | print(id, '----', prefix) 238 | id2key[id] = key 239 | if norm: 240 | id2vec[id] = norm_vec(content[key]) 241 | else: 242 | id2vec[id] = content[key] 243 | 244 | id2key_save_path = os.path.join(index_save_path, 'id2key_{0}.pkl'.format(prefix)) 245 | id2vec_save_path = os.path.join(index_save_path, 'id2vec_{0}.pkl'.format(prefix)) 246 | save_pkl(id2key, 
id2key_save_path) 247 | save_pkl(id2vec, id2vec_save_path) 248 | 249 | 250 | 251 | 252 | 253 | 254 | if __name__ == "__main__": 255 | with Pool(10) as p: 256 | p.starmap(core_fedora_embedding, [(i, True) for i in range(10)]) 257 | print("core func embedding done") 258 | 259 | # valid_embedding_pairs(True) 260 | -------------------------------------------------------------------------------- /main/torch/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from numpy.random import choice as npc 4 | import numpy as np 5 | 6 | 7 | class dataset(Dataset): 8 | 9 | def __init__(self, Gs, classes, batch_size, neg_batch, neg_batch_flag, train): 10 | super(dataset, self).__init__() 11 | np.random.seed(0) 12 | self.Gs = Gs 13 | self.classes = classes 14 | self.batch_size = batch_size 15 | self.neg_batch = neg_batch 16 | if train: 17 | self.perm = np.random.permutation(len(self.Gs)) 18 | else: 19 | self.perm = range(len(Gs)) 20 | self.neg_batch_flag = neg_batch_flag 21 | 22 | def __len__(self): 23 | return len(self.Gs) / self.batch_size 24 | 25 | def shuffle(self): 26 | self.perm = np.random.permutation(len(self.Gs)) 27 | 28 | def get_pair(self, st, neg_batch_flag, output_id=False, load_id=None): 29 | if load_id is None: 30 | C = len(self.classes) 31 | if (st + self.batch_size > len(self.perm)): 32 | M = len(self.perm) - st 33 | else: 34 | M = self.batch_size 35 | ed = st + M 36 | triple_ids = [] # [(G_0, G_p, G_n)] 37 | p_funcs = [] 38 | true_pairs = [] 39 | n_ids = [] 40 | 41 | for g_id in self.perm[st:ed]: 42 | g0 = self.Gs[g_id] 43 | cls = g0.label 44 | p_funcs.append(cls) 45 | tot_g = len(self.classes[cls]) 46 | if (len(self.classes[cls]) >= 2): 47 | p_id = self.classes[cls][np.random.randint(tot_g)] 48 | while g_id == p_id: 49 | p_id = self.classes[cls][np.random.randint(tot_g)] 50 | true_pairs.append((g_id, p_id)) 51 | else: 52 | triple_ids = load_id[0] 53 | if not neg_batch_flag: 54 | M = len(true_pairs) 55 | self.neg_batch = M 56 | for i in range(self.neg_batch): 57 | n_cls = np.random.randint(C) 58 | while (len(self.classes[n_cls]) == 0) or (n_cls in p_funcs): 59 | n_cls = np.random.randint(C) 60 | tot_g2 = len(self.classes[n_cls]) 61 | n_id = self.classes[n_cls][np.random.randint(tot_g2)] 62 | n_ids.append(n_id) 63 | maxN1 = 0 64 | maxN2 = 0 65 | maxN3 = 0 66 | for pair in true_pairs: 67 | maxN1 = max(maxN1, self.Gs[pair[0]].node_num) 68 | maxN2 = max(maxN2, self.Gs[pair[1]].node_num) 69 | for id in n_ids: 70 | maxN3 = max(maxN3, self.Gs[id].node_num) 71 | feature_dim = len(self.Gs[0].features[0]) 72 | X1_input = np.zeros((M, maxN1, feature_dim)) 73 | X2_input = np.zeros((M, maxN2, feature_dim)) 74 | X3_input = np.zeros((self.neg_batch, maxN3, feature_dim)) 75 | node1_mask = np.zeros((M, maxN1, maxN1)) 76 | node2_mask = np.zeros((M, maxN2, maxN2)) 77 | node3_mask = np.zeros((self.neg_batch, maxN3, maxN3)) 78 | 79 | for i in range(len(true_pairs)): 80 | g1 = self.Gs[true_pairs[i][0]] 81 | g2 = self.Gs[true_pairs[i][1]] 82 | 83 | for u in range(g1.node_num): 84 | X1_input[i, u, :] = np.array(g1.features[u]) 85 | for v in g1.succss[u]: 86 | node1_mask[i, u, v] = 1 87 | for u in range(g2.node_num): 88 | X2_input[i, u, :] = np.array(g2.features[u]) 89 | for v in g2.succss[u]: 90 | node2_mask[i, u, v] = 1 91 | 92 | for i in range(len(n_ids)): 93 | g3 = self.Gs[n_ids[i]] 94 | for u in range(g3.node_num): 95 | X3_input[i, u, :] = np.array(g3.features[u]) 96 | for v in g3.succss[u]: 97 | 
node3_mask[i, u, v] = 1 98 | if output_id: 99 | return X1_input, X2_input, X3_input, node1_mask, node2_mask, node3_mask, triple_ids 100 | else: 101 | return X1_input, X2_input, X3_input, node1_mask, node2_mask, node3_mask 102 | 103 | def __getitem__(self, index): 104 | X1, X2, X3, m1, m2, m3 = self.get_pair(index * self.batch_size, neg_batch_flag=self.neg_batch_flag) 105 | return torch.from_numpy(X1).float(), torch.from_numpy(X2).float(), torch.from_numpy(X3).float(), torch.from_numpy(m1).float(), torch.from_numpy(m2).float(), torch.from_numpy(m3).float() -------------------------------------------------------------------------------- /main/torch/eval.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | print(tf.__version__) 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from datetime import datetime 6 | from graphnnSiamese import graphnn 7 | from utils_valid import * 8 | import os 9 | import argparse 10 | import json 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--device', type=str, default='0,1,2,3', 14 | help='visible gpu device') 15 | parser.add_argument('--use_device', type=str, default='/gpu:1', 16 | help='used gpu device') 17 | parser.add_argument('--fea_dim', type=int, default=7, 18 | help='feature dimension') 19 | parser.add_argument('--embed_dim', type=int, default=64, 20 | help='embedding dimension') 21 | parser.add_argument('--embed_depth', type=int, default=5, 22 | help='embedding network depth') 23 | parser.add_argument('--output_dim', type=int, default=64, 24 | help='output layer dimension') 25 | parser.add_argument('--iter_level', type=int, default=5, 26 | help='iteration times') 27 | parser.add_argument('--lr', type=float, default=1e-4, 28 | help='learning rate') 29 | parser.add_argument('--epoch', type=int, default=100, 30 | help='epoch number') 31 | parser.add_argument('--batch_size', type=int, default=128, 32 | help='batch size') 33 | parser.add_argument('--load_path', type=str, 34 | default='./saved_model/graphnn-model_best', 35 | help='path for model loading, "#LATEST#" for the latest checkpoint') 36 | parser.add_argument('--log_path', type=str, default=None, 37 | help='path for training log') 38 | 39 | def get_true_pairs(js_path): 40 | true_pairs = [] 41 | with open(js_path) as load_f: 42 | for line in load_f: 43 | pair = json.loads(line.strip()) 44 | true_pairs.append(pair) 45 | return true_pairs 46 | 47 | 48 | def eval(fea_dim, model_load_path, t_pairs_path, use_device, thres): 49 | args = parser.parse_args() 50 | args.dtype = tf.float32 51 | print("=================================") 52 | print(args) 53 | print("=================================") 54 | Dtype = args.dtype 55 | NODE_FEATURE_DIM = args.fea_dim 56 | EMBED_DIM = args.embed_dim 57 | EMBED_DEPTH = args.embed_depth 58 | OUTPUT_DIM = args.output_dim 59 | ITERATION_LEVEL = args.iter_level 60 | LEARNING_RATE = args.lr 61 | MAX_EPOCH = args.epoch 62 | BATCH_SIZE = args.batch_size 63 | LOAD_PATH = args.load_path 64 | LOG_PATH = args.log_path 65 | DEVICE = args.use_device 66 | 67 | NODE_FEATURE_DIM = fea_dim 68 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device 69 | LOAD_PATH = model_load_path 70 | DEVICE = use_device 71 | 72 | t_pairs = get_true_pairs(t_pairs_path) 73 | print("true pairs: ", len(t_pairs)) 74 | 75 | 76 | # Model 77 | gnn = graphnn( 78 | N_x = NODE_FEATURE_DIM, 79 | Dtype = Dtype, 80 | N_embed = EMBED_DIM, 81 | depth_embed = EMBED_DEPTH, 82 | N_o = OUTPUT_DIM, 83 | ITER_LEVEL = 
ITERATION_LEVEL, 84 | lr = LEARNING_RATE, 85 | device = DEVICE 86 | ) 87 | gnn.init(LOAD_PATH, LOG_PATH) 88 | 89 | recall = get_recall_epoch_batch(gnn, t_pairs, BATCH_SIZE, thres) 90 | gnn.say("recall rate = {0} @ {1}".format(recall, datetime.now())) 91 | print(recall) 92 | return recall 93 | # print(max((1-fpr+tpr)/2)) 94 | # index = np.argmax((1-fpr+tpr)/2) 95 | # print("index:", index) 96 | # print("fpr", fpr[index]) 97 | # print("tpr", tpr[index]) 98 | # print(thres[index]) 99 | 100 | 101 | if __name__ == '__main__': 102 | thres7 = 0.7367 103 | # thres76 = 0.7482 104 | thres76 = 0.7532 105 | 106 | model_7_fea_dim = '../data/saved_model/graphnn_model_gemini/graphnn_model_gemini_best' 107 | # model_76_fea_dim = '../data/saved_model/graphnn_model_ghidra/saved_ghidra_model_best' 108 | model_76_fea_dim = '../data/saved_model/graphnn_model_ghidra_depth5/graphnn_model_ghidra_best' 109 | 110 | gpu_device = '/gpu:3' 111 | 112 | pairs_7_fea_dim_dir = '../data/validation_pairs/valid_pairs_v1/7_fea_dim' 113 | pair_fs = os.listdir(pairs_7_fea_dim_dir) 114 | res7 = {} 115 | # for f in pair_fs: 116 | # recall7 = eval(7, model_7_fea_dim, os.path.join(pairs_7_fea_dim_dir, f), gpu_device, thres7) 117 | # res7[f] = recall7 118 | 119 | pairs_76_fea_dim_dir = '../data/validation_pairs/valid_pairs_v1/76_fea_dim' 120 | pair_fs = os.listdir(pairs_76_fea_dim_dir) 121 | res76 = {} 122 | for f in pair_fs: 123 | if f != 'cc_version_diff_76_fea_dim.json': 124 | continue 125 | recall76 = eval(76, model_76_fea_dim, os.path.join(pairs_76_fea_dim_dir, f), gpu_device, thres76) 126 | res76[f] = recall76 127 | 128 | # print("recall:") 129 | # for i in res7: 130 | # print(i, " ", res7[i]) 131 | 132 | for i in res76: 133 | print(i, " ", res76[i]) 134 | 135 | # plt.figure() 136 | # plt.title('ROC CURVE') 137 | # plt.xlabel('False Positive Rate') 138 | # plt.ylabel('True Positive Rate') 139 | # plt.plot(fpr7,tpr7,color='b') 140 | # plt.plot(fpr76, tpr76,color='r') 141 | # # plt.plot([0, 1], [0, 1], color='m', linestyle='--') 142 | # plt.savefig('auc.png') 143 | 144 | 145 | -------------------------------------------------------------------------------- /main/torch/eval_re_large.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import pickle as pkl 3 | import os 4 | import json 5 | import numpy as np 6 | from milvus import MetricType 7 | import time 8 | from milvus_mod import mil 9 | 10 | 11 | def read_pkl(pkl_path): 12 | with open(pkl_path, 'rb') as f: 13 | content = pkl.load(f) 14 | return content 15 | 16 | 17 | def save_pkl(content, save_path): 18 | with open(save_path, 'wb') as f: 19 | pkl.dump(content, f) 20 | 21 | 22 | def read_json(js_path): 23 | with open(js_path, 'r') as f: 24 | content = json.load(f) 25 | return content 26 | 27 | 28 | def generate_vec_index(data_path, id2key_save_path, id2vec_save_path, norm): 29 | raw_data_files = os.listdir(data_path) 30 | raw_data = {} 31 | for f in raw_data_files: 32 | with open(os.path.join(data_path, f), 'rb') as load_f: 33 | raw_data.update(pkl.load(load_f)) 34 | id = 0 35 | id2key = {} 36 | id2vec = {} 37 | for key in raw_data: 38 | id += 1 39 | if id % 100000 == 0: 40 | print(id) 41 | id2key[id] = key 42 | if norm: 43 | id2vec[id] = norm_vec(raw_data[key]) 44 | else: 45 | id2vec[id] = raw_data[key] 46 | print(id) 47 | save_pkl(id2key, id2key_save_path) 48 | save_pkl(id2vec, id2vec_save_path) 49 | 50 | 51 | def get_filtered_data(id2key, id2vec, pid2pname, save_path, norm): 52 | test_p = ['sqlite', 'stunnel', 
'yasm', 'ytasm'] 53 | id2keys = read_pkl(id2key) 54 | id2vecs = read_pkl(id2vec) 55 | pid2pnames = read_json(pid2pname) 56 | count = 0 57 | print(len(id2vecs)) 58 | deleted_p = [] 59 | for id in list(id2vecs.keys()): 60 | count += 1 61 | if norm: 62 | id2vecs[id] = norm_vec(id2vecs[id]) 63 | pid = id2keys[id][0] 64 | pname = pid2pnames[pid] 65 | fname = id2keys[id][1] 66 | for i in test_p: 67 | if i in pname or i in fname: 68 | deleted_p.append(pname) 69 | del id2vecs[id] 70 | print(len(id2vecs)) 71 | save_pkl(id2vecs, save_path) 72 | return id2vecs 73 | 74 | 75 | def norm_vec(vec): 76 | return vec/np.sqrt(sum(vec**2)) 77 | 78 | 79 | def generate_test_cases(test_embeddings_path, test_cases_dir, norm): 80 | test_cases = os.listdir(test_embeddings_path) 81 | id = 50000000 82 | count = 0 83 | if not os.path.exists(test_cases_dir): 84 | os.makedirs(test_cases_dir) 85 | for test_case in test_cases: 86 | id2key = {} 87 | id2lvec = {} 88 | vecs = {} 89 | print(test_case) 90 | with open(os.path.join(test_embeddings_path, test_case), 'rb') as load_f: 91 | true_pairs = pkl.load(load_f) 92 | test_id_vec_pairs = [] 93 | for pair in true_pairs: 94 | count += 1 95 | if count % 10000 == 0: 96 | print(count) 97 | key = list(pair.keys())[0] 98 | l_vec = list(pair.values())[0][0].tolist() 99 | r_vec = list(pair.values())[0][1].tolist() 100 | if key not in vecs: 101 | id += 1 102 | vecs[key] = [[l_vec], [id]] 103 | index = id 104 | else: 105 | if l_vec not in vecs[key][0]: 106 | id += 1 107 | vecs[key][0].append(l_vec) 108 | vecs[key][1].append(id) 109 | index = id 110 | else: 111 | index = vecs[key][0].index(l_vec) 112 | index = vecs[key][1][index] 113 | id2key[index] = key 114 | if norm: 115 | id2lvec[index] = norm_vec(np.array(l_vec)) 116 | test_id_vec_pairs.append([index, norm_vec(np.array(r_vec))]) 117 | else: 118 | id2lvec[index] = np.array(l_vec) 119 | test_id_vec_pairs.append([index, np.array(r_vec)]) 120 | with open(os.path.join(test_cases_dir, test_case), 'wb') as fo: 121 | pkl.dump(test_id_vec_pairs, fo) 122 | with open(os.path.join(test_cases_dir, 'id2key_'+test_case), 'wb') as fo: 123 | pkl.dump(id2key, fo) 124 | with open(os.path.join(test_cases_dir, 'id2vec_'+test_case), 'wb') as fo: 125 | pkl.dump(id2lvec, fo) 126 | 127 | 128 | def build_database(collection_name, id2vecs, clear_old, metrictype): 129 | m = mil() 130 | m.load_data(list(id2vecs.values()), list(id2vecs.keys()), 131 | collection_name, clear_old=clear_old, metrictype=metrictype) 132 | 133 | 134 | def recall_rate(labels, query_results): 135 | total_num = len(labels) 136 | recall_num = 0 137 | for i in range(len(labels)): 138 | recall_ids = [] 139 | for cand in query_results[i]: 140 | recall_ids.append(cand.id) 141 | if labels[i] in recall_ids: 142 | recall_num += 1 143 | return recall_num / total_num 144 | 145 | 146 | def get_4_recalls(labels, query_results): 147 | total_num = len(labels) 148 | recall_num = [0, 0, 0, 0] 149 | for i in range(len(labels)): 150 | recall_ids = [] 151 | for cand in query_results[i]: 152 | recall_ids.append(cand.id) 153 | if labels[i] in recall_ids[:10]: 154 | recall_num[0] += 1 155 | if labels[i] in recall_ids[:20]: 156 | recall_num[1] += 1 157 | if labels[i] in recall_ids[:50]: 158 | recall_num[2] += 1 159 | if labels[i] in recall_ids: 160 | recall_num[3] += 1 161 | return (100 * np.array(recall_num) / total_num).tolist() 162 | 163 | 164 | def test_recall(test_id_vec, m, collection_name, k): 165 | l_ids = [] 166 | r_vecs = [] 167 | for pair in test_id_vec: 168 | l_ids.append(pair[0]) 169 | 
r_vecs.append(pair[1]) 170 | print("query length: ", len(r_vecs)) 171 | start = time.time() 172 | results = m.query(np.array(r_vecs), collection_name, k) 173 | if results: 174 | recalls = get_4_recalls(l_ids, results) 175 | # return [round(recall*100, 2), round(time.time()-start, 2), len(r_vecs)] 176 | return recalls 177 | else: 178 | return None 179 | 180 | 181 | def test_one_case(test_id_vec_pair, collection_name, id2vec_pkl, clear_old, metrictype): 182 | m = mil() 183 | print("data num:", m.get_count(collection_name)) 184 | id2vecs = read_pkl(id2vec_pkl) 185 | m.load_data(list(id2vecs.values()), list(id2vecs.keys()), 186 | collection_name, clear_old=clear_old, metrictype=metrictype) 187 | print("data num after added:", m.get_count(collection_name)) 188 | test_id_vec = read_pkl(test_id_vec_pair) 189 | print("test: ", test_id_vec_pair) 190 | # res_10 = test_recall(test_id_vec, m, collection_name, 10) 191 | # res_20 = test_recall(test_id_vec, m, collection_name, 20) 192 | # res_50 = test_recall(test_id_vec, m, collection_name, 50) 193 | res = test_recall(test_id_vec, m, collection_name, 100) 194 | 195 | m.delete_entities_by_ids(list(id2vecs.keys()), collection_name) 196 | print("data num after tested and deleted:", m.get_count(collection_name)) 197 | # return [res_10, res_20, res_50, res_100] 198 | return res 199 | 200 | 201 | def test_all(test_cases_dir, collection_name, clear_old, metrictype): 202 | res = {} 203 | fs = os.listdir(test_cases_dir) 204 | for f in fs: 205 | if f.startswith('id2'): 206 | continue 207 | test_case = os.path.join(test_cases_dir, f) 208 | # test_id_vec_pair = read_pkl(test_case) 209 | id2vec_pkl = os.path.join(test_cases_dir, 'id2vec_'+f) 210 | res[test_case] = test_one_case( 211 | test_case, collection_name, id2vec_pkl, clear_old, metrictype) 212 | print(res) 213 | 214 | for test_case in res: 215 | print('{0} & {1}& {2}& {3}& {4}'.format(test_case, str(round(res[test_case][0], 2)), str( 216 | round(res[test_case][1], 2)), str(round(res[test_case][2], 2)), str(round(res[test_case][3], 2)))) 217 | 218 | 219 | 220 | if __name__ == '__main__': 221 | model = '7fea_contra_torch_b128' 222 | # model = '76fea_triple_torch_negb128_epo150' 223 | # model = '76fea_triple_torch_epo150_gap0.5' 224 | # collection_name = '_76fea_triple_torch_epo150_gap05_filtered' 225 | 226 | PINFO = '../data/core_funcs_vectors/pid2package.json' 227 | ID2KEY = '../data/{0}/core_funcs/id2key.pkl'.format( 228 | model) 229 | ID2VEC = '../data/{0}/core_funcs/id2vec.pkl'.format( 230 | model) 231 | 232 | data_path = '../data/{0}/core_funcs'.format( 233 | model) 234 | # generate_vec_index(data_path, ID2KEY, ID2VEC, norm=False) 235 | 236 | # filtered_id2vecs_path = '../data/{0}/core_funcs/filtered_id2vecs.pkl'.format( 237 | # model) 238 | # filtered_id2vecs = get_filtered_data( 239 | # ID2KEY, ID2VEC, PINFO, filtered_id2vecs_path, True) 240 | 241 | # test_embeddings_path = '../data/{0}/valid_pairs'.format( 242 | # model) 243 | # test_cases_dir = '../data/{0}/test_cases'.format( 244 | # model) 245 | # generate_test_cases(test_embeddings_path, test_cases_dir, norm=True) 246 | 247 | # collection_name = '_76fea_triple_torch_negb128_epo150_filtered' 248 | # collection_name = '_7fea_contra_torch_b128_filtered' 249 | 250 | # filtered_id2vecs = read_pkl(filtered_id2vecs_path) 251 | # build_database(collection_name, filtered_id2vecs, 252 | # clear_old=True, metrictype=MetricType.IP) 253 | 254 | # test_all(test_cases_dir, collection_name, 255 | # clear_old=False, metrictype=MetricType.IP) 256 | 257 | 
collection_name = '_7fea_contra_torch_b128_all' 258 | all_id2vecs = read_pkl(ID2VEC) 259 | build_database(collection_name, all_id2vecs, 260 | clear_old=True, metrictype=MetricType.IP) 261 | -------------------------------------------------------------------------------- /main/torch/func2vec.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import argparse 5 | import os 6 | from torch.autograd import Variable 7 | 8 | from utils_loss import graph 9 | 10 | 11 | class func2vec(object): 12 | def __init__(self, load_path, gpu, fea_dim) -> None: 13 | self.gpu = gpu 14 | self.fea_dim = fea_dim 15 | if self.gpu: 16 | self.model = torch.load(load_path).cuda() 17 | else: 18 | self.model = torch.load(load_path, map_location='cpu') 19 | self.model.eval() 20 | 21 | def get_X_mask(self, cur_graph): 22 | X1_input = np.zeros( 23 | (1, cur_graph.node_num, len(cur_graph.features[0]))) 24 | mask1 = np.zeros((1, cur_graph.node_num, cur_graph.node_num)) 25 | for u in range(cur_graph.node_num): 26 | X1_input[0, u, :] = np.array(cur_graph.features[u]) 27 | for v in cur_graph.succss[u]: 28 | mask1[0, u, v] = 1 29 | return X1_input, mask1 30 | 31 | def get_item(self, func, func_signature): 32 | item = {} 33 | item["src"] = func_signature 34 | item["n_num"] = func['nodes'] 35 | item['succss'] = func['edgePairs'] 36 | if self.fea_dim == 7: 37 | item['features'] = func['nodeGeminiVectors'] 38 | else: 39 | item['features'] = func['nodeGhidraVectors'] 40 | item['fname'] = func_signature 41 | return item 42 | 43 | def get_graph(self, item): 44 | cur_graph = graph(item['n_num'], item['fname'], item['src']) 45 | for u in range(item['n_num']): 46 | cur_graph.features[u] = np.array(item['features'][u]) 47 | for v in item['succss'][u]: 48 | cur_graph.add_edge(u, v) 49 | return cur_graph 50 | 51 | def get_embedding_from_func_graph(self, func_graph): 52 | X1_input, mask1 = self.get_X_mask(func_graph) 53 | if self.gpu: 54 | X1_input, mask1 = torch.from_numpy(X1_input).float().cuda(), torch.from_numpy(mask1).float().cuda() 55 | else: 56 | X1_input, mask1 = torch.from_numpy(X1_input).float(), torch.from_numpy(mask1).float() 57 | vec = self.model.predict(X1_input, mask1)[0] 58 | return vec 59 | 60 | def get_embedding_from_func_fea(self, func_fea, correct_edges, func_sig=None): 61 | if correct_edges: 62 | new_edgePairs = [] 63 | for i in range(func_fea['nodes']): 64 | new_edgePairs.append([]) 65 | for i in func_fea['edgePairs']: 66 | if i[1]: 67 | new_edgePairs[i[0]].append(i[1]) 68 | func_fea['edgePairs'] = new_edgePairs 69 | item = self.get_item(func_fea, func_sig) 70 | func_graph = self.get_graph(item) 71 | vec = self.get_embedding_from_func_graph(func_graph) 72 | return vec 73 | 74 | def get_embeddings_from_func_fea_l(self, func_fea_l, correct_edges, func_sig=None): 75 | func_gs = [] 76 | for func_fea in func_fea_l: 77 | if correct_edges: 78 | new_edgePairs = [] 79 | for i in range(len(func_fea['nodesAsm'])): 80 | new_edgePairs.append([]) 81 | for i in func_fea['edgePairs']: 82 | if i[1]: 83 | new_edgePairs[i[0]].append(i[1]) 84 | func_fea['edgePairs'] = new_edgePairs 85 | item = self.get_item(func_fea, func_sig) 86 | func_graph = self.get_graph(item) 87 | func_gs.append(func_graph) 88 | maxN = 0 89 | for g in func_gs: 90 | maxN = max(maxN, g.node_num) 91 | fea_dim = len(g.features[0]) 92 | X_input = np.zeros((len(func_gs), maxN, fea_dim)) 93 | node_mask = np.zeros((len(func_gs), maxN, maxN)) 94 | for i in range(len(func_gs)): 95 | g = 
func_gs[i] 96 | for u in range(g.node_num): 97 | X_input[i, u, :] = np.array(g.features[u]) 98 | for v in g.succss[u]: 99 | node_mask[i, u, v] = 1 100 | if self.gpu: 101 | X_input, mask = torch.from_numpy(X_input).float().cuda(), torch.from_numpy(node_mask).float().cuda() 102 | else: 103 | X_input, mask = torch.from_numpy(X_input).float(), torch.from_numpy(node_mask).float() 104 | with torch.no_grad(): 105 | vecs = self.model.predict(X_input, mask) 106 | return vecs 107 | 108 | 109 | def get_vecs_from_bin_fea(self, bin_fea, correct_edges): 110 | res = {} 111 | for func in bin_fea['binFileFeature']['functions']: 112 | if func['nodes'] < 5: 113 | continue 114 | if func['isThunkFunction'] is True or 'text' not in func['memoryBlock']: 115 | continue 116 | key = func['entryPoint'] 117 | vec = self.get_embedding_from_func_fea( 118 | func, correct_edges, func_sig=key) 119 | res[key] = vec 120 | return res 121 | 122 | def get_vecs_from_package_fea(self, package_fea, correct_edges): 123 | res = {} 124 | for bin_fea in package_fea: 125 | if 'binFileFeature' not in bin_fea: 126 | res[bin_fea['formattedFileName']] = {} 127 | res[bin_fea['formattedFileName']] = self.get_vecs_from_bin_fea( 128 | bin_fea, correct_edges) 129 | return res 130 | 131 | def get_vecs_from_package_fea_json(self, package_fea_json, correct_edges): 132 | with open(package_fea_json, 'r') as load_f: 133 | package_fea = json.load(load_f) 134 | return self.get_vecs_from_package_fea(package_fea, correct_edges) 135 | 136 | 137 | if __name__ == '__main__': 138 | ## test 139 | 140 | args = parser.parse_args() 141 | NODE_FEATURE_DIM = args.fea_dim 142 | EMBED_DIM = args.embed_dim 143 | EMBED_DEPTH = args.embed_depth 144 | OUTPUT_DIM = args.output_dim 145 | ITERATION_LEVEL = args.iter_level 146 | LEARNING_RATE = args.lr 147 | MAX_EPOCH = args.epoch 148 | BATCH_SIZE = args.batch_size 149 | NEG_BATCH_SIZE = args.neg_batch_size 150 | LOAD_PATH = args.load_path 151 | SAVE_PATH = args.save_path 152 | LOG_PATH = args.log_path 153 | DEVICE = args.use_device 154 | WORKERS = args.workers 155 | 156 | func2vec_mod = func2vec('../data/saved_model/graphnn_model_7fea_contra_cos_torch/model-inter-best.pt') 157 | package_fea_json = '/raid/data/CoreFedoraFeatureJson0505/00/00/9210000/9210000.json' 158 | res = func2vec_mod.get_vecs_from_package_fea_json(package_fea_json, True, False) 159 | -------------------------------------------------------------------------------- /main/torch/generate_vec_index.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | import os 3 | import json 4 | import numpy as np 5 | 6 | # vec -> numpy array 7 | def norm_vec(vec): 8 | return vec/np.sqrt(sum(vec**2)) 9 | 10 | 11 | def generate_vec_index(data_path, id2key_sava_path, id2vec_save_path, norm): 12 | raw_data_files = os.listdir(data_path) 13 | raw_data = {} 14 | for f in raw_data_files: 15 | with open(os.path.join(data_path, f), 'rb') as load_f: 16 | raw_data.update(pkl.load(load_f)) 17 | id = 0 18 | id2key = {} 19 | id2vec = {} 20 | for key in raw_data: 21 | id += 1 22 | if id % 100000 == 0: 23 | print(id) 24 | id2key[id] = key 25 | if norm: 26 | id2vec[id] = norm_vec(raw_data[key]) 27 | else: 28 | id2vec[id] = raw_data[key] 29 | print(id) 30 | with open(id2key_sava_path, 'wb') as fo: 31 | pkl.dump(id2key, fo) 32 | with open(id2vec_save_path, 'wb') as fo: 33 | pkl.dump(id2vec, fo) 34 | 35 | 36 | 37 | def get_true_pairs(js_path): 38 | true_pairs = [] 39 | with open(js_path) as load_f: 40 | for line in load_f: 41 | pair 
= json.loads(line.strip()) 42 | true_pairs.append(pair) 43 | return true_pairs 44 | 45 | def save_data(data, save_path): 46 | for item in data: 47 | with open(save_path, 'a+') as f: 48 | line = json.dumps(item) 49 | f.write(line+'\n') 50 | 51 | def generate_valid_vec_index(js_dir, id_vec_dir, norm): 52 | js_fs = os.listdir(js_dir) 53 | id2key = {} 54 | id2vec = {} 55 | vecs = {} 56 | id = 50000000 57 | count =0 58 | if not os.path.exists(id_vec_dir): 59 | os.makedirs(id_vec_dir) 60 | for f in js_fs: 61 | print(f) 62 | # true_pairs = get_true_pairs(os.path.join(js_dir, f)) 63 | with open(os.path.join(js_dir, f), 'rb') as load_f: 64 | true_pairs = pkl.load(load_f) 65 | id_vec_pairs = [] 66 | for pair in true_pairs: 67 | count += 1 68 | if count % 10000 ==0: 69 | print(count) 70 | key = list(pair.keys())[0] 71 | l_vec_pair = list(pair.values())[0][0].tolist() 72 | r_vec_pair = list(pair.values())[0][1].tolist() 73 | if key not in vecs: 74 | id += 1 75 | vecs[key] = [[l_vec_pair],[id]] 76 | index = id 77 | else: 78 | if l_vec_pair not in vecs[key][0]: 79 | id += 1 80 | vecs[key][0].append(l_vec_pair) 81 | vecs[key][1].append(id) 82 | index = id 83 | else: 84 | index = vecs[key][0].index(l_vec_pair) 85 | index = vecs[key][1][index] 86 | id2key[index] = key 87 | if norm: 88 | id2vec[index] = norm_vec(np.array(l_vec_pair)) 89 | id_vec_pairs.append([index, norm_vec(np.array(r_vec_pair)).tolist()]) 90 | else: 91 | id2vec[index] = np.array(l_vec_pair) 92 | id_vec_pairs.append([index, np.array(r_vec_pair).tolist()]) 93 | # save_data(id_vec_pairs, os.path.join(id_vec_dir, f)) 94 | with open(os.path.join(id_vec_dir, f), 'wb') as fo: 95 | pkl.dump(id_vec_pairs, fo) 96 | with open(os.path.join(id_vec_dir, 'id2key.pkl'), 'wb') as fo: 97 | pkl.dump(id2key, fo) 98 | with open(os.path.join(id_vec_dir, 'id2vec.pkl'), 'wb') as fo: 99 | pkl.dump(id2vec, fo) 100 | 101 | def check(): 102 | id2vec = '../data/validation_pairs/valid_pairs_v1/id_vec/id_vec_7fea/id2vec.pkl' 103 | with open(id2vec, 'rb') as fo: 104 | a=pkl.load(fo) 105 | pairs = get_true_pairs('../data/validation_pairs/valid_pairs_v1/id_vec/id_vec_7fea/cc_version_diff_7_fea_dim.json') 106 | count = 0 107 | for pair in pairs: 108 | target_vec = np.array(pair[1]) 109 | base_vec = np.array(a[pair[0]]) 110 | cos = sum(base_vec * target_vec) 111 | if 0.5 + cos/2 > 0.7367: 112 | count += 1 113 | print(count/len(pairs)) 114 | 115 | 116 | if __name__ == '__main__': 117 | data_7fea = '../data/7fea_contra_tf/core_funcs' 118 | id2key_7fea = '../data/7fea_contra_tf/core_funcs/id2key.pkl' 119 | id2vec_7fea = '../data/7fea_contra_tf/core_funcs/id2vec.pkl' 120 | 121 | generate_vec_index(data_7fea, id2key_7fea, id2vec_7fea, norm=False) 122 | 123 | 124 | 125 | valid_vec_pairs_dir = '../data/7fea_contra_tf/valid_pairs' 126 | id_vec_valid_dir = '../data/7fea_contra_tf/id_vec_valid' 127 | 128 | generate_valid_vec_index(valid_vec_pairs_dir, id_vec_valid_dir, norm = True) 129 | 130 | # check() -------------------------------------------------------------------------------- /main/torch/get_data_gemini_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import numpy as np 5 | 6 | def save_data(data, save_path): 7 | for item in data: 8 | with open(save_path, 'a+') as f: 9 | line = json.dumps(item) 10 | f.write(line+'\n') 11 | 12 | 13 | def get_all_compilation_cases(data_path): 14 | libraries = os.listdir(data_path) 15 | all_compilation_cases = [] 16 | for library in libraries: 17 | 
print(library) 18 | if library.startswith('.'): 19 | continue 20 | compilation_cases = os.listdir(os.path.join(data_path, library)) 21 | all_compilation_cases += compilation_cases 22 | return set(all_compilation_cases) 23 | 24 | 25 | def get_filtered_compilation_cases(data_path): 26 | libraries = os.listdir(data_path) 27 | all_compilation_cases = [] 28 | names = [] 29 | for library in libraries: 30 | print(library) 31 | if library.startswith('.'): 32 | continue 33 | f_names = os.listdir(os.path.join(data_path, library)) 34 | 35 | for f_name in f_names: 36 | if f_name.startswith('arm_x86') or f_name.startswith('linux_gcc_5') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_7') or f_name.startswith('linux_gcc_8') or f_name.startswith('mac_gcc_8') or f_name.startswith("linux_clang_3.8") or f_name.startswith("linux_clang_4.0") or f_name.startswith("linux_clang_5.0") or f_name.startswith("mac_gcc_7"): 37 | continue 38 | all_compilation_cases.append(f_name) 39 | return set(all_compilation_cases) 40 | 41 | def get_filtered_compilation_cases_without_oplevel(data_path): 42 | compilation_cases = get_filtered_compilation_cases(data_path) 43 | res = [] 44 | for case in compilation_cases: 45 | res.append(case[:-3]) 46 | return set(res) 47 | 48 | 49 | def get_deduplicate_data_vectors(data_path, save_path, train_test): 50 | all_compilation_cases = get_filtered_compilation_cases(data_path) 51 | print(len(all_compilation_cases)) 52 | print(all_compilation_cases) 53 | libraries = os.listdir(data_path) 54 | replace_words = ['.exe', '.so', '.dylib', '.dll', '.json'] 55 | rex_num = re.compile('\_\d*$') 56 | all_function_features = {} 57 | valid_filenames = ['libsqlite3.0', 'sqlite3', 'ndisasm', 'nasm', 'libassuan.0', 'm4', 'libksba.8', 'yat2m', 'libgpg-error.0', 'gpg-error', 'gpg', 'gpgkeys_ldap', 'gpgsplit', 'gpgkeys_curl', 'gpgkeys_finger', 'gpgkeys_hkp', 'gpgv', 'libgcrypt.20','dumpsexp','mpicalc', 'hmac256','libnpth.0','libpng','libstunnel', 'stunnel', 'vsyasm'] 58 | valid_libs = ['sqlite-autoconf-3330000', 'stunnel-5.56', 'yasm-1.3.0'] 59 | 60 | duplicate_num = 0 61 | count = 0 62 | 63 | for compilation_case in all_compilation_cases: 64 | # print(compilation_case) 65 | if compilation_case.startswith('.'): 66 | continue 67 | data = [] 68 | for library in libraries: 69 | if library.startswith('.'): 70 | continue 71 | if train_test: 72 | if library in valid_libs: 73 | continue 74 | else: 75 | if not library in valid_libs: 76 | continue 77 | compilation_cases = os.listdir(os.path.join(data_path, library)) 78 | if compilation_case not in compilation_cases: 79 | continue 80 | 81 | subdir_path = os.path.join( 82 | os.path.join(data_path, library, compilation_case)) 83 | feature_files = os.listdir(subdir_path) 84 | for feature_file in feature_files: 85 | if feature_file.startswith('.') or feature_file == 'status': 86 | continue 87 | filename = feature_file.replace('.json', '') 88 | feature_file_path = os.path.join(subdir_path, feature_file) 89 | for word in replace_words: 90 | filename = filename.replace(word, '') 91 | if filename == 'libnpth-0': 92 | filename = 'libnpth.0' 93 | if filename not in valid_filenames and library != 'gnupg-2.2.23': 94 | continue 95 | all_non_thunk_funcs = get_all_non_thunk_funcs( 96 | feature_file_path) 97 | selected_ones = {} 98 | for func in all_non_thunk_funcs: 99 | if func['nodes'] < 5: 100 | continue 101 | func_name = func['functionName'] 102 | if 'mac' in compilation_case or 'win' in compilation_case: 103 | if 
func_name.startswith('_'): 104 | func_name = func_name[1:] 105 | func_name = func_name.replace('.', '_') 106 | func_name = rex_num.sub('', func_name) 107 | func_signature = library + '##' + filename + '##' + func_name 108 | 109 | item = get_data_gemini_item(func, func_signature) 110 | if func_signature not in selected_ones: 111 | selected_ones[func_signature] = [] 112 | selected_ones[func_signature].append(item) 113 | for func_signature in selected_ones: 114 | if len(selected_ones[func_signature]) > 1: 115 | continue 116 | added_vector = selected_ones[func_signature][0]['features'] 117 | if func_signature in all_function_features: 118 | if added_vector in all_function_features[func_signature]['f']: 119 | duplicate_num += 1 120 | continue 121 | else: 122 | all_function_features[func_signature]['f'].append(added_vector) 123 | all_function_features[func_signature]['c'].append(compilation_case) 124 | else: 125 | all_function_features[func_signature] = {"f":[], 'c':[]} 126 | all_function_features[func_signature]['f'].append(added_vector) 127 | all_function_features[func_signature]['c'].append(compilation_case) 128 | data.append(selected_ones[func_signature][0]) 129 | save_data(data, os.path.join(save_path, compilation_case+'.json')) 130 | count += len(data) 131 | print(duplicate_num) 132 | print(count) 133 | 134 | 135 | def get_all_non_thunk_funcs(feature_file_path): 136 | with open(feature_file_path, 'r') as feature_file: 137 | feature = json.load(feature_file) 138 | res = [] 139 | for func in feature['binFileFeature']['functions']: 140 | if func['isThunkFunction'] is False and 'text' in func['memoryBlock']: 141 | res.append(func) 142 | return res 143 | 144 | 145 | 146 | # {"src": "openssl-1.0.1f-armeb-linux-O0v54/ectest.o.txt", "n_num": 33, "succs": [[], [0, 5], [28, 22], [24, 2], [], [4, 15], [16, 23], [30, 6], [18, 3], [26, 19], [32, 9], [8, 27], [11, 20], [29, 7], [17, 12], [21, 14], [10, 31], [], [], [1, 25], [], [], [], [], [], [], [], [], [], [], [], [], []], "features": [[0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 1.0, 18.0, 0.0, 2.0, 7.0, 1.0], [0.0, 1.0, 2.0, 0.0, 2.0, 8.0, 1.0], [0.0, 1.0, 4.0, 0.0, 2.0, 12.0, 2.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 1.0, 16.0, 0.0, 2.0, 12.0, 2.0], [0.0, 1.0, 28.0, 0.0, 2.0, 6.0, 1.0], [0.0, 1.0, 30.0, 0.0, 2.0, 12.0, 3.0], [0.0, 1.0, 6.0, 0.0, 2.0, 8.0, 1.0], [0.0, 1.0, 22.0, 0.0, 2.0, 6.0, 1.0], [0.0, 1.0, 24.0, 0.0, 2.0, 12.0, 3.0], [0.0, 1.0, 8.0, 0.0, 2.0, 8.0, 1.0], [0.0, 1.0, 10.0, 0.0, 2.0, 12.0, 2.0], [0.0, 3.0, 32.0, 0.0, 10.0, 39.0, 4.0], [0.0, 1.0, 12.0, 0.0, 3.0, 9.0, 3.0], [0.0, 1.0, 14.0, 0.0, 2.0, 8.0, 1.0], [0.0, 2.0, 26.0, 0.0, 4.0, 17.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 4.0, 20.0, 0.0, 4.0, 22.0, 6.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 7.0, 23.0, 2.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0], [0.0, 2.0, 0.0, 0.0, 4.0, 21.0, 4.0]], "fname": "group_order_tests"} 147 | 148 | def get_data_gemini_item(func, func_signature, compilation_case): 149 | item = {} 150 | item["src"] = func_signature 151 | item["n_num"] = len(func['nodesAsm']) 152 | item['succss'] = func['edgePairs'] 153 | item['features'] = 
func['nodeGeminiVectors'] 154 | item['fname'] = func_signature 155 | item['compilation'] = compilation_case 156 | return item 157 | 158 | 159 | def get_data_ghidra_item(func, func_signature): 160 | item = {} 161 | item["src"] = func_signature 162 | item["n_num"] = len(func['nodesAsm']) 163 | item['succss'] = func['edgePairs'] 164 | item['features'] = func['nodeGhidraVectors'] 165 | item['fname'] = func_signature 166 | return item 167 | 168 | if __name__ == '__main__': 169 | 170 | feature_data_path = '/mnt/c/Users/user/Desktop/data/featureJson0417' 171 | gemini_data_train_test_path = '/mnt/c/Users/user/Desktop/data/vector_deduplicate_gemini_format_less_compilation_cases/train_test' 172 | gemini_data_valid_path = '/mnt/c/Users/user/Desktop/data/vector_deduplicate_gemini_format_less_compilation_cases/valid' 173 | ghidra_data_train_test_path = '/mnt/c/Users/user/Desktop/data/vector_deduplicate_ghidra_format_less_compilation_cases/train_test' 174 | ghidra_data_valid_path = '/mnt/c/Users/user/Desktop/data/vector_deduplicate_ghidra_format_less_compilation_cases/valid' 175 | 176 | data_path = '/mnt/c/Users/user/Desktop/data/vector_gemini_format/train_test' 177 | 178 | # get_deduplicate_data_vectors(feature_data_path, ghidra_data_train_test_path, True) 179 | # get_deduplicate_data_vectors(feature_data_path, ghidra_data_valid_path, False) 180 | 181 | valid_arm2non_arm_feature_json = '/mnt/c/Users/user/Desktop/data/validation_arm2non_arm/feature_json' 182 | valid_arm2non_arm_save_path = '/mnt/c/Users/user/Desktop/data/validation_arm2non_arm/gemini/validation_arm2non_arm_data' 183 | get_deduplicate_data_vectors(valid_arm2non_arm_feature_json, valid_arm2non_arm_save_path, False) 184 | 185 | # print(get_filtered_compilation_cases_without_oplevel(feature_data_path)) -------------------------------------------------------------------------------- /main/torch/get_threshold.py: -------------------------------------------------------------------------------- 1 | import tensorflow.compat.v1 as tf 2 | print(tf.__version__) 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from datetime import datetime 6 | from graphnnSiamese import graphnn 7 | from utils import * 8 | import os 9 | import argparse 10 | import json 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--device', type=str, default='0,1,2,3', 14 | help='visible gpu device') 15 | parser.add_argument('--use_device', type=str, default='/gpu:1', 16 | help='used gpu device') 17 | parser.add_argument('--fea_dim', type=int, default=76, 18 | help='feature dimension') 19 | parser.add_argument('--embed_dim', type=int, default=64, 20 | help='embedding dimension') 21 | parser.add_argument('--embed_depth', type=int, default=5, 22 | help='embedding network depth') 23 | parser.add_argument('--output_dim', type=int, default=64, 24 | help='output layer dimension') 25 | parser.add_argument('--iter_level', type=int, default=5, 26 | help='iteration times') 27 | parser.add_argument('--lr', type=float, default=1e-4, 28 | help='learning rate') 29 | parser.add_argument('--epoch', type=int, default=100, 30 | help='epoch number') 31 | parser.add_argument('--batch_size', type=int, default=128, 32 | help='batch size') 33 | # parser.add_argument('--load_path', type=str, 34 | # default='../data/saved_model/graphnn_model_ghidra/saved_ghidra_model_best', 35 | # help='path for model loading, "#LATEST#" for the latest checkpoint') 36 | parser.add_argument('--load_path', type=str, 37 | 
default='../data/saved_model/graphnn_model_ghidra_depth5/graphnn_model_ghidra_best', 38 | help='path for model loading, "#LATEST#" for the latest checkpoint') 39 | parser.add_argument('--log_path', type=str, default=None, 40 | help='path for training log') 41 | 42 | 43 | 44 | 45 | if __name__ == '__main__': 46 | args = parser.parse_args() 47 | args.dtype = tf.float32 48 | print("=================================") 49 | print(args) 50 | print("=================================") 51 | 52 | os.environ["CUDA_VISIBLE_DEVICES"]=args.device 53 | Dtype = args.dtype 54 | NODE_FEATURE_DIM = args.fea_dim 55 | EMBED_DIM = args.embed_dim 56 | EMBED_DEPTH = args.embed_depth 57 | OUTPUT_DIM = args.output_dim 58 | ITERATION_LEVEL = args.iter_level 59 | LEARNING_RATE = args.lr 60 | MAX_EPOCH = args.epoch 61 | BATCH_SIZE = args.batch_size 62 | LOAD_PATH = args.load_path 63 | LOG_PATH = args.log_path 64 | DEVICE = args.use_device 65 | 66 | SHOW_FREQ = 1 67 | TEST_FREQ = 1 68 | SAVE_FREQ = 5 69 | # DATA_FILE_NAME_VALID = '../data/validation_arm2non_arm_gemini_data/' 70 | DATA_FILE_NAME_TRAIN_TEST = '../data/vector_deduplicate_ghidra_format_less_compilation_cases/train_test' 71 | F_PATH_TRAIN_TEST = get_f_name(DATA_FILE_NAME_TRAIN_TEST) 72 | FUNC_NAME_DICT_TRAIN_TEST = get_f_dict(F_PATH_TRAIN_TEST) 73 | 74 | print("start reading data") 75 | Gs_train_test, classes_train_test = read_graph(F_PATH_TRAIN_TEST, FUNC_NAME_DICT_TRAIN_TEST, NODE_FEATURE_DIM) 76 | print("train and test ---- 8:2") 77 | print("{} graphs, {} functions".format(len(Gs_train_test), len(classes_train_test))) 78 | 79 | perm = np.random.permutation(len(classes_train_test)) 80 | Gs_train, classes_train, Gs_test, classes_test =\ 81 | partition_data(Gs_train_test, classes_train_test, [0.8, 0.2], perm) 82 | print("Train: {} graphs, {} functions".format( 83 | len(Gs_train), len(classes_train))) 84 | print("Test: {} graphs, {} functions".format( 85 | len(Gs_test), len(classes_test))) 86 | 87 | 88 | print("valid") 89 | DATA_FILE_NAME_VALID = '../data/vector_deduplicate_ghidra_format_less_compilation_cases/valid' 90 | F_PATH_VALID = get_f_name(DATA_FILE_NAME_VALID) 91 | FUNC_NAME_DICT_VALID = get_f_dict(F_PATH_VALID) 92 | Gs_valid, classes_valid = read_graph(F_PATH_VALID, FUNC_NAME_DICT_VALID, NODE_FEATURE_DIM) 93 | print("{} graphs, {} functions".format(len(Gs_valid), len(classes_valid))) 94 | Gs_valid, classes_valid = partition_data(Gs_valid, classes_valid, [1], list(range(len(classes_valid)))) 95 | 96 | 97 | 98 | 99 | 100 | # Model 101 | gnn = graphnn( 102 | N_x = NODE_FEATURE_DIM, 103 | Dtype = Dtype, 104 | N_embed = EMBED_DIM, 105 | depth_embed = EMBED_DEPTH, 106 | N_o = OUTPUT_DIM, 107 | ITER_LEVEL = ITERATION_LEVEL, 108 | lr = LEARNING_RATE, 109 | device = DEVICE 110 | ) 111 | gnn.init(LOAD_PATH, LOG_PATH) 112 | 113 | auc0, fpr0, tpr0, thres0 = get_auc_epoch_batch(gnn, Gs_train, classes_train, 114 | BATCH_SIZE) 115 | gnn.say("Initial training auc = {0} @ {1}".format(auc0, datetime.now())) 116 | 117 | print(auc0) 118 | print(max((1-fpr0+tpr0)/2)) 119 | index = np.argmax((1-fpr0+tpr0)/2) 120 | print("index:", index) 121 | print("fpr", fpr0[index]) 122 | print("tpr", tpr0[index]) 123 | print(thres0[index]) 124 | 125 | 126 | auc1, fpr1, tpr1, thres1 = get_auc_epoch_batch(gnn, Gs_test, classes_test, 127 | BATCH_SIZE) 128 | gnn.say("Initial testing auc = {0} @ {1}".format(auc1, datetime.now())) 129 | 130 | print(auc1) 131 | print(max((1-fpr1+tpr1)/2)) 132 | index = np.argmax(1-fpr1+tpr1) 133 | print("index:", index) 134 | print("fpr", fpr1[index]) 135 
| print("tpr", tpr1[index]) 136 | print(thres1[index]) 137 | 138 | 139 | auc2, fpr2, tpr2, thres2 = get_auc_epoch_batch(gnn, Gs_valid, classes_valid, 140 | BATCH_SIZE) 141 | gnn.say("Initial validation auc = {0} @ {1}".format(auc2, datetime.now())) 142 | 143 | print(auc2) 144 | print(max((1-fpr2+tpr2)/2)) 145 | index = np.argmax((1-fpr2+tpr2)/2) 146 | print("index:", index) 147 | print("fpr", fpr2[index]) 148 | print("tpr", tpr2[index]) 149 | print(thres2[index]) 150 | 151 | plt.figure() 152 | plt.title('ROC CURVE') 153 | plt.xlabel('False Positive Rate') 154 | plt.ylabel('True Positive Rate') 155 | plt.plot(fpr1,tpr1,color='b') 156 | plt.plot(fpr1, 1-fpr1+tpr1, color='b') 157 | plt.plot(fpr2, tpr2,color='r') 158 | plt.plot(fpr2, 1-fpr2+tpr2, color='r') 159 | # plt.plot([0, 1], [0, 1], color='m', linestyle='--') 160 | plt.savefig('auc_depth5.png') -------------------------------------------------------------------------------- /main/torch/get_validation_pairs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | 5 | def save_data(data, save_path): 6 | for item in data: 7 | with open(save_path, 'a+') as f: 8 | line = json.dumps(item) 9 | f.write(line+'\n') 10 | 11 | 12 | def get_pairs(l_fea_js, r_fea_js): 13 | pairs = [] 14 | l_fea_d = {} 15 | r_fea_d = {} 16 | f_sigs = [] 17 | with open(l_fea_js) as load_f: 18 | for line in load_f: 19 | l_fea = json.loads(line.strip()) 20 | l_fea_d[l_fea['fname']] = l_fea 21 | 22 | with open(r_fea_js) as load_f: 23 | for line in load_f: 24 | r_fea = json.loads(line.strip()) 25 | r_fea_d[r_fea['fname']] = r_fea 26 | 27 | for f_sig in l_fea_d: 28 | if f_sig in r_fea_d: 29 | pairs.append([l_fea_d[f_sig], r_fea_d[f_sig]]) 30 | f_sigs.append(f_sig) 31 | return pairs, f_sigs 32 | 33 | 34 | def get_max_oplevel_pairs(fea_js_path, pairs_save_path): 35 | f_names = os.listdir(fea_js_path) 36 | valid_cases = [] 37 | for f_name in f_names: 38 | if f_name.startswith('arm_x86') or f_name.startswith('linux_gcc_5') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_7') or f_name.startswith('linux_gcc_8') or f_name.startswith('mac_gcc_8') or f_name.startswith("linux_clang_3.8") or f_name.startswith("linux_clang_4.0") or f_name.startswith("linux_clang_5.0") or f_name.startswith("mac_gcc_7"): 39 | continue 40 | valid_cases.append(f_name[:-6]) 41 | 42 | all_pairs = [] 43 | all_f_sigs = [] 44 | for comp_case in valid_cases: 45 | l_fea_js = os.path.join(fea_js_path, comp_case+'0.json') 46 | r_fea_js = os.path.join(fea_js_path, comp_case+'3.json') 47 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 48 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 49 | all_pairs += pairs 50 | all_f_sigs += f_sigs 51 | print('func num: ', len(set(all_f_sigs))) 52 | print('graph num:', len(all_pairs)*2) 53 | print('pair num: ', len(all_pairs)) 54 | save_data(all_pairs, pairs_save_path) 55 | 56 | 57 | def get_os_diff_linux_win_pairs(fea_js_path, pairs_save_path): 58 | compilation_case_pairs = [['linux_gcc_8_O', 'win_gcc_8.1_O'], ['linux_gcc_7_O', 'win_gcc_7.1_O'], [ 59 | 'linux_gcc_6_O', 'win_gcc_6.2_O'], ['linux_gcc_5_O', 'win_gcc_5.2_O'], ['linux_gcc_4.8_O', 'win_gcc_4.9_O']] 60 | 61 | all_pairs = [] 62 | all_f_sigs = [] 63 | for case_pair in compilation_case_pairs: 64 | for i in ['0', '1', '2', '3']: 65 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 66 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 67 | if os.path.exists(l_fea_js) 
and os.path.exists(r_fea_js): 68 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 69 | all_pairs += pairs 70 | all_f_sigs += f_sigs 71 | print('func num: ', len(set(all_f_sigs))) 72 | print('graph num:', len(all_pairs)*2) 73 | print('pair num: ', len(all_pairs)) 74 | save_data(all_pairs, pairs_save_path) 75 | 76 | 77 | def get_os_diff_linux_mac_pairs(fea_js_path, pairs_save_path): 78 | compilation_case_pairs = [['linux_gcc_9_O', 'mac_gcc_9_O'], ['linux_gcc_8_O', 'mac_gcc_8_O'], [ 79 | 'linux_gcc_7_O', 'mac_gcc_7_O'], ['linux_gcc_6_O', 'mac_gcc_6_O']] 80 | 81 | all_pairs = [] 82 | all_f_sigs = [] 83 | for case_pair in compilation_case_pairs: 84 | for i in ['0', '1', '2', '3']: 85 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 86 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 87 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 88 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 89 | all_pairs += pairs 90 | all_f_sigs += f_sigs 91 | print('func num: ', len(set(all_f_sigs))) 92 | print('graph num:', len(all_pairs)*2) 93 | print('pair num: ', len(all_pairs)) 94 | save_data(all_pairs, pairs_save_path) 95 | 96 | 97 | def get_os_diff_win_mac_pairs(fea_js_path, pairs_save_path): 98 | compilation_case_pairs = [['win_gcc_8.1_O', 'mac_gcc_8_O'], [ 99 | 'win_gcc_7.1_O', 'mac_gcc_7_O'], ['win_gcc_6.2_O', 'mac_gcc_6_O']] 100 | 101 | all_pairs = [] 102 | all_f_sigs = [] 103 | for case_pair in compilation_case_pairs: 104 | for i in ['0', '1', '2', '3']: 105 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 106 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 107 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 108 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 109 | all_pairs += pairs 110 | all_f_sigs += f_sigs 111 | print('func num: ', len(set(all_f_sigs))) 112 | print('graph num:', len(all_pairs)*2) 113 | print('pair num: ', len(all_pairs)) 114 | save_data(all_pairs, pairs_save_path) 115 | 116 | 117 | def get_cc_diff_pairs(fea_js_path, pairs_save_path): 118 | compilation_case_pairs = [['linux_clang_3.5_O', 'linux_gcc_4.8_O'], ['linux_clang_3.8_O', 'linux_gcc_5_O'], [ 119 | 'linux_clang_4.0_O', 'linux_gcc_6_O'], ['linux_clang_5.0_O', 'linux_gcc_7_O'], ['linux_clang_6.0_O', 'linux_gcc_8_O']] 120 | 121 | all_pairs = [] 122 | all_f_sigs = [] 123 | for case_pair in compilation_case_pairs: 124 | for i in ['0', '1', '2', '3']: 125 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 126 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 127 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 128 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 129 | all_pairs += pairs 130 | all_f_sigs += f_sigs 131 | print('func num: ', len(set(all_f_sigs))) 132 | print('graph num:', len(all_pairs)*2) 133 | print('pair num: ', len(all_pairs)) 134 | save_data(all_pairs, pairs_save_path) 135 | 136 | 137 | def get_arch_diff_pairs(fea_js_path, pairs_save_path): 138 | compilation_case_pairs = [['arm_arm-linux-gnueabi-gcc_5.4_O', 'linux_gcc_5_O'], ['arm_arm-linux-gnueabihf-gcc_5.4_O', 'linux_gcc_5_O']] 139 | 140 | all_pairs = [] 141 | all_f_sigs = [] 142 | for case_pair in compilation_case_pairs: 143 | for i in ['0', '1', '2', '3']: 144 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 145 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 146 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 147 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 148 | all_pairs += pairs 149 | all_f_sigs += 
f_sigs 150 | print('func num: ', len(set(all_f_sigs))) 151 | print('graph num:', len(all_pairs)*2) 152 | print('pair num: ', len(all_pairs)) 153 | save_data(all_pairs, pairs_save_path) 154 | 155 | 156 | def get_max_diff_pairs(fea_js_path, pairs_save_path): 157 | compilation_case_pairs = [['arm_arm-linux-gnueabihf-gcc_5.4_O3', 'mac_clang_12_O0'], ['arm_arm-linux-gnueabihf-gcc_5.4_O0', 158 | 'mac_clang_12_O3'], ['arm_arm-linux-gnueabi-gcc_5.4_O0', 'win_gcc_8.1_O3'], ['arm_arm-linux-gnueabi-gcc_5.4_O3', 'win_gcc_8.1_O0']] 159 | 160 | all_pairs = [] 161 | all_f_sigs = [] 162 | for case_pair in compilation_case_pairs: 163 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+'.json') 164 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+'.json') 165 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 166 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 167 | all_pairs += pairs 168 | all_f_sigs += f_sigs 169 | 170 | print('func num: ', len(set(all_f_sigs))) 171 | print('graph num:', len(all_pairs)*2) 172 | print('pair num: ', len(all_pairs)) 173 | save_data(all_pairs, pairs_save_path) 174 | 175 | 176 | def get_cc_version_diff_pairs(fea_js_path, pairs_save_path): 177 | compilation_case_pairs = [['linux_clang_3.5_O', 'linux_clang_5.0_O'], ['linux_clang_3.8_O', 'linux_clang_6.0_O'], ['linux_gcc_4.8_O', 'linux_gcc_8_O'], ['linux_gcc_6_O', 'linux_gcc_9_O']] 178 | 179 | all_pairs = [] 180 | all_f_sigs = [] 181 | for case_pair in compilation_case_pairs: 182 | for i in ['0', '1', '2', '3']: 183 | l_fea_js = os.path.join(fea_js_path, case_pair[0]+i+'.json') 184 | r_fea_js = os.path.join(fea_js_path, case_pair[1]+i+'.json') 185 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 186 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 187 | all_pairs += pairs 188 | all_f_sigs += f_sigs 189 | print('func num: ', len(set(all_f_sigs))) 190 | print('graph num:', len(all_pairs)*2) 191 | print('pair num: ', len(all_pairs)) 192 | save_data(all_pairs, pairs_save_path) 193 | 194 | 195 | 196 | def get_oplevel_pairs(fea_js_path, pairs_save_path, l_level, r_level): 197 | f_names = os.listdir(fea_js_path) 198 | valid_cases = [] 199 | for f_name in f_names: 200 | if f_name.startswith('arm_x86') or f_name.startswith('linux_gcc_5') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_7') or f_name.startswith('linux_gcc_8') or f_name.startswith('mac_gcc_8') or f_name.startswith("linux_clang_3.8") or f_name.startswith("linux_clang_4.0") or f_name.startswith("linux_clang_5.0") or f_name.startswith("mac_gcc_7"): 201 | continue 202 | valid_cases.append(f_name[:-6]) 203 | 204 | all_pairs = [] 205 | all_f_sigs = [] 206 | for comp_case in valid_cases: 207 | l_fea_js = os.path.join(fea_js_path, comp_case+'{0}.json'.format(l_level)) 208 | r_fea_js = os.path.join(fea_js_path, comp_case+'{0}.json'.format(r_level)) 209 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 210 | pairs, f_sigs = get_pairs(l_fea_js, r_fea_js) 211 | all_pairs += pairs 212 | all_f_sigs += f_sigs 213 | print('func num: ', len(set(all_f_sigs))) 214 | print('graph num:', len(all_pairs)*2) 215 | print('pair num: ', len(all_pairs)) 216 | save_data(all_pairs, pairs_save_path) 217 | 218 | 219 | def unique_l3(fea_js_path, pairs_save_path, l_level=0, r_level=3): 220 | f_names = os.listdir(fea_js_path) 221 | valid_cases = [] 222 | for f_name in f_names: 223 | if f_name.startswith('arm_x86') or f_name.startswith('linux_gcc_5') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_6') or 
f_name.startswith('linux_gcc_7') or f_name.startswith('linux_gcc_8') or f_name.startswith('mac_gcc_8') or f_name.startswith("linux_clang_3.8") or f_name.startswith("linux_clang_4.0") or f_name.startswith("linux_clang_5.0") or f_name.startswith("mac_gcc_7"): 224 | continue 225 | valid_cases.append(f_name[:-6]) 226 | 227 | all_pairs = [] 228 | all_f_sigs = [] 229 | for comp_case in valid_cases: 230 | l_fea_js = os.path.join(fea_js_path, comp_case+'{0}.json'.format(l_level)) 231 | r_fea_js = os.path.join(fea_js_path, comp_case+'{0}.json'.format(r_level)) 232 | if os.path.exists(l_fea_js) and os.path.exists(r_fea_js): 233 | pairs = [] 234 | l_fea_d = {} 235 | r_fea_d = {} 236 | f_sigs = [] 237 | with open(l_fea_js) as load_f: 238 | for line in load_f: 239 | l_fea = json.loads(line.strip()) 240 | l_fea_d[l_fea['fname']] = l_fea 241 | 242 | with open(r_fea_js) as load_f: 243 | for line in load_f: 244 | r_fea = json.loads(line.strip()) 245 | r_fea_d[r_fea['fname']] = r_fea 246 | 247 | for f_sig in r_fea_d: 248 | if f_sig not in l_fea_d: 249 | print(f_sig) 250 | 251 | 252 | if __name__ == '__main__': 253 | fea_js_gemini = '/mnt/d/data/vector_gemini_format/valid' 254 | fea_js_ghidra = '../data/func_comparison/vector_ghidra_format/valid' 255 | 256 | # subset1: os diff subset, linux - win 257 | os_diff_linux_win_7_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_linux_win_7_fea_dim.json' 258 | os_diff_linux_win_76_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_linux_win_76_fea_dim.json' 259 | # get_os_diff_linux_win_pairs(fea_js_gemini, os_diff_linux_win_7_fea_dim) 260 | # get_os_diff_linux_win_pairs(fea_js_ghidra, os_diff_linux_win_76_fea_dim) 261 | 262 | # subset1: os diff subset, linux - mac 263 | os_diff_linux_mac_7_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_linux_mac_7_fea_dim.json' 264 | os_diff_linux_mac_76_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_linux_mac_76_fea_dim.json' 265 | # get_os_diff_linux_mac_pairs(fea_js_gemini, os_diff_linux_mac_7_fea_dim) 266 | # get_os_diff_linux_mac_pairs(fea_js_ghidra, os_diff_linux_mac_76_fea_dim) 267 | 268 | # subset1: os diff subset, mac win 269 | os_diff_win_mac_7_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_win_mac_7_fea_dim.json' 270 | os_diff_win_mac_76_fea_dim = '/mnt/d/data/valid_pairs_v1/os_diff_win_mac_76_fea_dim.json' 271 | # get_os_diff_win_mac_pairs(fea_js_gemini, os_diff_win_mac_7_fea_dim) 272 | # get_os_diff_win_mac_pairs(fea_js_ghidra, os_diff_win_mac_76_fea_dim) 273 | 274 | # subset2: arch diff subset, arm, x86 275 | arch_diff_7_fea_dim = '/mnt/d/data/valid_pairs_v1/arch_diff_7_fea_dim.json' 276 | arch_diff_76_fea_dim = '/mnt/d/data/valid_pairs_v1/arch_diff_76_fea_dim.json' 277 | # get_arch_diff_pairs(fea_js_gemini, arch_diff_7_fea_dim) 278 | # get_arch_diff_pairs(fea_js_ghidra, arch_diff_76_fea_dim) 279 | 280 | # subset3: cc diff subset, linux clang , linux gcc 281 | cc_diff_7_fea_dim = '/mnt/d/data/valid_pairs_v1/cc_diff_7_fea_dim.json' 282 | cc_diff_76_fea_dim = '/mnt/d/data/valid_pairs_v1/cc_diff_76_fea_dim.json' 283 | # get_cc_diff_pairs(fea_js_gemini, cc_diff_7_fea_dim) 284 | # get_cc_diff_pairs(fea_js_ghidra, cc_diff_76_fea_dim) 285 | 286 | # subset4: optimization level diff 287 | max_oplevel_pairs_7_fea_dim = '/mnt/d/data/valid_pairs_v1/max_oplevel_pairs_7_fea_dim.json' 288 | max_oplevel_pairs_76_fea_dim = '/mnt/d/data/valid_pairs_v1/max_oplevel_pairs_76_fea_dim.json' 289 | # get_max_oplevel_pairs(fea_js_gemini, max_oplevel_pairs_7_fea_dim) 290 | # get_max_oplevel_pairs(fea_js_ghidra, max_oplevel_pairs_76_fea_dim) 291 | 292 | # 
subset5: max diff 293 | max_diff_pairs_7_fea_dim = '/mnt/d/data/valid_pairs_v1/max_diff_pairs_7_fea_dim.json' 294 | max_diff_pairs_76_fea_dim = '/mnt/d/data/valid_pairs_v1/max_diff_pairs_76_fea_dim.json' 295 | # get_max_diff_pairs(fea_js_gemini, max_diff_pairs_7_fea_dim) 296 | # get_max_diff_pairs(fea_js_ghidra, max_diff_pairs_76_fea_dim) 297 | 298 | 299 | ## cc version 300 | cc_version_diff_7_fea_dim = '/mnt/d/data/valid_pairs_v1/cc_version_diff_7_fea_dim.json' 301 | cc_version_diff_76_fea_dim = '/mnt/d/data/valid_pairs_v1/cc_version_diff_76_fea_dim.json' 302 | # get_cc_version_diff_pairs(fea_js_gemini, cc_version_diff_7_fea_dim) 303 | # get_cc_version_diff_pairs(fea_js_ghidra, cc_version_diff_76_fea_dim) 304 | 305 | 306 | oplevel_pairs_2_0_7_fea_dim = '/mnt/d/data/valid_pairs_v1/oplevel_pairs_2_0_7_fea_dim.json' 307 | 308 | a='../data/validation_pairs/valid_pairs_v1/76_fea_dim/' 309 | oplevel_pairs_2_0_76_fea_dim = './oplevel_pairs_2_0_76_fea_dim.json' 310 | # get_oplevel_pairs(fea_js_gemini, oplevel_pairs_2_0_7_fea_dim, 2, 0) 311 | get_oplevel_pairs(fea_js_ghidra, oplevel_pairs_2_0_76_fea_dim, 2, 0) 312 | 313 | oplevel_pairs_2_1_7_fea_dim = '/mnt/d/data/valid_pairs_v1/oplevel_pairs_2_1_7_fea_dim.json' 314 | oplevel_pairs_2_1_76_fea_dim = './oplevel_pairs_2_1_76_fea_dim.json' 315 | # get_oplevel_pairs(fea_js_gemini, oplevel_pairs_2_1_7_fea_dim, 2, 1) 316 | get_oplevel_pairs(fea_js_ghidra, oplevel_pairs_2_1_76_fea_dim, 2, 1) 317 | 318 | oplevel_pairs_2_3_7_fea_dim = '/mnt/d/data/valid_pairs_v1/oplevel_pairs_2_3_7_fea_dim.json' 319 | oplevel_pairs_2_3_76_fea_dim = './oplevel_pairs_2_3_76_fea_dim.json' 320 | # get_oplevel_pairs(fea_js_gemini, oplevel_pairs_2_3_7_fea_dim, 2, 3) 321 | get_oplevel_pairs(fea_js_ghidra, oplevel_pairs_2_3_76_fea_dim, 2, 3) -------------------------------------------------------------------------------- /main/torch/milvus_mod.py: -------------------------------------------------------------------------------- 1 | # This program demos how to connect to Milvus vector database, 2 | # create a vector collection, 3 | # insert 10 vectors, 4 | # and execute a vector similarity search. 5 | 6 | import random 7 | import numpy as np 8 | import pickle as pkl 9 | from milvus import Milvus, IndexType, MetricType, Status 10 | import os 11 | import json 12 | import time 13 | from tqdm import tqdm 14 | import math 15 | 16 | 17 | 18 | class mil(object): 19 | def __init__(self, dim=64) -> None: 20 | 21 | # Milvus server IP address and port. 22 | # You may need to change _HOST and _PORT accordingly. 23 | self._HOST = '127.0.0.1' 24 | self._PORT = '19530' # default value 25 | # _Ppip ORT = '19121' # default http value 26 | 27 | # Vector parameters 28 | self._DIM = dim # dimension of vector 29 | 30 | self._INDEX_FILE_SIZE = 4096 # max file size of stored index 31 | 32 | self.BATCH =10000 33 | self.milvus = Milvus(self._HOST, self._PORT) 34 | 35 | def get_count(self, collection_name): 36 | status, result = self.milvus.count_entities(collection_name) 37 | return result 38 | 39 | def delete_collection(self, collection_name): 40 | status, ok = self.milvus.has_collection(collection_name) 41 | if ok: 42 | self.milvus.drop_collection(collection_name) 43 | self.milvus.flush([collection_name]) 44 | 45 | 46 | def load_data(self, vec_l, id_l, collection_name, clear_old, metrictype): 47 | # Specify server addr when create milvus client instance 48 | # milvus client instance maintain a connection pool, param 49 | # `pool_size` specify the max connection num. 
50 | 51 | status, ok = self.milvus.has_collection(collection_name) 52 | if ok and clear_old: 53 | self.milvus.drop_collection(collection_name) 54 | # print('dropout old collection') 55 | if not ok or (ok and clear_old): 56 | param = { 57 | 'collection_name': collection_name, 58 | 'dimension': self._DIM, 59 | 'index_file_size': self._INDEX_FILE_SIZE, # optional 60 | 'metric_type': metrictype # optional 61 | } 62 | self.milvus.create_collection(param) 63 | # print('collection created:', collection_name) 64 | 65 | status, result = self.milvus.count_entities(collection_name) 66 | old_count = result 67 | vectors = np.array(vec_l).astype(np.float32) 68 | length = len(vectors) 69 | begin = 0 70 | end = self.BATCH 71 | # print('begin to insert') 72 | # print('data length:', length) 73 | # print('start with old count:', old_count) 74 | loops = math.ceil(length / self.BATCH) 75 | for i in tqdm(range(loops)): 76 | begin = self.BATCH * i 77 | end = self.BATCH * (i+1) 78 | if begin >= length: 79 | break 80 | if end>=length: 81 | end = length 82 | status, ids = self.milvus.insert(collection_name=collection_name, records=vectors[begin:end], ids=id_l[begin:end]) 83 | if not status.OK(): 84 | print("Insert failed: {}".format(status)) 85 | break 86 | begin += self.BATCH 87 | end += self.BATCH 88 | 89 | # Flush collection inserted data to disk. 90 | self.milvus.flush([collection_name]) 91 | 92 | def load_partition_data(self, bin2vecs, collection_name, clear_old, metrictype): 93 | status, ok = self.milvus.has_collection(collection_name) 94 | if ok and clear_old: 95 | self.milvus.drop_collection(collection_name) 96 | print('dropout old collection') 97 | if not ok or (ok and clear_old): 98 | param = { 99 | 'collection_name': collection_name, 100 | 'dimension': self._DIM, 101 | 'index_file_size': self._INDEX_FILE_SIZE, # optional 102 | 'metric_type': metrictype # optional 103 | } 104 | self.milvus.create_collection(param) 105 | print('collection created:', collection_name) 106 | count = 0 107 | for bin_f in bin2vecs: 108 | count += 1 109 | if count % 10000 == 0: 110 | print(count) 111 | vecs = bin2vecs[bin_f]['vecs'] 112 | ids = bin2vecs[bin_f]['ids'] 113 | vectors = np.array(vecs).astype(np.float32) 114 | self.milvus.create_partition(collection_name, bin_f) 115 | status, ids = self.milvus.insert(collection_name=collection_name,records=vectors, ids=ids, partition_tag=bin_f) 116 | self.milvus.flush([collection_name]) 117 | 118 | def creat_index(self, collection_name): 119 | # Get demo_collection row count 120 | status, result = self.milvus.count_entities(collection_name) 121 | print(result) 122 | 123 | # present collection statistics info 124 | _, info = self.milvus.get_collection_stats(collection_name) 125 | print(info) 126 | 127 | # Obtain raw vectors by providing vector ids 128 | # status, result_vectors = milvus.get_entity_by_id(collection_name, ids[:10]) 129 | 130 | # create index of vectors, search more rapidly 131 | index_param = { 132 | 'nlist': 8192 133 | } 134 | 135 | # Create ivflat index in demo_collection 136 | # You can search vectors without creating index. however, Creating index help to 137 | # search faster 138 | print("Creating index: {}".format(index_param)) 139 | status = self.milvus.create_index(collection_name, IndexType.IVF_FLAT, index_param) 140 | 141 | # describe index, get information of index 142 | status, index = self.milvus.get_index_info(collection_name) 143 | print(index) 144 | print('creating index. 
Done.') 145 | 146 | 147 | def query(self, query_vectors, collection_name, k, partition_tags=None, nprobe=100): 148 | # Show collections in Milvus server 149 | _, collections = self.milvus.list_collections() 150 | 151 | # Describe demo_collection 152 | _, collection = self.milvus.get_collection_info(collection_name) 153 | 154 | 155 | # query_vectors = np.load(query_vectors_path) 156 | 157 | # execute vector similarity search 158 | search_param = { 159 | "nprobe": nprobe 160 | } 161 | 162 | param = { 163 | 'collection_name': collection_name, 164 | 'query_records': query_vectors, 165 | 'top_k': k, 166 | 'params': search_param, 167 | 'partition_tags': partition_tags 168 | } 169 | 170 | status, results = self.milvus.search(**param) 171 | if status.OK(): 172 | return results 173 | else: 174 | print("Search failed. ", status) 175 | return None 176 | 177 | 178 | 179 | def get_vector_by_id(self, collection_name, ids): 180 | status, vectors = self.milvus.get_entity_by_id(collection_name, ids) 181 | print(status) 182 | for vector in vectors: 183 | print(vector) 184 | 185 | def delete_entities_by_ids(self, ids, collection_name): 186 | print("to delete: ", len(ids)) 187 | begin = 0 188 | end = self.BATCH 189 | length = len(ids) 190 | while(True): 191 | if begin >= length: 192 | break 193 | if end>=length: 194 | end = length 195 | print(begin) 196 | self.milvus.delete_entity_by_id(collection_name, id_array=ids[begin:end]) 197 | begin += self.BATCH 198 | end += self.BATCH 199 | self.milvus.flush([collection_name]) 200 | # status, result = self.milvus.count_entities(collection_name) 201 | # print("left entities:", result) 202 | 203 | 204 | def recall_rate(labels, query_results): 205 | total_num = len(labels) 206 | recall_num = 0 207 | for i in range(len(labels)): 208 | recall_ids = [] 209 | for cand in query_results[i]: 210 | recall_ids.append(cand.id) 211 | if labels[i] in recall_ids: 212 | recall_num += 1 213 | return recall_num / total_num 214 | 215 | 216 | 217 | def get_true_pairs(js_path): 218 | true_pairs = [] 219 | with open(js_path) as load_f: 220 | for line in load_f: 221 | pair = json.loads(line.strip()) 222 | true_pairs.append(pair) 223 | return true_pairs 224 | 225 | 226 | 227 | 228 | def query_all_valid_data(m, collection_name, target_dir, k): 229 | target_vec_pkls = os.listdir(target_dir) 230 | recalls = {} 231 | # for target in target_vec_pkls: 232 | for target_js in target_vec_pkls: 233 | if not target_js.endswith('.json'): 234 | continue 235 | # if target.startswith('id2'): 236 | # continue 237 | id_l = [] 238 | target_vec_r = [] 239 | # with open(os.path.join(target_dir, target), 'rb') as load_f: 240 | # true_pairs = pkl.load(load_f) 241 | true_pairs = get_true_pairs(os.path.join(target_dir, target_js)) 242 | for pair in true_pairs: 243 | id_l.append(pair[0]) 244 | target_vec_r.append(pair[1]) 245 | print(target_js) 246 | print("query length: ", len(target_vec_r)) 247 | start = time.time() 248 | results = m.query(np.array(target_vec_r), collection_name, k) 249 | if results: 250 | recall = recall_rate(id_l, results) 251 | recalls[target_js] = [str(round(recall*100, 2)), str(round(time.time()-start, 2)), str(len(target_vec_r))] 252 | return recalls 253 | 254 | 255 | def query_tasks_topk(m, collection_name, k, target_dir, dim76 = False): 256 | count = m.get_count(collection_name) 257 | print("count in {0}: {1}".format(collection_name, count)) 258 | 259 | recalls_7fea = query_all_valid_data(m, collection_name, target_dir, k) 260 | return recalls_7fea 261 | 262 | 263 | 264 | 265 | def 
print_table(collection_name, target_dir, m): 266 | # recalls_7fea_10, recalls_76fea_10, recalls_76fea_depth5_10 = query_tasks_topk(10, True) 267 | recalls_7fea_10 = query_tasks_topk(m, collection_name, 10, target_dir, False) 268 | recalls_7fea_20 = query_tasks_topk(m, collection_name, 20, target_dir, False) 269 | recalls_7fea_50 = query_tasks_topk(m, collection_name, 50, target_dir, False) 270 | recalls_7fea_100 = query_tasks_topk(m, collection_name, 100, target_dir, False) 271 | res_topk = {} 272 | print(recalls_7fea_10) 273 | print(recalls_7fea_20) 274 | print(recalls_7fea_50) 275 | print(recalls_7fea_100) 276 | 277 | print("top k table") 278 | for i in recalls_7fea_10: 279 | res_topk[i] = '{0} & {1}& {2}& {3}& {4}'.format(i, recalls_7fea_10[i][0], recalls_7fea_20[i][0], recalls_7fea_50[i][0], recalls_7fea_100[i][0]) 280 | print(res_topk[i]) 281 | 282 | # print("top 10 table") 283 | # for i in recalls_7fea_10: 284 | # if i.startswith('os_diff_linux'): 285 | # prefix = i[:15] 286 | # elif i.startswith('oplevel_pairs_2'): 287 | # prefix = i[:17] 288 | # else: 289 | # prefix = i[:10] 290 | # r = '& ' + recalls_7fea_10[i][0] +'& ' 291 | # for j in recalls_76fea_10: 292 | # if j.startswith(prefix): 293 | # r+= recalls_76fea_10[j][0]+'& ' 294 | # right_j = j 295 | # for z in recalls_76fea_depth5_10: 296 | # if z.startswith(prefix): 297 | # right_z = z 298 | # r+= recalls_76fea_depth5_10[z][0]+'& ' 299 | # r += recalls_7fea_10[i][1] +'& ' + recalls_76fea_10[right_j][1]+'& ' + recalls_76fea_depth5_10[right_z][1] 300 | # print(prefix, ' ', r) 301 | print('---------------------------------') 302 | 303 | 304 | # def get_lib_cand(target, collection_name, k): 305 | 306 | 307 | 308 | if __name__ == '__main__': 309 | collection_name = 'filtered_database' 310 | # collection_name = '_7fea_contra_torch_init' 311 | # id2vec_7fea = '../data/7fea_contra_tf/core_funcs/id2vec.pkl' 312 | # with open(id2vec_7fea, 'rb') as load_f: 313 | # id2vecs = pkl.load(load_f) 314 | # vecs = [] 315 | # ids = [] 316 | # for id in id2vecs: 317 | # ids.append(id) 318 | # vecs.append(id2vecs[id]) 319 | m = mil() 320 | m.milvus.list_collections() 321 | # m.load_data(np.array(list(id2vecs.values())), list(id2vecs.keys()), collection_name, clear_old=True, metrictype=MetricType.IP) 322 | 323 | id2vec_7fea_valid = '../data/validation_pairs/valid_pairs_v1/id_vec/id_vec_7fea/id2vec.pkl' 324 | with open(id2vec_7fea_valid, 'rb') as load_f: 325 | id2vecs = pkl.load(load_f) 326 | # m.load_data(np.array(list(id2vecs.values())), list(id2vecs.keys()), collection_name, clear_old = False, metrictype=MetricType.IP) 327 | 328 | # delete 329 | # m.delete_entities_by_ids(list(id2vecs.keys()), collection_name) 330 | 331 | # id2vec_l2tol013_7fea_valid = '../data/validation_pairs/valid_pairs_v1/id_vec/id_vec_7fea/id2vec_l2tol013.pkl' 332 | # with open(id2vec_l2tol013_7fea_valid, 'rb') as load_f: 333 | # id2vecs = pkl.load(load_f) 334 | # m.load_data(list(id2vecs.values()), list(id2vecs.keys()), collection_name, clear_old = False, metrictype=MetricType.IP) 335 | 336 | 337 | # query 338 | collection_name = 'filtered_database' 339 | target_dir = '../data/validation_pairs/valid_pairs_v1/id_vec/id_vec_7fea' 340 | # m = mil() 341 | # print_table(collection_name, target_dir, m) 342 | 343 | 344 | -------------------------------------------------------------------------------- /main/torch/run.sh: -------------------------------------------------------------------------------- 1 | out_dir=saved_model 2 | mkdir -p $out_dir 3 | 4 | 
train_valid_dir=~/sci2/user/data/func_comparison/vector_deduplicate_our_format_less_compilation_cases/train_test 5 | 6 | test_data_dir=~/sci2/user/data/func_comparison/vector_deduplicate_our_format_less_compilation_cases/valid 7 | 8 | python torch_main.py --fea_dim 7 --save_path $out_dir --train_valid $train_valid_dir --test_data $test_data_dir -------------------------------------------------------------------------------- /main/torch/torch_main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torch.utils.data import DataLoader 4 | from torch.autograd import Variable 5 | #import matplotlib.pyplot as plt 6 | import numpy as np 7 | import os 8 | import argparse 9 | import time 10 | import pickle 11 | from datetime import datetime 12 | import random 13 | from sklearn.metrics import auc, roc_curve 14 | from multiprocessing import Pool 15 | 16 | from torch_model import graphnn 17 | from dataset import dataset 18 | from utils_loss import get_f_name, get_f_dict, read_graph, partition_data 19 | 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('--device', type=str, default='0', 23 | help='visible gpu device') 24 | parser.add_argument('--workers', type=int, default=12, 25 | help='workers') 26 | parser.add_argument('--use_device', type=str, default='/gpu:0', 27 | help='used gpu device') 28 | parser.add_argument('--fea_dim', type=int, default=7, 29 | help='feature dimension') 30 | parser.add_argument('--embed_dim', type=int, default=64, 31 | help='embedding dimension') 32 | parser.add_argument('--embed_depth', type=int, default=2, 33 | help='embedding network depth') 34 | parser.add_argument('--output_dim', type=int, default=64, 35 | help='output layer dimension') 36 | parser.add_argument('--iter_level', type=int, default=5, 37 | help='iteration times') 38 | parser.add_argument('--lr', type=float, default=1e-4, 39 | help='learning rate') 40 | parser.add_argument('--epoch', type=int, default=150, 41 | help='epoch number') 42 | parser.add_argument('--batch_size', type=int, default=128, 43 | help='batch size') 44 | parser.add_argument('--neg_batch_size', type=int, default=128, 45 | help='negative batch size') 46 | parser.add_argument('--load_path', type=str, default=None, 47 | help='path for model loading, "#LATEST#" for the latest checkpoint') 48 | parser.add_argument('--save_path', type=str, 49 | default='../data/7fea_contra_torch_epo150_gap0.5/saved_model', help='path for model saving') 50 | parser.add_argument('--log_path', type=str, default=None, 51 | help='path for training log') 52 | parser.add_argument('--seed', type=int, default=1234) 53 | parser.add_argument('--train_valid', type=str, 54 | default='../data/func_comparison/vector_deduplicate_gemini_format_less_compilation_cases/train_test', help='path for train_valid data') 55 | parser.add_argument('--test_data', type=str, 56 | default='../data/func_comparison/vector_deduplicate_gemini_format_less_compilation_cases/valid', help='path for test data') 57 | 58 | 59 | def contra_loss_show(net, dataLoader, DEVICE): 60 | loss_val = [] 61 | tot_cos = [] 62 | tot_truth = [] 63 | for batch_id, (X1, X2, X3, m1, m2, m3) in enumerate(dataLoader, 1): 64 | X1, X2, X3, m1, m2, m3 = X1[0], X2[0], X3[0], m1[0], m2[0], m3[0] 65 | if 'gpu' in DEVICE: 66 | X1, X2, X3, m1, m2, m3 = X1.cuda(non_blocking=True), X2.cuda(non_blocking=True), X3.cuda( 67 | non_blocking=True), m1.cuda(non_blocking=True), m2.cuda(non_blocking=True), m3.cuda(non_blocking=True) 68 | # else: 69 | # X1, X2, 
X3, m1, m2, m3 = Variable(X1), Variable(X2), Variable( 70 | # X3), Variable(m1), Variable(m2), Variable(m3) 71 | loss, cos_p, cos_n = net.forward(X1, X2, X3, m1, m2, m3) 72 | cos_p = list(cos_p.cpu().detach().numpy()) 73 | cos_n = list(cos_n.cpu().detach().numpy()) 74 | tot_cos += cos_p 75 | tot_truth += [1]*len(cos_p) 76 | tot_cos += cos_n 77 | tot_truth += [-1]*len(cos_n) 78 | loss_val.append(loss.item()) 79 | cos = np.array(tot_cos) 80 | truth = np.array(tot_truth) 81 | 82 | fpr, tpr, thres = roc_curve(truth, (1+cos)/2) 83 | model_auc = auc(fpr, tpr) 84 | return loss_val, model_auc, tpr 85 | 86 | 87 | if __name__ == '__main__': 88 | args = parser.parse_args() 89 | print("=================================") 90 | print(args) 91 | print("=================================") 92 | 93 | np.random.seed(args.seed) 94 | torch.manual_seed(args.seed) 95 | torch.cuda.manual_seed(args.seed) 96 | random.seed(args.seed) 97 | 98 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 99 | NODE_FEATURE_DIM = args.fea_dim 100 | EMBED_DIM = args.embed_dim 101 | EMBED_DEPTH = args.embed_depth 102 | OUTPUT_DIM = args.output_dim 103 | ITERATION_LEVEL = args.iter_level 104 | LEARNING_RATE = args.lr 105 | MAX_EPOCH = args.epoch 106 | BATCH_SIZE = args.batch_size 107 | NEG_BATCH_SIZE = args.neg_batch_size 108 | LOAD_PATH = args.load_path 109 | SAVE_PATH = args.save_path 110 | LOG_PATH = args.log_path 111 | DEVICE = args.use_device 112 | WORKERS = args.workers 113 | DDATA_FILE_NAME_TRAIN_VALID = args.train_valid 114 | DATA_FILE_NAME_TEST = args.test_data 115 | 116 | SHOW_FREQ = 1 117 | TEST_FREQ = 1 118 | SAVE_FREQ = 5 119 | 120 | if not os.path.exists(SAVE_PATH): 121 | os.makedirs(SAVE_PATH) 122 | 123 | F_PATH_TRAIN_VALID = get_f_name(DDATA_FILE_NAME_TRAIN_VALID) 124 | FUNC_NAME_DICT_TRAIN_VALID = get_f_dict(F_PATH_TRAIN_VALID) 125 | F_PATH_TEST = get_f_name(DATA_FILE_NAME_TEST) 126 | FUNC_NAME_DICT_TEST = get_f_dict(F_PATH_TEST) 127 | print("start reading data") 128 | Gs_train_valid, classes_train_valid = read_graph( 129 | F_PATH_TRAIN_VALID, FUNC_NAME_DICT_TRAIN_VALID, NODE_FEATURE_DIM) 130 | print("train and test ---- 8:2") 131 | print("{} graphs, {} functions".format( 132 | len(Gs_train_valid), len(classes_train_valid))) 133 | 134 | perm = np.random.permutation(len(classes_train_valid)) 135 | Gs_train, classes_train, Gs_valid, classes_valid =\ 136 | partition_data(Gs_train_valid, classes_train_valid, [0.8, 0.2], perm) 137 | print("Train: {} graphs, {} functions".format( 138 | len(Gs_train), len(classes_train))) 139 | print("Valid: {} graphs, {} functions".format( 140 | len(Gs_valid), len(classes_valid))) 141 | 142 | print("Test") 143 | Gs_test, classes_test = read_graph( 144 | F_PATH_TEST, FUNC_NAME_DICT_TEST, NODE_FEATURE_DIM) 145 | print("{} graphs, {} functions".format(len(Gs_test), len(classes_test))) 146 | Gs_test, classes_test = partition_data( 147 | Gs_test, classes_test, [1], list(range(len(classes_test)))) 148 | 149 | trainSet = dataset(Gs_train, classes_train, BATCH_SIZE, 150 | NEG_BATCH_SIZE, neg_batch_flag=False, train=True) 151 | validSet = dataset(Gs_valid, classes_valid, BATCH_SIZE, 152 | NEG_BATCH_SIZE, neg_batch_flag=False, train=True) 153 | testSet = dataset(Gs_test, classes_test, BATCH_SIZE, 154 | NEG_BATCH_SIZE, neg_batch_flag=False, train=True) 155 | 156 | trainLoader = DataLoader( 157 | trainSet, batch_size=1, shuffle=False, num_workers=WORKERS, pin_memory=True) 158 | validLoader = DataLoader( 159 | validSet, batch_size=1, shuffle=False, num_workers=WORKERS, pin_memory=True) 160 | 161 | 
testLoader = DataLoader(testSet, batch_size=1, 162 | shuffle=False, num_workers=WORKERS, pin_memory=True) 163 | 164 | net = graphnn(NODE_FEATURE_DIM, EMBED_DIM, OUTPUT_DIM, 165 | EMBED_DEPTH, ITERATION_LEVEL) 166 | 167 | net = net.cuda() 168 | 169 | optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE) 170 | optimizer.zero_grad() 171 | 172 | train_loss = [] 173 | time_start = time.time() 174 | 175 | best_loss = 99999999 176 | best_auc = 0 177 | 178 | for i in range(1, MAX_EPOCH+1): 179 | trainSet.shuffle() 180 | trainLoader = DataLoader( 181 | trainSet, batch_size=1, shuffle=False, num_workers=WORKERS) 182 | loss_val = [] 183 | tot_cos = [] 184 | tot_truth = [] 185 | time_start = time.time() 186 | net.train() 187 | p_n_gap = [] 188 | for batch_id, (X1, X2, X3, m1, m2, m3) in enumerate(trainLoader, 1): 189 | X1, X2, X3, m1, m2, m3 = X1[0], X2[0], X3[0], m1[0], m2[0], m3[0] 190 | if 'gpu' in DEVICE: 191 | X1, X2, X3, m1, m2, m3 = X1.cuda(non_blocking=True), X2.cuda(non_blocking=True), X3.cuda( 192 | non_blocking=True), m1.cuda(non_blocking=True), m2.cuda(non_blocking=True), m3.cuda(non_blocking=True) 193 | # else: 194 | # X1, X2, X3, m1, m2, m3 = X1, Variable(X2), Variable( 195 | # X3), Variable(m1), Variable(m2), Variable(m3) 196 | loss, cos_p, cos_n = net.forward(X1, X2, X3, m1, m2, m3) 197 | cos_p = cos_p.cpu().detach().numpy() 198 | cos_n = cos_n.cpu().detach().numpy() 199 | p_n_gap.append(np.mean(cos_p - cos_n)) 200 | cos_p = list(cos_p) 201 | cos_n = list(cos_n) 202 | tot_cos += cos_p 203 | tot_truth += [1]*len(cos_p) 204 | tot_cos += cos_n 205 | tot_truth += [-1]*len(cos_n) 206 | loss_val.append(loss.item()) 207 | optimizer.zero_grad() 208 | loss.backward() 209 | optimizer.step() 210 | cos = np.array(tot_cos) 211 | truth = np.array(tot_truth) 212 | 213 | fpr, tpr, thres = roc_curve(truth, (1+cos)/2) 214 | model_auc = auc(fpr, tpr) 215 | print('Epoch: [%d]\tloss:%.4f\tp_n_gap:%.4f\tauc:%.4f\t@%s\ttime lapsed:\t%.2f s' % 216 | (i, np.mean(loss_val), np.mean(p_n_gap), model_auc, datetime.now(), time.time() - time_start)) 217 | if i % SHOW_FREQ == 0: 218 | net.eval() 219 | with torch.no_grad(): 220 | time_start = time.time() 221 | loss_val, model_auc, tpr = contra_loss_show( 222 | net, validLoader, DEVICE) 223 | print('Valid: [%d]\tloss:%.4f\tauc:%.4f\t@%s\ttime lapsed:\t%.2f s' % 224 | (i, np.mean(loss_val), model_auc, datetime.now(), time.time() - time_start)) 225 | 226 | time_start = time.time() 227 | loss_test, test_model_auc, tpr = contra_loss_show( 228 | net, testLoader, DEVICE) 229 | print("#"*70) 230 | print('Test: [%d]\tloss:%.4f\tauc:%.4f\t@%s\ttime lapsed:\t%.2f s' % 231 | (i, np.mean(loss_test), test_model_auc, datetime.now(), time.time() - time_start)) 232 | print("#"*70) 233 | time_start = time.time() 234 | train_loss.append(np.mean(loss_test)) 235 | 236 | # if model_auc > best_auc: 237 | # torch.save(net, SAVE_PATH + "/model-inter-best.pt") 238 | 239 | if np.mean(loss_val) < best_loss: 240 | torch.save(net, SAVE_PATH + "/model-inter-best.pt") 241 | 242 | if i % SAVE_FREQ == 0: 243 | torch.save(net, SAVE_PATH + 244 | '/model-inter-' + str(i+1) + ".pt") 245 | 246 | # learning_rate = learning_rate * 0.95 247 | 248 | with open('train_loss', 'wb') as f: 249 | pickle.dump(train_loss, f) 250 | 251 | from core_fedora_embeddings import * 252 | with Pool(10) as p: 253 | p.starmap(core_fedora_embedding, [(i, True) for i in range(10)]) 254 | valid_embedding_pairs(True) 255 | 256 | -------------------------------------------------------------------------------- 
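A minimal sketch of reusing a checkpoint written by `torch_main.py` for inference, assuming the PyTorch version used by the repository (so `torch.load` unpickles the whole `graphnn` module), that `torch_model.py` is importable, and that a GPU is available; the path, graph shape, and feature values below are illustrative only.

```python
# Illustrative inference sketch, not a file from the repository.
# torch_main.py saves the full graphnn module with torch.save(net, ...),
# so torch_model.py must be importable here and a GPU is required
# (the model holds internal CUDA tensors).
import torch

net = torch.load('../data/7fea_contra_torch_epo150_gap0.5/saved_model/model-inter-best.pt')
net.eval()

# One toy graph: 4 nodes with 7 attributes each (matches --fea_dim 7).
x = torch.zeros((1, 4, 7)).cuda()     # [batch, max_nodes, fea_dim] node features
mask = torch.zeros((1, 4, 4)).cuda()  # [batch, max_nodes, max_nodes] adjacency mask
mask[0, 0, 1] = 1                     # example edge 0 -> 1
with torch.no_grad():
    emb = net.predict(x, mask)        # [1, output_dim] graph embedding
print(emb.shape)
```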
/main/torch/torch_model.py: -------------------------------------------------------------------------------- 1 | from numpy.core.fromnumeric import repeat 2 | import torch 3 | from torch.autograd import Variable 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | import numpy as np 8 | 9 | def truncated_normal_(tensor,mean=0,std=0.1): 10 | with torch.no_grad(): 11 | size = tensor.shape 12 | tmp = tensor.new_empty(size+(4,)).normal_() 13 | valid = (tmp < 2) & (tmp > -2) 14 | ind = valid.max(-1, keepdim=True)[1] 15 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 16 | tensor.data.mul_(std).add_(mean) 17 | return tensor 18 | 19 | class graphnn(nn.Module): 20 | 21 | def __init__(self, N_x_feadim7, N_embed_outdim64, N_o, Wembed_depth, iter_level): 22 | super(graphnn, self).__init__() 23 | self.zero = torch.tensor(0.0).cuda() 24 | self.N_x_feadim7 = N_x_feadim7 25 | self.N_embed_outdim64 = N_embed_outdim64 26 | self.N_o = N_o 27 | self.Wembed_depth = Wembed_depth 28 | self.iter_level = iter_level 29 | 30 | self.node_val = nn.Linear(self.N_x_feadim7, self.N_embed_outdim64, False) 31 | truncated_normal_(self.node_val.weight) 32 | self.Wembed0 = nn.Linear(self.N_embed_outdim64, self.N_embed_outdim64, False) 33 | truncated_normal_(self.Wembed0.weight) 34 | self.Wembed1 = nn.Linear(self.N_embed_outdim64, self.N_embed_outdim64, False) 35 | truncated_normal_(self.Wembed1.weight) 36 | # for i in range(self.Wembed_depth): 37 | 38 | # self.Wembed.append() 39 | self.re = nn.ReLU(inplace=True) 40 | self.ta = nn.Tanh() 41 | 42 | self.cos1 = nn.CosineSimilarity(dim=1, eps=1e-10) 43 | self.cos2 = nn.CosineSimilarity(dim=2, eps=1e-10) 44 | 45 | self.out = nn.Linear(self.N_embed_outdim64, self.N_o) 46 | truncated_normal_(self.out.weight) 47 | torch.nn.init.constant_(self.out.bias, 0) 48 | 49 | def predict(self, x, msg_mask): 50 | return self.forward_once(x, msg_mask) 51 | 52 | def forward_once(self, x, msg_mask): 53 | node_embed = torch.reshape(self.node_val(torch.reshape(x, [-1, self.N_x_feadim7])), [x.shape[0], -1, self.N_embed_outdim64]) 54 | cur_msg = self.re(node_embed) 55 | for t in range(self.iter_level): 56 | Li_t = torch.matmul(msg_mask, cur_msg) 57 | cur_info = torch.reshape(Li_t, [-1, self.N_embed_outdim64]) 58 | # for i in range(self.Wembed_depth-1): 59 | 60 | cur_info = self.re(self.Wembed0(cur_info)) 61 | cur_info = self.Wembed1(cur_info) 62 | 63 | neigh_val_t = torch.reshape(cur_info, Li_t.shape) 64 | tot_val_t = node_embed + neigh_val_t 65 | tot_msg_t = self.ta(tot_val_t) 66 | cur_msg = tot_msg_t 67 | g_embed = torch.sum(cur_msg, 1) 68 | output = self.out(g_embed) 69 | return output 70 | 71 | def forward(self, X1, X2, X3, m1, m2, m3): 72 | embed1 = self.forward_once(X1, m1) 73 | embed2 = self.forward_once(X2, m2) 74 | embed3 = self.forward_once(X3, m3) 75 | 76 | 77 | # triple l2 distance, neg batch N, N>batch size 78 | # dist_p = torch.sum((embed1-embed2) ** 2, 1) 79 | # dist_n = torch.sum((embed1.reshape(embed1.shape[0], 1, embed1.shape[1]) - embed3).reshape(embed1.shape[0]*embed3.shape[0], embed1.shape[1]) ** 2, 1) 80 | # all_loss = torch.maximum(dist_p - torch.min(dist_n.reshape(embed1.shape[0], embed3.shape[0]), 1).values + 0.5, torch.tensor(0.0).cuda()) 81 | 82 | # old method to delete, neg batch N, N>batch size 83 | # dist_n = torch.sum((torch.repeat_interleave(embed1, embed3.shape[0], 0)- embed3.repeat(embed1.shape[0], 1)) ** 2, 1) 84 | # all_loss = torch.max(torch.reshape(torch.maximum(torch.repeat_interleave(dist_p, embed3.shape[0], axis=0) - dist_n + 0.5, 
torch.tensor(0.0).cuda()), [embed1.shape[0], -1]), 1).values 85 | 86 | # triple cos similarity, neg batch N, N>batch size 87 | # cos_dist_p = self.cos1(embed1, embed2) 88 | # cos_dist_n = self.cos2(embed1.reshape(embed1.shape[0], 1, embed1.shape[1]), embed3.reshape(1, embed3.shape[0], embed3.shape[1])) 89 | # all_loss = torch.maximum(torch.max(cos_dist_n, 1).values - cos_dist_p + 0.1, self.zero) 90 | # loss = torch.mean(all_loss) 91 | # return loss, cos_dist_p, torch.max(cos_dist_n, 1).values 92 | 93 | # triple cos similarity, neg batch N, N>batch size, mean cos_dist_n 94 | # cos_dist_p = self.cos1(embed1, embed2) 95 | # cos_dist_n = self.cos2(embed1.reshape(embed1.shape[0], 1, embed1.shape[1]), embed3.reshape(1, embed3.shape[0], embed3.shape[1])) 96 | # all_loss = torch.maximum((cos_dist_n - cos_dist_p.reshape(cos_dist_p.shape[0], 1)).reshape(-1) + 0.5, self.zero) 97 | # loss = torch.mean(all_loss) 98 | # return loss, cos_dist_p, self.cos1(embed1, embed3) 99 | 100 | # triple cos sim, non-neg_batch 101 | # cos_dist_p = self.cos1(embed1, embed2) 102 | # cos_dist_n = self.cos1(embed1, embed3) 103 | # all_loss = torch.maximum(cos_dist_n - cos_dist_p + 0.5, self.zero) 104 | # loss = torch.mean(all_loss) 105 | 106 | # contrastive loss 107 | cos_dist_p = self.cos1(embed1, embed2) 108 | cos_dist_n = self.cos1(embed1, embed3) 109 | loss = (torch.mean((1-cos_dist_p)**2) + torch.mean((cos_dist_n+1)**2)) / 2 110 | return loss, cos_dist_p, cos_dist_n 111 | -------------------------------------------------------------------------------- /main/torch/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import pickle as pkl 4 | 5 | 6 | def norm_vec(vec): 7 | return vec/np.sqrt(sum(vec**2)) 8 | 9 | def read_pkl(pkl_path): 10 | with open(pkl_path, 'rb') as f: 11 | content = pkl.load(f) 12 | return content 13 | 14 | 15 | def save_pkl(content, save_path): 16 | with open(save_path, 'wb') as f: 17 | pkl.dump(content, f) 18 | 19 | 20 | def read_json(js_path): 21 | with open(js_path, 'r') as f: 22 | content = json.load(f) 23 | return content 24 | 25 | 26 | def save_json(content, js_path): 27 | with open(js_path, 'w') as f: 28 | json.dump(content, f) -------------------------------------------------------------------------------- /main/torch/utils_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import auc, roc_curve 3 | import json 4 | import os 5 | from collections import deque 6 | import time 7 | 8 | from multiprocessing import Process,Queue 9 | 10 | 11 | def get_f_name(DATA, SF, CM, OP, VS): 12 | F_NAME = [] 13 | for sf in SF: 14 | for cm in CM: 15 | for op in OP: 16 | for vs in VS: 17 | F_NAME.append(DATA+sf+cm+op+vs+".json") 18 | return F_NAME 19 | 20 | 21 | def get_f_name(DATA): 22 | F_PATH = [] 23 | for f_name in os.listdir(DATA): 24 | if f_name.startswith('arm_x86') or f_name.startswith('linux_gcc_5') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_6') or f_name.startswith('linux_gcc_7') or f_name.startswith('linux_gcc_8') or f_name.startswith('mac_gcc_8'): 25 | continue 26 | # if not f_name.startswith('mac_gcc_8_O'): 27 | # continue 28 | F_PATH.append(os.path.join(DATA, f_name)) 29 | return F_PATH 30 | 31 | 32 | def get_f_dict(F_NAME): 33 | name_num = 0 34 | name_dict = {} 35 | for f_name in F_NAME: 36 | with open(f_name) as inf: 37 | for line in inf: 38 | g_info = json.loads(line.strip()) 39 | if (g_info['fname'] not 
in name_dict): 40 | name_dict[g_info['fname']] = name_num 41 | name_num += 1 42 | return name_dict 43 | 44 | 45 | class graph(object): 46 | def __init__(self, node_num=0, label=None, name=None): 47 | self.node_num = node_num 48 | self.label = label 49 | self.name = name 50 | self.features = [] 51 | self.succss = [] 52 | self.preds = [] 53 | if (node_num > 0): 54 | for i in range(node_num): 55 | self.features.append([]) 56 | self.succss.append([]) 57 | self.preds.append([]) 58 | 59 | def add_node(self, feature=[]): 60 | self.node_num += 1 61 | self.features.append(feature) 62 | self.succss.append([]) 63 | self.preds.append([]) 64 | 65 | def add_edge(self, u, v): 66 | self.succss[u].append(v) 67 | self.preds[v].append(u) 68 | 69 | def toString(self): 70 | ret = '{} {}\n'.format(self.node_num, self.label) 71 | for u in range(self.node_num): 72 | for fea in self.features[u]: 73 | ret += '{} '.format(fea) 74 | ret += str(len(self.succss[u])) 75 | for succ in self.succss[u]: 76 | ret += ' {}'.format(succ) 77 | ret += '\n' 78 | return ret 79 | 80 | 81 | def read_graph(F_NAME, FUNC_NAME_DICT, FEATURE_DIM): 82 | graphs = [] 83 | classes = [] 84 | if FUNC_NAME_DICT != None: 85 | for f in range(len(FUNC_NAME_DICT)): 86 | classes.append([]) 87 | 88 | for f_name in F_NAME: 89 | with open(f_name) as inf: 90 | for line in inf: 91 | g_info = json.loads(line.strip()) 92 | label = FUNC_NAME_DICT[g_info['fname']] 93 | classes[label].append(len(graphs)) 94 | cur_graph = graph(g_info['n_num'], label, g_info['src']) 95 | for u in range(g_info['n_num']): 96 | cur_graph.features[u] = np.array(g_info['features'][u]) 97 | for v in g_info['succss'][u]: 98 | cur_graph.add_edge(u, v) 99 | graphs.append(cur_graph) 100 | return graphs, classes 101 | 102 | 103 | def partition_data(Gs, classes, partitions, perm): 104 | C = len(classes) 105 | st = 0.0 106 | ret = [] 107 | for part in partitions: 108 | cur_g = [] 109 | cur_c = [] 110 | ed = st + part * C 111 | for cls in range(int(st), int(ed)): 112 | prev_class = classes[perm[cls]] 113 | cur_c.append([]) 114 | for i in range(len(prev_class)): 115 | cur_g.append(Gs[prev_class[i]]) 116 | cur_g[-1].label = len(cur_c)-1 117 | cur_c[-1].append(len(cur_g)-1) 118 | 119 | ret.append(cur_g) 120 | ret.append(cur_c) 121 | st = ed 122 | 123 | return ret 124 | 125 | 126 | def generate_epoch_pair(Gs, classes, M, train, output_id=False, load_id=None): 127 | epoch_data = [] 128 | if train: 129 | perm = np.random.permutation(len(Gs)) 130 | else: 131 | perm = range(len(Gs)) 132 | st = 0 133 | while st < len(Gs): 134 | X1, X2, X3, m1, m2, m3 = generate_batch_pairs(Gs, classes, M, st, perm) 135 | epoch_data.append((X1, X2, X3, m1, m2, m3)) 136 | st += M 137 | 138 | return epoch_data 139 | 140 | 141 | def generate_batch_pairs(Gs, classes, M, st, perm): 142 | X1, X2, X3, m1, m2, m3 = get_pair(Gs, classes, M, st=st, perm=perm) 143 | return X1, X2, X3, m1, m2, m3 144 | 145 | 146 | def get_pair(Gs, classes, M, st, perm, output_id=False, load_id=None): 147 | if load_id is None: 148 | C = len(classes) 149 | if (st + M > len(perm)): 150 | M = len(perm) - st 151 | ed = st + M 152 | triple_ids = [] # [(G_0, G_p, G_n)] 153 | p_funcs = [] 154 | true_pairs = [] 155 | n_ids = [] 156 | 157 | for g_id in perm[st:ed]: 158 | g0 = Gs[g_id] 159 | cls = g0.label 160 | p_funcs.append(cls) 161 | tot_g = len(classes[cls]) 162 | if (len(classes[cls]) >= 2): 163 | p_id = classes[cls][np.random.randint(tot_g)] 164 | while g_id == p_id: 165 | p_id = classes[cls][np.random.randint(tot_g)] 166 | true_pairs.append((g_id, 
p_id)) 167 | else: 168 | triple_ids = load_id[0] 169 | 170 | M = len(true_pairs) 171 | neg_batch = M 172 | for i in range(neg_batch): 173 | n_cls = np.random.randint(C) 174 | while (len(classes[n_cls]) == 0) or (n_cls in p_funcs): 175 | n_cls = np.random.randint(C) 176 | tot_g2 = len(classes[n_cls]) 177 | n_id = classes[n_cls][np.random.randint(tot_g2)] 178 | n_ids.append(n_id) 179 | maxN1 = 0 180 | maxN2 = 0 181 | maxN3 = 0 182 | for pair in true_pairs: 183 | maxN1 = max(maxN1, Gs[pair[0]].node_num) 184 | maxN2 = max(maxN2, Gs[pair[1]].node_num) 185 | for id in n_ids: 186 | maxN3 = max(maxN3, Gs[id].node_num) 187 | feature_dim = len(Gs[0].features[0]) 188 | X1_input = np.zeros((M, maxN1, feature_dim)) 189 | X2_input = np.zeros((M, maxN2, feature_dim)) 190 | X3_input = np.zeros((neg_batch, maxN3, feature_dim)) 191 | node1_mask = np.zeros((M, maxN1, maxN1)) 192 | node2_mask = np.zeros((M, maxN2, maxN2)) 193 | node3_mask = np.zeros((neg_batch, maxN3, maxN3)) 194 | 195 | for i in range(len(true_pairs)): 196 | g1 = Gs[true_pairs[i][0]] 197 | g2 = Gs[true_pairs[i][1]] 198 | 199 | for u in range(g1.node_num): 200 | X1_input[i, u, :] = np.array(g1.features[u]) 201 | for v in g1.succss[u]: 202 | node1_mask[i, u, v] = 1 203 | for u in range(g2.node_num): 204 | X2_input[i, u, :] = np.array(g2.features[u]) 205 | for v in g2.succss[u]: 206 | node2_mask[i, u, v] = 1 207 | 208 | for i in range(len(n_ids)): 209 | g3 = Gs[n_ids[i]] 210 | for u in range(g3.node_num): 211 | X3_input[i, u, :] = np.array(g3.features[u]) 212 | for v in g3.succss[u]: 213 | node3_mask[i, u, v] = 1 214 | if output_id: 215 | return X1_input, X2_input, X3_input, node1_mask, node2_mask, node3_mask, triple_ids 216 | else: 217 | return X1_input, X2_input, X3_input, node1_mask, node2_mask, node3_mask 218 | 219 | def f(queue,i_l,i_h,graphs, classes, batch_size, perm): 220 | l = i_h - i_l 221 | for i in range(int(l/batch_size)): 222 | t = get_pair(graphs, classes, batch_size, st=i_l+i*batch_size, perm=perm) 223 | queue.put(t) 224 | 225 | def train_epoch(model, graphs, classes, batch_size, load_data=None): 226 | count = 0 227 | cum_loss = 0.0 228 | 229 | perm = np.random.permutation(len(graphs)) 230 | st = 0 231 | while(st + batch_size < len(graphs)): 232 | X1, X2, X3, m1, m2, m3 = generate_batch_pairs( 233 | graphs, classes, batch_size, st, perm) 234 | st += batch_size 235 | if len(X1) == 0: 236 | continue 237 | loss = model.train(X1, X2, X3, m1, m2, m3) 238 | cum_loss += loss 239 | count += 1 240 | 241 | # queue = Queue(maxsize=10) 242 | # Process_num=6 243 | # for i in range(Process_num): 244 | # print(i,'start') 245 | # ii = int((len(graphs) - batch_size)/Process_num) 246 | # t = Process(target=f,args=(queue, i*ii,(i+1)*ii, graphs, classes, batch_size, perm)) 247 | # t.start() 248 | # print(int(len(graphs) /batch_size)) 249 | # for j in range(int(len(graphs) /batch_size)): 250 | # print(j) 251 | # t=queue.get() 252 | # print("q size:", queue.qsize()) 253 | # if len(t[0]) == 0: 254 | # continue 255 | # loss = model.train(t[0], t[1], t[2], t[3], t[4], t[5]) 256 | # cum_loss += loss 257 | # count += 1 258 | return cum_loss / count 259 | 260 | 261 | def get_loss(model, graphs, classes, batch_size, load_data=None): 262 | count = 0 263 | cum_loss = 0.0 264 | perm = range(len(graphs)) 265 | 266 | # queue = Queue(maxsize=5) 267 | # Process_num=3 268 | # for i in range(Process_num): 269 | # print(i,'start') 270 | # ii = int((len(graphs) - batch_size)/Process_num) 271 | # t = Process(target=f,args=(queue, i*ii,(i+1)*ii, graphs, classes, 
batch_size, perm)) 272 | # t.start() 273 | # print(int(len(graphs) /batch_size)) 274 | # for j in range(int(len(graphs) /batch_size)): 275 | # print(j) 276 | # t=queue.get() 277 | # if len(t[0]) == 0: 278 | # continue 279 | # loss = model.calc_loss(t[0], t[1], t[2], t[3], t[4], t[5]) 280 | # cum_loss += loss 281 | # count += 1 282 | 283 | 284 | st = 0 285 | while(st + batch_size < len(graphs)): 286 | X1, X2, X3, m1, m2, m3 = generate_batch_pairs( 287 | graphs, classes, batch_size, st, perm) 288 | st += batch_size 289 | if len(X1) == 0: 290 | continue 291 | loss = model.calc_loss(X1, X2, X3, m1, m2, m3) 292 | cum_loss += loss 293 | count += 1 294 | return cum_loss / count 295 | 296 | # def get_auc_epoch_batch(model, graphs, classes, batch_size): 297 | # tot_diff = [] 298 | # tot_truth = [] 299 | # st = 0 300 | # perm = range(len(graphs)) 301 | # while(st < len(graphs)): 302 | # X1, X2, X3, m1, m2, m3 = generate_batch_pairs( 303 | # graphs, classes, batch_size, st, perm) 304 | # st += batch_size 305 | # if len(X1) == 0: 306 | # continue 307 | # diff_p = model.calc_diff(X1, X2, m1, m2) 308 | # diff_n = model.calc_diff(X1, X3, m1, m3) 309 | # tot_diff += list(diff_p) + list(diff_n) 310 | # y_p = np.ones(len(diff_p)) 311 | # y_n = np.zeros(len(diff_n)) 312 | # tot_truth += list(y_n > 0) + list(y_p > 0) 313 | 314 | # diff = np.array(tot_diff) 315 | # truth = np.array(tot_truth) 316 | 317 | # fpr, tpr, thres = roc_curve(truth, diff) 318 | # model_auc = auc(fpr, tpr) 319 | # return model_auc, fpr, tpr, thres 320 | 321 | 322 | 323 | class SequenceData(): 324 | def __init__(self, graphs, classes, batch_size, perm): 325 | self.graphs = graphs 326 | self.classes = classes 327 | self.batch_size = batch_size 328 | self.perm = perm 329 | self.L = len(self.graphs) 330 | self.queue = Queue(maxsize=30) 331 | 332 | self.Process_num=3 333 | for i in range(self.Process_num): 334 | print(i,'start') 335 | ii = int(self.__len__()/self.Process_num) 336 | t = Process(target=self.f,args=(i*ii,(i+1)*ii)) 337 | t.start() 338 | def __len__(self): 339 | return self.L - self.batch_size 340 | def __getitem__(self, st): 341 | X1, X2, X3, m1, m2, m3 = get_pair(self.graphs, self.classes, self.batch_size, st=st, perm=self.perm) 342 | return X1, X2, X3, m1, m2, m3 343 | 344 | def f(self,i_l,i_h): 345 | l = i_h - i_l 346 | for i in range(int(l/self.batch_size)): 347 | t = self.__getitem__(i_l+i*self.batch_size) 348 | self.queue.put(t) 349 | 350 | # def gen(self): 351 | # while 1: 352 | # t = self.queue.get() 353 | # yield t[0],t[1],t[2],t[3] 354 | 355 | 356 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | autopep8==1.5.7 2 | certifi==2021.5.30 3 | charset-normalizer==2.0.3 4 | cycler==0.10.0 5 | decorator==4.4.2 6 | grpcio==1.37.1 7 | grpcio-tools==1.37.1 8 | idna==3.2 9 | joblib==1.0.1 10 | kiwisolver==1.3.1 11 | matplotlib==3.4.2 12 | networkx==2.5.1 13 | numpy==1.21.1 14 | Pillow==8.3.1 15 | protobuf==3.17.3 16 | pycodestyle==2.7.0 17 | pymilvus==1.1.2 18 | pyparsing==2.4.7 19 | python-dateutil==2.8.2 20 | requests==2.26.0 21 | scikit-learn==0.24.2 22 | scipy==1.7.0 23 | six==1.16.0 24 | sklearn==0.0 25 | threadpoolctl==2.2.0 26 | toml==0.10.2 27 | tqdm==4.61.2 28 | typing-extensions==3.10.0.0 29 | ujson==4.0.2 30 | urllib3==1.26.6 31 | --------------------------------------------------------------------------------