├── .dockerignore
├── .gitignore
├── Dockerfile
├── INSTALL
├── LICENSE
├── README-docker.md
├── README.md
├── REQUIREMENTS
├── STATUS
├── app.py
├── codeart
│   ├── .gitignore
│   ├── README.md
│   ├── code
│   │   ├── arguments.py
│   │   ├── codeart_tokenizer
│   │   │   ├── special_tokens_map.json
│   │   │   ├── tokenizer.json
│   │   │   └── tokenizer_config.json
│   │   ├── data_utils.py
│   │   ├── modeling_utils.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── configuration_codeart.py
│   │   │   ├── configuration_rabert.py
│   │   │   ├── modeling_codeart.py
│   │   │   ├── modeling_jtrans.py
│   │   │   ├── modeling_rabert.py
│   │   │   ├── tokenization_codeart.py
│   │   │   └── tokenization_rabert.py
│   │   ├── run.py
│   │   └── trainer.py
│   ├── evaluation-jtrans
│   │   ├── malware-family-classification
│   │   │   ├── config
│   │   │   │   ├── eval.json
│   │   │   │   └── train.json
│   │   │   ├── eval_config.sh
│   │   │   ├── evaluate_multilabel.py
│   │   │   ├── run.py
│   │   │   ├── run_config.sh
│   │   │   ├── run_multilabel.py
│   │   │   ├── temp_evaluate.py
│   │   │   └── utils.py
│   │   └── provenance-attribution
│   │       ├── run.py
│   │       └── run.sh
│   ├── evaluation
│   │   ├── binary-similarity
│   │   │   ├── .gitignore
│   │   │   ├── README.md
│   │   │   ├── binsim_dataset.py
│   │   │   ├── binsim_trainer.py
│   │   │   ├── config
│   │   │   │   └── train.json
│   │   │   ├── dump_files.py
│   │   │   ├── encode.sh
│   │   │   ├── eval.py
│   │   │   ├── inference.py
│   │   │   ├── model_utils.py
│   │   │   ├── pretty_print.sh
│   │   │   ├── pretty_print_all.py
│   │   │   ├── run.py
│   │   │   ├── run_config.sh
│   │   │   ├── sample_and_report.py
│   │   │   ├── sample_and_report.sh
│   │   │   └── utils.py
│   │   ├── malware-family-classification
│   │   │   ├── config
│   │   │   │   ├── eval-2f-100c.json
│   │   │   │   └── train-2f-100c.json
│   │   │   ├── eval_config.sh
│   │   │   ├── evaluate_multilabel.py
│   │   │   ├── run_config.sh
│   │   │   ├── run_multilabel.py
│   │   │   └── utils.py
│   │   └── type-inference
│   │       ├── config
│   │       │   ├── eval-O0.json
│   │       │   └── train-O0.json
│   │       ├── eval_config.sh
│   │       ├── labels.json
│   │       ├── run.py
│   │       ├── run_config.sh
│   │       └── utils.py
│   ├── preprocess
│   │   ├── README.md
│   │   ├── analysis
│   │   │   ├── expr_lang_analyzer.py
│   │   │   └── prog_model.py
│   │   ├── analyze.py
│   │   ├── binary_base.py
│   │   ├── collect.py
│   │   ├── disassemble.py
│   │   ├── type_inference
│   │   │   ├── base_calculator.py
│   │   │   ├── die_globals.py
│   │   │   ├── gen_dataset.py
│   │   │   ├── parse_dwarf.py
│   │   │   ├── upload_dataset.py
│   │   │   └── utils.py
│   │   └── utils
│   │       ├── asm_parser.py
│   │       └── data_utils.py
│   └── scripts
│       ├── config
│       │   └── default.json
│       └── train_config.sh
└── requirements.txt

--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
.vscode/
.save/
convert.ipynb
*.npy

__pycache__
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.vscode/
save/
convert.ipynb
*.npy

*/__pycache__
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
# Use an official PyTorch image with CUDA and cuDNN pre-installed
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# Set the working directory in the container to /workspace
WORKDIR /workspace

# Copy the current directory contents into the container at /workspace
COPY . /workspace

# Install any needed packages specified in requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# Make port 47907 available to the world outside this container
EXPOSE 47907

# Define an environment variable
ENV NAME World

# Run app.py when the container launches
CMD ["python", "app.py"]
--------------------------------------------------------------------------------
/INSTALL:
--------------------------------------------------------------------------------
Please refer to the section "Setup and Run a Quick Example" in `README-docker.md` for instructions on running CodeArt with a minimal example.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Zian Su

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README-docker.md:
--------------------------------------------------------------------------------
# Overview

This file contains the instructions for running CodeArt and reproducing the results in the paper.
Reviewers will need to install Docker and pull the Docker image from [Docker Hub](https://hub.docker.com/repository/docker/sheepy928/codeart).
Then, they can run CodeArt inside the Docker image.

# Setup and Run a Quick Example

## Pull the Docker Image

Please use the following command to pull the Docker image.
The size of the image is about 3.8GB; it contains the CodeArt repository and a full PyTorch environment with CUDA installed. Please note that it may take a few minutes to download the image.

```bash
docker pull sheepy928/codeart:v1-release
```

After pulling the image, please run the container with the following command
(`-it` makes the container interactive):
```bash
docker run -it -p 47906:47906 sheepy928/codeart:v1-release
```
The following message will be shown if the container starts successfully. Note that the public URL is different each time, and each URL is valid for 72 hours or until the container is stopped. You can access our demo through either the local URL or the public URL.
```
Running on local URL: http://127.0.0.1:47906
Running on public URL: https://<xxx>.gradio.live
```

The CodeArt repository is located at `/workspace/codeart` in the container. The model checkpoints and datasets are hosted on Hugging Face, but the links are already hardcoded in the demo.

If you want to access the models and datasets, you can find them [here](https://huggingface.co/collections/PurCL/codeart-artifacts-662e73dcd9b837e4b970c9be).
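The experiments below expect a CUDA-capable GPU to be visible inside the container. If it is not, the run command above can be extended with Docker's GPU flag. This is a suggestion, not part of the original instructions, and it assumes the NVIDIA Container Toolkit is installed on the host:

```bash
# Expose all host GPUs to the container (requires the NVIDIA Container Toolkit)
docker run -it --gpus all -p 47906:47906 sheepy928/codeart:v1-release
```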
## A Quick Example

In this example, we will reproduce one experiment reported in Table 1 of the paper.

1. Open a web browser and visit the URL `https://<xxx>.gradio.live` as shown in the message, or `http://127.0.0.1:47906` if you are running the container locally.

2. Click on the `Table 1` tab at the top of the page.

3. Select the model `PurCL/codeart-26m` and the dataset `binutilsh`, which corresponds to the `Binutils` experiment in Table 1. Then, select the pool size `50`. You will notice that `Run Alias` is automatically populated, and you do not need to change it.

4. Click on the `Run` button. The model will be downloaded from Hugging Face, which may take a few minutes depending on your network speed. After the model is downloaded, it will automatically be evaluated on the dataset, and the results will be displayed on the right side of the page. Please note that on a single A6000 GPU, the evaluation may take around 10 minutes.

5. You can see the actual command executed in the terminal.

If you notice “Error” in the output block, please refresh the webpage. If the error persists, please restart the server program.

## Figure 8

Figure 8 evaluates the performance of CodeArt on the binary similarity downstream task.
The testing process of a binary similarity model is as follows: The model takes as input two binary files that are compiled from the same source code but with different compiler flags. For each function in the binary files, the model encodes the function into a vector (i.e., the embedding of the function).

Then, for each function in the binary file compiled with the `-O0` flag, we compute the cosine similarity between the embedding of the function and the embeddings of candidate functions in the binary file compiled with the `-O3` flag. We rank the candidate functions by cosine similarity and record the rank of the candidate that corresponds to the same source code function as the function being queried.

In Figure 8, the x-axis denotes pool sizes and the y-axis the performance. It is worth noting that a larger pool size implies a more challenging setup. Therefore, while our interface supports all the experiments, the performance of CodeArt can be validated by running only the experiments with a pool size of 500.

### Key Experiments

This section describes how to run CodeArt with a pool size of 500. It takes around 1 hour in total. For the baseline models, we use the numbers reported by DiEmph[1] (for JTrans) and PEM[2] (for GNNs).

First, please click the tab “Figure8” in the UI. In the “Select Model” field, please select “PurCL/codeart-binsim”. In the “Select Dataset” field, please select a project corresponding to a subfigure in Figure 8. For the “Pool Size” field, please select 500. Then click the button “Run”. The results will be available within around 10 minutes. (We conducted the tests on a single Nvidia A6000 (48GB) GPU; the time may vary depending on the GPU.)

For example, for the following experiment:
{Select Model: PurCL/codeart-binsim, Select Dataset: libcurlh, Pool Size: 500}
the expected output looks like:
```shell
Number of overlapped functions: 666
……
Number of selected overlapped functions: 500
source embedding shape: (500, 768), target embedding shape: (500, 768)
{'recall': {1: 0.6948000000000001, 3: 0.8308, 5: 0.8704000000000001, 10: 0.9128000000000001}, 'mrr': 0.7722353968253969}
Final-PR@1: 0.6948000000000001
Final-MRR: 0.7722353968253969
```
The value at 'recall'-1 (0.6948) corresponds to the point (500, 0.694) in the subfigure for Curl.

Note that due to the randomness in sampling the pool of functions, the results may vary by no more than 5%.

### Other Experiments

Please change “Pool Size” to other values to obtain the results for the other pool sizes.
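For reference, the `recall` and `mrr` values printed above can be computed from the source and target embedding matrices roughly as follows. This is a minimal illustrative sketch, not the repository's actual evaluation code; the function name and the use of NumPy are assumptions.

```python
import numpy as np

def recall_and_mrr(src_emb, tgt_emb, ks=(1, 3, 5, 10)):
    # src_emb, tgt_emb: (N, 768) arrays; row i of src_emb (an -O0 function)
    # matches row i of tgt_emb (the same function compiled with -O3).
    src = src_emb / np.linalg.norm(src_emb, axis=1, keepdims=True)
    tgt = tgt_emb / np.linalg.norm(tgt_emb, axis=1, keepdims=True)
    sim = src @ tgt.T                                # (N, N) cosine similarities
    order = np.argsort(-sim, axis=1)                 # candidates, best first
    # 1-based rank of the ground-truth candidate for each query
    ranks = np.argmax(order == np.arange(len(src))[:, None], axis=1) + 1
    recall = {k: float(np.mean(ranks <= k)) for k in ks}
    mrr = float(np.mean(1.0 / ranks))
    return recall, mrr
```

With a pool size of 500, `recall[1]` corresponds to the Final-PR@1 value and `mrr` to the Final-MRR value reported by the interface.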
## Table 1

Table 1 is similar to Figure 8. It evaluates the performance of CodeArt in a zero-shot setup. Please refer to the Figure 8 section above for background on evaluating models on the binary similarity task.

It is sufficient to evaluate CodeArt on the most challenging setup (a pool size of 500), although our interface supports all the experiments.

### Key Experiments

This section describes how to run CodeArt with a pool size of 500. It takes around 1 hour in total.

Please set “Select Model” to “PurCL/codeart-26m”. In the “Select Dataset” field, please select a project corresponding to a row in Table 1. For the “Pool Size” field, please select 500. Then click the button “Run”. The results will be available within around 10 minutes.

For example, for the curl dataset, the expected output looks like:
```shell
Number of overlapped functions: 666
Number of selected overlapped functions: 500
source embedding shape: (500, 768), target embedding shape: (500, 768)
…
{'recall': {1: 0.47639999999999993, 3: 0.63, 5: 0.6799999999999999, 10: 0.7464}, 'mrr': 0.5643303174603174}
Final-PR@1: 0.47639999999999993
Final-MRR: 0.5643303174603174
```
The 'recall'-1 value is 0.476, corresponding to the cell at column “Pool-size 500, CodeArt” (the last column) and row “Curl” in Table 1, which reports 0.47.
Note that the results may vary by no more than 5%.

### Other Experiments

Please set “Pool Size” to other values for experiments with different pool sizes.

## Table 2

Table 2 evaluates the performance of CodeArt on the malware family classification downstream task. The model takes as input N binary functions from a malware sample and outputs a label denoting the family of the malware. “N-Funcs” denotes how many functions are taken as input for the model to classify the whole binary into a malware family. Our interface supports results for the baseline JTrans and for CodeArt with different “N-Funcs” options.

### Key Experiments

First, please click the tab “Table 2” in the UI. In the “Select Model” field, please select “2Funcs-CodeArt”. Then, click the button “Run”. The results will be available within around 10 minutes.

The expected output looks like:

```shell

[some logging output]

auc: 0.9248027229306747, lrap: 0.5966341229736315, lrl: 0.0864753141245597
```
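The three numbers are standard multi-label ranking metrics. Assuming `y_true` holds multi-hot family labels and `y_score` holds the model's per-family scores (both toy stand-ins here; this is a sketch for orientation, not the repository's `evaluate_multilabel.py`), they can be computed with scikit-learn:

```python
import numpy as np
from sklearn.metrics import (
    roc_auc_score,
    label_ranking_average_precision_score,
    label_ranking_loss,
)

# Toy stand-ins for the real evaluation outputs.
y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])                     # multi-hot labels
y_score = np.array([[0.9, 0.2, 0.7], [0.1, 0.8, 0.3], [0.6, 0.7, 0.2]])  # model scores

print("auc:", roc_auc_score(y_true, y_score))                            # 'auc'
print("lrap:", label_ranking_average_precision_score(y_true, y_score))   # 'lrap'
print("lrl:", label_ranking_loss(y_true, y_score))                       # 'lrl'
```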
## Table 3

Table 3 evaluates the zero-shot binary-similarity performance of CodeArt variants pretrained on BinCorp-3m for the ablation study. The test set is coreutils and the pool size is 100. Our interface supports all variants in the table. Specifically, the mapping between the names in the table and the options in “Select Model” is:
“w/o local mask” -> “PurCL/codeart-3m-wo_local_mask”
“w/o trans-closure” -> “PurCL/codeart-3m-wo_trans_closure”
“max-trans-closure 4” -> “PurCL/codeart-3m-max_trans_closure_4”
“max-trans-closure 6” -> “PurCL/codeart-3m-max_trans_closure_6”
“w/o rel-pos-bias” -> “PurCL/codeart-3m-wo_rel_pos_bias”

### Key Experiments

First, please click the tab “Table 3” in the UI. In the “Select Model” field, please select “PurCL/codeart-3m-wo_local_mask”. In the “Select Dataset” field, please select “coreutilsh”. In the “Pool Size” field, please select 100. Then, click the button “Run”. The time cost and output format are similar to the previous binary similarity experiments.

Note that the results in this table are subject to the same randomness as the other binary-similarity results, but the relative performance gap is consistent.

## Figure 9

Figure 9 evaluates the performance of CodeArt on the type inference downstream task. Our interface supports CodeArt for type inference at different optimization levels (O0, O1, O2, and O3).

### Key Experiments

First, please click the tab “Figure 9” in the UI. In the “Select Optimization Level” field, please select “O1”. Then, click the button “Run”. The results will be available within around 10 minutes.

The expected output looks like:

```shell

[some logging output]

***** predict metrics *****
  predict_f1                 =     0.9447
  predict_loss               =     0.0175
  predict_precision          =     0.9447
  predict_recall             =     0.9447
  predict_runtime            = 0:00:48.47
  predict_samples            =       4124
  predict_samples_per_second =     85.081
  predict_steps_per_second   =      0.681
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CodeArt: Better Code Models by Attention Regularization When Symbols Are Lacking

This is the official implementation of

Zian Su, Xiangzhe Xu, Ziyang Huang, Zhuo Zhang, Yapeng Ye, Jianjun Huang, Xiangyu Zhang, "[CodeArt: Better Code Models by Attention Regularization When Symbols Are Lacking](https://arxiv.org/abs/2402.11842)", FSE'24

---

Check the `artifact` branch for the reproducible artifact.
--------------------------------------------------------------------------------
/REQUIREMENTS:
--------------------------------------------------------------------------------
===Artifact Evaluation===

For the artifact evaluation, CodeArt requires a `docker` installation.
We provide a runnable docker image for the artifact evaluation.

We recommend using a machine with at least 64GB of RAM and 8 CPU cores.
We also recommend a machine with an Nvidia GPU with at least 24GB of VRAM.
Our reproduction was tested on an A6000 (48GB).

Most of the experiments should work fine on a machine with 32GB of RAM and 16GB of VRAM.

Running all experiments may require up to 100GB of disk space.
This is because the CodeArt checkpoints and datasets are hosted on Hugging Face and need to be downloaded.

The space usage can be reduced with minor engineering effort.
We will improve it in the future.
--------------------------------------------------------------------------------
/STATUS:
--------------------------------------------------------------------------------
CodeArt is applying for the following badges:

1. Available

2. Functional

3. Reusable

Below are the reasons for applying for these badges:

1. Available: We provide a runnable docker image that contains the code and the links to the datasets. This docker image can be used to run CodeArt.

2. Functional: In addition to the code and the docker image, we provide a `README-docker.md` file that contains the instructions to run the docker image and reproduce the results of the paper. We also integrate the experiments into a frontend web interface, which allows users to interactively run the experiments and see the results with a few clicks.

3. Reusable: The code is well-structured and documented, with launch scripts and configuration files. It can easily be reused to run on other datasets or to test other models. In addition, the frontend web interface is by itself a reusable tool that can run experiments on other models and datasets; for simplicity and to avoid confusion, however, we only allow running the experiments for the models and datasets used in the paper.
--------------------------------------------------------------------------------
/codeart/.gitignore:
--------------------------------------------------------------------------------
save/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/

checkpoints/**
checkpoint*/**
MLM-w-regularization/save/**
**/cache/**
**/wandb/**
.vscode/settings.json
--------------------------------------------------------------------------------
/codeart/README.md:
--------------------------------------------------------------------------------
# CodeArt: Better Code Models by Attention Regularization When Symbols Are Lacking

This repo contains the code for the paper *CodeArt: Better Code Models by Attention Regularization When Symbols Are Lacking*.

## Environment

- torch==2.0.1
- transformers==4.30.2
- datasets==2.14.4
- networkx==3.1
- scikit-learn==1.3.0

## Quick Tour

To use CodeArt, you can follow this general pipeline for encoding instructions and dependences:

```python
import sys
sys.path.append('path_to_/code/')

import torch

from models import (
    CodeArtConfig,
    CodeArtTokenizer,
    CodeArtModel
)
from modeling_utils import MaskBuilder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = CodeArtTokenizer.from_pretrained("path_to_pretrained_checkpoint")
maskbuilder = MaskBuilder(
    preset=None,
    enable_global_memory_patterns=True,
    enable_bridge_patterns=False,
    enable_graph_patterns=True
)
tokenizer.maskbuilder = maskbuilder

model = CodeArtModel.from_pretrained("path_to_pretrained_checkpoint")
model.to(device)

# encode one example
instructions = [[0, 'push r15'], [1, 'push r14'], [2, 'mov r14d, r9d'], ...]
dependences = [[16, 12], [17, 16], [21, 10], [25, 13], ...]
encoded = tokenizer.inst_encode(instructions, dependences)
embeddings = model(
    input_ids=encoded['input_ids'],
    attention_mask=encoded['attention_mask'],
    relative_position_matrix=encoded['relative_position_matrix']
)
```

## Dependence Analysis and Preprocessing

To convert a binary program into an input to CodeArt, we need to
first use IDA Pro to disassemble the binary program and then
perform a conservative dependence analysis to extract the program dependences.

Please refer to `preprocess/README.md` for details.

## Datasets

The preprocessed datasets we use for training and evaluation are on the Hugging Face Hub; you can refer to the configuration files to find them.

> Due to safety concerns, we will not directly release the raw binaries of the malware dataset.
> Instead, after publication, we will provide the raw binaries upon request to interested researchers.
> For now, we release the sha256 hashes of the samples in the malware dataset in `evaluation/malware-family-classification/id2family.jsonl`.

## Pretraining

To replicate the pretraining, navigate to `scripts/` and run `train_config.sh config/default.json`.

## Evaluation

The evaluation code of CodeArt is under the directory `evaluation/`.

### Binary Similarity Analysis

Please refer to `evaluation/binary-similarity/README.md` for details.

### Malware Family Classification

Navigate to `evaluation/malware-family-classification/`. To finetune CodeArt on this task, run `run_config.sh config/train-2f-100c.json`. To evaluate the finetuned model, run `eval_config.sh config/eval-2f-100c.json`.
Note that you need to specify the correct `model_name_or_path` in the configurations. 84 | 85 | ### Type Inference 86 | 87 | Navigate to `evaluation/type-inference/`. To finetune CodeArt on this task, run `run_config.sh config/train-O0.json` (you can modify the `dataset_name` in the configuration to finetune for O1, O2, and O3). To evaluate the finetuned model, run `eval_config.sh config/eval-O0.json`. Note that you need to specify the correct `model_name_or_path` in the configurations. 88 | 89 | ## Checkpoints 90 | 91 | We release checkpoints of pretraining and downstream tasks in [this link](https://drive.google.com/drive/folders/1PwNLmWmjXYH8ZYYtD7HmOtMybvpTBMRp?usp=sharing). You can download these checkpoints and extract them to `checkpoints/`. -------------------------------------------------------------------------------- /codeart/code/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | from transformers.utils.versions import require_version 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | """ 9 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 10 | """ 11 | 12 | model_name_or_path: Optional[str] = field( 13 | default=None, 14 | metadata={ 15 | "help": ( 16 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 17 | ) 18 | }, 19 | ) 20 | # model_type: Optional[str] = field( 21 | # default=None, 22 | # metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)}, 23 | # ) 24 | config_overrides: Optional[str] = field( 25 | default=None, 26 | metadata={ 27 | "help": ( 28 | "Override some existing default config settings when a model is trained from scratch. Example: " 29 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 30 | ) 31 | }, 32 | ) 33 | config_name: Optional[str] = field( 34 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 35 | ) 36 | tokenizer_name: Optional[str] = field( 37 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 38 | ) 39 | cache_dir: Optional[str] = field( 40 | default=None, 41 | metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 42 | ) 43 | use_fast_tokenizer: bool = field( 44 | default=True, 45 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 46 | ) 47 | model_revision: str = field( 48 | default="main", 49 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 50 | ) 51 | use_auth_token: bool = field( 52 | default=False, 53 | metadata={ 54 | "help": ( 55 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 56 | "with private models)." 57 | ) 58 | }, 59 | ) 60 | low_cpu_mem_usage: bool = field( 61 | default=False, 62 | metadata={ 63 | "help": ( 64 | "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded." 65 | "set True will benefit LLM loading time and RAM consumption." 
            )
        },
    )

    # positional embedding arguments
    position_embedding_type: str = field(
        default='absolute',
        metadata={
            "help": (
                "Type of positional embedding to use, can be either 'absolute' or 'mixed'"
            ),
            "choices": ['absolute', 'mixed'],
        }
    )

    max_relative_position_embeddings: Optional[int] = field(
        default=8,
        metadata={
            "help": (
                "Max relative position distance to consider in relative embeddings"
            )
        }
    )

    # edge prediction arguments
    ep_add_linear_projection: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to add a linear projection before computing dot product for edge prediction"
            )
        }
    )

    # masking arguments
    masking_preset: Optional[str] = field(
        default=None,
        metadata={
            "help": ("Preset masking strategy"),
            "choices": [None, 'aggressive', 'conservative'],
        }
    )
    masking_enable_global_memory_patterns: bool = field(
        default=True,
        metadata={
            "help": ("enable global memory patterns")
        }
    )
    masking_enable_bridge_patterns: bool = field(
        default=True,
        metadata={
            "help": ("enable bridge patterns")
        }
    )
    masking_enable_graph_patterns: bool = field(
        default=True,
        metadata={
            "help": ("enable graph patterns")
        }
    )
    masking_enable_local_patterns: bool = field(
        default=True,
        metadata={
            "help": ("enable local patterns")
        }
    )

    with_transitive_closure: bool = field(
        default=True,
        metadata={
            "help": ("Whether to include transitive closure in the masking process")
        }
    )

    max_transitions: Optional[int] = field(
        default=None,
        metadata={
            "help": ("Maximum number of transitions to consider in masking")
        }
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
158 | """ 159 | 160 | dataset_name: Optional[str] = field( 161 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 162 | ) 163 | dataset_config_name: Optional[str] = field( 164 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 165 | ) 166 | train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."}) 167 | validation_file: Optional[str] = field( 168 | default=None, 169 | metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."}, 170 | ) 171 | overwrite_cache: bool = field( 172 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 173 | ) 174 | validation_split_percentage: Optional[int] = field( 175 | default=5, 176 | metadata={ 177 | "help": "The percentage of the train set used as validation set in case there's no validation split" 178 | }, 179 | ) 180 | max_seq_length: Optional[int] = field( 181 | default=None, 182 | metadata={ 183 | "help": ( 184 | "The maximum total input sequence length after tokenization. Sequences longer " 185 | "than this will be truncated." 186 | ) 187 | }, 188 | ) 189 | preprocessing_num_workers: Optional[int] = field( 190 | default=None, 191 | metadata={"help": "The number of processes to use for the preprocessing."}, 192 | ) 193 | mlm_probability: float = field( 194 | default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"} 195 | ) 196 | ep_probability: float = field( 197 | default=0.15, metadata={"help": "Ratio of nodes to select for edge prediction loss"} 198 | ) 199 | line_by_line: bool = field( 200 | default=False, 201 | metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, 202 | ) 203 | pad_to_max_length: bool = field( 204 | default=False, 205 | metadata={ 206 | "help": ( 207 | "Whether to pad all samples to `max_seq_length`. " 208 | "If False, will pad the samples dynamically when batching to the maximum length in the batch." 209 | ) 210 | }, 211 | ) 212 | max_train_samples: Optional[int] = field( 213 | default=None, 214 | metadata={ 215 | "help": ( 216 | "For debugging purposes or quicker training, truncate the number of training examples to this " 217 | "value if set." 218 | ) 219 | }, 220 | ) 221 | max_eval_samples: Optional[int] = field( 222 | default=None, 223 | metadata={ 224 | "help": ( 225 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 226 | "value if set." 
227 | ) 228 | }, 229 | ) 230 | streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"}) 231 | 232 | def __post_init__(self): 233 | if self.streaming: 234 | require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`") 235 | 236 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 237 | raise ValueError("Need either a dataset name or a training/validation file.") 238 | else: 239 | if self.train_file is not None: 240 | extension = self.train_file.split(".")[-1] 241 | if extension not in ["csv", "json", "txt"]: 242 | raise ValueError("`train_file` should be a csv, a json or a txt file.") 243 | if self.validation_file is not None: 244 | extension = self.validation_file.split(".")[-1] 245 | if extension not in ["csv", "json", "txt"]: 246 | raise ValueError("`validation_file` should be a csv, a json or a txt file.") -------------------------------------------------------------------------------- /codeart/code/codeart_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /codeart/code/codeart_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "clean_up_tokenization_spaces": true, 3 | "model_max_length": 512, 4 | "tokenizer_class": "PreTrainedTokenizerFast" 5 | } 6 | -------------------------------------------------------------------------------- /codeart/code/modeling_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | from networkx.algorithms.shortest_paths import floyd_warshall_numpy 4 | import torch 5 | from typing import Any, Dict, List, Tuple 6 | 7 | 8 | def create_attention_mask_aggressive( 9 | input_length: int, 10 | instruction_node_positions: List[int], 11 | data_dep: List[Tuple[int, int]]): 12 | 13 | # 0. [CLS] (attend to all) and [SEP] (attended by all) tokens 14 | # 1. regular token local pattern: letting local tokens only attend to bridge tokens 15 | # 2. bridge () token's local and global patterns 16 | # 3. 
bridge dependency pattern 17 | # NOTE: directed graph can be easier for later processing 18 | 19 | attention_mask = torch.zeros(size=(input_length, input_length), dtype=torch.bool) 20 | 21 | # [CLS] 22 | attention_mask[0, :] = 1 23 | attention_mask[:, 0] = 1 # all tokens can somehow get some global information from [CLS] token 24 | 25 | # [SEP]: currently this is not meaningful as rabert has no NSP 26 | 27 | for inst_id, position_id in enumerate(instruction_node_positions): 28 | next_position_id = instruction_node_positions[inst_id + 1] \ 29 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 30 | 31 | # attention_mask[position_id + 1: next_position_id, instruction_node_positions] = 1 32 | attention_mask[position_id, position_id: next_position_id] = 1 33 | for token_position in range(position_id + 1, next_position_id): 34 | # local attend to local and direct bridge 35 | attention_mask[token_position, position_id: next_position_id] = 1 36 | # local to attend to all other bridges 37 | attention_mask[token_position, instruction_node_positions] = 1 38 | 39 | for source_inst_id, target_inst_id in data_dep: 40 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 41 | continue 42 | source_position_id = instruction_node_positions[source_inst_id] 43 | target_position_id = instruction_node_positions[target_inst_id] 44 | attention_mask[source_position_id, target_position_id] = 1 45 | attention_mask[target_position_id, source_position_id] = 1 46 | 47 | return attention_mask 48 | 49 | 50 | def create_attention_mask_conservative( 51 | input_length, 52 | instruction_node_positions, 53 | data_dep 54 | ): 55 | """ 56 | The difference between `aggressive` and `conservative` mask creation is that, 57 | `aggressive` only allows tokens as information source between dependent 58 | instructions, whereas `conservative` allows all tokens in dependent instruction 59 | context to attend to each other. 60 | 61 | Only `aggressive` version can lead to sparse solutions. 
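
    A minimal usage sketch (toy values; the positions and dependence pair
    below are illustrative assumptions, not real tokenizer output):

        >>> mask = create_attention_mask_conservative(
        ...     input_length=8,
        ...     instruction_node_positions=[1, 4],
        ...     data_dep=[(1, 0)],
        ... )
        >>> mask.shape
        torch.Size([8, 8])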
62 | """ 63 | 64 | attention_mask = torch.zeros(size=(input_length, input_length), dtype=torch.bool) 65 | 66 | # [CLS] 67 | attention_mask[0, :] = 1 68 | attention_mask[:, 0] = 1 # all tokens can somehow get some global information from [CLS] token 69 | 70 | # [SEP]: currently this is not meaningful as rabert has no NSP 71 | 72 | for inst_id, position_id in enumerate(instruction_node_positions): 73 | next_position_id = instruction_node_positions[inst_id + 1] \ 74 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 75 | 76 | attention_mask[position_id, position_id: next_position_id] = 1 77 | for token_position in range(position_id + 1, next_position_id): 78 | # local attend to local and direct bridge 79 | attention_mask[token_position, position_id: next_position_id] = 1 80 | # local to attend to all other bridges 81 | attention_mask[token_position, instruction_node_positions] = 1 82 | 83 | for source_inst_id, target_inst_id in data_dep: 84 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 85 | continue 86 | source_position_id = instruction_node_positions[source_inst_id] 87 | target_position_id = instruction_node_positions[target_inst_id] 88 | attention_mask[source_position_id, target_position_id] = 1 89 | attention_mask[target_position_id, source_position_id] = 1 90 | 91 | return attention_mask 92 | 93 | 94 | def create_attention_mask_gcb( 95 | input_length, 96 | instruction_node_positions, 97 | data_dep 98 | ): 99 | attention_mask = torch.ones(size=(input_length, input_length), dtype=torch.bool) 100 | 101 | # local patterns 102 | for inst_id, position_id in enumerate(instruction_node_positions): 103 | next_position_id = instruction_node_positions[inst_id + 1] \ 104 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 105 | 106 | # remove all node related attention first 107 | attention_mask[position_id, :] = 0 108 | attention_mask[:, position_id] = 0 109 | 110 | if position_id + 2 <= next_position_id: 111 | attention_mask[position_id, position_id + 2] = 1 112 | attention_mask[position_id + 2, position_id] = 1 113 | 114 | # graph patterns 115 | for source_inst_id, target_inst_id in data_dep: 116 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 117 | continue 118 | source_position_id = instruction_node_positions[source_inst_id] 119 | target_position_id = instruction_node_positions[target_inst_id] 120 | attention_mask[source_position_id, target_position_id] = 1 121 | attention_mask[target_position_id, source_position_id] = 1 122 | 123 | return attention_mask 124 | 125 | 126 | 127 | class MaskBuilder(object): 128 | 129 | def __init__( 130 | self, 131 | preset=None, 132 | enable_global_memory_patterns=True, 133 | enable_bridge_patterns=True, 134 | enable_graph_patterns=True, 135 | device='cpu' 136 | ): 137 | self.preset = preset 138 | self.enable_global_memory_patterns = enable_global_memory_patterns 139 | self.enable_bridge_patterns = enable_bridge_patterns 140 | self.enable_graph_patterns = enable_graph_patterns 141 | self.device = device 142 | 143 | def create_attention_mask( 144 | self, 145 | input_length, 146 | instruction_node_positions, 147 | data_dep=None, 148 | ): 149 | if self.preset == 'graphcodebert': 150 | return create_attention_mask_gcb(input_length, instruction_node_positions, data_dep) 151 | elif self.preset is None: 152 | pass 153 | else: 154 | raise NotImplementedError 155 | 156 | attention_mask = torch.zeros(size=(input_length, input_length), dtype=torch.bool) 157 | 158 
| # [CLS] token 159 | attention_mask[0, :] = 1 160 | 161 | # global memory 162 | if self.enable_global_memory_patterns: 163 | attention_mask[:, 0] = 1 # all tokens can somehow get some global information from [CLS] token 164 | 165 | # local patterns 166 | for inst_id, position_id in enumerate(instruction_node_positions): 167 | next_position_id = instruction_node_positions[inst_id + 1] \ 168 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 169 | 170 | attention_mask[position_id, position_id: next_position_id] = 1 171 | for token_position in range(position_id + 1, next_position_id): 172 | # local attend to local and direct bridge 173 | attention_mask[token_position, position_id: next_position_id] = 1 174 | # local to attend to all other bridges 175 | if self.enable_bridge_patterns: 176 | attention_mask[token_position, instruction_node_positions] = 1 177 | 178 | # graph patterns 179 | if self.enable_graph_patterns: 180 | for source_inst_id, target_inst_id in data_dep: 181 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 182 | continue 183 | source_position_id = instruction_node_positions[source_inst_id] 184 | target_position_id = instruction_node_positions[target_inst_id] 185 | attention_mask[source_position_id, target_position_id] = 1 186 | attention_mask[target_position_id, source_position_id] = 1 187 | 188 | return attention_mask 189 | 190 | def create_attention_mask_no_local( 191 | self, 192 | input_length, 193 | instruction_node_positions, 194 | data_dep=None, 195 | ): 196 | 197 | attention_mask = torch.ones(size=(input_length, input_length), dtype=torch.bool) 198 | 199 | # [CLS] token and global memory already there 200 | 201 | # local patterns 202 | for inst_id, position_id in enumerate(instruction_node_positions): 203 | next_position_id = instruction_node_positions[inst_id + 1] \ 204 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 205 | 206 | 207 | for i in instruction_node_positions: 208 | attention_mask[i, instruction_node_positions] = 0 209 | 210 | 211 | # graph patterns 212 | if self.enable_graph_patterns: 213 | for source_inst_id, target_inst_id in data_dep: 214 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 215 | continue 216 | source_position_id = instruction_node_positions[source_inst_id] 217 | target_position_id = instruction_node_positions[target_inst_id] 218 | attention_mask[source_position_id, target_position_id] = 1 219 | attention_mask[target_position_id, source_position_id] = 1 220 | 221 | 222 | return attention_mask 223 | 224 | def create_attention_mask_and_relative_position_matrix( 225 | self, 226 | input_length, 227 | instruction_node_positions, 228 | data_dep, 229 | max_transitions=None, 230 | ): 231 | # 1. data_dep mask option 232 | # 2. bridge token option 233 | # 3. 
global memory token option 234 | 235 | attention_mask = torch.zeros(size=(input_length, input_length), dtype=torch.bool, device=self.device) 236 | 237 | # [CLS] token 238 | attention_mask[0, :] = 1 239 | 240 | # global memory 241 | if self.enable_global_memory_patterns: # TODO: enable larger memory 242 | attention_mask[:, 0] = 1 # all tokens can somehow get some global information from [CLS] token 243 | # attention_mask[:, -1] = 1 # all tokens can somehow get some global information from [SEP] token 244 | 245 | # local patterns 246 | for inst_id, position_id in enumerate(instruction_node_positions): 247 | next_position_id = instruction_node_positions[inst_id + 1] \ 248 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 249 | 250 | attention_mask[position_id, position_id: next_position_id] = 1 251 | for token_position in range(position_id + 1, next_position_id): 252 | # local attend to local and direct bridge 253 | attention_mask[token_position, position_id: next_position_id] = 1 254 | # local to attend to all other bridges 255 | if self.enable_bridge_patterns: 256 | attention_mask[token_position, instruction_node_positions] = 1 257 | 258 | # filter out out-of-range dependencies 259 | remaining_data_dep = [] 260 | if self.enable_graph_patterns: 261 | for source_inst_id, target_inst_id in data_dep: 262 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 263 | continue 264 | else: 265 | remaining_data_dep.append((source_inst_id, target_inst_id)) 266 | 267 | # graph construction 268 | graph = nx.Graph() 269 | graph.add_nodes_from(range(len(instruction_node_positions))) 270 | graph.add_edges_from(remaining_data_dep) 271 | 272 | # dense graph all pairs shortest path length 273 | spl_matrix = floyd_warshall_numpy(graph) 274 | if max_transitions is not None: 275 | spl_matrix[spl_matrix > max_transitions] = -1 276 | spl_matrix[spl_matrix == np.inf] = -1 277 | spl_matrix = torch.tensor(spl_matrix, dtype=torch.long, device=self.device) 278 | 279 | # recover full matrix 280 | rel_pos_matrix = torch.full(size=(input_length, input_length), fill_value=-1, dtype=torch.long, device=self.device) 281 | for i, nid in enumerate(instruction_node_positions): 282 | rel_pos_matrix[nid, instruction_node_positions] = spl_matrix[i, :] 283 | 284 | 285 | # merge rel_pos_matrix into attention_mask 286 | attention_mask = torch.logical_or(attention_mask, rel_pos_matrix >= 0) 287 | 288 | return attention_mask, rel_pos_matrix 289 | 290 | def create_attention_mask_and_relative_position_matrix_no_local( 291 | self, 292 | input_length, 293 | instruction_node_positions, 294 | data_dep, 295 | max_transitions=None, 296 | ): 297 | # 1. data_dep mask option 298 | # 2. bridge token option 299 | # 3. 
global memory token option 300 | 301 | 302 | attention_mask = torch.ones(size=(input_length, input_length), dtype=torch.bool, device=self.device) 303 | 304 | 305 | # local patterns 306 | for inst_id, position_id in enumerate(instruction_node_positions): 307 | next_position_id = instruction_node_positions[inst_id + 1] \ 308 | if inst_id + 1 < len(instruction_node_positions) else input_length - 1 309 | 310 | 311 | # filter out out-of-range dependencies 312 | remaining_data_dep = [] 313 | if self.enable_graph_patterns: 314 | for source_inst_id, target_inst_id in data_dep: 315 | if source_inst_id >= inst_id or target_inst_id >= inst_id: # support truncation 316 | continue 317 | else: 318 | remaining_data_dep.append((source_inst_id, target_inst_id)) 319 | 320 | # graph construction 321 | graph = nx.Graph() 322 | graph.add_nodes_from(range(len(instruction_node_positions))) 323 | graph.add_edges_from(remaining_data_dep) 324 | 325 | 326 | # dense graph all pairs shortest path length 327 | spl_matrix = floyd_warshall_numpy(graph) 328 | if max_transitions is not None: 329 | spl_matrix[spl_matrix > max_transitions] = -1 330 | spl_matrix[spl_matrix == np.inf] = -1 331 | spl_matrix = torch.tensor(spl_matrix, dtype=torch.long, device=self.device) 332 | 333 | 334 | # recover full matrix 335 | rel_pos_matrix = torch.full(size=(input_length, input_length), fill_value=-1, dtype=torch.long, device=self.device) 336 | for i, nid in enumerate(instruction_node_positions): 337 | rel_pos_matrix[nid, instruction_node_positions] = spl_matrix[i, :] 338 | 339 | for i in instruction_node_positions: 340 | attention_mask[i, instruction_node_positions] = 0 341 | 342 | 343 | # merge rel_pos_matrix into attention_mask 344 | attention_mask = torch.logical_or(attention_mask, rel_pos_matrix >= 0) 345 | 346 | 347 | return attention_mask, rel_pos_matrix -------------------------------------------------------------------------------- /codeart/code/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_rabert import RabertConfig 2 | from .configuration_codeart import CodeArtConfig 3 | 4 | from .modeling_rabert import RabertModel, RabertForMaskedLMWithEdgePrediction, RabertForSequenceClassification, RabertForBinSim 5 | from .modeling_codeart import CodeArtModel, CodeArtForMaskedLMWithEdgePrediction, CodeArtForSequenceClassification, CodeArtForMultipleSequenceClassification, CodeArtForBinSim, CodeArtForTokenClassification 6 | 7 | from .tokenization_rabert import RabertTokenizer, GCBLikeTokenizer 8 | from .tokenization_codeart import CodeArtTokenizer 9 | 10 | from .modeling_jtrans import JTransModel, JTransForSequenceClassification, JTransForMultipleSequenceClassification, JTransForTokenClassification 11 | -------------------------------------------------------------------------------- /codeart/code/models/configuration_codeart.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | 3 | 4 | class CodeArtConfig(PretrainedConfig): 5 | r""" 6 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is 7 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. 8 | Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa 9 | [roberta-base](https://huggingface.co/roberta-base) architecture. 
10 | 11 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 12 | documentation from [`PretrainedConfig`] for more information. 13 | 14 | 15 | Args: 16 | vocab_size (`int`, *optional*, defaults to 50265): 17 | Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the 18 | `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 19 | hidden_size (`int`, *optional*, defaults to 768): 20 | Dimensionality of the encoder layers and the pooler layer. 21 | num_hidden_layers (`int`, *optional*, defaults to 12): 22 | Number of hidden layers in the Transformer encoder. 23 | num_attention_heads (`int`, *optional*, defaults to 12): 24 | Number of attention heads for each attention layer in the Transformer encoder. 25 | intermediate_size (`int`, *optional*, defaults to 3072): 26 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 27 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 28 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 29 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 30 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 31 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 32 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 33 | The dropout ratio for the attention probabilities. 34 | max_position_embeddings (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. Typically set this to something large 36 | just in case (e.g., 512 or 1024 or 2048). 37 | type_vocab_size (`int`, *optional*, defaults to 2): 38 | The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 39 | initializer_range (`float`, *optional*, defaults to 0.02): 40 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 41 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 42 | The epsilon used by the layer normalization layers. 43 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 44 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 45 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 46 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 47 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 48 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 49 | is_decoder (`bool`, *optional*, defaults to `False`): 50 | Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 51 | use_cache (`bool`, *optional*, defaults to `True`): 52 | Whether or not the model should return the last key/values attentions (not used by all models). Only 53 | relevant if `config.is_decoder=True`. 54 | classifier_dropout (`float`, *optional*): 55 | The dropout ratio for the classification head. 
56 | 57 | Examples: 58 | 59 | ```python 60 | >>> from transformers import RobertaConfig, RobertaModel 61 | 62 | >>> # Initializing a RoBERTa configuration 63 | >>> configuration = RobertaConfig() 64 | 65 | >>> # Initializing a model (with random weights) from the configuration 66 | >>> model = RobertaModel(configuration) 67 | 68 | >>> # Accessing the model configuration 69 | >>> configuration = model.config 70 | ```""" 71 | model_type = "codeart" 72 | 73 | def __init__( 74 | self, 75 | vocab_size=50265, 76 | hidden_size=768, 77 | num_hidden_layers=12, 78 | num_attention_heads=12, 79 | intermediate_size=3072, 80 | hidden_act="gelu", 81 | hidden_dropout_prob=0.1, 82 | attention_probs_dropout_prob=0.1, 83 | max_position_embeddings=512, 84 | max_relative_position_embeddings=8, # NOTE: codeart unique 85 | type_vocab_size=2, 86 | initializer_range=0.02, 87 | layer_norm_eps=1e-12, 88 | pad_token_id=3, 89 | bos_token_id=1, 90 | eos_token_id=2, 91 | # position_embedding_type="absolute", 92 | position_embedding_type="mixed", 93 | use_cache=True, 94 | classifier_dropout=None, 95 | ep_add_linear_projection=False, # NOTE: codeart unique 96 | **kwargs, 97 | ): 98 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 99 | 100 | self.vocab_size = vocab_size 101 | self.hidden_size = hidden_size 102 | self.num_hidden_layers = num_hidden_layers 103 | self.num_attention_heads = num_attention_heads 104 | self.hidden_act = hidden_act 105 | self.intermediate_size = intermediate_size 106 | self.hidden_dropout_prob = hidden_dropout_prob 107 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 108 | self.max_position_embeddings = max_position_embeddings 109 | self.max_relative_position_embeddings = max_relative_position_embeddings # NOTE: codeart unique 110 | self.type_vocab_size = type_vocab_size 111 | self.initializer_range = initializer_range 112 | self.layer_norm_eps = layer_norm_eps 113 | self.position_embedding_type = position_embedding_type 114 | self.use_cache = use_cache 115 | self.classifier_dropout = classifier_dropout 116 | 117 | self.ep_add_linear_projection = ep_add_linear_projection # NOTE: codeart unique -------------------------------------------------------------------------------- /codeart/code/models/configuration_rabert.py: -------------------------------------------------------------------------------- 1 | from transformers.configuration_utils import PretrainedConfig 2 | 3 | 4 | class RabertConfig(PretrainedConfig): 5 | r""" 6 | This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is 7 | used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. 8 | Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa 9 | [roberta-base](https://huggingface.co/roberta-base) architecture. 10 | 11 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 12 | documentation from [`PretrainedConfig`] for more information. 13 | 14 | 15 | Args: 16 | vocab_size (`int`, *optional*, defaults to 50265): 17 | Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the 18 | `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 19 | hidden_size (`int`, *optional*, defaults to 768): 20 | Dimensionality of the encoder layers and the pooler layer. 
21 | num_hidden_layers (`int`, *optional*, defaults to 12): 22 | Number of hidden layers in the Transformer encoder. 23 | num_attention_heads (`int`, *optional*, defaults to 12): 24 | Number of attention heads for each attention layer in the Transformer encoder. 25 | intermediate_size (`int`, *optional*, defaults to 3072): 26 | Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. 27 | hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): 28 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 29 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 30 | hidden_dropout_prob (`float`, *optional*, defaults to 0.1): 31 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 32 | attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): 33 | The dropout ratio for the attention probabilities. 34 | max_position_embeddings (`int`, *optional*, defaults to 512): 35 | The maximum sequence length that this model might ever be used with. Typically set this to something large 36 | just in case (e.g., 512 or 1024 or 2048). 37 | type_vocab_size (`int`, *optional*, defaults to 2): 38 | The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. 39 | initializer_range (`float`, *optional*, defaults to 0.02): 40 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 41 | layer_norm_eps (`float`, *optional*, defaults to 1e-12): 42 | The epsilon used by the layer normalization layers. 43 | position_embedding_type (`str`, *optional*, defaults to `"absolute"`): 44 | Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For 45 | positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to 46 | [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). 47 | For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models 48 | with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). 49 | is_decoder (`bool`, *optional*, defaults to `False`): 50 | Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. 51 | use_cache (`bool`, *optional*, defaults to `True`): 52 | Whether or not the model should return the last key/values attentions (not used by all models). Only 53 | relevant if `config.is_decoder=True`. 54 | classifier_dropout (`float`, *optional*): 55 | The dropout ratio for the classification head. 
56 | 57 | Examples: 58 | 59 | ```python 60 | >>> from models import RabertConfig, RabertModel 61 | 62 | >>> # Initializing a Rabert configuration 63 | >>> configuration = RabertConfig() 64 | 65 | >>> # Initializing a model (with random weights) from the configuration 66 | >>> model = RabertModel(configuration) 67 | 68 | >>> # Accessing the model configuration 69 | >>> configuration = model.config 70 | ```""" 71 | model_type = "rabert" 72 | 73 | def __init__( 74 | self, 75 | vocab_size=50265, 76 | hidden_size=768, 77 | num_hidden_layers=12, 78 | num_attention_heads=12, 79 | intermediate_size=3072, 80 | hidden_act="gelu", 81 | hidden_dropout_prob=0.1, 82 | attention_probs_dropout_prob=0.1, 83 | max_position_embeddings=512, 84 | type_vocab_size=2, 85 | initializer_range=0.02, 86 | layer_norm_eps=1e-12, 87 | pad_token_id=3, 88 | bos_token_id=1, 89 | eos_token_id=2, 90 | position_embedding_type="absolute", 91 | use_cache=True, 92 | classifier_dropout=None, 93 | ep_add_linear_projection=False, # NOTE: rabert unique 94 | **kwargs, 95 | ): 96 | super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) 97 | 98 | self.vocab_size = vocab_size 99 | self.hidden_size = hidden_size 100 | self.num_hidden_layers = num_hidden_layers 101 | self.num_attention_heads = num_attention_heads 102 | self.hidden_act = hidden_act 103 | self.intermediate_size = intermediate_size 104 | self.hidden_dropout_prob = hidden_dropout_prob 105 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 106 | self.max_position_embeddings = max_position_embeddings 107 | self.type_vocab_size = type_vocab_size 108 | self.initializer_range = initializer_range 109 | self.layer_norm_eps = layer_norm_eps 110 | self.position_embedding_type = position_embedding_type 111 | self.use_cache = use_cache 112 | self.classifier_dropout = classifier_dropout 113 | 114 | self.ep_add_linear_projection = ep_add_linear_projection # NOTE: rabert unique -------------------------------------------------------------------------------- /codeart/code/models/modeling_jtrans.py: -------------------------------------------------------------------------------- 1 | from transformers import ( 2 | BertTokenizer, 3 | BertForMaskedLM, 4 | BertModel, 5 | ) 6 | from typing import List, Optional, Tuple, Union 7 | import torch 8 | from transformers.modeling_outputs import SequenceClassifierOutput, TokenClassifierOutput 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from transformers.modeling_utils import ( 12 | PreTrainedModel 13 | ) 14 | 15 | from transformers.models.roberta.modeling_roberta import ( 16 | RobertaClassificationHead 17 | ) 18 | 19 | 20 | 21 | class JTransModel(BertModel): 22 | def __init__(self, config, add_pooling_layer=True): 23 | super().__init__(config, add_pooling_layer=add_pooling_layer) 24 | self.config = config 25 | self.embeddings.position_embeddings = self.embeddings.word_embeddings # jTrans design: position embeddings share parameters with token embeddings, so jump-target tokens and their positions are tied 26 | 27 | 28 | class JTransForSequenceClassification(PreTrainedModel): 29 | def __init__(self, config): 30 | super().__init__(config) 31 | self.num_labels = config.num_labels 32 | self.config = config 33 | self.bert = JTransModel(config) 34 | classifier_dropout = ( 35 | config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 36 | ) 37 | self.dropout = nn.Dropout(classifier_dropout) 38 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 39 | self.post_init() 40 | 41 | def forward( 42 | self, 43 | input_ids: 
Optional[torch.Tensor] = None, 44 | attention_mask: Optional[torch.Tensor] = None, 45 | token_type_ids: Optional[torch.Tensor] = None, 46 | position_ids: Optional[torch.Tensor] = None, 47 | head_mask: Optional[torch.Tensor] = None, 48 | inputs_embeds: Optional[torch.Tensor] = None, 49 | labels: Optional[torch.Tensor] = None, 50 | output_attentions: Optional[bool] = None, 51 | output_hidden_states: Optional[bool] = None, 52 | return_dict: Optional[bool] = None, 53 | ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: 54 | return_dict = ( 55 | return_dict if return_dict is not None else self.config.use_return_dict 56 | ) 57 | outputs = self.bert( 58 | input_ids, 59 | attention_mask=attention_mask, 60 | token_type_ids=token_type_ids, 61 | position_ids=position_ids, 62 | head_mask=head_mask, 63 | inputs_embeds=inputs_embeds, 64 | output_attentions=output_attentions, 65 | output_hidden_states=output_hidden_states, 66 | return_dict=return_dict, 67 | ) 68 | pooled_output = outputs[1] 69 | 70 | pooled_output = self.dropout(pooled_output) 71 | logits = self.classifier(pooled_output) 72 | loss = None 73 | if labels is not None: 74 | labels = labels.to(logits.device) 75 | if self.config.problem_type is None: 76 | if self.num_labels == 1: 77 | self.config.problem_type = "regression" 78 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 79 | self.config.problem_type = "single_label_classification" 80 | else: 81 | self.config.problem_type = "multi_label_classification" 82 | 83 | if self.config.problem_type == "regression": 84 | loss_fct = nn.MSELoss() 85 | if self.num_labels == 1: 86 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 87 | else: 88 | loss = loss_fct(logits, labels) 89 | elif self.config.problem_type == "single_label_classification": 90 | loss_fct = nn.CrossEntropyLoss() 91 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 92 | elif self.config.problem_type == "multi_label_classification": 93 | loss_fct = nn.BCEWithLogitsLoss() 94 | loss = loss_fct(logits, labels) 95 | if not return_dict: 96 | output = (logits,) + outputs[2:] 97 | return ((loss,) + output) if loss is not None else output 98 | 99 | return SequenceClassifierOutput( 100 | loss=loss, 101 | logits=logits, 102 | hidden_states=outputs.hidden_states, 103 | attentions=outputs.attentions, 104 | ) 105 | 106 | 107 | class JTransForMultipleSequenceClassification(PreTrainedModel): 108 | def __init__(self, config, num_sequences): 109 | super().__init__(config) 110 | 111 | self.num_labels = config.num_labels 112 | self.config = config 113 | self.num_sequences = num_sequences 114 | 115 | self.bert = JTransModel(config) 116 | 117 | # attention pooling 118 | self.attn_vector = nn.Parameter(torch.normal(0.0, config.initializer_range, size=(1, config.hidden_size)).squeeze(), requires_grad=True) 119 | 120 | # classifier_dropout = ( 121 | # config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob 122 | # ) 123 | # self.dropout = nn.Dropout(classifier_dropout) 124 | # self.classifier = nn.Linear(config.hidden_size, config.num_labels) 125 | self.classifier = RobertaClassificationHead(config) 126 | self.post_init() 127 | 128 | def forward( 129 | self, 130 | input_ids: Optional[torch.Tensor] = None, 131 | attention_mask: Optional[torch.Tensor] = None, 132 | token_type_ids: Optional[torch.Tensor] = None, 133 | position_ids: Optional[torch.Tensor] = None, 134 | head_mask: Optional[torch.Tensor] = None, 135 | sequence_mask: 
Optional[torch.Tensor] = None, 136 | inputs_embeds: Optional[torch.Tensor] = None, 137 | labels: Optional[torch.Tensor] = None, 138 | output_attentions: Optional[bool] = None, 139 | output_hidden_states: Optional[bool] = None, 140 | return_dict: Optional[bool] = None, 141 | **kwargs 142 | ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: 143 | return_dict = ( 144 | return_dict if return_dict is not None else self.config.use_return_dict 145 | ) 146 | outputs = self.bert( 147 | input_ids, 148 | attention_mask=attention_mask, 149 | token_type_ids=token_type_ids, 150 | position_ids=position_ids, 151 | head_mask=head_mask, 152 | inputs_embeds=inputs_embeds, 153 | output_attentions=output_attentions, 154 | output_hidden_states=output_hidden_states, 155 | return_dict=return_dict, 156 | ) 157 | # pooled_output = outputs[1] 158 | 159 | # pooled_output = self.dropout(pooled_output) 160 | # logits = self.classifier(pooled_output) 161 | 162 | 163 | sequence_output = outputs[0] 164 | sequence_output = sequence_output.view(-1, self.num_sequences, sequence_output.shape[1], self.config.hidden_size) # (bs, num_seq, max_length, hidden_size) 165 | 166 | # mean pooling 167 | # sequence_output = sequence_output.mean(dim=1) # (bs, max_length, hidden_size) 168 | # logits = self.classifier(sequence_output) 169 | 170 | # attention pooling 171 | sequence_output = sequence_output[:, :, 0, :] # get [CLS], (bs, num_seq, hidden_size); no squeeze(), which would also drop the batch dim when bs == 1 172 | attention_scores = torch.matmul(sequence_output, self.attn_vector) 173 | extended_sequence_mask = (1.0 - sequence_mask) * torch.finfo(self.dtype).min 174 | attention_scores = attention_scores + extended_sequence_mask 175 | attention_probs = F.softmax(attention_scores, dim=-1) 176 | attention_output = torch.matmul(attention_probs.unsqueeze(1), sequence_output) # (bs, 1, hidden_size) 177 | logits = self.classifier(attention_output) 178 | 179 | 180 | loss = None 181 | if labels is not None: 182 | labels = labels.to(logits.device) 183 | if self.config.problem_type is None: 184 | if self.num_labels == 1: 185 | self.config.problem_type = "regression" 186 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 187 | self.config.problem_type = "single_label_classification" 188 | else: 189 | self.config.problem_type = "multi_label_classification" 190 | 191 | if self.config.problem_type == "regression": 192 | loss_fct = nn.MSELoss() 193 | if self.num_labels == 1: 194 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 195 | else: 196 | loss = loss_fct(logits, labels) 197 | elif self.config.problem_type == "single_label_classification": 198 | loss_fct = nn.CrossEntropyLoss() 199 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 200 | elif self.config.problem_type == "multi_label_classification": 201 | loss_fct = nn.BCEWithLogitsLoss() 202 | loss = loss_fct(logits, labels) 203 | if not return_dict: 204 | output = (logits,) + outputs[2:] 205 | return ((loss,) + output) if loss is not None else output 206 | 207 | return SequenceClassifierOutput( 208 | loss=loss, 209 | logits=logits, 210 | hidden_states=outputs.hidden_states, 211 | attentions=outputs.attentions, 212 | ) 213 | 214 | 215 | class JTransForTokenClassification(PreTrainedModel): 216 | def __init__(self, config): 217 | super().__init__(config) 218 | self.num_labels = config.num_labels 219 | 220 | self.roberta = JTransModel(config, add_pooling_layer=False) 221 | classifier_dropout = ( 222 | config.classifier_dropout if config.classifier_dropout is not None 
else config.hidden_dropout_prob 223 | ) 224 | self.dropout = nn.Dropout(classifier_dropout) 225 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 226 | 227 | # Initialize weights and apply final processing 228 | self.post_init() 229 | 230 | def forward( 231 | self, 232 | input_ids: Optional[torch.LongTensor] = None, 233 | attention_mask: Optional[torch.FloatTensor] = None, 234 | token_type_ids: Optional[torch.LongTensor] = None, 235 | position_ids: Optional[torch.LongTensor] = None, 236 | head_mask: Optional[torch.FloatTensor] = None, 237 | inputs_embeds: Optional[torch.FloatTensor] = None, 238 | labels: Optional[torch.LongTensor] = None, 239 | output_attentions: Optional[bool] = None, 240 | output_hidden_states: Optional[bool] = None, 241 | return_dict: Optional[bool] = None, 242 | ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: 243 | r""" 244 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 245 | Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 246 | """ 247 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 248 | 249 | outputs = self.roberta( 250 | input_ids, 251 | attention_mask=attention_mask, 252 | token_type_ids=token_type_ids, 253 | position_ids=position_ids, 254 | head_mask=head_mask, 255 | inputs_embeds=inputs_embeds, 256 | output_attentions=output_attentions, 257 | output_hidden_states=output_hidden_states, 258 | return_dict=return_dict, 259 | ) 260 | 261 | sequence_output = outputs[0] 262 | 263 | sequence_output = self.dropout(sequence_output) 264 | logits = self.classifier(sequence_output) 265 | 266 | loss = None 267 | if labels is not None: 268 | # move labels to correct device to enable model parallelism 269 | labels = labels.to(logits.device) 270 | loss_fct = nn.CrossEntropyLoss() 271 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 272 | 273 | if not return_dict: 274 | output = (logits,) + outputs[2:] 275 | return ((loss,) + output) if loss is not None else output 276 | 277 | return TokenClassifierOutput( 278 | loss=loss, 279 | logits=logits, 280 | hidden_states=outputs.hidden_states, 281 | attentions=outputs.attentions, 282 | ) 283 | -------------------------------------------------------------------------------- /codeart/code/models/tokenization_codeart.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from typing import List, Dict, Optional, Tuple 4 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 5 | 6 | 7 | class CodeArtTokenizer(PreTrainedTokenizerFast): 8 | 9 | inst_token = '<INST>' 10 | maskbuilder = None 11 | local_patterns=True 12 | 13 | def inst_encode( 14 | self, 15 | code: List[Tuple[int, str]], 16 | data_dep: List[Tuple[int, int]], 17 | max_transitions=None, 18 | return_extra_info=False, 19 | ): 20 | # tokenization & locate tokens 21 | tokens, instruction_node_positions = [], [] 22 | special_tokens_mask = [] # for MLM 23 | tokens.append(self.cls_token) 24 | special_tokens_mask.append(1) 25 | 26 | for inst_id, instruction in code: 27 | if len(tokens) >= self.model_max_length: 28 | break 29 | instruction_tokens = self.tokenize(instruction) 30 | instruction_node_positions.append(len(tokens)) 31 | tokens.append(self.inst_token) 32 | special_tokens_mask.append(1) 33 | tokens += instruction_tokens 34 | special_tokens_mask += [0] * len(instruction_tokens) 35 | assert(len(tokens) == len(special_tokens_mask)) # debug 
36 | 37 | # truncation & padding & [SEP] 38 | tokens = tokens[:self.model_max_length - 1] 39 | special_tokens_mask = special_tokens_mask[:self.model_max_length - 1] 40 | assert(len(tokens) == len(special_tokens_mask)) # debug 41 | tokens.append(self.sep_token) 42 | special_tokens_mask.append(1) 43 | if len(tokens) < self.model_max_length: 44 | tokens += [self.pad_token] * (self.model_max_length - len(tokens)) 45 | special_tokens_mask += [1] * (self.model_max_length - len(special_tokens_mask)) 46 | 47 | # print(len(tokens), len(special_tokens_mask)) 48 | assert(len(tokens) == len(special_tokens_mask)) # debug 49 | 50 | # convert tokens to ids 51 | input_ids = self.convert_tokens_to_ids(tokens) 52 | 53 | assert self.maskbuilder 54 | 55 | if not self.local_patterns: 56 | attention_mask, relative_position_matrix = \ 57 | self.maskbuilder.create_attention_mask_and_relative_position_matrix_no_local( 58 | self.model_max_length, 59 | instruction_node_positions, 60 | data_dep, 61 | max_transitions=max_transitions 62 | ) 63 | else: 64 | attention_mask, relative_position_matrix = \ 65 | self.maskbuilder.create_attention_mask_and_relative_position_matrix( 66 | self.model_max_length, 67 | instruction_node_positions, 68 | data_dep, 69 | max_transitions=max_transitions 70 | ) 71 | 72 | if return_extra_info: 73 | return { # NOTE: `attention_mask` needs to be bool to support collator's edge sampling 74 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 75 | 'attention_mask': attention_mask, \ 76 | 'special_tokens_mask': torch.tensor(special_tokens_mask, dtype=torch.long), 77 | 'relative_position_matrix': relative_position_matrix, \ 78 | 'instruction_node_positions': instruction_node_positions 79 | } 80 | else: 81 | return { 82 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 83 | 'attention_mask': attention_mask, \ 84 | 'relative_position_matrix': relative_position_matrix 85 | } 86 | 87 | def batch_inst_encode( 88 | self, 89 | examples, 90 | max_transitions=None, 91 | ): 92 | batch = { 93 | 'input_ids': [], 94 | 'attention_mask': [], 95 | 'relative_position_matrix': [] 96 | } 97 | 98 | for example in examples: 99 | # encoded = self.inst_encode(eval(example['code']), eval(example['data_dep'])) 100 | encoded = self.inst_encode(example['code'], example['data_dep'], max_transitions=max_transitions) 101 | batch['input_ids'].append(encoded['input_ids']) 102 | batch['attention_mask'].append(encoded['attention_mask']) 103 | batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 104 | return { 105 | 'input_ids': torch.stack(batch['input_ids']), 106 | 'attention_mask': torch.stack(batch['attention_mask']), 107 | 'relative_position_matrix': torch.stack(batch['relative_position_matrix']) 108 | } -------------------------------------------------------------------------------- /codeart/code/models/tokenization_rabert.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | from typing import List, Dict, Optional, Tuple 5 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 6 | 7 | 8 | class RabertTokenizer(PreTrainedTokenizerFast): 9 | 10 | inst_token = '<INST>' 11 | maskbuilder = None 12 | local_patterns=True 13 | 14 | def inst_encode( 15 | self, 16 | code: List[Tuple[int, str]], 17 | data_dep: List[Tuple[int, int]], 18 | return_extra_info=False, 19 | **kwargs 20 | ): 21 | # tokenization & locate tokens 22 | tokens, instruction_node_positions = [], [] 23 | special_tokens_mask = [] # for MLM 24 | 
tokens.append(self.cls_token) 25 | special_tokens_mask.append(1) 26 | 27 | for inst_id, instruction in code: 28 | if len(tokens) >= self.model_max_length: 29 | break 30 | instruction_tokens = self.tokenize(instruction) 31 | instruction_node_positions.append(len(tokens)) 32 | tokens.append(self.inst_token) 33 | special_tokens_mask.append(1) 34 | tokens += instruction_tokens 35 | special_tokens_mask += [0] * len(instruction_tokens) 36 | assert(len(tokens) == len(special_tokens_mask)) # debug 37 | 38 | # truncation & padding & [SEP] 39 | tokens = tokens[:self.model_max_length - 1] 40 | special_tokens_mask = special_tokens_mask[:self.model_max_length - 1] 41 | assert(len(tokens) == len(special_tokens_mask)) # debug 42 | tokens.append(self.sep_token) 43 | special_tokens_mask.append(1) 44 | if len(tokens) < self.model_max_length: 45 | tokens += [self.pad_token] * (self.model_max_length - len(tokens)) 46 | special_tokens_mask += [1] * (self.model_max_length - len(special_tokens_mask)) 47 | 48 | # print(len(tokens), len(special_tokens_mask)) 49 | assert(len(tokens) == len(special_tokens_mask)) # debug 50 | 51 | # convert tokens to ids 52 | input_ids = self.convert_tokens_to_ids(tokens) 53 | 54 | # rabert mask (NOTE: currently paddings will all be masked out) 55 | if self.maskbuilder is None: 56 | from modeling_utils import create_attention_mask_aggressive 57 | attention_mask = create_attention_mask_aggressive( 58 | self.model_max_length, 59 | instruction_node_positions, 60 | data_dep 61 | ) 62 | else: 63 | if not self.local_patterns: 64 | attention_mask = self.maskbuilder.create_attention_mask_no_local( 65 | self.model_max_length, 66 | instruction_node_positions, 67 | data_dep 68 | ) 69 | else: 70 | attention_mask = self.maskbuilder.create_attention_mask( 71 | self.model_max_length, 72 | instruction_node_positions, 73 | data_dep 74 | ) 75 | 76 | if return_extra_info: 77 | return { # NOTE: `attention_mask` needs to be bool to support collator's edge sampling 78 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 79 | 'attention_mask': attention_mask, \ 80 | 'special_tokens_mask': torch.tensor(special_tokens_mask, dtype=torch.long), 81 | 'instruction_node_positions': instruction_node_positions 82 | } 83 | else: 84 | return { 85 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 86 | 'attention_mask': torch.tensor(attention_mask, dtype=torch.long) 87 | } 88 | 89 | def batch_inst_encode( 90 | self, 91 | examples 92 | ): 93 | batch = { 94 | 'input_ids': [], 95 | 'attention_mask': [] 96 | } 97 | 98 | for example in examples: 99 | encoded = self.inst_encode(example['code'], example['data_dep']) 100 | batch['input_ids'].append(encoded['input_ids']) 101 | batch['attention_mask'].append(encoded['attention_mask']) 102 | 103 | return { 104 | 'input_ids': torch.stack(batch['input_ids']), 105 | 'attention_mask': torch.stack(batch['attention_mask']) 106 | } 107 | 108 | 109 | 110 | class GCBLikeTokenizer(PreTrainedTokenizerFast): 111 | 112 | inst_token = '<INST>' 113 | maskbuilder = None 114 | local_patterns=True 115 | 116 | def inst_encode( 117 | self, 118 | code: List[Tuple[int, str]], 119 | data_dep: List[Tuple[int, int]], 120 | return_extra_info=False, 121 | **kwargs 122 | ): 123 | # tokenization & locate tokens 124 | tokens, instruction_node_positions = [], [] 125 | special_tokens_mask = [] # for MLM 126 | tokens.append(self.cls_token) 127 | special_tokens_mask.append(1) 128 | 129 | for inst_id, instruction in code: 130 | if len(tokens) >= self.model_max_length: 131 | break 132 | 
instruction_tokens = self.tokenize(instruction) 133 | instruction_node_positions.append(len(tokens)) 134 | tokens.append(self.inst_token) 135 | special_tokens_mask.append(1) 136 | tokens += instruction_tokens 137 | if len(instruction_tokens) > 1: # maybe with operand 138 | tokens[instruction_node_positions[-1]] = instruction_tokens[1] # overwrite with the operand 139 | special_tokens_mask += [0] * len(instruction_tokens) 140 | assert(len(tokens) == len(special_tokens_mask)) # debug 141 | 142 | # truncation & padding & [SEP] 143 | tokens = tokens[:self.model_max_length - 1] 144 | special_tokens_mask = special_tokens_mask[:self.model_max_length - 1] 145 | assert(len(tokens) == len(special_tokens_mask)) # debug 146 | tokens.append(self.sep_token) 147 | special_tokens_mask.append(1) 148 | if len(tokens) < self.model_max_length: 149 | tokens += [self.pad_token] * (self.model_max_length - len(tokens)) 150 | special_tokens_mask += [1] * (self.model_max_length - len(special_tokens_mask)) 151 | 152 | # print(len(tokens), len(special_tokens_mask)) 153 | assert(len(tokens) == len(special_tokens_mask)) # debug 154 | # convert tokens to ids 155 | input_ids = self.convert_tokens_to_ids(tokens) 156 | 157 | # gcb-like mask 158 | assert self.maskbuilder 159 | attention_mask = self.maskbuilder.create_attention_mask( 160 | self.model_max_length, 161 | instruction_node_positions, 162 | data_dep 163 | ) 164 | 165 | if return_extra_info: 166 | return { # NOTE: `attention_mask` needs to be bool to support collator's edge sampling 167 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 168 | 'attention_mask': attention_mask, \ 169 | 'special_tokens_mask': torch.tensor(special_tokens_mask, dtype=torch.long), 170 | 'instruction_node_positions': instruction_node_positions 171 | } 172 | else: 173 | return { 174 | 'input_ids': torch.tensor(input_ids, dtype=torch.long), \ 175 | 'attention_mask': torch.tensor(attention_mask, dtype=torch.long) 176 | } 177 | 178 | def batch_inst_encode( 179 | self, 180 | examples 181 | ): 182 | batch = { 183 | 'input_ids': [], 184 | 'attention_mask': [] 185 | } 186 | 187 | for example in examples: 188 | encoded = self.inst_encode(example['code'], example['data_dep']) 189 | batch['input_ids'].append(encoded['input_ids']) 190 | batch['attention_mask'].append(encoded['attention_mask']) 191 | 192 | return { 193 | 'input_ids': torch.stack(batch['input_ids']), 194 | 'attention_mask': torch.stack(batch['attention_mask']) 195 | } -------------------------------------------------------------------------------- /codeart/evaluation-jtrans/malware-family-classification/config/eval.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "/data3/share/ziansu/jtrans-malware-4f-100c-1/checkpoint-4000/", 3 | "max_functions": 4, 4 | "dataset_name": "PurCL/malware-top-100-jtrans", 5 | "output_dir": "../save/jtrans-xxx", 6 | 7 | "max_seq_length": 512, 8 | 9 | "use_auth_token": true, 10 | "dataloader_num_workers": 2, 11 | "remove_unused_columns": false, 12 | 13 | "do_train": true, 14 | "do_eval": true, 15 | "do_predict": true, 16 | 17 | "per_device_train_batch_size": 4, 18 | "gradient_accumulation_steps": 1, 19 | "per_device_eval_batch_size": 8, 20 | 21 | "num_train_epochs": 5, 22 | "learning_rate": 5e-5, 23 | "evaluation_strategy": "steps", 24 | "eval_steps": 100, 25 | "save_steps": 1000, 26 | "logging_steps": 10, 27 | 28 | "report_to": "tensorboard", 29 | "cache_dir": "../save/.cache", 30 | 31 | "overwrite_output_dir": true 32 | } 
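The JSON file above (and the companion `train.json` that follows) doubles as an argument file: `eval_config.sh` and `run_config.sh` pass the JSON path as the sole argument, and the Python entry points parse it with `HfArgumentParser.parse_json_file`, the same pattern that appears in `inference.py` later in this dump. A minimal sketch of that mechanism follows; the trimmed-down dataclasses here are illustrative, not the repository's actual argument classes, which declare a field for every key they accept.

```python
# Minimal sketch (not the repository's run script) of how these JSON configs
# are consumed: HfArgumentParser maps each JSON key onto a dataclass field.
import sys
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class ModelArguments:
    # Illustrative fields mirroring keys in eval.json.
    model_name_or_path: str = field(default="")
    max_functions: int = field(default=2)
    use_auth_token: bool = field(default=False)
    cache_dir: Optional[str] = field(default=None)

@dataclass
class DataArguments:
    dataset_name: Optional[str] = field(default=None)
    max_seq_length: int = field(default=512)

# Keys such as "output_dir", "learning_rate", or "eval_steps" land in
# TrainingArguments; every key must match some field, which is why each
# script's dataclasses mirror its config file.
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_json_file(
    json_file=sys.argv[1]
)
print(model_args.model_name_or_path, training_args.output_dir)
```
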
-------------------------------------------------------------------------------- /codeart/evaluation-jtrans/malware-family-classification/config/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "../../save/jtrans/models/jTrans-pretrain/", 3 | 4 | "dataset_name": "PurCL/malware-top-100-jtrans", 5 | "output_dir": "../save/jtrans-malware-2f-100c", 6 | "max_functions": 2, 7 | "overwrite_cache": true, 8 | "max_seq_length": 512, 9 | 10 | "use_auth_token": true, 11 | "dataloader_num_workers": 2, 12 | "remove_unused_columns": false, 13 | 14 | "do_train": true, 15 | "do_eval": true, 16 | "do_predict": true, 17 | 18 | "per_device_train_batch_size": 4, 19 | "gradient_accumulation_steps": 1, 20 | "per_device_eval_batch_size": 8, 21 | 22 | "num_train_epochs": 10, 23 | "learning_rate": 1e-4, 24 | "evaluation_strategy": "steps", 25 | "eval_steps": 100, 26 | "save_steps": 500, 27 | "logging_steps": 10, 28 | "load_best_model_at_end": true, 29 | "metric_for_best_model": "roc_auc_score", 30 | 31 | "report_to": "tensorboard", 32 | "cache_dir": "../save/.cache", 33 | 34 | "overwrite_output_dir": true, 35 | "no_cuda": false 36 | } -------------------------------------------------------------------------------- /codeart/evaluation-jtrans/malware-family-classification/eval_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=purcl 5 | WANDB_project=malware-family-classification 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! -f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | export CUDA_VISIBLE_DEVICES=2 22 | # python temp_evaluate.py $CURRENT_DIR/$CONFIG 23 | python evaluate_multilabel.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation-jtrans/malware-family-classification/run_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=purcl 5 | WANDB_project=malware-family-classification 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! 
-f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | export CUDA_VISIBLE_DEVICES=3,7 22 | # python run.py $CURRENT_DIR/$CONFIG 23 | # torchrun --nproc_per_node=2 run.py $CURRENT_DIR/$CONFIG 24 | torchrun --nproc_per_node=2 run_multilabel.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation-jtrans/malware-family-classification/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class DataCollatorForMFC(object): 6 | 7 | def __init__(self, tokenizer, label2id, max_functions): 8 | self.tokenizer = tokenizer 9 | self.label2id = label2id 10 | self.max_functions = max_functions 11 | 12 | def __call__( 13 | self, 14 | examples 15 | ): 16 | 17 | batch = { 18 | 'input_ids': [], 19 | 'token_type_ids': [], 20 | 'attention_mask': [], 21 | 'labels': [], 22 | 'sequence_mask': [], 23 | 'all_labels': [] 24 | } 25 | 26 | for example in examples: 27 | num_functions = 0 28 | sequence_mask = [] 29 | for function in eval(example['functions'])[:self.max_functions]: 30 | funcstr = function['jtrans_function_string'] 31 | encoded = self.tokenizer(funcstr, add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_tensors='pt') 32 | batch['input_ids'].append(encoded['input_ids'].squeeze(0)) 33 | batch['token_type_ids'].append(encoded['token_type_ids'].squeeze(0)) 34 | batch['attention_mask'].append(encoded['attention_mask'].squeeze(0)) 35 | num_functions += 1 36 | sequence_mask.append(1) 37 | for _ in range(num_functions, self.max_functions): 38 | funcstr = function['jtrans_function_string'] 39 | encoded = self.tokenizer(funcstr, add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_tensors='pt') 40 | batch['input_ids'].append(encoded['input_ids'].squeeze(0)) 41 | batch['token_type_ids'].append(encoded['token_type_ids'].squeeze(0)) 42 | batch['attention_mask'].append(encoded['attention_mask'].squeeze(0)) 43 | sequence_mask.append(0) 44 | batch['labels'].append(self.label2id[random.sample(example['labels'], 1)[0]]) 45 | batch["sequence_mask"].append(sequence_mask) 46 | batch["all_labels"].append(example['labels']) 47 | 48 | return { 49 | 'input_ids': torch.stack(batch['input_ids']), 50 | 'token_type_ids': torch.stack(batch['token_type_ids']), 51 | 'attention_mask': torch.stack(batch['attention_mask']), 52 | 'labels': torch.tensor(batch['labels'], dtype=torch.long).unsqueeze(1), 53 | 'sequence_mask': torch.tensor(batch['sequence_mask'], dtype=torch.long), 54 | 'all_labels': batch['all_labels'] 55 | } 56 | 57 | 58 | class DataCollatorForMFCMultilabel(object): 59 | 60 | def __init__(self, tokenizer, label2id, max_functions): 61 | self.tokenizer = tokenizer 62 | self.label2id = label2id 63 | self.max_functions = max_functions 64 | 65 | def __call__( 66 | self, 67 | examples 68 | ): 69 | 70 | batch = { 71 | 'input_ids': [], 72 | 'token_type_ids': [], 73 | 'attention_mask': [], 74 | 'labels': [], 75 | 'sequence_mask': [], 76 | } 77 | 78 | for example in examples: 79 | num_functions = 0 80 | sequence_mask = [] 81 | for function in eval(example['functions'])[:self.max_functions]: 82 | funcstr = function['jtrans_function_string'] 83 | encoded = self.tokenizer(funcstr, add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_tensors='pt') 84 | 
batch['input_ids'].append(encoded['input_ids'].squeeze(0)) 85 | batch['token_type_ids'].append(encoded['token_type_ids'].squeeze(0)) 86 | batch['attention_mask'].append(encoded['attention_mask'].squeeze(0)) 87 | num_functions += 1 88 | sequence_mask.append(1) 89 | for _ in range(num_functions, self.max_functions): 90 | funcstr = function['jtrans_function_string'] 91 | encoded = self.tokenizer(funcstr, add_special_tokens=True,max_length=512,padding='max_length',truncation=True,return_tensors='pt') 92 | batch['input_ids'].append(encoded['input_ids'].squeeze(0)) 93 | batch['token_type_ids'].append(encoded['token_type_ids'].squeeze(0)) 94 | batch['attention_mask'].append(encoded['attention_mask'].squeeze(0)) 95 | sequence_mask.append(0) 96 | labels = torch.zeros(len(self.label2id)) 97 | for l in example['labels']: 98 | labels[self.label2id[l]] = 1 99 | batch['labels'].append(labels) 100 | batch["sequence_mask"].append(sequence_mask) 101 | 102 | return { 103 | 'input_ids': torch.stack(batch['input_ids']), 104 | 'token_type_ids': torch.stack(batch['token_type_ids']), 105 | 'attention_mask': torch.stack(batch['attention_mask']), 106 | 'labels': torch.stack(batch['labels']), 107 | 'sequence_mask': torch.tensor(batch['sequence_mask'], dtype=torch.long), 108 | } -------------------------------------------------------------------------------- /codeart/evaluation-jtrans/provenance-attribution/run.sh: -------------------------------------------------------------------------------- 1 | TOKENIZERS_PARALLELISM=false 2 | export CUDA_VISIBLE_DEVICES=7 3 | 4 | MODEL_NAME_OR_PATH=../../save/jtrans/models/jTrans-pretrain/ 5 | # MODEL_NAME_OR_PATH=../../save/jtrans/models/jTrans-finetune/ 6 | 7 | python run.py \ 8 | --model_name_or_path $MODEL_NAME_OR_PATH \ 9 | --dataset_name PurCL/binkit-jtrans-all \ 10 | --use_auth_token \ 11 | --dataloader_num_workers 2 \ 12 | --max_train_samples 100 \ 13 | --max_eval_samples 10000 \ 14 | --max_predict_samples 10000 \ 15 | --max_seq_length 512 \ 16 | --labels O0,O1,O2,O3 \ 17 | --remove_unused_columns False \ 18 | --do_train \ 19 | --do_eval \ 20 | --do_predict \ 21 | --per_device_train_batch_size 8 \ 22 | --gradient_accumulation_steps 1 \ 23 | --per_device_eval_batch_size 32 \ 24 | --num_train_epochs 5 \ 25 | --learning_rate 5e-5 \ 26 | --evaluation_strategy steps \ 27 | --eval_steps 2 \ 28 | --save_steps 100 \ 29 | --logging_step 2 \ 30 | --report_to tensorboard \ 31 | --cache_dir ../save/.cache \ 32 | --output_dir ../save/jtrans-pa \ 33 | --overwrite_output_dir \ 34 | --load_best_model_at_end=True \ 35 | --metric_for_best_model eval_loss \ 36 | --load_best_model_at_end 37 | # --overwrite_cache \ 38 | # --warmup_steps 1000 \ -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/.gitignore: -------------------------------------------------------------------------------- 1 | report_*pool*.txt 2 | cache/** 3 | output/** -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/README.md: -------------------------------------------------------------------------------- 1 | # Binary Similarity 2 | 3 | This file illustrates how we evaluate CodeArt on the binary similarity task. 4 | Please first download the checkpoint for CodeArt as specified in the `README.md` in the root directory of this repository. 
5 | 6 | ## Download Other Files 7 | 8 | Please download the preprocessed test data from [this Google Drive link](https://drive.google.com/file/d/1KSJMVtSgoI5bBMx9xhQtvBadUU3rehER/view?usp=share_link) to `evaluation/binary-similarity/cache` and unzip it. 9 | 10 | Your directory should look like this: 11 | 12 | ``` 13 | evaluation/binary-similarity 14 | |--cache 15 | |--binary_clone_detection 16 | |--binutilsh-pool.id 17 | |--... 18 | ``` 19 | 20 | ## Finetuned Model 21 | 22 | For a quick validation of CodeArt, we provide a model finetuned on the binary similarity task. 23 | Please download the finetuned model from [this Google Drive link](https://drive.google.com/file/d/1FF1BS4kXkkB6561CV63GwruumPsGvgF6/view?usp=share_link) to `evaluation/save/codeart-binsim`. 24 | 25 | Your directory should look like this: 26 | 27 | ``` 28 | evaluation/save 29 | |--codeart-binsim 30 | |--checkpoint-4000 31 | |--pytorch_model.bin 32 | |--... 33 | ``` 34 | 35 | ## Evaluation 36 | 37 | The evaluation has three steps. 38 | The first step encodes the test data into embeddings. 39 | In the second step, the script constructs candidate function pools of different sizes, 40 | and randomly picks functions to query the candidate pools. 41 | In the third step, the script reports the results averaged over multiple runs. 42 | 43 | For the encoding step, please run `encode.sh`. 44 | 45 | For the evaluation step, please run `sample_and_report.sh`. 46 | 47 | The raw results are stored in files matching `report_ckpt-4k-*.txt`, one per benchmark and pool size. 48 | 49 | For the report step, please use `pretty_print_all.py` to generate a human-readable report. 50 | Specifically, please run `python3 pretty_print_all.py report_ckpt-4k`. (A minimal sanity-check sketch for the encoded embeddings is given after the Finetuning section below.) 51 | 52 | ## Finetuning 53 | 54 | Interested readers can finetune CodeArt on the binary similarity task by running `./run_config.sh config/train.json`. 55 | Please fill in your wandb information in `run_config.sh` before running the script. 
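For a quick sanity check of the encoding step, the saved embeddings can be compared directly. A minimal sketch, assuming the default file names produced by `encode.sh` and that query `i`'s ground-truth match is pool entry `i` (which is how `dump_files.py` pairs the `-query` and `-pool` files):

```python
import numpy as np

# Embeddings written by encode.sh; they are L2-normalized
# (--normalize_embed true), so a dot product is cosine similarity.
src = np.load("output/coreutilsh-src_ckpt-4k.npy")  # queries
tgt = np.load("output/coreutilsh-tgt_ckpt-4k.npy")  # candidate pool

sim = src @ tgt.T                # (num_queries, pool_size)
gold = np.diag(sim)[:, None]     # similarity to the true match
rank = (sim > gold).sum(axis=1)  # 0 means the true match ranks first

print("PR@1:", float((rank == 0).mean()))
print("MRR :", float((1.0 / (rank + 1)).mean()))
```

This mirrors the PR@1/MRR computation in `binsim_trainer.py` and should roughly agree with the pooled numbers reported by `sample_and_report.sh` when the pool covers the whole benchmark.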
56 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/binsim_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | class BinSimDataset(torch.utils.data.Dataset): 5 | 6 | def __init__(self, args, raw_ds): 7 | self.args = args 8 | self.raw_ds = raw_ds 9 | 10 | def __len__(self): 11 | return len(self.raw_ds) 12 | 13 | def __getitem__(self, idx): 14 | # 1 positive sample, 1 negative samples 15 | selected = self.raw_ds[idx%len(self.raw_ds)] 16 | first_selected_function = selected['functions'][0] 17 | # randomly pick an int from 0 to len(selected['functions']) - 1 18 | random_idx = random.randint(1, len(selected['functions']) - 1) 19 | pos_selected_function = selected['functions'][random_idx] 20 | # randomly pick an int from 0 to len(self.raw_ds) - 1 21 | random_idx = random.randint(0, len(self.raw_ds) - 1) 22 | while random_idx == idx%len(self.raw_ds): 23 | random_idx = random.randint(0, len(self.raw_ds) - 1) 24 | neg_selected_function = random.choice(self.raw_ds[random_idx]['functions']) 25 | return { 26 | 'functions': [first_selected_function, pos_selected_function, neg_selected_function] 27 | } 28 | 29 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/binsim_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from transformers import Trainer, TrainingArguments 3 | from transformers.trainer_utils import ( 4 | EvalPrediction, 5 | has_length, 6 | speed_metrics 7 | ) 8 | from transformers.modeling_utils import ( 9 | PreTrainedModel, 10 | ) 11 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 12 | from transformers.trainer_callback import ( 13 | TrainerCallback, 14 | ) 15 | from transformers.trainer_pt_utils import ( 16 | find_batch_size 17 | ) 18 | from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint 19 | import torch 20 | from typing import Callable, Dict, List, Optional, Tuple 21 | from torch.utils.data.dataset import Dataset 22 | from transformers.utils import logging 23 | import numpy as np 24 | import torch.nn.functional as F 25 | from models.modeling_codeart import CodeArtForBinSim 26 | 27 | logger = logging.get_logger(__name__) 28 | 29 | 30 | class BinSimTrainer(Trainer): 31 | def __init__( 32 | self, 33 | model: CodeArtForBinSim = None, 34 | args: TrainingArguments = None, 35 | data_collator: Optional[Callable] = None, 36 | train_dataset: Optional[Dataset] = None, 37 | eval_dataset: Optional[Dataset] = None, 38 | tokenizer: Optional[PreTrainedTokenizerBase] = None, 39 | model_init: Optional[Callable[[], PreTrainedModel]] = None, 40 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 41 | callbacks: Optional[List[TrainerCallback]] = None, 42 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( 43 | None, 44 | None, 45 | ), 46 | preprocess_logits_for_metrics: Optional[ 47 | Callable[[torch.Tensor, torch.Tensor], torch.Tensor] 48 | ] = None, 49 | ): 50 | super(BinSimTrainer, self).__init__( 51 | model=model, 52 | args=args, 53 | data_collator=data_collator, 54 | train_dataset=train_dataset, 55 | eval_dataset=eval_dataset, 56 | tokenizer=tokenizer, 57 | model_init=model_init, 58 | compute_metrics=compute_metrics, 59 | callbacks=callbacks, 60 | optimizers=optimizers, 61 | 
preprocess_logits_for_metrics=preprocess_logits_for_metrics, 62 | ) 63 | 64 | def evaluate( 65 | self, 66 | eval_dataset: Optional[Dataset] = None, 67 | ignore_keys: Optional[List[str]] = None, 68 | metric_key_prefix: str = "valid", 69 | ) -> Dict[str, float]: 70 | # # # # # # # # # # # # # # # # # # # # # # # # # 71 | # 72 | # BEGIN MAGIC 73 | # 74 | # # # # # # # # # # # # # # # # # # # # # # # # # 75 | self._memory_tracker.start() 76 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 77 | start_time = time.time() 78 | if self.is_deepspeed_enabled and self.deepspeed is None: 79 | _, _ = deepspeed_init(self, num_training_steps=0, inference=True) 80 | 81 | model = self._wrap_model(self.model, training=False, dataloader=eval_dataloader) 82 | 83 | if len(self.accelerator._models) == 0 and model is self.model: 84 | model = ( 85 | self.accelerator.prepare(model) 86 | if self.is_deepspeed_enabled 87 | else self.accelerator.prepare_model(model, evaluation_mode=True) 88 | ) 89 | 90 | if self.is_fsdp_enabled: 91 | self.model = model 92 | 93 | # for the rest of this function `model` is the outside model, whether it was wrapped or not 94 | if model is not self.model: 95 | self.model_wrapped = model 96 | 97 | # backward compatibility 98 | if self.is_deepspeed_enabled: 99 | self.deepspeed = self.model_wrapped 100 | if not self.is_in_train: 101 | if self.args.fp16_full_eval: 102 | model = model.to(dtype=torch.float16, device=self.args.device) 103 | elif self.args.bf16_full_eval: 104 | model = model.to(dtype=torch.bfloat16, device=self.args.device) 105 | 106 | batch_size = self.args.eval_batch_size 107 | 108 | logger.info(f"***** Running Validation *****") 109 | if has_length(eval_dataloader): 110 | logger.info(f" Num examples = {self.num_examples(eval_dataloader)}") 111 | else: 112 | logger.info(" Num examples: Unknown") 113 | logger.info(f" Batch size = {batch_size}") 114 | model.eval() 115 | self.callback_handler.eval_dataloader = eval_dataloader 116 | # Do this before wrapping. 
117 | eval_dataset = getattr(eval_dataloader, "dataset", None) 118 | 119 | if self.args.past_index >= 0: 120 | self._past = None 121 | 122 | # # # # # # # # # # # # # # # # # # # # # # # # # 123 | # 124 | # END MAGIC 125 | # 126 | # # # # # # # # # # # # # # # # # # # # # # # # # 127 | # torch.cuda.empty_cache() 128 | pr1_all=[] 129 | mrr_all=[] 130 | pr1_wo_pooler_all=[] 131 | mrr_wo_pooler_all=[] 132 | observed_batch_size = 0 133 | for step, inputs in enumerate(eval_dataloader): 134 | if observed_batch_size == 0: 135 | observed_batch_size = find_batch_size(inputs) 136 | pr1, mrr, pr1_wo_pooler, mrr_wo_pooler = self.validate(model, inputs) 137 | pr1_all.append(pr1) 138 | mrr_all.append(mrr) 139 | pr1_wo_pooler_all.append(pr1_wo_pooler) 140 | mrr_wo_pooler_all.append(mrr_wo_pooler) 141 | self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) 142 | 143 | pr1_all = np.array(pr1_all) 144 | mrr_all = np.array(mrr_all) 145 | pr1_wo_pooler_all = np.array(pr1_wo_pooler_all) 146 | mrr_wo_pooler_all = np.array(mrr_wo_pooler_all) 147 | pr1 = np.mean(pr1_all) 148 | mrr = np.mean(mrr_all) 149 | pr1_wo_pooler = np.mean(pr1_wo_pooler_all) 150 | mrr_wo_pooler = np.mean(mrr_wo_pooler_all) 151 | 152 | ret_metrics = {"valid_pr1": pr1, "valid_mrr": mrr, 153 | "valid_pr1_wo_pooler": pr1_wo_pooler, 154 | "valid_mrr_wo_pooler": mrr_wo_pooler, 155 | 'batch_size': observed_batch_size} 156 | ret_metrics.update( 157 | speed_metrics( 158 | metric_key_prefix, 159 | start_time, 160 | num_samples=self.num_examples(eval_dataloader), 161 | num_steps=step 162 | ) 163 | ) 164 | self.log(ret_metrics) 165 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, ret_metrics) 166 | self._memory_tracker.stop_and_update_metrics(ret_metrics) 167 | return ret_metrics 168 | 169 | def validate(self, model, valid_data): 170 | valid_data = self._prepare_inputs(valid_data) 171 | batch_size = valid_data["input_ids"].shape[0] 172 | view_size = valid_data["input_ids"].shape[1] 173 | 174 | with torch.no_grad(): 175 | outputs = model(**valid_data) 176 | embs = outputs.embeddings 177 | embs_wo_pooler = outputs.embs_wo_pooler 178 | # reshape back to batch_size * view_size 179 | embs = embs.view(batch_size, view_size, -1) 180 | # take the first emb in each batch 181 | anchor_embs = embs[:, 0, :] 182 | # emb_pool 183 | emb_pool = embs[:, 1, :] 184 | # normalize both 185 | anchor_embs_normalized = F.normalize(anchor_embs, dim=-1) 186 | emb_pool_normalized = F.normalize(emb_pool, dim=-1) 187 | sim = torch.matmul(anchor_embs_normalized, emb_pool_normalized.transpose(0, 1)) 188 | rank = [] 189 | for i in range(batch_size): 190 | rank.append(torch.sum(sim[i, :] > sim[i, i]).item()) 191 | rank_arr = np.array(rank) 192 | mrr = np.mean(1 / (rank_arr + 1)) 193 | pr1 = np.mean(rank_arr == 0) 194 | 195 | embs_wo_pooler = embs_wo_pooler.view(batch_size, view_size, -1) 196 | anchor_embs_wo_pooler = embs_wo_pooler[:, 0, :] 197 | emb_pool_wo_pooler = embs_wo_pooler[:, 1, :] 198 | anchor_embs_wo_pooler_normalized = F.normalize(anchor_embs_wo_pooler, dim=-1) 199 | emb_pool_wo_pooler_normalized = F.normalize(emb_pool_wo_pooler, dim=-1) 200 | sim_wo_pooler = torch.matmul(anchor_embs_wo_pooler_normalized, emb_pool_wo_pooler_normalized.transpose(0, 1)) 201 | rank_wo_pooler = [] 202 | for i in range(batch_size): 203 | rank_wo_pooler.append(torch.sum(sim_wo_pooler[i, :] > sim_wo_pooler[i, i]).item()) 204 | rank_wo_pooler_arr = np.array(rank_wo_pooler) 205 | mrr_wo_pooler = np.mean(1 / 
(rank_wo_pooler_arr + 1)) 206 | pr1_wo_pooler = np.mean(rank_wo_pooler_arr == 0) 207 | return pr1, mrr, pr1_wo_pooler, mrr_wo_pooler 208 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/config/train.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_name": "PurCL/bincorp-3m-binsim", 3 | "model_name_or_path": "../../checkpoints/codeart-26m", 4 | "output_dir": "../save/codeart-binsim", 5 | "remove_unused_columns": false, 6 | "overwrite_output_dir": true, 7 | "do_train": true, 8 | "do_eval": true, 9 | "evaluation_strategy": "steps", 10 | "eval_steps": 1000, 11 | "num_train_epochs": 10, 12 | "save_steps": 2000, 13 | "margin": 0.5, 14 | 15 | "dataloader_num_workers": 2, 16 | "per_device_train_batch_size": 16, 17 | "per_device_eval_batch_size": 128, 18 | "gradient_accumulation_steps": 2, 19 | "learning_rate": 1e-5, 20 | 21 | "logging_steps": 10, 22 | "report_to": "wandb", 23 | 24 | "masking_enable_global_memory_patterns": true, 25 | "masking_enable_bridge_patterns": false, 26 | "masking_enable_graph_patterns": true, 27 | "masking_enable_local_patterns": true, 28 | "with_transitive_closure": true, 29 | 30 | "position_embedding_type": "mixed", 31 | "max_relative_position_embeddings": 8, 32 | "max_seq_length": 512 33 | } -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/dump_files.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import json 3 | import numpy as np 4 | import os 5 | import torch 6 | from tqdm import tqdm 7 | from typing import List, Union, Tuple, Optional, Dict 8 | 9 | from datasets import load_dataset 10 | 11 | def dump_files_jsonl( 12 | source_dataset_path, 13 | target_dataset_path, 14 | output_prefix, 15 | swap=False, 16 | ): 17 | func2id = {} 18 | 19 | with open(source_dataset_path, 'r') as f: 20 | source_dataset = [json.loads(line.strip()) for line in f.readlines()] 21 | for i, js in enumerate(source_dataset): 22 | func2id[(js['metadata']['project_name'], js['metadata']['function_name'])] = i 23 | 24 | with open(target_dataset_path, 'r') as f: 25 | target_dataset = [json.loads(line.strip()) for line in f.readlines()] 26 | 27 | query_ids, queries = [], [] 28 | pool_ids, pool = [], [] 29 | success = 0 30 | for example in target_dataset: 31 | func_id = (example['metadata']['project_name'], example['metadata']['function_name']) 32 | try: 33 | source_id = func2id[func_id] 34 | query_ids.append(func_id) 35 | queries.append(source_dataset[source_id]) 36 | pool_ids.append(func_id) 37 | pool.append(example) 38 | success += 1 39 | except KeyError: 40 | pass 41 | print(f"success/source/target: {success}/{len(source_dataset)}/{len(target_dataset)}") 42 | 43 | if swap: 44 | query_ids, pool_ids = pool_ids, query_ids 45 | queries, pool = pool, queries 46 | 47 | # dump queries and pool 48 | oracle_file = {} # NOTE: not necessary right now, modify under strict evaluation 49 | with open("%s-query.id"%output_prefix, 'w') as f: 50 | for query_id in query_ids: 51 | f.write(json.dumps(query_id) + '\n') 52 | with open("%s-query.jsonl"%output_prefix, 'w') as f: 53 | for query in queries: 54 | f.write(json.dumps(query) + '\n') 55 | with open("%s-pool.id"%output_prefix, 'w') as f: 56 | for pool_id in pool_ids: 57 | f.write(json.dumps(pool_id) + '\n') 58 | with open("%s-pool.jsonl"%output_prefix, 'w') as f: 59 | for pool_binary in pool: 60 | 
f.write(json.dumps(pool_binary) + '\n') 61 | 62 | 63 | if __name__ == '__main__': 64 | import sys 65 | sys.path.append('../../code/') 66 | import argparse 67 | 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument( 70 | '--src_dataset_path', type=str, default='binsim-dataset/coreutilsh-O0.jsonl', 71 | help='source dataset path' 72 | ) 73 | parser.add_argument( 74 | '--tgt_dataset_path', type=str, default='binsim-dataset/coreutilsh-O3.jsonl', 75 | help='target dataset path' 76 | ) 77 | parser.add_argument( 78 | '--output_prefix', type=str, default='cache/binary_clone_detection/coreutilsh', 79 | help='output prefix' 80 | ) 81 | 82 | args = parser.parse_args() 83 | 84 | dump_files_jsonl( 85 | source_dataset_path=args.src_dataset_path, 86 | target_dataset_path=args.tgt_dataset_path, 87 | output_prefix=args.output_prefix 88 | ) -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/encode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | 3 | function encode(){ 4 | model="--model_name_or_path $1" 5 | python3 inference.py $model \ 6 | --masking_enable_global_memory_patterns true \ 7 | --masking_enable_bridge_patterns false \ 8 | --masking_enable_graph_patterns true \ 9 | --masking_enable_local_patterns true \ 10 | --with_transitive_closure true \ 11 | --position_embedding_type mixed \ 12 | --max_relative_position_embeddings 8 \ 13 | --normalize_embed true \ 14 | --batch_size 48 \ 15 | --source_file cache/binary_clone_detection/$3-query.jsonl \ 16 | --target_file cache/binary_clone_detection/$3-pool.jsonl \ 17 | --source_embed_save_file output/$3-src_$2.npy \ 18 | --target_embed_save_file output/$3-tgt_$2.npy \ 19 | --zero_shot false \ 20 | --top_k 1 21 | } 22 | 23 | 24 | function encode_benchmarks(){ 25 | encode $1 $2 coreutilsh 26 | encode $1 $2 binutilsh 27 | encode $1 $2 libcurlh 28 | encode $1 $2 libmagickh 29 | encode $1 $2 opensslh 30 | encode $1 $2 libsqlh 31 | encode $1 $2 puttyh 32 | } 33 | 34 | 35 | mkdir -p output 36 | 37 | encode_benchmarks ../save/codeart-binsim/checkpoint-4000 ckpt-4k 38 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | from collections import OrderedDict 4 | import json 5 | import numpy as np 6 | from typing import Dict 7 | 8 | 9 | TOP_K = [1, 3, 5, 10] 10 | 11 | 12 | def mrr(gold, pred): 13 | ranks = [] 14 | for g, p in zip(gold, pred): 15 | try: 16 | r = p['retrieved'].index(g) 17 | ranks.append(1 / (r + 1)) 18 | except ValueError: 19 | ranks.append(0) 20 | 21 | return np.mean(ranks) 22 | 23 | 24 | def recall(gold, pred, top_k=None): 25 | top_k = TOP_K if top_k is None else top_k 26 | recall_n = {x: 0 for x in top_k} 27 | 28 | for g, p in zip(gold, pred): 29 | 30 | for k in top_k: 31 | candidates = p['retrieved'][: k] 32 | recall_n[k] += 1 if g in candidates else 0 33 | 34 | recall_n = {k: v / len(pred) for k, v in recall_n.items()} 35 | 36 | return recall_n 37 | 38 | 39 | def eval_from_dict(results: Dict): 40 | 41 | gold, pred = [], [] 42 | 43 | for k, v in results.items(): 44 | gold.append(k) 45 | pred.append(v) 46 | 47 | metrics = { 48 | 'recall': recall(gold, pred), 49 | 'mrr': mrr(gold, pred) 50 | } 51 | 52 | return metrics 53 | 54 | 55 | def eval_from_file(result_file): 56 | gold, pred = [], [] 57 | 58 | with open(result_file, 
'r') as f: 59 | results = json.load(f) 60 | 61 | for k, v in results.items(): 62 | gold.append(k) 63 | pred.append(v) 64 | 65 | metrics = { 66 | 'recall': recall(gold, pred), 67 | 'mrr': mrr(gold, pred) 68 | } 69 | 70 | print(metrics) 71 | 72 | 73 | if __name__ == '__main__': 74 | eval_from_file('../../save/.cache/binary_clone_detection/retrieval_results.json') -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('../../code') 3 | 4 | from transformers import ( 5 | HfArgumentParser, 6 | set_seed, 7 | ) 8 | from typing import Optional 9 | import shlex 10 | import os 11 | import numpy as np 12 | import json 13 | from dataclasses import dataclass, field 14 | import argparse 15 | import torch 16 | from utils import AverageMeter, BinaryRetriever, BinaryRetrieverForGCBLike 17 | from eval import eval_from_dict, eval_from_file, TOP_K 18 | from modeling_utils import MaskBuilder 19 | from models import ( 20 | RabertConfig, 21 | CodeArtConfig, 22 | RabertModel, 23 | CodeArtModel, 24 | RabertTokenizer, 25 | CodeArtTokenizer, 26 | CodeArtForBinSim, 27 | RabertForBinSim, 28 | GCBLikeTokenizer, 29 | ) 30 | 31 | 32 | 33 | @dataclass 34 | class ModelArguments: 35 | """ 36 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 37 | """ 38 | 39 | model_name_or_path: str = field( 40 | metadata={ 41 | "help": "Path to pretrained model or model identifier from huggingface.co/models"} 42 | ) 43 | config_name: Optional[str] = field( 44 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 45 | ) 46 | tokenizer_name: Optional[str] = field( 47 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 48 | ) 49 | cache_dir: Optional[str] = field( 50 | default=None, 51 | metadata={ 52 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, 53 | ) 54 | use_fast_tokenizer: bool = field( 55 | default=True, 56 | metadata={ 57 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 58 | ) 59 | model_revision: str = field( 60 | default="main", 61 | metadata={ 62 | "help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 63 | ) 64 | use_auth_token: bool = field( 65 | default=False, 66 | metadata={ 67 | "help": ( 68 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 69 | "with private models)." 
70 |             )
71 |         },
72 |     )
73 |     ignore_mismatched_sizes: bool = field(
74 |         default=False,
75 |         metadata={
76 |             "help": "Whether to allow loading a pretrained model whose head dimensions are different."},
77 |     )
78 |
79 |     # position arguments
80 |     position_embedding_type: str = field(
81 |         default='absolute',
82 |         metadata={
83 |             "help": (
84 |                 "Type of positional embedding to use, can be either 'absolute' or 'mixed'"
85 |             ),
86 |             "choices": ['absolute', 'mixed'],
87 |         }
88 |     )
89 |     max_relative_position_embeddings: Optional[int] = field(
90 |         default=5,
91 |         metadata={
92 |             "help": (
93 |                 "Max relative position distance to consider in relative embeddings"
94 |             )
95 |         }
96 |     )
97 |
98 |     # masking arguments
99 |     masking_preset: Optional[str] = field(
100 |         default=None,
101 |         metadata={
102 |             "help": ("Preset masking strategy"),
103 |             "choices": [None, 'aggressive', 'conservative'],
104 |         }
105 |     )
106 |     masking_enable_global_memory_patterns: bool = field(
107 |         default=True,
108 |         metadata={
109 |             "help": ("enable global memory patterns")
110 |         }
111 |     )
112 |     masking_enable_bridge_patterns: bool = field(
113 |         default=True,
114 |         metadata={
115 |             "help": ("enable bridge patterns")
116 |         }
117 |     )
118 |     masking_enable_graph_patterns: bool = field(
119 |         default=True,
120 |         metadata={
121 |             "help": ("enable graph patterns")
122 |         }
123 |     )
124 |     masking_enable_local_patterns: bool = field(
125 |         default=True,
126 |         metadata={
127 |             "help": ("enable local patterns")
128 |         }
129 |     )
130 |
131 |     with_transitive_closure: bool = field(
132 |         default=True,
133 |         metadata={
134 |             "help": ("Whether to include transitive closure in the masking process")
135 |         }
136 |     )
137 |
138 |     max_transitions: Optional[int] = field(
139 |         default=None,
140 |     )
141 |
142 |     normalize_embed: bool = field(
143 |         default=True
144 |     )
145 |
146 |     zero_shot: bool = field(
147 |         default=False
148 |     )
149 |
150 |     gcb_like: bool = field(
151 |         default=False
152 |     )
153 |
154 |
155 | @dataclass
156 | class DataArguments:
157 |
158 |     batch_size: int = field(
159 |         default=16
160 |     )
161 |     source_file: str = field(
162 |         default='coreutils.clang.O0.jsonl'
163 |     )
164 |     target_file: str = field(
165 |         default='coreutils.gcc.O3.jsonl'
166 |     )
167 |     source_embed_save_file: str = field(
168 |         default='src_embedding.npy'
169 |     )
170 |     target_embed_save_file: str = field(
171 |         default='tgt_embedding.npy'
172 |     )
173 |     save_file: str = field(
174 |         default='retrieval_results.json'
175 |     )
176 |     top_k: int = field(
177 |         default=200
178 |     )
179 |     cpu: Optional[bool] = field(
180 |         default=False
181 |     )
182 |     pool_size: int = field(
183 |         default=100
184 |     )
185 |
186 |     def __post_init__(self):
187 |         self.source_idx_file = self.source_file.replace(".jsonl", ".id")
188 |         self.target_idx_file = self.target_file.replace(".jsonl", ".id")
189 |
190 |
191 | def main():
192 |     set_seed(42)
193 |     parser = HfArgumentParser((ModelArguments, DataArguments))
194 |     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
195 |         # If we pass only one argument to the script and it's the path to a json file,
196 |         # let's parse it to get our arguments.
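        # parse_json_file maps the top-level keys of a JSON config onto the dataclass
        # fields of ModelArguments and DataArguments above, mirroring the config/*.json
        # convention used by the training scripts; otherwise the usual --flag form
        # (as in encode.sh) is parsed below.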
197 | model_args, data_args = parser.parse_json_file( 198 | json_file=os.path.abspath(sys.argv[1])) 199 | else: 200 | model_args, data_args = parser.parse_args_into_dataclasses() 201 | 202 | if not model_args.with_transitive_closure: 203 | if model_args.gcb_like: 204 | tokenizer = GCBLikeTokenizer.from_pretrained( 205 | model_args.model_name_or_path 206 | ) 207 | 208 | binsim_model = RabertForBinSim.from_pretrained( 209 | model_args.model_name_or_path 210 | ) 211 | model = binsim_model.rabert 212 | 213 | else: 214 | 215 | tokenizer = RabertTokenizer.from_pretrained( 216 | model_args.model_name_or_path 217 | ) 218 | model = RabertModel.from_pretrained( 219 | model_args.model_name_or_path, 220 | ) 221 | else: 222 | tokenizer = CodeArtTokenizer.from_pretrained( 223 | model_args.model_name_or_path 224 | ) 225 | # model = CodeArtModel.from_pretrained( 226 | # model_args.model_name_or_path, 227 | # ) 228 | # loaded_weights = torch.load( 229 | # os.path.join(model_args.model_name_or_path, 'pytorch_model.bin'), 230 | # map_location='cpu' 231 | # ) 232 | binsim_model = CodeArtForBinSim.from_pretrained( 233 | model_args.model_name_or_path 234 | ) 235 | model = binsim_model.codeart 236 | # mask builder 237 | maskbuilder = MaskBuilder( 238 | preset=model_args.masking_preset, 239 | enable_global_memory_patterns=model_args.masking_enable_global_memory_patterns, 240 | enable_bridge_patterns=model_args.masking_enable_bridge_patterns, 241 | enable_graph_patterns=model_args.masking_enable_graph_patterns, 242 | device='cpu' if data_args.cpu else 'cuda' 243 | ) 244 | 245 | tokenizer.add_tokens('') 246 | tokenizer.maskbuilder = maskbuilder 247 | 248 | pooler_method = 'cls' 249 | if not model_args.zero_shot: 250 | pooler_method = 'pooler' 251 | 252 | if not model_args.gcb_like: 253 | searcher = BinaryRetriever 254 | print("\033[93m searcher = BinaryRetriever \033[0m") 255 | else: 256 | searcher = BinaryRetrieverForGCBLike 257 | print("\033[93m searcher = BinaryRetrieverForGCBLike \033[0m") 258 | 259 | searcher = searcher( 260 | tokenizer=tokenizer, 261 | encoder=model, 262 | pooler=pooler_method 263 | ) 264 | 265 | searcher.encode_file( 266 | data_args.source_file, 267 | data_args.source_embed_save_file, 268 | normalize_embed=model_args.normalize_embed, 269 | max_transitions=model_args.max_transitions 270 | ) 271 | searcher.encode_file( 272 | data_args.target_file, 273 | data_args.target_embed_save_file, 274 | normalize_embed=model_args.normalize_embed, 275 | max_transitions=model_args.max_transitions 276 | ) 277 | 278 | with open(data_args.source_idx_file, 'r') as f: 279 | source_id_map = {} 280 | for idx, line in enumerate(f): 281 | source_id_map[idx] = line.strip() 282 | 283 | with open(data_args.target_idx_file, 'r') as f: 284 | target_id_map = {} 285 | for idx, line in enumerate(f): 286 | target_id_map[idx] = line.strip() 287 | 288 | source_embed = np.load(data_args.source_embed_save_file) 289 | target_embed = np.load(data_args.target_embed_save_file) 290 | assert (len(source_id_map) == source_embed.shape[0]) 291 | assert (len(target_id_map) == target_embed.shape[0]) 292 | 293 | results = { 294 | 'recall': {k: AverageMeter() for k in TOP_K}, 295 | 'mrr': AverageMeter() 296 | } 297 | total_pools = len(source_id_map) // data_args.pool_size 298 | for i in range(total_pools): 299 | pool_source_embed = source_embed[i * 300 | data_args.pool_size: (i + 1) * data_args.pool_size] 301 | pool_target_embed = target_embed[i * 302 | data_args.pool_size: (i + 1) * data_args.pool_size] 303 | pool_source_id_map, 
pool_target_id_map = {}, {} 304 | for j in range(data_args.pool_size): 305 | pool_source_id_map[j] = source_id_map[i * data_args.pool_size + j] 306 | pool_target_id_map[j] = target_id_map[i * data_args.pool_size + j] 307 | pool_results = searcher.retrieve( 308 | pool_source_embed, 309 | pool_target_embed, 310 | pool_source_id_map, 311 | pool_target_id_map, 312 | data_args.top_k 313 | ) 314 | pool_results = eval_from_dict(pool_results) 315 | results['mrr'].update(pool_results['mrr']) 316 | for k in TOP_K: 317 | results['recall'][k].update(pool_results['recall'][k]) 318 | 319 | print(results) 320 | 321 | 322 | if __name__ == '__main__': 323 | main() 324 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/model_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class DataCollatorForRabert(object): 6 | 7 | def __init__(self, tokenizer) -> None: 8 | self.tokenizer = tokenizer 9 | 10 | def __call__( 11 | self, 12 | examples 13 | ): 14 | batch = { 15 | 'input_ids': [], 16 | 'attention_mask': [], 17 | 'labels': [] 18 | } 19 | 20 | for example in examples: 21 | encoded = self.tokenizer.inst_encode(eval(example['code']), eval(example['data_dep'])) 22 | batch['input_ids'].append(encoded['input_ids']) 23 | batch['attention_mask'].append(encoded['attention_mask']) 24 | return { 25 | 'input_ids': torch.stack(batch['input_ids']), 26 | 'attention_mask': torch.stack(batch['attention_mask']) 27 | } 28 | 29 | 30 | class DataCollatorForCodeArt(object): 31 | 32 | def __init__(self, tokenizer) -> None: 33 | self.tokenizer = tokenizer 34 | 35 | def __call__( 36 | self, 37 | examples 38 | ): 39 | batch = { 40 | 'input_ids': [], 41 | 'attention_mask': [], 42 | 'relative_position_matrix': [] 43 | } 44 | 45 | for example in examples: 46 | num_functions = 0 47 | sequence_mask = [] 48 | current_ids = [] 49 | current_attention_mask = [] 50 | current_relative_position_matrix = [] 51 | for function in example['functions']: 52 | encoded = self.tokenizer.inst_encode(eval(function['code']), eval(function['data_dep'])) 53 | current_ids.append(encoded['input_ids']) 54 | current_attention_mask.append(encoded['attention_mask']) 55 | current_relative_position_matrix.append(encoded['relative_position_matrix']) 56 | num_functions += 1 57 | sequence_mask.append(1) 58 | batch['input_ids'].append(torch.stack(current_ids)) 59 | batch['attention_mask'].append(torch.stack(current_attention_mask)) 60 | batch['relative_position_matrix'].append(torch.stack(current_relative_position_matrix)) 61 | return { 62 | 'input_ids': torch.stack(batch['input_ids']), 63 | 'attention_mask': torch.stack(batch['attention_mask']), 64 | 'relative_position_matrix': torch.stack(batch['relative_position_matrix']), 65 | } -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/pretty_print.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | if [ $# -ne 1 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | FIN_PATTERN_PREFIX=$1 10 | 11 | echo "PR@1 for pool size 32, 50, 100, 200, 300, 500" 12 | cat $FIN_PATTERN_PREFIX*32.txt|egrep "Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 13 | cat $FIN_PATTERN_PREFIX*50.txt|egrep "Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 14 | cat $FIN_PATTERN_PREFIX*100.txt|egrep "Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 15 | cat $FIN_PATTERN_PREFIX*200.txt|egrep 
"Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 16 | cat $FIN_PATTERN_PREFIX*300.txt|egrep "Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 17 | cat $FIN_PATTERN_PREFIX*500.txt|egrep "Final-PR@1"|sed 's/.*:\s*\(\S*\)/\1/g' 18 | 19 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/pretty_print_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | 5 | benchmarks = [ 6 | "coreutilsh", 7 | "binutilsh", 8 | "libcurlh", 9 | "libmagickh", 10 | "opensslh", 11 | "libsqlh", 12 | "puttyh", 13 | ] 14 | 15 | pool_sizes = [ 16 | 32, 50, 100, 200, 300, 500 17 | ] 18 | 19 | 20 | print("Pretty print for ", sys.argv[1]) 21 | 22 | prefix = sys.argv[1] 23 | if prefix.endswith("-"): 24 | prefix = prefix[:-1] 25 | 26 | file_map = {} 27 | for benchmark in benchmarks: 28 | for pool_size in pool_sizes: 29 | fname = "%s-%s-pool%d.txt" % (prefix, benchmark, pool_size) 30 | file_map[(benchmark, pool_size)] = fname 31 | 32 | 33 | result = {} 34 | for k,v in file_map.items(): 35 | fin = open(v, "r") 36 | # find the line starts with Final-PR@1: ... 37 | for line in fin.readlines(): 38 | if line.startswith("Final-PR@1:"): 39 | result[k] = float(line.split(":")[1].strip()) 40 | break 41 | fin.close() 42 | 43 | print("Pool size,", end="") 44 | for benchmark in benchmarks: 45 | print("%s," % benchmark, end="") 46 | print() 47 | 48 | for pool_size in pool_sizes: 49 | print("%d" % pool_size, end="") 50 | for benchmark in benchmarks: 51 | print(",%.3f" % result[(benchmark, pool_size)], end="") 52 | print() 53 | -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/run_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org= 5 | WANDB_project=codeart-binsim 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! 
-f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | WANDB_ENTITY=$WANDB_org \ 22 | WANDB_PROJECT=$WANDB_project \ 23 | torchrun --nproc_per_node=2 run.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/sample_and_report.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import json 3 | import numpy as np 4 | import os 5 | import torch 6 | from tqdm import tqdm 7 | from typing import List, Union, Tuple, Optional, Dict 8 | from eval import eval_from_dict, TOP_K 9 | from datasets import load_dataset 10 | 11 | 12 | class AverageMeter(object): 13 | """Computes and stores the average and current value""" 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | def __repr__(self): 30 | return str(self.avg) 31 | 32 | 33 | class BinaryRetriever(object): 34 | 35 | @staticmethod 36 | def retrieve_from_file_random( 37 | source_embed_file, 38 | target_embed_file, 39 | source_id_file, 40 | target_id_file, 41 | pool_size, 42 | top_k, 43 | seed=42, 44 | round_num=5 45 | ): 46 | with open(source_id_file, 'r') as f: 47 | source_id_map = {} 48 | source_func_id2idx = {} 49 | for idx, line in enumerate(f.readlines()): 50 | source_id_map[idx] = line.strip() 51 | source_func_id2idx[line.strip()] = idx 52 | 53 | with open(target_id_file, 'r') as f: 54 | target_id_map = {} 55 | target_func_id2idx = {} 56 | for idx, line in enumerate(f.readlines()): 57 | target_id_map[idx] = line.strip() 58 | target_func_id2idx[line.strip()] = idx 59 | 60 | overlapped_funcs = set(source_func_id2idx.keys()) & set(target_func_id2idx.keys()) 61 | print(f'Number of overlapped functions: {len(overlapped_funcs)}') 62 | source_embed_all = np.load(source_embed_file + '.npy') 63 | target_embed_all = np.load(target_embed_file + '.npy') 64 | # random sample 65 | np.random.seed(seed) 66 | pool_size = min(pool_size, len(overlapped_funcs)) 67 | 68 | overlapped_funcs_list = sorted(list(overlapped_funcs)) 69 | 70 | results_all = { 71 | 'recall': {k: AverageMeter() for k in TOP_K}, 72 | 'mrr': AverageMeter() 73 | } 74 | for rnd in range(round_num): 75 | selected_overlapped_funcs = np.random.choice(overlapped_funcs_list, pool_size, replace=False) 76 | print(f'Number of selected overlapped functions: {len(selected_overlapped_funcs)}') 77 | selected_indices = [source_func_id2idx[x] for x in selected_overlapped_funcs] 78 | 79 | 80 | source_embed = source_embed_all[selected_indices] 81 | target_embed = target_embed_all[selected_indices] 82 | 83 | print(f'source embedding shape: {source_embed.shape}, target embedding shape: {target_embed.shape}') 84 | indexer = faiss.IndexFlatIP(target_embed.shape[1]) 85 | indexer.add(target_embed) 86 | D, I = indexer.search(source_embed, top_k) 87 | 88 | results = {} 89 | for source_idx, (dist, retrieved_index) in enumerate(zip(D, I)): 90 | source_id = source_id_map[source_idx] 91 | results[source_id] = {} 92 | retrieved_target_id = [target_id_map[x] for x in retrieved_index] 93 | results[source_id]['retrieved'] = retrieved_target_id 94 | results[source_id]['score'] = dist.tolist() 95 | ret = eval_from_dict(results) 96 | 
results_all['mrr'].update(ret['mrr']) 97 | for k in TOP_K: 98 | results_all['recall'][k].update(ret['recall'][k]) 99 | 100 | return results_all 101 | 102 | 103 | 104 | @staticmethod 105 | def retrieve_from_file( 106 | source_embed_file, 107 | target_embed_file, 108 | source_id_file, 109 | target_id_file, 110 | pool_size, 111 | top_k, 112 | save_file, 113 | ): 114 | with open(source_id_file, 'r') as f: 115 | source_id_map = {} 116 | for idx, line in enumerate(f.readlines()[:pool_size]): 117 | source_id_map[idx] = line.strip() 118 | 119 | with open(target_id_file, 'r') as f: 120 | target_id_map = {} 121 | for idx, line in enumerate(f.readlines()[:pool_size]): 122 | target_id_map[idx] = line.strip() 123 | 124 | source_embed = np.load(source_embed_file + '.npy') 125 | target_embed = np.load(target_embed_file + '.npy') 126 | assert (len(source_id_map) == source_embed.shape[0]) 127 | assert (len(target_id_map) == target_embed.shape[0]) 128 | indexer = faiss.IndexFlatIP(target_embed.shape[1]) 129 | indexer.add(target_embed) 130 | print(f'source embedding shape: {source_embed.shape}, target embedding shape: {target_embed.shape}') 131 | D, I = indexer.search(source_embed, top_k) 132 | 133 | results = {} 134 | for source_idx, (dist, retrieved_index) in enumerate(zip(D, I)): 135 | source_id = source_id_map[source_idx] 136 | results[source_id] = {} 137 | retrieved_target_id = [target_id_map[x] for x in retrieved_index] 138 | results[source_id]['retrieved'] = retrieved_target_id 139 | results[source_id]['score'] = dist.tolist() 140 | 141 | with open(save_file, 'w+') as f: 142 | json.dump(results, f, indent=2) 143 | 144 | return results 145 | 146 | 147 | if __name__ == '__main__': 148 | import sys 149 | sys.path.append('../../code/') 150 | import argparse 151 | 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument( 154 | '--source_file', type=str, default='output/src_embedding', 155 | help='source embedding file' 156 | ) 157 | parser.add_argument( 158 | '--target_file', type=str, default='output/tgt_embedding', 159 | help='target embedding file' 160 | ) 161 | parser.add_argument( 162 | '--source_id_file', type=str, default='cache/binary_clone_detection/query.id', 163 | help='source id file' 164 | ) 165 | parser.add_argument( 166 | '--target_id_file', type=str, default='cache/binary_clone_detection/pool.id', 167 | help='target id file' 168 | ) 169 | parser.add_argument( 170 | '--pool_size', type=int, default=100, 171 | help='pool size' 172 | ) 173 | parser.add_argument( 174 | '--top_k', type=int, default=10, 175 | help='top k' 176 | ) 177 | parser.add_argument( 178 | '--save_file', type=str, default='output/retrieval_results.json', 179 | help='save file' 180 | ) 181 | 182 | args = parser.parse_args() 183 | 184 | # dump_files( 185 | # 'sheepy928/binkit-O0-raw', 186 | # '../../data/binkit-O3.jsonl', 187 | # '../../save/.cache', 188 | # swap=True 189 | # ) 190 | # dump_files_jsonl( 191 | # '../../data/coreutils.gcc.O3.jsonl', 192 | # '../../data/coreutils.clang.O0.jsonl', 193 | # cache_dir='../../save/.cache', 194 | # swap=False 195 | # ) 196 | ret = BinaryRetriever.retrieve_from_file_random( 197 | source_embed_file=args.source_file, 198 | target_embed_file=args.target_file, 199 | source_id_file=args.source_id_file, 200 | target_id_file=args.target_id_file, 201 | pool_size=args.pool_size, 202 | top_k=args.top_k 203 | ) 204 | 205 | print(ret) 206 | print("Final-PR@1: ", ret['recall'][1]) 207 | print("Final-MRR: ", ret['mrr']) 208 | 209 | 
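# Example invocation (values mirror sample_and_report.sh; paths are illustrative):
#   python3 sample_and_report.py \
#       --source_file output/coreutilsh-src_ckpt-4k \
#       --target_file output/coreutilsh-tgt_ckpt-4k \
#       --source_id_file cache/binary_clone_detection/coreutilsh-query.id \
#       --target_id_file cache/binary_clone_detection/coreutilsh-pool.id \
#       --pool_size 100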
-------------------------------------------------------------------------------- /codeart/evaluation/binary-similarity/sample_and_report.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -x 2 | 3 | function sample_and_report_one(){ 4 | proj_name=$1 5 | ckpt_name=$2 6 | pool_size=$3 7 | python3 sample_and_report.py \ 8 | --source_file output/$proj_name-src_$ckpt_name \ 9 | --target_file output/$proj_name-tgt_$ckpt_name \ 10 | --source_id_file cache/binary_clone_detection/$proj_name-query.id \ 11 | --target_id_file cache/binary_clone_detection/$proj_name-pool.id \ 12 | --pool_size $pool_size |tee report_$ckpt_name-$proj_name-pool$pool_size.txt 13 | } 14 | 15 | function sample_and_report_pools(){ 16 | proj_name=$1 17 | ckpt_name=$2 18 | sample_and_report_one $proj_name $ckpt_name 32 19 | sample_and_report_one $proj_name $ckpt_name 50 20 | sample_and_report_one $proj_name $ckpt_name 100 21 | sample_and_report_one $proj_name $ckpt_name 200 22 | sample_and_report_one $proj_name $ckpt_name 300 23 | sample_and_report_one $proj_name $ckpt_name 500 24 | } 25 | 26 | function sample_and_report_all(){ 27 | ckpt_name=$1 28 | sample_and_report_pools coreutilsh $ckpt_name 29 | sample_and_report_pools binutilsh $ckpt_name 30 | sample_and_report_pools libcurlh $ckpt_name 31 | sample_and_report_pools libmagickh $ckpt_name 32 | sample_and_report_pools opensslh $ckpt_name 33 | sample_and_report_pools libsqlh $ckpt_name 34 | sample_and_report_pools puttyh $ckpt_name 35 | } 36 | 37 | 38 | sample_and_report_all ckpt-4k 39 | 40 | -------------------------------------------------------------------------------- /codeart/evaluation/malware-family-classification/config/eval-2f-100c.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "../../checkpoints/codeart-26m-mfc-2f-100c", 3 | 4 | "dataset_name": "PurCL/malware-top-100", 5 | "output_dir": "../save/codeart-26m-mfc-2f-100c/", 6 | "max_functions": 2, 7 | 8 | "masking_enable_global_memory_patterns": true, 9 | "masking_enable_bridge_patterns": false, 10 | "masking_enable_graph_patterns": true, 11 | "masking_enable_local_patterns": true, 12 | "with_transitive_closure": true, 13 | 14 | "position_embedding_type": "mixed", 15 | "max_relative_position_embeddings": 8, 16 | "max_seq_length": 512, 17 | 18 | "use_auth_token": true, 19 | "dataloader_num_workers": 2, 20 | "remove_unused_columns": false, 21 | 22 | "do_train": true, 23 | "do_eval": true, 24 | "do_predict": true, 25 | 26 | "per_device_train_batch_size": 4, 27 | "gradient_accumulation_steps": 1, 28 | "per_device_eval_batch_size": 8, 29 | 30 | "num_train_epochs": 5, 31 | "learning_rate": 5e-5, 32 | "evaluation_strategy": "steps", 33 | "eval_steps": 100, 34 | "save_steps": 1000, 35 | "logging_steps": 10, 36 | 37 | "report_to": "tensorboard", 38 | "cache_dir": "../save/.cache", 39 | 40 | "overwrite_output_dir": true 41 | } -------------------------------------------------------------------------------- /codeart/evaluation/malware-family-classification/config/train-2f-100c.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "../../checkpoints/codeart-26m", 3 | 4 | "dataset_name": "PurCL/malware-top-100", 5 | "output_dir": "../save/codeart-26m-mfc-2f-100c/", 6 | "overwrite_cache": true, 7 | 8 | "masking_enable_global_memory_patterns": true, 9 | "masking_enable_bridge_patterns": false, 10 | "masking_enable_graph_patterns": true, 11 | 
"masking_enable_local_patterns": true, 12 | "with_transitive_closure": true, 13 | 14 | "position_embedding_type": "mixed", 15 | "max_relative_position_embeddings": 8, 16 | "max_seq_length": 512, 17 | "max_functions": 2, 18 | 19 | "use_auth_token": true, 20 | "dataloader_num_workers": 2, 21 | "remove_unused_columns": false, 22 | 23 | "do_train": true, 24 | "do_eval": true, 25 | "do_predict": true, 26 | 27 | "per_device_train_batch_size": 4, 28 | "gradient_accumulation_steps": 1, 29 | "per_device_eval_batch_size": 8, 30 | 31 | "num_train_epochs": 10, 32 | "learning_rate": 1e-4, 33 | "evaluation_strategy": "steps", 34 | "eval_steps": 100, 35 | "save_steps": 500, 36 | "logging_steps": 10, 37 | "load_best_model_at_end": true, 38 | "metric_for_best_model": "eval_roc_auc_score", 39 | 40 | "report_to": "tensorboard", 41 | "cache_dir": "../save/.cache", 42 | 43 | "overwrite_output_dir": true 44 | } -------------------------------------------------------------------------------- /codeart/evaluation/malware-family-classification/eval_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=your_org 5 | WANDB_project=malware-family-classification 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! -f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | python evaluate_multilabel.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation/malware-family-classification/run_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=your_org 5 | WANDB_project=malware-family-classification 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! 
-f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | torchrun --nproc_per_node=2 run_multilabel.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation/malware-family-classification/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class DataCollatorForCodeArt(object): 6 | 7 | def __init__(self, tokenizer, label2id, max_functions) -> None: 8 | self.tokenizer = tokenizer 9 | self.label2id = label2id 10 | self.max_functions = max_functions 11 | 12 | def __call__( 13 | self, 14 | examples 15 | ): 16 | batch = { 17 | 'input_ids': [], 18 | 'attention_mask': [], 19 | 'relative_position_matrix': [], 20 | 'labels': [], 21 | 'sequence_mask': [], 22 | 'all_labels': [] 23 | } 24 | 25 | for example in examples: 26 | num_functions = 0 27 | sequence_mask = [] 28 | for function in eval(example['functions'])[:self.max_functions]: 29 | encoded = self.tokenizer.inst_encode(function['code'], function['data_dep']) 30 | batch['input_ids'].append(encoded['input_ids']) 31 | batch['attention_mask'].append(encoded['attention_mask']) 32 | batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 33 | num_functions += 1 34 | sequence_mask.append(1) 35 | for _ in range(num_functions, self.max_functions): # pad to max_functions 36 | encoded = self.tokenizer.inst_encode([], []) 37 | batch['input_ids'].append(encoded['input_ids']) 38 | batch['attention_mask'].append(encoded['attention_mask']) 39 | batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 40 | sequence_mask.append(0) 41 | batch["labels"].append(self.label2id[random.sample(example['labels'], 1)[0]]) 42 | batch["sequence_mask"].append(sequence_mask) 43 | batch["all_labels"].append(example['labels']) 44 | return { 45 | 'input_ids': torch.stack(batch['input_ids']), 46 | 'attention_mask': torch.stack(batch['attention_mask']), 47 | 'relative_position_matrix': torch.stack(batch['relative_position_matrix']), 48 | 'labels': torch.tensor(batch['labels'], dtype=torch.long).unsqueeze(1), 49 | 'sequence_mask': torch.tensor(batch['sequence_mask'], dtype=torch.long), 50 | 'all_labels': batch['all_labels'] 51 | } 52 | 53 | 54 | class DataCollatorForCodeArtMultilabel(object): 55 | 56 | def __init__(self, tokenizer, label2id, max_functions) -> None: 57 | self.tokenizer = tokenizer 58 | self.label2id = label2id 59 | self.max_functions = max_functions 60 | 61 | def __call__( 62 | self, 63 | examples 64 | ): 65 | batch = { 66 | 'input_ids': [], 67 | 'attention_mask': [], 68 | 'relative_position_matrix': [], 69 | 'labels': [], 70 | 'sequence_mask': [], 71 | } 72 | 73 | for example in examples: 74 | num_functions = 0 75 | sequence_mask = [] 76 | for function in eval(example['functions'])[:self.max_functions]: 77 | encoded = self.tokenizer.inst_encode(function['code'], function['data_dep']) 78 | batch['input_ids'].append(encoded['input_ids']) 79 | batch['attention_mask'].append(encoded['attention_mask']) 80 | batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 81 | num_functions += 1 82 | sequence_mask.append(1) 83 | for _ in range(num_functions, self.max_functions): # pad to max_functions 84 | encoded = self.tokenizer.inst_encode([], []) 85 | batch['input_ids'].append(encoded['input_ids']) 86 | batch['attention_mask'].append(encoded['attention_mask']) 87 | 
batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 88 | sequence_mask.append(0) 89 | labels = torch.zeros(len(self.label2id)) 90 | for l in example['labels']: 91 | labels[self.label2id[l]] = 1 92 | batch['labels'].append(labels) 93 | batch["sequence_mask"].append(sequence_mask) 94 | return { 95 | 'input_ids': torch.stack(batch['input_ids']), 96 | 'attention_mask': torch.stack(batch['attention_mask']), 97 | 'relative_position_matrix': torch.stack(batch['relative_position_matrix']), 98 | 'labels': torch.stack(batch['labels']), 99 | 'sequence_mask': torch.tensor(batch['sequence_mask'], dtype=torch.long), 100 | } -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/config/eval-O0.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "../../checkpoints/codeart-26m-ti-O0", 3 | "dataset_name": "PurCL/marinda-type-inference-debuginfo-only-O0-shuffle", 4 | "output_dir": "../save/codeart-ti-O0", 5 | 6 | "masking_enable_global_memory_patterns": true, 7 | "masking_enable_bridge_patterns": false, 8 | "masking_enable_graph_patterns": true, 9 | "masking_enable_local_patterns": true, 10 | "with_transitive_closure": true, 11 | 12 | "position_embedding_type": "mixed", 13 | "max_relative_position_embeddings": 8, 14 | "max_seq_length": 512, 15 | 16 | "seed": 42, 17 | 18 | "label_file": "labels.json", 19 | 20 | "use_auth_token": true, 21 | "dataloader_num_workers": 2, 22 | "remove_unused_columns": false, 23 | 24 | "do_predict": true, 25 | 26 | "per_device_train_batch_size": 48, 27 | "gradient_accumulation_steps": 2, 28 | "per_device_eval_batch_size": 128, 29 | 30 | "num_train_epochs": 5, 31 | "learning_rate": 5e-5, 32 | "evaluation_strategy": "steps", 33 | "eval_steps": 10, 34 | "save_steps": 100, 35 | "logging_steps": 2, 36 | 37 | "report_to": "tensorboard", 38 | "cache_dir": "../save/.cache", 39 | 40 | "save_total_limit": 10, 41 | "load_best_model_at_end": true, 42 | "metric_for_best_model": "eval_f1", 43 | 44 | "overwrite_output_dir": true, 45 | "no_cuda": false 46 | } -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/config/train-O0.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_name_or_path": "../../checkpoints/codeart-26m", 3 | "dataset_name": "PurCL/marinda-type-inference-debuginfo-only-O0-shuffle", 4 | "output_dir": "../save/codeart-ti-O0", 5 | 6 | "masking_enable_global_memory_patterns": true, 7 | "masking_enable_bridge_patterns": false, 8 | "masking_enable_graph_patterns": true, 9 | "masking_enable_local_patterns": true, 10 | "with_transitive_closure": true, 11 | 12 | "position_embedding_type": "mixed", 13 | "max_relative_position_embeddings": 8, 14 | "max_seq_length": 512, 15 | 16 | "seed": 42, 17 | 18 | "label_file": "labels.json", 19 | 20 | "use_auth_token": true, 21 | "dataloader_num_workers": 2, 22 | "remove_unused_columns": false, 23 | 24 | "do_train": true, 25 | "do_eval": true, 26 | "do_predict": true, 27 | 28 | "per_device_train_batch_size": 48, 29 | "gradient_accumulation_steps": 2, 30 | "per_device_eval_batch_size": 128, 31 | 32 | "num_train_epochs": 5, 33 | "learning_rate": 1e-4, 34 | "evaluation_strategy": "steps", 35 | "eval_steps": 50, 36 | "save_steps": 100, 37 | "logging_steps": 2, 38 | 39 | "report_to": "tensorboard", 40 | "cache_dir": "../save/.cache", 41 | 42 | "save_total_limit": 10, 43 | 
"load_best_model_at_end": true, 44 | "metric_for_best_model": "eval_f1", 45 | 46 | "overwrite_output_dir": true, 47 | "no_cuda": false 48 | } -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/eval_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=your_org 5 | WANDB_project=type-inference 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! -f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | python run.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/labels.json: -------------------------------------------------------------------------------- 1 | ["noacc", "base(int)", "base(char)**", "base(char)*", "base(_Bool)", "base(unsigned int)", "base(long unsigned int)", "struct*", "void*", "base(unsigned char)*", "enum", "base(unsigned char)", "base(char)", "base(unsigned int)*", "base(long int)", "base(short int)*", "union*", "subroutine*", "base(long int)*", "base(int)*", "array*", "struct**", "base(long unsigned int)*", "base(double)*", "void**", "base(short unsigned int)", "struct", "base(_Bool)*", "base(char)***", "base(long unsigned int)**", "base(long double)", "struct***", "enum*", "base(unsigned int)**", "base(long double)*", "base(float)*", "base(long long unsigned int)*", "base(short unsigned int)*", "base(signed char)*", "base(long long int)", "array", "base(double)", "base(unsigned char)**", "union", "base(short int)", "base(long long unsigned int)", "subroutine**", "array**", "base(signed char)", "base(float)"] -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/run_config.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CURRENT_DIR=$(pwd) 3 | 4 | WANDB_org=your_org 5 | WANDB_project=type-inference 6 | 7 | # Check if the CONFIG file is provided 8 | if [ -z "$CONFIG" ]; then 9 | echo "Please provide a config file" 10 | exit 1 11 | fi 12 | 13 | # Check if the CONFIG file exists 14 | if [ ! 
-f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | torchrun --nproc_per_node=2 --master_port=$2 run.py $CURRENT_DIR/$CONFIG -------------------------------------------------------------------------------- /codeart/evaluation/type-inference/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DataCollatorForCodeArt(object): 5 | 6 | def __init__(self, tokenizer, label2id) -> None: 7 | self.tokenizer = tokenizer 8 | self.label2id = label2id 9 | self.comma_id = tokenizer.convert_tokens_to_ids(',') 10 | self.ignore = [ 11 | tokenizer.convert_tokens_to_ids('u'), 12 | tokenizer.convert_tokens_to_ids('cpu'), 13 | tokenizer.convert_tokens_to_ids('m'), 14 | tokenizer.convert_tokens_to_ids('x') 15 | ] 16 | 17 | def __call__( 18 | self, 19 | examples 20 | ): 21 | batch = { 22 | 'input_ids': [], 23 | 'attention_mask': [], 24 | 'relative_position_matrix': [], 25 | 'labels': [] 26 | } 27 | 28 | for example in examples: 29 | encoded = self.tokenizer.inst_encode( 30 | eval(example['code']), 31 | eval(example['data_dep']), 32 | return_extra_info=True 33 | ) 34 | batch['input_ids'].append(encoded['input_ids']) 35 | batch['attention_mask'].append(encoded['attention_mask']) 36 | batch['relative_position_matrix'].append(encoded['relative_position_matrix']) 37 | 38 | instruction_node_positions = encoded['instruction_node_positions'] 39 | instruction_labels = eval(example['code_w_type']) 40 | token_labels = [-100] # ignore cls 41 | for i, inst_id in enumerate(instruction_node_positions): 42 | if 'base(char)' in instruction_labels[i]: 43 | print(eval(example['code'])[i]) 44 | start = inst_id 45 | if i + 1 == len(instruction_node_positions): 46 | end = encoded['input_ids'].tolist().index(self.tokenizer.sep_token_id) 47 | else: 48 | end = instruction_node_positions[i + 1] 49 | 50 | # ignore 51 | token_labels.append(-100) # NOTE: may surpass max_length 52 | # get operator type 53 | token_labels.append(self.label2id[instruction_labels[i][0]]) # NOTE: may surpass max_length 54 | # get operands by `,` 55 | cur_id = start + 2 56 | if start < 511 and encoded['input_ids'][start + 1] in self.ignore: 57 | cur_type_id = 0 58 | else: 59 | cur_type_id = 1 60 | while cur_id < end: 61 | if encoded['input_ids'][cur_id] == self.comma_id: 62 | token_labels.append(-100) 63 | cur_type_id += 1 64 | else: 65 | try: 66 | token_labels.append(self.label2id[instruction_labels[i][cur_type_id]]) 67 | except IndexError: 68 | print(example['metadata']) 69 | print(self.tokenizer.convert_ids_to_tokens(encoded['input_ids'][start: end])) 70 | print(end, cur_type_id, self.tokenizer.sep_token_id, self.tokenizer.convert_ids_to_tokens(self.tokenizer.sep_token_id)) 71 | # raise IndexError 72 | 73 | # set `cur_type_id` to 0 74 | cur_type_id = 0 75 | token_labels.append(self.label2id[instruction_labels[i][cur_type_id]]) 76 | cur_id += 1 77 | 78 | # brute force 79 | if len(token_labels) < self.tokenizer.model_max_length: 80 | token_labels += [-100] * (self.tokenizer.model_max_length - len(token_labels)) 81 | 82 | if len(token_labels) > self.tokenizer.model_max_length: 83 | # print([(t, l) for l, t in zip(token_labels, self.tokenizer.convert_ids_to_tokens(encoded['input_ids']))]) 84 | # print(token_labels[-1]) 85 | token_labels = token_labels[:self.tokenizer.model_max_length - 1] + [-100] # force [SEP] at end 86 | assert len(token_labels) == self.tokenizer.model_max_length 87 | 88 
| batch["labels"].append(token_labels) 89 | 90 | return { 91 | 'input_ids': torch.stack(batch['input_ids']), 92 | 'attention_mask': torch.stack(batch['attention_mask']), 93 | 'relative_position_matrix': torch.stack(batch['relative_position_matrix']), 94 | 'labels': torch.tensor(batch['labels'], dtype=torch.long) 95 | } -------------------------------------------------------------------------------- /codeart/preprocess/README.md: -------------------------------------------------------------------------------- 1 | # Preprocess 2 | 3 | Our preprocess pipeline is derived from the code of [jTrans](https://github.com/vul337/jTrans). 4 | Our preprocess has three steps: 5 | 6 | 1. We use IDA Pro to disassemble the binary program. This script is largely based on the preprocessing script of jTrans. 7 | 2. We collect the preprocessed results from individual binary programs and merge them to a single file. 8 | 3. We perform a conservative dependence analysis to extract the program dependence. 9 | 10 | This file introduces the two key components of our preprocess pipeline. 11 | 12 | ## Disassemble 13 | 14 | Given an input binary program, we use IDA Pro to disassemble it. 15 | Specifically, we use IDA Pro to obtain the control flow graph (CFG) of the binary program, 16 | and further use IDA Pro to decode binary code to assembly instructions for each basic block. 17 | 18 | The IDA script we use is `disassemble.py`. 19 | It assumes the following directory structure: 20 | 21 | ``` 22 | example-project 23 | |--unstrip 24 | | |--example-binary0.elf 25 | | |--... 26 | |--extracted-bins 27 | | |--(empty) 28 | |--example-binary0.elf 29 | |--... 30 | ``` 31 | 32 | `example-project` is a directory denoting the project name (e.g., Coreutils). 33 | Suppose that the example project contains a list of binary programs (e.g., example-bianry0.elf, example-binary1.elf, ...). 34 | The directory `unstrip` contains the unstripped binary programs. 35 | They are used to obtain the function names of the binary programs. 36 | 37 | **The names are used exclusively for generating ground truth labels for the binary 38 | similarity task and is not used in CodeArt.** 39 | 40 | The directory `extracted-bins` is empty at the beginning. The IDA script will 41 | store the intermediate results in this directory in the format of `.pickle` files. 42 | 43 | The file `example-project/example-binary0.elf` is the binary program we want to disassemble. 44 | It is stripped. 45 | 46 | The IDA script runs as follow: 47 | 48 | ```shell 49 | $PATH_TO_IDA/idat64 -A -S"$PWD/disassemble.py" example-project/example-binary0.elf 50 | ``` 51 | 52 | The script will generate the following file(s): 53 | 54 | ``` 55 | example-project 56 | |--extracted-bins 57 | | |--example-binary0.elf_extract.pkl 58 | |--... 59 | ``` 60 | 61 | ## Collect Preprocessed Results 62 | 63 | This step aims to merge the preprocessed results from individual binary programs to a single file. The script is `collect.py`. 64 | It takes as input a file that contains paths to the preprocessed results of individual binary programs and outputs a single pickle file that contains all the preprocessed results. 65 | For example, the input file is similar to the following: 66 | 67 | ``` 68 | > cat example-list.txt 69 | example-project/extracted-bins/example-binary0.elf_extract.pkl 70 | example-project/extracted-bins/example-binary1.elf_extract.pkl 71 | ... 72 | ``` 73 | 74 | ## Dependence Analysis 75 | 76 | Then we perform a conservative dependence analysis to extract the program dependence. 
77 | The entry point of our analysis is in `analyze.py`.
78 | It takes as input a pickle file obtained from the previous "collect" step, iterates over all functions in the disassembled binary program,
79 | and outputs the input to CodeArt in a `.jsonl` file.
80 |
81 | The main logic of our analysis is in `ExprLangAnalyzer` of `preprocess/analysis/expr_lang_analyzer.py`.
82 |
83 | To use the resulting `.jsonl` file for evaluation, please refer to the file `codeart/evaluation/binary-similarity/dump_files.py`.
84 | This file takes as input three arguments: the paths to the query `.jsonl` and the pool `.jsonl` files, and the output path. The `README.md` file under `codeart/evaluation/binary-similarity` provides a link to Google Drive, which contains the preprocessed test data used in the CodeArt paper. The results of `codeart/evaluation/binary-similarity/dump_files.py` are expected to have the same format as our example data on Google Drive.
85 |
86 | For more details on the binary similarity task evaluation, please refer to the `README.md` under `codeart/evaluation/binary-similarity`.
87 |
--------------------------------------------------------------------------------
/codeart/preprocess/analysis/expr_lang_analyzer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Set
3 | from analysis.prog_model import Instruction, BasicBlock, Function, ReachDefinitionAnalysis, PostDominatorAnalysis, ControlDependenceAnalysis
4 | import networkx as nx
5 | import json
6 |
7 | class ExprLangAnalyzer:
8 |
9 |     def __init__(self, cfg: nx.classes.DiGraph):
10 |         self.func = Function(cfg)
11 |         self.reach_def = ReachDefinitionAnalysis(self.func)
12 |         self.reach_def.run()
13 |         self.post_dominator = PostDominatorAnalysis(self.func)
14 |         self.post_dominator.run()
15 |         self.control_dependence = ControlDependenceAnalysis(self.func, self.post_dominator)
16 |         self.control_dependence.run()
17 |         self.addr2our_bb = self.func.addr2our_bb
18 |         self.dep = []
19 |         # Note: we need to precompute intra-block dependencies
20 |         # because we might recurse into a block and we need to know
21 |         # what the intra-block dependencies are for that block
22 |         self.instr_to_intra_block_dep = {}
23 |         for bb_addr in sorted(self.addr2our_bb.keys()):
24 |             intra_block_dep = {}
25 |             bb = self.addr2our_bb[bb_addr]
26 |             for instr in bb.instrs:
27 |                 self.instr_to_intra_block_dep[instr.id] = dict(intra_block_dep)
28 |                 for current_def in instr.defs:
29 |                     intra_block_dep[current_def] = instr
30 |         for bb_addr in sorted(self.addr2our_bb.keys()):
31 |             self._print_dep_for_bb(self.addr2our_bb[bb_addr])
32 |
33 |     def print_func_to_jsonl(self, fout, metadata={}):
34 |         instr_strs = []
35 |         for bb_addr in sorted(self.addr2our_bb.keys()):
36 |             bb = self.addr2our_bb[bb_addr]
37 |             for instr in bb.instrs:
38 |                 current_instr_id = instr.id
39 |                 current_instr_str = instr.code.split(';')[0]
40 |                 instr_strs.append((current_instr_id, current_instr_str))
41 |         deps_strs = []
42 |         distinct_deps = set(self.dep)
43 |         # sort by first element
44 |         distinct_deps = sorted(distinct_deps, key=lambda x: x[0])
45 |         for dep in distinct_deps:
46 |             deps_strs.append((dep[0], dep[1]))
47 |
48 |         data_out = {
49 |             'metadata': metadata,
50 |             'code': instr_strs,
51 |             'data_dep': deps_strs,
52 |         }
53 |         fout.write(json.dumps(data_out))
54 |         fout.write('\n')
55 |         fout.flush()
56 |
57 |
58 |     def _print_dep_for_instr(self, visited: Set, indent: int, instr: Instruction):
59 |         intra_block_dep = self.instr_to_intra_block_dep[instr.id]
60 |
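        # Dependence sources come from two maps: intra-block definitions precomputed
        # per instruction in __init__ (instr_to_intra_block_dep), and inter-block
        # reaching definitions at the entry of the instruction's basic block (bb_in).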
inter_block_dep = self.reach_def.bb_in[instr.basic_block.addr] 61 | INDENT = ' ' 62 | current_indent = INDENT * indent 63 | if instr in visited: 64 | # print(current_indent, end='') 65 | # print("Cyclic dependency detected, skipping instr %d" % instr.id) 66 | return 67 | visited.add(instr.id) 68 | # print(current_indent, end='') 69 | # print('instr: %d: %s' % (instr.id, instr)) 70 | current_instr_id = instr.id 71 | deps = [] 72 | for use in instr.uses: 73 | # print(current_indent, end='') 74 | # print(' ;use: %s' % use) 75 | # print(current_indent, end=' ') 76 | if use in intra_block_dep: 77 | self.dep.append((current_instr_id, intra_block_dep[use].id)) 78 | # print(';intra_block_dep: %d' % intra_block_dep[use].id) 79 | # self._print_dep_for_instr(visited, indent+1, intra_block_dep[use]) 80 | elif use in inter_block_dep: 81 | # print(';inter_block_dep:[', end=' ') 82 | for def_instr in inter_block_dep[use]: 83 | self.dep.append((current_instr_id, def_instr.id)) 84 | # print('%d,' % def_instr.id, end=' ') 85 | # print(']', end=' ') 86 | # if len(inter_block_dep[use]) > 1: 87 | # print("**phi node here**") 88 | # else: 89 | # print() 90 | # for def_instr in inter_block_dep[use]: 91 | # self._print_dep_for_instr(visited, indent+1, def_instr) 92 | else: 93 | # print(';definition may be outside of current function') 94 | pass 95 | visited.remove(instr.id) 96 | 97 | 98 | 99 | 100 | 101 | # used for dbg 102 | def _print_dep_for_bb(self, bb:BasicBlock): 103 | def_in = self.reach_def.bb_in[bb.addr] 104 | # print() 105 | # print(";BB: %x" % bb.addr) 106 | # for k,v in def_in.items(): 107 | # print(";var: %s, def: [" % k, end='') 108 | # for instr in v: 109 | # print("%d, " % instr.id, end='') 110 | # print("]") 111 | 112 | for instr in bb.instrs: 113 | self._print_dep_for_instr(set(), 0, instr) -------------------------------------------------------------------------------- /codeart/preprocess/analyze.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import argparse 4 | import json 5 | import glob 6 | from utils.data_utils import dump_cfg, parse_cfg 7 | from analysis import prog_model 8 | import networkx as nx 9 | from analysis.expr_lang_analyzer import ExprLangAnalyzer 10 | from tqdm import tqdm 11 | import signal 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data-in', type=str, default='') 16 | parser.add_argument('--fout', type=str, default='') 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def timeout_handler(signum, frame): 22 | raise Exception("timeout") 23 | 24 | def main(): 25 | args = parse_args() 26 | 27 | data_in = pickle.load(open(args.data_in, 'rb')) 28 | fout = open(args.fout, 'w') 29 | signal.signal(signal.SIGALRM, timeout_handler) 30 | for function in tqdm(data_in): 31 | meta = { 32 | 'project_name': function['project_name'], 33 | 'function_name': function['funcname'], 34 | 'function_addr': function['funcaddr'], 35 | "binary_name": function['binname'], 36 | } 37 | # maximum analysis time is 10 seconds 38 | signal.alarm(10) 39 | try: 40 | expr_lang_analyzer = ExprLangAnalyzer(function['cfg']) 41 | signal.alarm(10) 42 | expr_lang_analyzer.print_func_to_jsonl(fout, metadata=meta) 43 | except Exception as e: 44 | # if e is not time out 45 | if str(e) != "timeout": 46 | print("Error in function: ") 47 | print(meta) 48 | # raise e 49 | else: 50 | print(meta) 51 | print(e) 52 | 53 | fout.close() 54 | print(args) 55 | 56 | 57 | 58 | if __name__ == 
'__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /codeart/preprocess/binary_base.py: -------------------------------------------------------------------------------- 1 | from elftools.elf.elffile import ELFFile 2 | import elftools.elf.elffile as elffile 3 | from elftools.elf.sections import SymbolTableSection 4 | from collections import defaultdict 5 | import os 6 | 7 | class Binarybase(object): 8 | def __init__(self, unstrip_path): 9 | self.unstrip_path = unstrip_path 10 | assert(os.path.exists(unstrip_path)) 11 | self.addr2name = self.extract_addr2name(self.unstrip_path) 12 | 13 | def get_func_name(self, name, functions): 14 | 15 | if name not in functions: 16 | return name 17 | 18 | i = 0 19 | while True: 20 | 21 | new_name = name+'_'+str(i) 22 | if new_name not in functions: 23 | return new_name 24 | 25 | i += 1 26 | 27 | def scan_section(self, functions, section): 28 | """ 29 | Function to extract function names from a shared library file. 30 | """ 31 | if not section or not isinstance(section, SymbolTableSection) or section['sh_entsize'] == 0: 32 | return 0 33 | 34 | count = 0 35 | for nsym, symbol in enumerate(section.iter_symbols()): 36 | 37 | if symbol['st_info']['type'] == 'STT_FUNC' and symbol['st_shndx'] != 'SHN_UNDEF': 38 | 39 | func = symbol.name 40 | 41 | name = self.get_func_name(func, functions) 42 | 43 | if not name in functions: 44 | 45 | functions[name] = {} 46 | 47 | functions[name]['begin'] = symbol.entry['st_value'] 48 | 49 | 50 | def extract_addr2name(self, path): 51 | ''' 52 | return: 53 | ''' 54 | functions = {} 55 | with open(path, 'rb') as stream: 56 | 57 | elffile = ELFFile(stream) 58 | 59 | self.scan_section(functions, elffile.get_section_by_name('.symtab')) 60 | 61 | self.scan_section(functions, elffile.get_section_by_name('.dynsym')) 62 | 63 | addr2name = {func['begin']: name for (name, func) in functions.items()} 64 | return defaultdict(lambda:-1, addr2name) 65 | -------------------------------------------------------------------------------- /codeart/preprocess/collect.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from tqdm import tqdm 4 | 5 | import argparse 6 | 7 | ADDR_IDX = 0 8 | ASM_IDX = 1 9 | RAW_IDX = 2 10 | CFG_IDX = 3 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(description="Collect the preprocess results") 15 | parser.add_argument( 16 | "--binary_list_file", 17 | type=str, 18 | default="", 19 | help="This file contains the file names to be loaded", 20 | ) 21 | parser.add_argument( 22 | "--fout", type=str, default="", help="The output file name, in pickle format" 23 | ) 24 | args = parser.parse_args() 25 | return args 26 | 27 | 28 | if __name__ == "__main__": 29 | MAX_LEN = 512 30 | args = get_args() 31 | print(args) 32 | # load picked binary 33 | fin = open(args.binary_list_file, "r") 34 | binary_list = fin.readlines() 35 | fin.close() 36 | binfolder_binary_entries = [] 37 | all_binary_len = len(binary_list) 38 | print("Loading binaries ...") 39 | 40 | for b in tqdm(binary_list): 41 | project_name = os.path.basename(os.path.dirname(b)) 42 | bin_fin = open(b.strip(), "rb") 43 | binary = pickle.load(bin_fin) 44 | bin_fin.close() 45 | addr2function = {} 46 | for name, entry in binary.items(): 47 | addr2function[entry[0]] = entry 48 | 49 | for name, entry in binary.items(): 50 | my_cfg = entry[CFG_IDX] 51 | # this is to support legacy code 52 | # there is only one version of CFG now 53 | 
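            # Each `entry` is the list written by disassemble.py:
            # [func_addr, asm_list, rawbytes_list, cfg, None], indexed by the
            # ADDR_IDX/ASM_IDX/RAW_IDX/CFG_IDX constants at the top of this file.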
new_cfg = my_cfg 54 | func_addr = entry[ADDR_IDX] 55 | new_cfg.nodes[func_addr]["num"] = -1 56 | logical_order = [(n, my_cfg.nodes[n]) for n in my_cfg.nodes()] 57 | 58 | binfolder_binary_entries.append( 59 | { 60 | "project_name": project_name, 61 | "funcname": name, 62 | "binname": os.path.basename(b), 63 | "funcaddr": entry[ADDR_IDX], 64 | "cfg": new_cfg, 65 | "dbg_logical_order": logical_order, 66 | } 67 | ) 68 | 69 | fout = open(args.fout, "wb") 70 | pickle.dump(binfolder_binary_entries, fout) 71 | fout.close() 72 | exit(0) 73 | -------------------------------------------------------------------------------- /codeart/preprocess/disassemble.py: -------------------------------------------------------------------------------- 1 | import idc 2 | import idautils 3 | import idaapi 4 | import pickle 5 | import sys 6 | 7 | import networkx as nx 8 | from binary_base import Binarybase 9 | 10 | 11 | SAVEROOT = "./extracted-bins" # dir of pickle files saved by IDA 12 | DATAROOT = "./unstrip" # dir of binaries (not stripped) 13 | 14 | 15 | class BinaryData(Binarybase): 16 | def __init__(self, unstrip_path): 17 | super(BinaryData, self).__init__(unstrip_path) 18 | self.fix_up() 19 | 20 | def fix_up(self): 21 | for addr in self.addr2name: 22 | # incase some functions' instructions are not recognized by IDA 23 | idc.create_insn(addr) 24 | idc.add_func(addr) 25 | 26 | def get_asm(self, func): 27 | instGenerator = idautils.FuncItems(func) 28 | asm_list = [] 29 | for inst in instGenerator: 30 | asm_list.append(idc.GetDisasm(inst)) 31 | return asm_list 32 | 33 | def get_rawbytes(self, func): 34 | instGenerator = idautils.FuncItems(func) 35 | rawbytes_list = b"" 36 | for inst in instGenerator: 37 | rawbytes_list += idc.get_bytes(inst, idc.get_item_size(inst)) 38 | return rawbytes_list 39 | 40 | def get_cfg(self, func): 41 | 42 | def get_attr(block, func_addr_set): 43 | asm, raw = [], b"" 44 | curr_addr = block.start_ea 45 | if curr_addr not in func_addr_set: 46 | return -1 47 | # print(f"[*] cur: {hex(curr_addr)}, block_end: {hex(block.end_ea)}") 48 | while curr_addr <= block.end_ea: 49 | asm.append(idc.GetDisasm(curr_addr)) 50 | raw += idc.get_bytes(curr_addr, idc.get_item_size(curr_addr)) 51 | curr_addr = idc.next_head(curr_addr, block.end_ea) 52 | return asm, raw 53 | 54 | nx_graph = nx.DiGraph() 55 | flowchart = idaapi.FlowChart( 56 | idaapi.get_func(func), flags=idaapi.FC_PREDS) 57 | func_addr_set = set([addr for addr in idautils.FuncItems(func)]) 58 | for block in flowchart: 59 | # Make sure all nodes are added (including edge-less nodes) 60 | attr = get_attr(block, func_addr_set) 61 | if attr == -1: 62 | continue 63 | nx_graph.add_node(block.start_ea, asm=attr[0], raw=attr[1]) 64 | # print(f"[*] bb: {hex(block.start_ea)}, asm: {attr[0]}") 65 | for pred in block.preds(): 66 | if pred.start_ea not in func_addr_set: 67 | continue 68 | nx_graph.add_edge(pred.start_ea, block.start_ea) 69 | for succ in block.succs(): 70 | if succ.start_ea not in func_addr_set: 71 | continue 72 | nx_graph.add_edge(block.start_ea, succ.start_ea) 73 | return nx_graph 74 | 75 | 76 | 77 | def extract_all(self): 78 | for func in idautils.Functions(): 79 | if idc.get_segm_name(func) in ['.plt', 'extern', '.init', '.fini']: 80 | continue 81 | print("[+] %s" % idc.get_func_name(func)) 82 | asm_list = self.get_asm(func) 83 | rawbytes_list = self.get_rawbytes(func) 84 | cfg = self.get_cfg(func) 85 | unstrip_name = self.addr2name[func] 86 | if unstrip_name == -1: 87 | name = idc.get_func_name(func) 88 | else: 89 | name = unstrip_name 
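            # `func` is the function's start address; the yielded tuple parallels
            # the [func, asm_list, rawbytes_list, cfg, None] entry that __main__
            # stores into saved_dict (with `name` as the key).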
90 |             yield (name, func, asm_list, rawbytes_list, cfg)
91 |
92 |
93 | if __name__ == "__main__":
94 |     import os
95 |     from collections import defaultdict
96 |
97 |     print(DATAROOT)
98 |     print(os.getcwd())
99 |     assert os.path.exists(DATAROOT), "DATAROOT does not exist"
100 |     assert os.path.exists(SAVEROOT), "SAVEROOT does not exist"
101 |     print("Current filename: %s" % idc.get_input_file_path())
102 |     binary_abs_path = idc.get_input_file_path()
103 |     # filename = binary_abs_path.split('/')[-1][:-6]
104 |     filename = binary_abs_path.split('/')[-1]
105 |     unstrip_path = os.path.join(DATAROOT, filename)
106 |     # unstrip_path = binary_abs_path
107 |     idc.auto_wait()
108 |     binary_data = BinaryData(unstrip_path)
109 |
110 |     saved_dict = defaultdict(list)
111 |     saved_path = os.path.join(
112 |         SAVEROOT, filename + "_extract.pkl")  # unpair data
113 |     with open(saved_path, 'wb') as f:
114 |         for func_name, func, asm_list, rawbytes_list, cfg in binary_data.extract_all():
115 |             saved_dict[func_name] = [func, asm_list,
116 |                                      rawbytes_list, cfg, None]
117 |         pickle.dump(dict(saved_dict), f)
118 |     idc.qexit(0)  # exit IDA
--------------------------------------------------------------------------------
/codeart/preprocess/type_inference/die_globals.py:
--------------------------------------------------------------------------------
1 | all_dies = {}
2 | gof = 0
3 | range_list = None
4 | location_list = None
5 |
6 | import elftools
7 |
8 |
9 | def resolve_ofs(ofs):
10 |     return ofs + gof
11 |
12 |
13 | cu_code_base_addr = 0
14 |
15 |
16 | def resolve_code_ofs(ofs):
17 |     return ofs + cu_code_base_addr
18 |
19 |
20 | _dbg_cfi_entries = {}
21 | addr2cfi_entries = {}
22 |
23 |
24 | def parse_cfi_entires(cfi_entries):
25 |     for entry in cfi_entries:
26 |         if "header" not in dir(entry):
27 |             continue
28 |         if "initial_location" not in entry.header:
29 |             continue
30 |         addr = entry.header["initial_location"]
31 |         cfi_table = entry._decode_CFI_table()
32 |         _dbg_cfi_entries[addr] = cfi_table
33 |         if len(cfi_table) != 2:
34 |             raise Exception("CFI table length is not 2")
35 |         current_cfa_entries = cfi_table[0]
36 |         # the register ids mentioned in the CFI table
37 |         # cols = cfi_table[1]
38 |         has_rbp = False
39 |         for cfa_entry in current_cfa_entries:
40 |             cfa_info = cfa_entry["cfa"]
41 |             if type(cfa_info) != elftools.dwarf.callframe.CFARule:
42 |                 raise Exception("CFI table is not CFARule")
43 |             if cfa_info.reg == 6:
44 |                 has_rbp = True
45 |                 break
46 |
47 |         if has_rbp:
48 |             addr2cfi_entries[addr] = "rbp"
49 |         else:
50 |             addr2cfi_entries[addr] = "rsp"
51 |
52 | # heuristics to identify whether stack frame is rsp or rbp based
53 |
--------------------------------------------------------------------------------
/codeart/preprocess/type_inference/gen_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import json
5 | import utils
6 | from tqdm import tqdm
7 |
8 |
9 | def parse_args():
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("--jsonl-data-in", type=str, help="path to codeart jsonl data")
12 |     parser.add_argument(
13 |         "--raw-data-in",
14 |         type=str,
15 |         help="path to the .pkl file corresponding to the jsonl data",
16 |     )
17 |     parser.add_argument(
18 |         "--bin-dir-root", type=str, help="root dir containing the binaries"
19 |     )
20 |     parser.add_argument("--fout", type=str, default="")
21 |     args = parser.parse_args()
22 |     return args
23 |
24 |
25 | def main():
26 |     args = parse_args()
27 |     raw_data_in =
pickle.load(open(args.raw_data_in, "rb")) 28 | jsonl_data_in = open(args.jsonl_data_in, "r") 29 | name2raw = {} 30 | binname2debug_info = {} 31 | for function in raw_data_in: 32 | proj_name = function["project_name"] 33 | binname = function["binname"] 34 | func_addr = function["funcaddr"] 35 | name2raw[(proj_name, binname, func_addr)] = function 36 | 37 | del proj_name 38 | del binname 39 | del func_addr 40 | del function 41 | if args.fout == "": 42 | args.fout = args.jsonl_data_in + ".type_data.jsonl" 43 | fout = open(args.fout, "w") 44 | for line in tqdm(jsonl_data_in): 45 | data = json.loads(line) 46 | project_name = data["metadata"]["project_name"] 47 | function_addr = data["metadata"]["function_addr"] 48 | binary_name = data["metadata"]["binary_name"] 49 | if (project_name, binary_name, function_addr) not in name2raw: 50 | raise Exception( 51 | "Function %s not found" 52 | % str((project_name, binary_name, function_addr)) 53 | ) 54 | 55 | normalized_bin_name = binary_name.strip().replace(".elf_extract.pkl", "") 56 | if (project_name, normalized_bin_name) not in binname2debug_info: 57 | # binname2debug_info[(proj_name, normalized_bin_name)] = None 58 | bin_path = os.path.join( 59 | args.bin_dir_root, 60 | project_name, 61 | normalized_bin_name + ".type_info.jsonl", 62 | ) 63 | dbg_info_searcher = utils.BinaryDebugInfoSearcher(bin_path) 64 | binname2debug_info[(project_name, normalized_bin_name)] = dbg_info_searcher 65 | 66 | current_dbg_info_searcher = binname2debug_info[ 67 | (project_name, normalized_bin_name) 68 | ] 69 | marinda_insns = data["code"] 70 | function_raw = name2raw[(project_name, binary_name, function_addr)] 71 | cfg = function_raw["cfg"] 72 | sorted_node_ids = sorted(cfg.nodes) 73 | nodes = [cfg.nodes[node_id] for node_id in sorted_node_ids] 74 | raw_insn_list = [] 75 | current_insns = 0 76 | for node in nodes: 77 | if len(node["asm"]) != len(node["addr_list"]): 78 | raise Exception("Length of asm and addr_list not equal") 79 | for asm, addr in zip(node["asm"], node["addr_list"]): 80 | if current_insns == len(marinda_insns): 81 | break 82 | raw_insn_list.append((current_insns, asm, addr)) 83 | current_insns += 1 84 | 85 | instrs_w_type = [] 86 | for i, marinda_insn in enumerate(marinda_insns): 87 | raw_insn = raw_insn_list[i] 88 | raw_insn_str = raw_insn[1] 89 | insn_addr = raw_insn[2] 90 | parse_ret = utils.parse_insn_possible_op(raw_insn_str) 91 | current_insn_w_type = [] 92 | for item in parse_ret: 93 | if item is None: 94 | continue 95 | if type(item) == str: 96 | current_insn_w_type.append((item, None)) 97 | else: 98 | ties = current_dbg_info_searcher.query(insn_addr, item[1]) 99 | if ties is None: 100 | current_insn_w_type.append((item[0], None)) 101 | else: 102 | current_insn_w_type.append((item[0], ties[-1].type)) 103 | 104 | instrs_w_type.append((marinda_insn, current_insn_w_type)) 105 | to_print = { 106 | "metadata": data["metadata"], 107 | "code": marinda_insns, 108 | "data_dep": data["data_dep"], 109 | "code_w_type": [i[1] for i in instrs_w_type], 110 | } 111 | fout.write(json.dumps(to_print)) 112 | fout.write("\n") 113 | 114 | fout.close() 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /codeart/preprocess/type_inference/parse_dwarf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from elftools.elf.elffile import ELFFile 3 | import sys 4 | from tqdm import tqdm 5 | import die_globals 6 | import 
base_calculator 7 | import json 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--fin", type=str, default="type-inference/diffutils-3.4-O2/diff" 14 | ) 15 | parser.add_argument("--fout", type=str, default="") 16 | 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | def collect_dies(die): 22 | die_globals.all_dies[die.offset] = die 23 | for child in die.iter_children(): 24 | collect_dies(child) 25 | 26 | 27 | def parse_die(die): 28 | if ( 29 | die.tag == "DW_TAG_subprogram" 30 | or die.tag == "DW_TAG_inlined_subroutine" 31 | or die.tag == "DW_TAG_lexical_block" 32 | ): 33 | sub_prog_range_list = [] 34 | if "DW_AT_low_pc" in die.attributes: 35 | low_pc = die.attributes["DW_AT_low_pc"].value 36 | high_pc = die.attributes["DW_AT_high_pc"].value 37 | if die.attributes["DW_AT_high_pc"].form.startswith("DW_FORM_addr"): 38 | pass 39 | elif die.attributes["DW_AT_high_pc"].form.startswith("DW_FORM_data"): 40 | high_pc += low_pc 41 | else: 42 | raise Exception("Invalid DW_AT_high_pc form") 43 | sub_prog_range_list.append((low_pc, high_pc)) 44 | elif "DW_AT_ranges" in die.attributes: 45 | offset = die.attributes["DW_AT_ranges"].value 46 | ranges = die_globals.range_list.get_range_list_at_offset(offset) 47 | for entry in ranges: 48 | low_pc = die_globals.resolve_code_ofs(entry.begin_offset) 49 | high_pc = die_globals.resolve_code_ofs(entry.end_offset) 50 | sub_prog_range_list.append((low_pc, high_pc)) 51 | if len(sub_prog_range_list) > 0: 52 | # for entry in sub_prog_range_list: 53 | # print("[%x, %x)" % (entry[0], entry[1])) 54 | # all_structure_list = [] 55 | all_type_list = [] 56 | to_visit = [(die, sub_prog_range_list)] 57 | while to_visit: 58 | entry = to_visit.pop() 59 | current_die = entry[0] 60 | my_range_list = entry[1] 61 | for child in current_die.iter_children(): 62 | if ( 63 | child.tag == "DW_TAG_variable" 64 | or child.tag == "DW_TAG_formal_parameter" 65 | ): 66 | # if base_calculator.complex_data_type(child): 67 | base_list = base_calculator.calculate_base_addr( 68 | child, my_range_list 69 | ) 70 | all_type_list.extend(base_list) 71 | elif ( 72 | child.tag == "DW_TAG_subprogram" 73 | or child.tag == "DW_TAG_inlined_subroutine" 74 | or child.tag == "DW_TAG_lexical_block" 75 | ): 76 | types_in_child = parse_die(child) 77 | all_type_list.extend(types_in_child) 78 | else: 79 | to_visit.append((child, my_range_list)) 80 | 81 | all_type_list_pretty = [] 82 | if "DW_AT_frame_base" in die.attributes: 83 | frame_base = die.attributes["DW_AT_frame_base"].value 84 | fbase = base_calculator.parse_exprloc(frame_base) 85 | if fbase[0] == "cfa": 86 | # resolve cfa 87 | my_addr = sub_prog_range_list[0][0] 88 | if my_addr not in die_globals.addr2cfi_entries: 89 | raise Exception("CFI entry not found") 90 | fbase_str = die_globals.addr2cfi_entries[my_addr] 91 | else: 92 | fbase_str = fbase[0] 93 | 94 | for entry in all_type_list: 95 | if "fbreg" in entry[4][0]: 96 | pretty_fbreg = entry[4][0].replace( 97 | "fbreg", "fbreg(%s)" % fbase_str 98 | ) 99 | pretty_loc = (pretty_fbreg, entry[4][1]) 100 | all_type_list_pretty.append( 101 | (entry[0], entry[1], entry[2], entry[3], pretty_loc) 102 | ) 103 | else: 104 | all_type_list_pretty.append(entry) 105 | else: 106 | all_type_list_pretty = all_type_list 107 | 108 | # for entry in all_type_list_pretty: 109 | # print( 110 | # "%s, %s, [%x, %x), %s" 111 | # % (entry[0], entry[1], entry[2], entry[3], entry[4]) 112 | # ) 113 | return all_type_list_pretty 114 | else: 115 | # print("No range list for 
%s" % die) 116 | # function declarations do not have range list 117 | return [] 118 | else: 119 | ret = [] 120 | if die.has_children: 121 | for child in die.iter_children(): 122 | types_in_child = parse_die(child) 123 | ret.extend(types_in_child) 124 | return ret 125 | 126 | 127 | def main(): 128 | args = parse_args() 129 | fin = open(args.fin, "rb") 130 | elf = ELFFile(fin) 131 | dwarf = elf.get_dwarf_info() 132 | die_globals.range_list = dwarf.range_lists() 133 | die_globals.location_list = dwarf.location_lists() 134 | eh_cfi = dwarf.EH_CFI_entries() 135 | die_globals.parse_cfi_entires(eh_cfi) 136 | 137 | types_all = [] 138 | types_all_pretty = [] 139 | for cu in dwarf.iter_CUs(): 140 | die_globals.gof = cu.cu_offset 141 | die = cu.get_top_DIE() 142 | if "DW_AT_low_pc" not in die.attributes: 143 | continue 144 | die_globals.cu_code_base_addr = die.attributes["DW_AT_low_pc"].value 145 | collect_dies(die) 146 | type_info = parse_die(die) 147 | type_info_pretty = [] 148 | for info in type_info: 149 | pretty_info = (info[0], info[1], "%x" % info[2], "%x" % info[3], info[4]) 150 | type_info_pretty.append(pretty_info) 151 | types_all.extend(type_info) 152 | types_all_pretty.extend(type_info_pretty) 153 | 154 | if args.fout == "": 155 | args.fout = args.fin + ".type_info.jsonl" 156 | with open(args.fout, "w") as fout: 157 | for info in types_all_pretty: 158 | if info[1] is None: 159 | continue 160 | entry = { 161 | "varname": info[0], 162 | "type": info[1], 163 | "low_pc": info[2], 164 | "high_pc": info[3], 165 | "loc": info[4], 166 | } 167 | fout.write(json.dumps(entry) + "\n") 168 | 169 | print(base_calculator.skip_cnt) 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /codeart/preprocess/type_inference/upload_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import pandas as pd 5 | from datasets import Dataset, concatenate_datasets, Features 6 | from tqdm import tqdm 7 | import numpy as np 8 | import glob 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--opt", type=str, default="O3") 12 | 13 | args = parser.parse_args() 14 | 15 | shards = glob.glob("dataset/*.type_data.jsonl") 16 | OPT = args.opt 17 | out_ds_name = "type-inference-all-%s" % OPT 18 | 19 | print("OPT: %s" % OPT) 20 | 21 | feature_dict = { 22 | "metadata": { 23 | "binary_name": {"dtype": "string", "_type": "Value"}, 24 | "function_addr": {"dtype": "int64", "_type": "Value"}, 25 | "function_name": {"dtype": "string", "_type": "Value"}, 26 | "project_name": {"dtype": "string", "_type": "Value"}, 27 | }, 28 | "code_w_type": {"dtype": "string", "_type": "Value"}, 29 | "code": {"dtype": "string", "_type": "Value"}, 30 | "data_dep": {"dtype": "string", "_type": "Value"}, 31 | } 32 | features = Features.from_dict(feature_dict) 33 | 34 | valid_projs = set( 35 | [ 36 | "coreutils-5.93-O0", 37 | "coreutils-5.93-O1", 38 | "coreutils-5.93-O2", 39 | "coreutils-5.93-O3", 40 | "coreutils-6.4-O0", 41 | "coreutils-6.4-O1", 42 | "coreutils-6.4-O2", 43 | "coreutils-6.4-O3", 44 | "coreutils-7.6-O0", 45 | "coreutils-7.6-O1", 46 | "coreutils-7.6-O2", 47 | "coreutils-7.6-O3", 48 | "coreutils-8.1-O0", 49 | "coreutils-8.1-O1", 50 | "coreutils-8.1-O2", 51 | "coreutils-8.1-O3", 52 | "coreutils-8.30-O0", 53 | "coreutils-8.30-O1", 54 | "coreutils-8.30-O2", 55 | "coreutils-8.30-O3", 56 | "diffutils-2.8-O0", 57 | "diffutils-2.8-O1", 58 | "diffutils-2.8-O2", 
59 | "diffutils-2.8-O3", 60 | "diffutils-3.1-O0", 61 | "diffutils-3.1-O1", 62 | "diffutils-3.1-O2", 63 | "diffutils-3.1-O3", 64 | "diffutils-3.3-O0", 65 | "diffutils-3.3-O1", 66 | "diffutils-3.3-O2", 67 | "diffutils-3.3-O3", 68 | "diffutils-3.4-O0", 69 | "diffutils-3.4-O1", 70 | "diffutils-3.4-O2", 71 | "diffutils-3.4-O3", 72 | "findutils-4.233-O0", 73 | "findutils-4.233-O1", 74 | "findutils-4.233-O2", 75 | "findutils-4.233-O3", 76 | "findutils-4.41-O0", 77 | "findutils-4.41-O1", 78 | "findutils-4.41-O2", 79 | "findutils-4.41-O3", 80 | "findutils-4.6-O0", 81 | "findutils-4.6-O1", 82 | "findutils-4.6-O2", 83 | "findutils-4.6-O3", 84 | ] 85 | ) 86 | 87 | valid_projs = set([x for x in valid_projs if x.endswith(OPT)]) 88 | 89 | 90 | def gen(shards): 91 | for shard in shards: 92 | with open(shard, "r") as f: 93 | for line in f: 94 | data = json.loads(line) 95 | proj_name = data["metadata"]["project_name"] 96 | # if proj_name not in valid_projs: 97 | # continue 98 | if not proj_name.endswith(OPT): 99 | continue 100 | code_w_type = [] 101 | for types in data["code_w_type"]: 102 | code_w_type.append([i[1] if i[1] else "noacc" for i in types]) 103 | out_data_entry = { 104 | "metadata": data["metadata"], 105 | "code_w_type": str(code_w_type), 106 | "code": str(data["code"]), 107 | "data_dep": str(data["data_dep"]), 108 | } 109 | yield out_data_entry 110 | 111 | 112 | ds = Dataset.from_generator( 113 | gen, 114 | features=features, 115 | gen_kwargs={"shards": shards}, 116 | num_proc=8, 117 | cache_dir="./ds-cache", 118 | ) 119 | 120 | # split train/valid/test 121 | dataset = ds.shuffle(seed=42) 122 | ret_data_dict = dataset.train_test_split(test_size=0.1, seed=42) 123 | ret_data_dict.push_to_hub(out_ds_name + "-shuffle", private=True) 124 | 125 | print() 126 | -------------------------------------------------------------------------------- /codeart/preprocess/type_inference/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | 5 | 6 | class TypeInfoEntry: 7 | def __init__(self, raw_entry): 8 | self.type = raw_entry["type"] 9 | self.low_pc = int(raw_entry["low_pc"], 16) 10 | self.high_pc = int(raw_entry["high_pc"], 16) 11 | self.loc = self._normalize_loc(raw_entry["loc"]) 12 | 13 | def _normalize_loc(self, loc): 14 | if loc[0].startswith("fbreg"): 15 | # fbreg(rbp) --> RBP 16 | # fbreg(rsp) --> RSP 17 | normalized_reg_str = ( 18 | loc[0].replace("fbreg", "").replace("(", "").replace(")", "") 19 | ) 20 | normalized_reg_str = normalized_reg_str.upper() 21 | return "[%s%x]" % (normalized_reg_str, loc[1]) 22 | elif loc[0].startswith("reg"): 23 | reg_name = loc[0].split("_")[1] 24 | return reg_name 25 | elif loc[0].startswith("breg"): 26 | reg_name = loc[0].split("_")[1] 27 | return "[%s%x]" % (reg_name, loc[1]) 28 | else: 29 | return None 30 | 31 | def __repr__(self) -> str: 32 | return "TIE-[%x, %x)@%s: %s" % (self.low_pc, self.high_pc, self.loc, self.type) 33 | 34 | 35 | class DebugInfoIntervalTreeNode: 36 | def __init__(self, med, parent): 37 | self.med = med 38 | self.parent = parent 39 | self.left = None 40 | self.right = None 41 | 42 | def __repr__(self) -> str: 43 | return "Node-%x" % (self.med) 44 | 45 | 46 | def _build_tree(sorted_points, ties, parent): 47 | if not sorted_points: 48 | return None 49 | 50 | mid = len(sorted_points) // 2 51 | median = sorted_points[mid] 52 | # intervals that are on the left 53 | left_ties = [tie for tie in ties if tie.high_pc < median] 54 | # intervals that are on the right 55 | 
right_ties = [tie for tie in ties if tie.low_pc > median] 56 | # intervals that contain the median 57 | median_ties = [ 58 | tie for tie in ties if tie.low_pc <= median and tie.high_pc >= median 59 | ] 60 | node = DebugInfoIntervalTreeNode(median, parent) 61 | if len(left_ties) > 0: 62 | node.left = _build_tree(sorted_points[:mid], left_ties, node) 63 | if len(right_ties) > 0: 64 | node.right = _build_tree(sorted_points[mid:], right_ties, node) 65 | my_tie_sorted_by_low = sorted(median_ties, key=lambda x: x.low_pc) 66 | node.ties_sorted_by_low = my_tie_sorted_by_low 67 | my_tie_sorted_by_high = sorted(median_ties, key=lambda x: x.high_pc, reverse=True) 68 | node.ties_sorted_by_high = my_tie_sorted_by_high 69 | return node 70 | 71 | 72 | def _precompute_loc_info(node): 73 | if not node: 74 | return 75 | my_locs = set() 76 | for tie in node.ties_sorted_by_low: 77 | my_locs.add(tie.loc) 78 | node.my_locs = my_locs 79 | my_children_locs = set() 80 | if node.left: 81 | _precompute_loc_info(node.left) 82 | my_children_locs.update(node.left.locs_all) 83 | if node.right: 84 | _precompute_loc_info(node.right) 85 | my_children_locs.update(node.right.locs_all) 86 | 87 | node.locs_all = my_locs | my_children_locs 88 | 89 | 90 | def _query_tree(addr, loc, node): 91 | match_all = loc == "**" 92 | if loc not in node.locs_all and not match_all: 93 | return None 94 | related_ties = [] 95 | if addr <= node.med: 96 | if loc in node.my_locs or match_all: 97 | for tie in node.ties_sorted_by_low: 98 | if (tie.loc == loc or match_all) and tie.low_pc <= addr: 99 | related_ties.append(tie) 100 | if tie.low_pc > addr: 101 | break 102 | if node.left: 103 | left_ret = _query_tree(addr, loc, node.left) 104 | if left_ret: 105 | related_ties.extend(left_ret) 106 | else: 107 | if loc in node.my_locs or match_all: 108 | for tie in node.ties_sorted_by_high: 109 | if (tie.loc == loc or match_all) and tie.high_pc > addr: 110 | related_ties.append(tie) 111 | if tie.high_pc <= addr: 112 | break 113 | if node.right: 114 | right_ret = _query_tree(addr, loc, node.right) 115 | if right_ret: 116 | related_ties.extend(right_ret) 117 | if len(related_ties) == 0: 118 | return None 119 | return related_ties 120 | 121 | 122 | class BinaryDebugInfoSearcher: 123 | 124 | def __init__(self, file_name): 125 | self.file_name = file_name 126 | self.dbg_info_list_raw = [] 127 | if os.path.exists(file_name): 128 | self._load() 129 | self.interval_entries = [] 130 | for entry in self.dbg_info_list_raw: 131 | typeinfo_entry = TypeInfoEntry(entry) 132 | if typeinfo_entry.loc is not None: 133 | self.interval_entries.append(typeinfo_entry) 134 | interval_points = set() 135 | for entry in self.interval_entries: 136 | interval_points.add(entry.low_pc) 137 | interval_points.add(entry.high_pc) 138 | interval_points_sorted = sorted(list(interval_points)) 139 | # create interval tree 140 | self.root = _build_tree(interval_points_sorted, self.interval_entries, None) 141 | # also pre-compute the location information for each tree node 142 | _precompute_loc_info(self.root) 143 | 144 | def _load(self): 145 | with open(self.file_name, "r") as fin: 146 | for line in fin: 147 | data = json.loads(line) 148 | self.dbg_info_list_raw.append(data) 149 | 150 | def query(self, addr, loc): 151 | if len(self.interval_entries) == 0: 152 | return None 153 | return _query_tree(addr, loc, self.root) 154 | 155 | 156 | reg_normalize = { 157 | "al": "RAX", 158 | "ah": "RAX", 159 | "ax": "RAX", 160 | "eax": "RAX", 161 | "rax": "RAX", 162 | "bl": "RBX", 163 | "bh": "RBX", 164 | 
"bx": "RBX", 165 | "ebx": "RBX", 166 | "rbx": "RBX", 167 | "cl": "RCX", 168 | "ch": "RCX", 169 | "cx": "RCX", 170 | "ecx": "RCX", 171 | "rcx": "RCX", 172 | "dl": "RDX", 173 | "dh": "RDX", 174 | "dx": "RDX", 175 | "edx": "RDX", 176 | "rdx": "RDX", 177 | "sil": "RSI", 178 | "si": "RSI", 179 | "esi": "RSI", 180 | "rsi": "RSI", 181 | "dil": "RDI", 182 | "di": "RDI", 183 | "edi": "RDI", 184 | "rdi": "RDI", 185 | "bpl": "RBP", 186 | "bp": "RBP", 187 | "ebp": "RBP", 188 | "rbp": "RBP", 189 | "spl": "RSP", 190 | "sp": "RSP", 191 | "esp": "RSP", 192 | "rsp": "RSP", 193 | "r8b": "R8", 194 | "r8w": "R8", 195 | "r8d": "R8", 196 | "r8": "R8", 197 | "r9b": "R9", 198 | "r9w": "R9", 199 | "r9d": "R9", 200 | "r9": "R9", 201 | "r10b": "R10", 202 | "r10w": "R10", 203 | "r10d": "R10", 204 | "r10": "R10", 205 | "r11b": "R11", 206 | "r11w": "R11", 207 | "r11d": "R11", 208 | "r11": "R11", 209 | "r12b": "R12", 210 | "r12w": "R12", 211 | "r12d": "R12", 212 | "r12": "R12", 213 | "r13b": "R13", 214 | "r13w": "R13", 215 | "r13d": "R13", 216 | "r13": "R13", 217 | "r14b": "R14", 218 | "r14w": "R14", 219 | "r14d": "R14", 220 | "r14": "R14", 221 | "r15b": "R15", 222 | "r15w": "R15", 223 | "r15d": "R15", 224 | "r15": "R15", 225 | } 226 | 227 | NUM_PAT = re.compile(r"[0-9]+") 228 | 229 | 230 | def _eval_addr_str(addr_str): 231 | if "+" in addr_str: 232 | plus_idx = addr_str.find("+") 233 | base_str = addr_str[:plus_idx] 234 | offset_str = addr_str[plus_idx + 1 :] 235 | base_expr = _eval_addr_str(base_str) 236 | offset_expr = _eval_addr_str(offset_str) 237 | if type(base_expr) == int and type(offset_expr) == int: 238 | return base_expr + offset_expr 239 | elif type(offset_expr) == int and offset_expr > 0: 240 | return "%s+%x" % (base_expr, offset_expr) 241 | elif type(offset_expr) == int and offset_expr < 0: 242 | return "%s-%x" % (base_expr, -offset_expr) 243 | elif type(offset_expr) == int and offset_expr == 0: 244 | return base_expr 245 | elif "-" in addr_str: 246 | minus_idx = addr_str.find("-") 247 | base_str = addr_str[:minus_idx] 248 | offset_str = addr_str[minus_idx + 1 :] 249 | base_expr = _eval_addr_str(base_str) 250 | offset_expr = _eval_addr_str(offset_str) 251 | if type(base_expr) == int and type(offset_expr) == int: 252 | return base_expr - offset_expr 253 | if type(base_expr) == int and type(offset_expr) == int: 254 | return base_expr - offset_expr 255 | elif type(offset_expr) == int and offset_expr > 0: 256 | return "%s-%s" % (base_expr, offset_expr) 257 | elif type(offset_expr) == int and offset_expr < 0: 258 | return "%s+%x" % (base_expr, -offset_expr) 259 | elif type(offset_expr) == int and offset_expr == 0: 260 | return base_expr 261 | elif "*" in addr_str: 262 | mul_idx = addr_str.find("*") 263 | base_str = addr_str[:mul_idx] 264 | offset_str = addr_str[mul_idx + 1 :] 265 | base_expr = _eval_addr_str(base_str) 266 | offset_expr = _eval_addr_str(offset_str) 267 | if type(base_expr) == int and type(offset_expr) == int: 268 | return base_expr * offset_expr 269 | elif type(offset_expr) == int and offset_expr > 0: 270 | return "%s*%s" % (base_expr, offset_expr) 271 | elif type(offset_expr) == int and offset_expr < 0: 272 | return "%s*%x" % (base_expr, -offset_expr) 273 | elif type(offset_expr) == int and offset_expr == 0: 274 | return 0 275 | elif addr_str.strip() in reg_normalize: 276 | return reg_normalize[addr_str.strip()] 277 | elif addr_str.startswith("var_"): 278 | # try to parse it as a hex number 279 | hex_num = addr_str[4:] 280 | try: 281 | return -int(hex_num, 16) 282 | except: 283 | pass 284 | 
--------------------------------------------------------------------------------
/codeart/preprocess/utils/asm_parser.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | def ispurenumber(number):
4 |     if number[0] == '+' or number[0] == '-':
5 |         number = number[1:]
6 |     # whether every char is digit
7 |     for i in range(len(number)):
8 |         if str.isdigit(number[i]):
9 |             continue
10 |         else:
11 |             return False
12 |     return True
13 | 
14 | 
15 | def isaddr(number):
16 |     return number[0] == '[' and number[-1] == ']'  # IDA renders memory operands as [...]
17 | 
18 | 
19 | def ishexnumber(number):
20 |     if number[0] == '+' or number[0] == '-':
21 |         number = number[1:]
22 |     if number[-1] == 'h':  # IDA-style hex literal, e.g. 28h or 0FFh
23 |         for i in range(len(number)-1):
24 |             if str.isdigit(number[i]) or (number[i] >= 'A' and number[i] <= 'F'):
25 |                 continue
26 |             else:
27 |                 return False
28 |     else:
29 |         return False
30 |     return True
--------------------------------------------------------------------------------
/codeart/preprocess/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | 
3 | 
4 | def dump_cfg(cfg, fout_name):
5 |     """
6 |     dump cfg to text form
7 |     """
8 |     node_ids = [n for n in cfg.nodes]
9 |     node_ids.sort(key=lambda x: (cfg.nodes[x]['num'] if 'num' in cfg.nodes[x] else 0, x))
10 |     with open(fout_name, 'w') as f:
11 |         for n in node_ids:
12 |             f.write('node: %s\n' % n)
13 |             if 'num' in cfg.nodes[n]:
14 |                 f.write('num: %x\n' % cfg.nodes[n]['num'])
15 |             else:
16 |                 f.write('num: %x\n' % 0)
17 |             f.write('asm: %x\n' % n)
18 |             for a in cfg.nodes[n]['asm']:
19 |                 f.write('  %s\n' % a)
20 |             f.write('edges:\n')
21 |             for e in cfg.edges(n):
22 |                 f.write('  %s\n' % str(e))
23 |             f.write('endedges\n')
24 | 
25 | 
26 | def parse_cfg(fin_name):
27 |     """
28 |     parse cfg from text form
29 |     """
30 |     cfg = nx.DiGraph()
31 |     fin = open(fin_name, 'r')
32 |     lines = fin.readlines()
33 |     fin.close()
34 |     PARSE_NODE = 0
35 |     PARSE_ASM = 1
36 |     PARSE_EDGE = 2
37 |     PARSE_NUM = 3
38 |     state = PARSE_NODE
39 |     for line in lines:
40 |         line = line.strip()
41 |         parts = line.split(';')  # drop any trailing ';' comment
42 |         if len(parts) > 0:
43 |             line = parts[0]
44 |         if len(line) == 0:
45 |             continue
46 |         if state == PARSE_NODE:
47 |             if line.startswith('node:'):
48 |                 node_id = int(line.split(':')[1].strip())
49 |                 cfg.add_node(node_id)
50 |                 state = PARSE_NUM
51 |                 cfg.nodes[node_id]['asm'] = []
52 |         elif state == PARSE_NUM:
53 |             if line.startswith('num:'):
54 |                 cfg.nodes[node_id]['num'] = int(line.split(':')[1].strip(), 16)
55 |                 state = PARSE_ASM
56 |         elif state == PARSE_ASM:
57 |             if line.startswith('asm:'):
58 |                 continue
59 |             elif line.startswith('edges:'):
60 |                 state = PARSE_EDGE
61 |             else:
62 |                 cfg.nodes[node_id]['asm'].append(line.strip())
63 |         elif state == PARSE_EDGE:
64 |             if line.startswith('edges:'):
65 |                 continue
66 |             else:
67 |                 if line.startswith('endedges'):
68 |                     state = PARSE_NODE
69 |                 else:
70 |                     edge = eval(line.strip())  # edges were dumped via str(tuple)
71 |                     cfg.add_edge(edge[0], edge[1])
72 |     return cfg
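A quick sanity check of the text round-trip these two helpers implement (a sketch; the tiny two-block graph and the `/tmp` path are made up for illustration):

```python
import networkx as nx
from data_utils import dump_cfg, parse_cfg

# Build a toy CFG in the shape produced by disassemble.py.
g = nx.DiGraph()
g.add_node(0, asm=["push rbp", "mov rbp, rsp"], num=0)
g.add_node(1, asm=["pop rbp", "retn"], num=1)
g.add_edge(0, 1)

dump_cfg(g, "/tmp/cfg.txt")
g2 = parse_cfg("/tmp/cfg.txt")
assert set(g.edges) == set(g2.edges)
assert g2.nodes[0]["asm"] == g.nodes[0]["asm"]
```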
--------------------------------------------------------------------------------
/codeart/scripts/config/default.json:
--------------------------------------------------------------------------------
1 | {
2 |     "dataset_name": "PurCL/bincorp-26m-all",
3 |     "run_name": "codeart",
4 |     "output_dir": "../save/codeart",
5 | 
6 | 
7 |     "masking_enable_global_memory_patterns": true,
8 |     "masking_enable_bridge_patterns": false,
9 |     "masking_enable_graph_patterns": true,
10 |     "masking_enable_local_patterns": true,
11 |     "with_transitive_closure": true,
12 | 
13 | 
14 |     "position_embedding_type": "mixed",
15 |     "max_relative_position_embeddings": 8,
16 |     "ep_add_linear_projection": true,
17 |     "max_seq_length": 512,
18 |     "mlm_probability": 0.15,
19 |     "ep_probability": 0.4,
20 | 
21 | 
22 |     "use_auth_token": true,
23 |     "dataloader_num_workers": 2,
24 |     "remove_unused_columns": false,
25 |     "do_train": true,
26 |     "do_eval": true,
27 | 
28 | 
29 |     "per_device_train_batch_size": 64,
30 |     "gradient_accumulation_steps": 2,
31 |     "per_device_eval_batch_size": 256,
32 | 
33 | 
34 |     "num_train_epochs": 20,
35 |     "learning_rate": 5e-4,
36 |     "warmup_steps": 10000,
37 |     "weight_decay": 0.01,
38 |     "evaluation_strategy": "steps",
39 |     "eval_steps": 100,
40 |     "save_steps": 1000,
41 |     "logging_steps": 10,
42 | 
43 | 
44 |     "report_to": "wandb",
45 |     "cache_dir": "../save/.cache",
46 | 
47 | 
48 |     "overwrite_output_dir": true,
49 |     "torch_compile": true,
50 |     "fp16": true
51 | }
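This flat JSON mixes HuggingFace training options with CodeArt-specific pretraining options (the `masking_*`, `ep_*`, and dataset keys handled by run.py's own argument classes, which are not shown here). As a rough sketch of how such a config maps onto transformers' `TrainingArguments` alone, with the non-TrainingArguments keys filtered out:

```python
import json
from dataclasses import fields
from transformers import TrainingArguments

# Hedged sketch: run.py's real dispatch lives in code/arguments.py and run.py;
# here we only route the TrainingArguments subset of default.json.
with open("config/default.json") as f:
    cfg = json.load(f)

ta_names = {f.name for f in fields(TrainingArguments)}
training_args = TrainingArguments(**{k: v for k, v in cfg.items() if k in ta_names})
print(training_args.learning_rate, training_args.per_device_train_batch_size)
```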
-f "$CURRENT_DIR/$CONFIG" ]; then 15 | echo "Config file $CURRENT_DIR/$CONFIG does not exist" 16 | exit 1 17 | fi 18 | 19 | echo "Config: $CONFIG" 20 | 21 | 22 | cd ../code/ 23 | 24 | 25 | WANDB_ENTITY=$WANDB_org \ 26 | WANDB_PROJECT=$WANDB_project \ 27 | torchrun --nproc_per_node=8 run.py $CURRENT_DIR/$CONFIG 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | transformers==4.30.2 3 | datasets==2.14.4 4 | networkx==3.1 5 | scikit-learn==1.3.0 6 | gradio 7 | faiss-gpu --------------------------------------------------------------------------------