├── .dockerignore ├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── data └── .gitignore ├── docs └── adding_cwes.md ├── environment.yml ├── iris_arch.png ├── output └── .keep ├── results ├── CodeQL.csv ├── IRIS+DeepSeekCoder-33B.csv ├── IRIS+DeepSeekCoder-7B.csv ├── IRIS+GPT-3.5.csv ├── IRIS+GPT-4.csv ├── IRIS+Gemma2-27B.csv ├── IRIS+Llama3-70B.csv ├── IRIS+Llama3-8B.csv └── IRIS+Qwen2.5-32B.csv ├── scripts ├── build_codeql_dbs.py ├── codeql-queries │ ├── packages.ql │ └── qlpack.yml ├── get_packages_codeql.py └── setup_environment.sh └── src ├── __init__.py ├── codeql_queries.py ├── codeql_vul.py ├── codeql_vul_for_query.py ├── config.py ├── cwe-queries ├── GeneralQuerywLLM.ql ├── cwe-022 │ ├── MyTaintedPathQuery.qll │ ├── cwe-022wLLM.ql │ ├── cwe-022wLLMAugmented.ql │ ├── cwe-022wLLMSinksOnly.ql │ └── cwe-022wLLMSourcesOnly.ql ├── cwe-078 │ ├── CommandInjectionRuntimeExecLocalwLLM.ql │ ├── CommandInjectionRuntimeExecwLLM.ql │ ├── CommandInjectionRuntimeExecwLLMSinksOnly.ql │ ├── CommandInjectionRuntimeExecwLLMSourcesOnly.ql │ ├── ExecRelativewLLM.ql │ ├── ExecTaintedLocalwLLM.ql │ ├── ExecTaintedwLLM.ql │ ├── ExecUnescapedwLLM.ql │ ├── MyCommandArguments.qll │ ├── MyCommandInjectionRuntimeExec.qll │ ├── MyCommandLineQuery.qll │ └── MyExternalProcess.qll ├── cwe-079 │ ├── MyXSS.qll │ ├── MyXssLocalQuery.qll │ ├── MyXssQuery.qll │ ├── XSS.ql │ ├── XSSLocal.ql │ ├── XSSSinksOnly.ql │ └── XSSSourcesOnly.ql ├── cwe-094 │ ├── MySpelInjection.qll │ ├── MySpelInjectionQuery.qll │ ├── MyTemplateInjection.qll │ ├── MyTemplateInjectionQuery.qll │ ├── SpelInjection.ql │ ├── SpelInjectionSinksOnly.ql │ ├── SpelInjectionSourcesOnly.ql │ └── TemplateInjection.ql ├── cwe-295 │ ├── InsecureTrustManager.ql │ ├── MyInsecureTrustManager.qll │ └── MyInsecureTrustManagerQuery.qll ├── cwe-502 │ ├── UnsafeDeserialization.ql │ └── UnsafeDeserializationQuery.qll ├── cwe-611 │ ├── MyXxe.qll │ ├── MyXxeQuery.qll │ ├── MyXxeRemoteQuery.qll │ └── XXE.ql └── cwe-general │ └── GeneralSpecs.qll ├── evaluate_spec_against_codeql.py ├── logger.py ├── models ├── __init__.py ├── codegen.py ├── codellama.py ├── codet5.py ├── config.py ├── deepseek.py ├── gemini.py ├── google.py ├── gpt.py ├── llama.py ├── llm.py ├── mistral.py ├── ollama.py ├── openaimodels.py ├── qwen.py ├── starcoder.py └── wizarcoder.py ├── modules ├── codeql_query_runner.py ├── contextual_analysis_pipeline.py ├── evaluation_pipeline.py └── postprocess_cwe_query.py ├── neusym_vul.py ├── neusym_vul_for_query.py ├── prompts.py ├── queries.py ├── queries ├── fetch_class_locs.ql ├── fetch_external_apis.ql ├── fetch_field_reads.ql ├── fetch_func_locs.ql ├── fetch_func_params.ql ├── fetch_sinks.ql ├── fetch_sources.ql └── getpackages.ql └── utils ├── __init__.py ├── cwe_top_25.txt ├── cwenames.txt ├── cwenames_top25.txt ├── cweparser.py ├── metrics_table.py ├── metrics_table_cwe.py ├── metrics_test.py ├── mylogger.py ├── prompt_utils.py ├── sample_spec.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # we instead clone cwe-bench-java within the docker container 2 | data/cwe-bench-java/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | codeql-all-dbs 2 | codeql-all-dbs-gh 3 | patch_info 4 | projects 5 | projects-gh 6 | commit_data 7 | old 8 | outputs* 9 | output/ 10 | 11 | **/__pycache__ 
-------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "data/cwe-bench-java"] 2 | path = data/cwe-bench-java 3 | url = https://github.com/iris-sast/cwe-bench-java.git 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Start with Ubuntu base image 2 | FROM ubuntu:22.04 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # Install dependencies 7 | RUN apt-get update && apt-get install -y \ 8 | git wget curl python3 python3-pip unzip tar \ 9 | && rm -rf /var/lib/apt/lists/* 10 | 11 | # Install Miniconda based on architecture 12 | RUN arch=$(uname -m) && \ 13 | if [ "$arch" = "x86_64" ]; then \ 14 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh"; \ 15 | elif [ "$arch" = "aarch64" ]; then \ 16 | MINICONDA_URL="https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh"; \ 17 | else \ 18 | echo "Unsupported architecture: $arch"; \ 19 | exit 1; \ 20 | fi && \ 21 | wget $MINICONDA_URL -O miniconda.sh && \ 22 | bash miniconda.sh -b -p /opt/conda && \ 23 | rm miniconda.sh 24 | 25 | ENV PATH=/opt/conda/bin:$PATH 26 | 27 | WORKDIR /iris 28 | COPY . /iris/ 29 | RUN git clone https://github.com/iris-sast/cwe-bench-java.git data/cwe-bench-java 30 | 31 | RUN chmod +x scripts/setup_environment.sh 32 | RUN bash ./scripts/setup_environment.sh 33 | 34 | # Set up shell 35 | SHELL ["/bin/bash", "-c"] 36 | #RUN echo "conda activate $(head -1 environment.yml | cut -d' ' -f2)" >> ~/.bashrc 37 | RUN ENV_NAME=$(head -1 environment.yml | cut -d' ' -f2) && \ 38 | conda init bash && \ 39 | echo ". /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 40 | echo "conda activate $ENV_NAME" >> ~/.bashrc 41 | # Copy JDKs from build context 42 | COPY jdk-7u80-linux-x64.tar.gz jdk-8u202-linux-x64.tar.gz jdk-17_linux-x64_bin.tar.gz /iris/data/cwe-bench-java/java-env/ 43 | RUN cd /iris/data/cwe-bench-java/java-env/ && \ 44 | tar xzf jdk-8u202-linux-x64.tar.gz --no-same-owner && \ 45 | tar xzf jdk-7u80-linux-x64.tar.gz --no-same-owner && \ 46 | tar xzf jdk-17_linux-x64_bin.tar.gz --no-same-owner && \ 47 | chmod -R 755 */bin */lib && \ 48 | chmod -R 755 */jre/bin */jre/lib && \ 49 | ls -la jdk-17/lib/libjli.so # Verify the library exists and permissions 50 | CMD ["/bin/bash"] 51 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ziyang Li 4 | Copyright (c) 2024 Saikat Dutta 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | codeql-dbs
--------------------------------------------------------------------------------
/docs/adding_cwes.md:
--------------------------------------------------------------------------------
1 | # Adding CWEs
2 | We are always open to supporting new CWEs. We recommend starting with any of the CWEs in the [CWE Top 25](https://cwe.mitre.org/top25/) that we don't currently support.
3 | 
4 | To add a CWE, you will need to provide the CodeQL queries and add the CWE queries to `queries.py`.
5 | 
6 | Typically, the query files for a CWE are structured as follows:
7 | ```
8 | cwe-*
9 | ├── cwe-*wLLM.ql
10 | │   
11 | └── My[CodeQLCWEQueryModuleName].qll
12 | ```
13 | `cwe-*wLLM.ql` is the wrapper query that imports the `*.qll` module file. The `*.qll` file is the module library - this is where the logic for the sources and sinks is implemented.
14 | 
15 | 1. Find the CWE definition on the [MITRE CWE site](https://cwe.mitre.org/data/definitions/502.html). A strong understanding of the CWE will help you in the following steps.
16 | 2. We recommend using CodeQL's own CWE queries as examples. You can find them in the [CodeQL GitHub repository](https://github.com/github/codeql). In `java/ql/src/Security/CWE`, locate the CWE you're interested in adding. Within each CWE directory, locate the `.ql` file. Often there are multiple `.ql` files - a quick heuristic is to pick the one with the most general name that is closest to the CWE name.
17 | 
18 | For example - [CWE-022](https://cwe.mitre.org/data/definitions/22.html) has `TaintedPath.ql` and `ZipSlip.ql`. We used `TaintedPath.ql`.
19 | 3. Once you've found the corresponding `.ql` file for the CWE, make note of it. This will be the wrapper query. Within the file, there should be an import statement that refers to the module related to the CWE. Often it will be prefixed with `semmle.code.java.security` and end with `Query`. Within the CodeQL repository, find that module in `codeql/java/ql/lib/semmle/code/java/security`.
20 | 4. Within the `cwe-queries` directory of iris, create a new folder titled `cwe-[CWE number]`. Copy the `.ql` and the `.qll` files into that folder and rename them with the prefix `My`. Within the `.qll` file there may be multiple modules suffixed with `Config`. Find the `Config` module whose name matches the `.qll` file name - this is where the source and sink predicates are defined.
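For reference, the finished `.qll` usually ends up mirroring the existing `src/cwe-queries/cwe-022/MyTaintedPathQuery.qll`. A minimal sketch of what such a module looks like after the edits described next (the `Foo` names are placeholders, not files that exist in this repo):
```
import java
import semmle.code.java.dataflow.DataFlow
import semmle.code.java.dataflow.TaintTracking
import MySources
import MySinks
import MySummaries

module MyFooConfig implements DataFlow::ConfigSig {
  predicate isSource(DataFlow::Node source) { isGPTDetectedSource(source) }

  predicate isSink(DataFlow::Node sink) { isGPTDetectedSink(sink) }

  predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) {
    isGPTDetectedStep(n1, n2)
  }
}

module MyFooFlow = TaintTracking::Global<MyFooConfig>;
```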
21 | 22 | Within the module, replace the predicates with the following 23 | ``` 24 | predicate isSource(DataFlow::Node source) { 25 | isGPTDetectedSource(source) 26 | } 27 | 28 | predicate isSink(DataFlow::Node sink) { 29 | isGPTDetectedSink(sink) 30 | } 31 | 32 | predicate isBarrier(DataFlow::Node sanitizer) { 33 | sanitizer.getType() instanceof BoxedType or 34 | sanitizer.getType() instanceof PrimitiveType or 35 | sanitizer.getType() instanceof NumberType 36 | } 37 | 38 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 39 | isGPTDetectedStep(n1, n2) 40 | } 41 | ``` 42 | 43 | Also add the following imports: 44 | ``` 45 | import MySources 46 | import MySinks 47 | import MySummaries 48 | ``` 49 | 50 | Remove the former predicate definitions and anything else in the file related to the former predicates. Now in the `.ql` file, update the imports to refer to the renamed `.qll` module. 51 | 52 | 5. Now within the [`queries.py`](../src/queries.py) file, add the CWE and its queries to the `QUERIES` dictionary. Note - if the CWE is double digits - for the id use 0[number]. For example - CWE 22 would be `cwe-022`. Use the following format - we use CWE-22 as an example: 53 | ``` 54 | "cwe-[number]wLLM": { 55 | "name": "cwe-[number]wLLM", 56 | "type": "cwe-query", 57 | "cwe_id": "022", 58 | "cwe_id_short": "22", 59 | "cwe_id_tag": "CWE-22", 60 | "desc": "Path Traversal or Zip Slip", 61 | "queries": [ 62 | "cwe-queries/cwe-022/cwe-022wLLM.ql", 63 | "cwe-queries/cwe-022/MyTaintedPathQuery.qll", 64 | ], 65 | "prompts": { 66 | "cwe_id": "CWE-022", 67 | "desc": "Path Traversal or Zip Slip", 68 | "long_desc": """\ 69 | A path traversal vulnerability allows an attacker to access files \ 70 | on your web server to which they should not have access. They do this by tricking either \ 71 | the web server or the web application running on it into returning files that exist outside \ 72 | of the web root folder. Another attack pattern is that users can pass in malicious Zip file \ 73 | which may contain directories like "../". Typical sources of this vulnerability involves \ 74 | obtaining information from untrusted user input through web requests, getting entry directory \ 75 | from Zip files. Sinks will relate to file system manipulation, such as creating file, listing \ 76 | directories, and etc.""", 77 | "examples": [ 78 | { 79 | "package": "java.util.zip", 80 | "class": "ZipEntry", 81 | "method": "getName", 82 | "signature": "String getName()", 83 | "sink_args": [], 84 | "type": "source", 85 | }, 86 | { 87 | "package": "java.io", 88 | "class": "FileInputStream", 89 | "method": "FileInputStream", 90 | "signature": "FileInputStream(File file)", 91 | "sink_args" : ["file"], 92 | "type": "sink", 93 | }, 94 | { 95 | "package": "java.net", 96 | "class": "URL", 97 | "method": "URL", 98 | "signature": "URL(String url)", 99 | "sink_args": [], 100 | "type": "taint-propagator", 101 | }, 102 | { 103 | "package": "java.io", 104 | "class": "File", 105 | "method": "File", 106 | "signature": "File(String path)", 107 | "sink_args": [], 108 | "type": "taint-propagator", 109 | }, 110 | ] 111 | } 112 | }, 113 | ``` 114 | 115 | For the `long_desc` key - look up definitions of the CWE and find a clear description that summarizes what the CWE is and how it's exploited. 116 | 117 | For the examples, you will need to provide sources and sinks. A CodeQL source is a value that an attacker can use for malicious operations within a system. 
A CodeQL sink is a program point where data from a malicious source ends up being used. You can use the [GitHub Advisory Database](https://github.com/advisories) to find examples of the CWE. Alternatively, the CWE definition may describe common abstractions; search Java's most widely used libraries for APIs that implement those abstractions.
118 | 
119 | 6. Add a hint for the contextual analysis prompt related to the CWE in [`prompts.py`](../src/prompts.py). Hints are stored in `POSTHOC_FILTER_HINTS`. The key should be the CWE number and the value should be one or more sentences describing extra details to look out for when detecting the CWE. Sites that define the CWE often provide more specific guidance.
120 | 
121 | 7. Test out the query. You can pass the `--test-run` parameter when running `neusym_vul.py` to check that the CodeQL queries compile. Afterwards, you can try a test run with a small model on one of the Java projects associated with the CWE. The [GitHub Advisory Database](https://github.com/advisories) is an easy way to find a vulnerable project for a given CWE.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: iris
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | dependencies:
7 | - python=3.10
8 | - pandas
9 | - requests
10 | - tqdm
11 | - openai
12 | - accelerate
13 | - transformers
14 | - pytorch=2.5
15 | 
--------------------------------------------------------------------------------
/iris_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/iris_arch.png
--------------------------------------------------------------------------------
/output/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/output/.keep
--------------------------------------------------------------------------------
/scripts/build_codeql_dbs.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os
3 | import argparse
4 | import subprocess
5 | from pathlib import Path
6 | import sys
7 | sys.path.append(str(Path(__file__).parent.parent))
8 | 
9 | from src.config import CWE_BENCH_JAVA_DIR, CODEQL_DB_PATH, PROJECT_SOURCE_CODE_DIR
10 | 
11 | def verify_java_installation(java_home):
12 |     if not os.path.exists(java_home):
13 |         raise Exception(f"JAVA_HOME directory does not exist: {java_home}")
14 | 
15 |     java_exe = os.path.join(java_home, 'bin', 'java')
16 |     if not os.path.exists(java_exe):
17 |         raise Exception(f"Java executable not found at: {java_exe}")
18 | 
19 |     javac_exe = os.path.join(java_home, 'bin', 'javac')
20 |     if not os.path.exists(javac_exe):
21 |         raise Exception(f"Javac executable not found at: {javac_exe}")
22 | 
23 | def verify_maven_installation(maven_path):
24 |     if not os.path.exists(maven_path):
25 |         raise Exception(f"Maven directory does not exist: {maven_path}")
26 | 
27 |     mvn_exe = os.path.join(maven_path, 'mvn')
28 |     if not os.path.exists(mvn_exe):
29 |         raise Exception(f"Maven executable not found at: {mvn_exe}")
30 | 
31 | def find_java_home(java_version, java_env_path):
32 |     """
33 |     Find the appropriate Java home directory based on version and available installations.
34 | 35 | Args: 36 | java_version (str): Version string from build_info.csv 37 | java_env_path (str): Base path for Java installations 38 | 39 | Returns: 40 | str: Path to the appropriate Java installation 41 | """ 42 | if 'u' in java_version: 43 | # Handle Java 7 and 8 style versions (e.g., 8u202 -> jdk1.8.0_202) 44 | main_ver = java_version.split('u')[0] 45 | update_ver = java_version.split('u')[1] 46 | java_home = os.path.abspath(os.path.join(java_env_path, f"jdk1.{main_ver}.0_{update_ver}")) 47 | else: 48 | # Handle Java 9+ style versions 49 | # First try exact match (e.g., jdk-17) 50 | java_home = os.path.abspath(os.path.join(java_env_path, f"jdk-{java_version}")) 51 | 52 | if not os.path.exists(java_home): 53 | # Try finding a matching directory with a more specific version 54 | possible_dirs = [d for d in os.listdir(java_env_path) 55 | if d.startswith(f"jdk-{java_version}")] 56 | if possible_dirs: 57 | # Use the first matching directory 58 | java_home = os.path.abspath(os.path.join(java_env_path, possible_dirs[0])) 59 | 60 | return java_home 61 | 62 | def setup_environment(row, java_env_path): 63 | env = os.environ.copy() 64 | 65 | # Set Maven path if available 66 | if row['mvn_version'] != 'n/a': 67 | maven_path = os.path.abspath(os.path.join(java_env_path, f"apache-maven-{row['mvn_version']}/bin")) 68 | verify_maven_installation(maven_path) 69 | env['PATH'] = f"{maven_path}:{env.get('PATH', '')}" 70 | print(f"Maven path set to: {maven_path}") 71 | 72 | # Find and set Java home 73 | java_version = row['jdk_version'] 74 | java_home = find_java_home(java_version, java_env_path) 75 | 76 | verify_java_installation(java_home) 77 | env['JAVA_HOME'] = java_home 78 | print(f"JAVA_HOME set to: {java_home}") 79 | 80 | # Add Java binary to PATH 81 | env['PATH'] = f"{os.path.join(java_home, 'bin')}:{env.get('PATH', '')}" 82 | 83 | return env 84 | 85 | def create_codeql_database(project_slug, env, db_base_path, sources_base_path): 86 | print("\nEnvironment variables for CodeQL database creation:") 87 | print(f"PATH: {env.get('PATH', 'Not set')}") 88 | print(f"JAVA_HOME: {env.get('JAVA_HOME', 'Not set')}") 89 | 90 | try: 91 | java_version = subprocess.check_output(['java', '-version'], 92 | stderr=subprocess.STDOUT, 93 | env=env).decode() 94 | print(f"\nJava version check:\n{java_version}") 95 | except subprocess.CalledProcessError as e: 96 | print(f"Error checking Java version: {e}") 97 | raise 98 | 99 | database_path = os.path.abspath(os.path.join(db_base_path, project_slug)) 100 | source_path = os.path.abspath(os.path.join(sources_base_path, project_slug)) 101 | 102 | Path(database_path).parent.mkdir(parents=True, exist_ok=True) 103 | 104 | command = [ 105 | "codeql", "database", "create", 106 | database_path, 107 | "--source-root", source_path, 108 | "--language", "java", 109 | "--overwrite" 110 | ] 111 | 112 | try: 113 | print(f"Creating database at: {database_path}") 114 | print(f"Using source path: {source_path}") 115 | print(f"Using JAVA_HOME: {env.get('JAVA_HOME', 'Not set')}") 116 | subprocess.run(command, env=env, check=True) 117 | print(f"Successfully created CodeQL database for {project_slug}") 118 | except subprocess.CalledProcessError as e: 119 | print(f"Error creating CodeQL database for {project_slug}: {e}") 120 | raise 121 | 122 | def main(): 123 | parser = argparse.ArgumentParser(description='Create CodeQL databases for cwe-bench-java projects') 124 | parser.add_argument('--project', help='Specific project slug', default=None) 125 | parser.add_argument('--db-path', help='Base 
path for storing CodeQL databases', default=CODEQL_DB_PATH) 126 | parser.add_argument('--sources-path', help='Base path for project sources', default=PROJECT_SOURCE_CODE_DIR) 127 | parser.add_argument('--cwe-bench-java-path', help='Base path to cwe-bench-java', default=CWE_BENCH_JAVA_DIR) 128 | args = parser.parse_args() 129 | 130 | cwe_bench_java_path = os.path.abspath(args.cwe_bench_java_path) 131 | csv_path = os.path.join(cwe_bench_java_path, "data", "build_info.csv") 132 | java_env_path = os.path.join(cwe_bench_java_path, "java-env") 133 | 134 | with open(csv_path, 'r') as f: 135 | reader = csv.DictReader(f) 136 | projects = list(reader) 137 | 138 | if args.project: 139 | project = next((p for p in projects if p['project_slug'] == args.project), None) 140 | if project: 141 | env = setup_environment(project, java_env_path) 142 | create_codeql_database(project['project_slug'], env, args.db_path, args.sources_path) 143 | else: 144 | print(f"Project {args.project} not found in CSV file") 145 | else: 146 | for project in projects: 147 | env = setup_environment(project, java_env_path) 148 | create_codeql_database(project['project_slug'], env, args.db_path, args.sources_path) 149 | 150 | if __name__ == "__main__": 151 | main() -------------------------------------------------------------------------------- /scripts/codeql-queries/packages.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Find All Packages 3 | * @description Lists all packages in the Java project along with their class counts 4 | * @kind table 5 | * @id java/all-packages 6 | * @tags packages 7 | * inventory 8 | */ 9 | 10 | import java 11 | 12 | from Package p 13 | where p.fromSource() 14 | select p.getName() 15 | 16 | -------------------------------------------------------------------------------- /scripts/codeql-queries/qlpack.yml: -------------------------------------------------------------------------------- 1 | name: custom-queries 2 | version: 0.0.1 3 | dependencies: 4 | codeql/java-all: "*" 5 | -------------------------------------------------------------------------------- /scripts/get_packages_codeql.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import re 4 | import argparse 5 | from pathlib import Path 6 | import xml.etree.ElementTree as ET 7 | import sys 8 | sys.path.append(str(Path(__file__).parent.parent)) 9 | from src.config import PROJECT_SOURCE_CODE_DIR 10 | 11 | def run_codeql_query(db_path, query_path): 12 | try: 13 | subprocess.run( 14 | ["codeql", "query", "run", "--database", str(db_path), 15 | "--output=results.bqrs", str(query_path)], 16 | check=True, capture_output=True, text=True 17 | ) 18 | result = subprocess.run( 19 | ["codeql", "bqrs", "decode", "--format=csv", "results.bqrs"], 20 | capture_output=True, text=True, check=True 21 | ) 22 | packages = set() 23 | rows = result.stdout.strip().split("\n")[1:] 24 | for line in rows: 25 | parts = line.split(",") 26 | if len(parts) >= 1: 27 | pkg_name = parts[0].strip().strip('"') 28 | packages.add(pkg_name) 29 | return packages 30 | except subprocess.CalledProcessError as e: 31 | print(f"Error running CodeQL command: {e}") 32 | if e.stderr: 33 | print(f"Error output: {e.stderr}") 34 | return set() 35 | 36 | def find_maven_group_id(project_dir): 37 | pom_paths = [ 38 | Path(project_dir) / "pom.xml", 39 | ] 40 | 41 | for pom_path in pom_paths: 42 | if pom_path.exists(): 43 | try: 44 | tree = ET.parse(pom_path) 45 | root = 
tree.getroot() 46 | 47 | ns = {'mvn': re.findall(r'{(.*)}', root.tag)[0]} if '{' in root.tag else {} 48 | 49 | if ns: 50 | group_id = root.find('./mvn:groupId', ns) 51 | else: 52 | group_id = root.find('./groupId') 53 | 54 | if group_id is not None: 55 | return group_id.text 56 | 57 | if ns: 58 | parent = root.find('./mvn:parent', ns) 59 | if parent is not None: 60 | group_id = parent.find('./mvn:groupId', ns) 61 | else: 62 | parent = root.find('./parent') 63 | if parent is not None: 64 | group_id = parent.find('./groupId') 65 | 66 | if group_id is not None: 67 | return group_id.text 68 | 69 | except Exception as e: 70 | print(f"Error parsing pom.xml: {e}") 71 | 72 | return None 73 | 74 | def find_gradle_group_id(project_dir): 75 | gradle_paths = [ 76 | Path(project_dir) / "build.gradle", 77 | Path(project_dir) / "build.gradle.kts", 78 | ] 79 | 80 | for gradle_path in gradle_paths: 81 | if gradle_path.exists(): 82 | try: 83 | with open(gradle_path, 'r') as f: 84 | content = f.read() 85 | 86 | match = re.search(r'group\s*=\s*[\'"]([^\'"]+)[\'"]', content) 87 | if match: 88 | return match.group(1) 89 | 90 | match = re.search(r'group\s*=\s*"([^"]+)"', content) 91 | if match: 92 | return match.group(1) 93 | 94 | except Exception as e: 95 | print(f"Error parsing Gradle file: {e}") 96 | 97 | return None 98 | 99 | def filter_internal_packages(packages, internal_package): 100 | internal_packages = [] 101 | 102 | for pkg_name in packages: 103 | if pkg_name.startswith(internal_package): 104 | internal_packages += [pkg_name] 105 | 106 | return internal_packages 107 | 108 | def main(): 109 | parser = argparse.ArgumentParser(description="Extract internal packages from a Java project") 110 | parser.add_argument("project_name", help="Name of the project") 111 | parser.add_argument("--internal-package", help="Base package name for internal packages (e.g., 'org.keycloak')") 112 | 113 | args = parser.parse_args() 114 | project_name = args.project_name 115 | internal_package = args.internal_package 116 | 117 | iris_root = Path(__file__).parent.parent 118 | output_file = iris_root / "data" / "cwe-bench-java" / "package-names" / f"{project_name}.txt" 119 | query_path = iris_root / "scripts" / "codeql-queries" / "packages.ql" 120 | db_path = iris_root / "data" / "codeql-dbs" / project_name 121 | project_path = Path(PROJECT_SOURCE_CODE_DIR) / project_name 122 | 123 | if not internal_package: 124 | print("Internal package not specified, trying to detect from build files...") 125 | 126 | internal_package = find_maven_group_id(project_path) 127 | if internal_package: 128 | print(f"Found Maven groupId: {internal_package}") 129 | else: 130 | internal_package = find_gradle_group_id(project_path) 131 | if internal_package: 132 | print(f"Found Gradle group: {internal_package}") 133 | 134 | if not internal_package: 135 | print("Error: Could not detect internal package name.") 136 | print("Please specify it with --internal-package (e.g., --internal-package org.keycloak)") 137 | sys.exit(1) 138 | 139 | print(f"Running CodeQL query for all packages in {project_name}...") 140 | all_packages = run_codeql_query(db_path, query_path) 141 | if not all_packages: 142 | print("No packages found or CodeQL query failed.") 143 | return 144 | print(f"Found {len(all_packages)} total packages.") 145 | 146 | internal_packages = filter_internal_packages(all_packages, internal_package) 147 | excluded_packages = [pkg for pkg in all_packages if pkg not in internal_packages] 148 | print("Excluded packages:", excluded_packages) 149 | 
output_file.parent.mkdir(parents=True, exist_ok=True) 150 | with open(output_file, "w") as f: 151 | for package in sorted(internal_packages): 152 | f.write(f"{package}\n") 153 | 154 | print(f"Results written to {output_file}") 155 | 156 | Path("results.bqrs").unlink(missing_ok=True) 157 | 158 | if __name__ == "__main__": 159 | main() 160 | -------------------------------------------------------------------------------- /scripts/setup_environment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: This script is designed to be executed from the 'iris-dev' root directory. 4 | # To run: ./scripts/setup_environment.sh 5 | 6 | # Exit on error 7 | set -e 8 | 9 | echo "Starting setup process..." 10 | 11 | # Check if conda is installed 12 | if ! command -v conda &> /dev/null; then 13 | echo "Error: conda is not installed. Please install conda first." 14 | exit 1 15 | fi 16 | 17 | # Create conda environment from environment.yml 18 | echo "Creating conda environment 'iris' from environment.yml..." 19 | if [ ! -f "environment.yml" ]; then 20 | echo "Error: environment.yml not found in current directory" 21 | exit 1 22 | fi 23 | 24 | # Remove existing environment if it exists 25 | conda env remove -n iris 2>/dev/null || true 26 | 27 | # Create new environment 28 | conda env create -f environment.yml 29 | 30 | # Create necessary directories 31 | echo "Creating directories..." 32 | PROJECT_ROOT=$(pwd) 33 | CODEQL_DIR="$PROJECT_ROOT/codeql" 34 | mkdir -p "$CODEQL_DIR" 35 | mkdir -p "$PROJECT_ROOT/data/codeql-dbs" 36 | 37 | echo "Downloading patched CodeQL..." 38 | CODEQL_URL="https://github.com/iris-sast/iris/releases/download/codeql-0.8.3-patched/codeql.zip" 39 | CODEQL_ZIP="codeql.zip" 40 | if ! curl -L -o "$CODEQL_ZIP" "$CODEQL_URL"; then 41 | echo "Error: Failed to download CodeQL" 42 | exit 1 43 | fi 44 | 45 | echo "Extracting CodeQL..." 46 | TEMP_DIR="$PROJECT_ROOT/temp_codeql_extract" 47 | mkdir -p "$TEMP_DIR" 48 | 49 | if ! unzip -qo "$CODEQL_ZIP" -d "$TEMP_DIR"; then 50 | echo "Error: Failed to extract CodeQL" 51 | rm -f "$CODEQL_ZIP" 52 | rm -rf "$TEMP_DIR" 53 | exit 1 54 | fi 55 | 56 | mv "$TEMP_DIR/codeql"/* "$CODEQL_DIR" 57 | rm -rf "$TEMP_DIR" 58 | rm -f "$CODEQL_ZIP" 59 | 60 | echo "export PATH=\"$CODEQL_DIR:$PATH\"" >> ~/.bashrc 61 | export PATH="$CODEQL_DIR:$PATH" 62 | 63 | echo "Setup completed successfully!" 
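# Optional sanity check after setup (illustrative; you may need to run `source ~/.bashrc`
# first so the CodeQL PATH entry added above is picked up):
#   conda activate iris && codeql version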
64 | echo "- Conda environment 'iris' has been created" 65 | echo "- CodeQL has been downloaded and extracted to $CODEQL_DIR" 66 | echo "- Created '$PROJECT_ROOT/data/codeql-dbs' directory" 67 | echo "- Added CodeQL to PATH in ~/.bashrc" 68 | echo "" 69 | echo "To activate the environment, run: conda activate iris" 70 | echo "You may need to restart your terminal or run 'source ~/.bashrc' for PATH changes to take effect" 71 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/src/__init__.py -------------------------------------------------------------------------------- /src/codeql_queries.py: -------------------------------------------------------------------------------- 1 | QL_SOURCE_PREDICATE = """\ 2 | import java 3 | import semmle.code.java.dataflow.DataFlow 4 | private import semmle.code.java.dataflow.ExternalFlow 5 | 6 | predicate isGPTDetectedSource(DataFlow::Node src) {{ 7 | {body} 8 | }} 9 | 10 | {additional} 11 | """ 12 | 13 | QL_SINK_PREDICATE = """\ 14 | import java 15 | import semmle.code.java.dataflow.DataFlow 16 | private import semmle.code.java.dataflow.ExternalFlow 17 | 18 | predicate isGPTDetectedSink(DataFlow::Node snk) {{ 19 | {body} 20 | }} 21 | 22 | {additional} 23 | """ 24 | 25 | QL_SUBSET_PREDICATE = """\ 26 | predicate isGPTDetected{kind}Part{part_id}(DataFlow::Node {node}) {{ 27 | {body} 28 | }} 29 | """ 30 | 31 | CALL_QL_SUBSET_PREDICATE = " isGPTDetected{kind}Part{part_id}({node})" 32 | 33 | QL_STEP_PREDICATE = """\ 34 | import java 35 | import semmle.code.java.dataflow.DataFlow 36 | private import semmle.code.java.dataflow.ExternalFlow 37 | 38 | predicate isGPTDetectedStep(DataFlow::Node prev, DataFlow::Node next) {{ 39 | {body} 40 | }} 41 | """ 42 | 43 | QL_METHOD_CALL_SOURCE_BODY_ENTRY = """\ 44 | ( 45 | src.asExpr().(Call).getCallee().getName() = "{method}" and 46 | src.asExpr().(Call).getCallee().getDeclaringType().getSourceDeclaration().hasQualifiedName("{package}", "{clazz}") 47 | )\ 48 | """ 49 | 50 | QL_FUNC_PARAM_SOURCE_ENTRY = """\ 51 | exists(Parameter p | 52 | src.asParameter() = p and 53 | p.getCallable().getName() = "{method}" and 54 | p.getCallable().getDeclaringType().getSourceDeclaration().hasQualifiedName("{package}", "{clazz}") and 55 | ({params}) 56 | )\ 57 | """ 58 | 59 | QL_FUNC_PARAM_NAME_ENTRY = """ p.getName() = "{arg_name}" """ 60 | 61 | QL_SUMMARY_BODY_ENTRY = """\ 62 | exists(Call c | 63 | (c.getArgument(_) = prev.asExpr() or c.getQualifier() = prev.asExpr()) 64 | and c.getCallee().getDeclaringType().hasQualifiedName("{package}", "{clazz}") 65 | and c.getCallee().getName() = "{method}" 66 | and c = next.asExpr() 67 | )\ 68 | """ 69 | 70 | QL_SINK_BODY_ENTRY = """\ 71 | exists(Call c | 72 | c.getCallee().getName() = "{method}" and 73 | c.getCallee().getDeclaringType().getSourceDeclaration().hasQualifiedName("{package}", "{clazz}") and 74 | ({args}) 75 | )\ 76 | """ 77 | 78 | QL_SINK_ARG_NAME_ENTRY = """ c.getArgument({arg_id}) = snk.asExpr().(Argument) """ 79 | 80 | QL_SINK_ARG_THIS_ENTRY = """ c.getQualifier() = snk.asExpr() """ 81 | 82 | QL_BODY_OR_SEPARATOR = "\n or\n" 83 | 84 | EXTENSION_YML_TEMPLATE = """\ 85 | extensions: 86 | - addsTo: 87 | pack: codeql/java-all 88 | extensible: sinkModel 89 | data: 90 | {sinks} 91 | - addsTo: 92 | pack: codeql/java-all 93 | extensible: sourceModel 94 | data: 95 | {sources} 96 
| """ 97 | 98 | EXTENSION_SRC_SINK_YML_ENTRY = """\ 99 | - ["{package}", "{clazz}", True, "{method}", "", "", "{access}", "{tag}", "manual"]\ 100 | """ 101 | 102 | EXTENSION_SUMMARY_YML_ENTRY = """\ 103 | - ["{package}", "{clazz}", True, "{method}", "", "", "{access_in}", "{access_out}", "{tag}", "manual"]\ 104 | """ 105 | -------------------------------------------------------------------------------- /src/codeql_vul.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess as sp 4 | import pandas as pd 5 | import shutil 6 | import json 7 | import re 8 | import argparse 9 | import numpy as np 10 | import copy 11 | import math 12 | import random 13 | 14 | import requests 15 | from tqdm import tqdm 16 | from tqdm.contrib.concurrent import thread_map 17 | 18 | THIS_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) 19 | NEUROSYMSA_ROOT_DIR = os.path.abspath(f"{THIS_SCRIPT_DIR}/../../") 20 | sys.path.append(NEUROSYMSA_ROOT_DIR) 21 | 22 | from src.config import CODEQL_DIR, CODEQL_DB_PATH, OUTPUT_DIR, ALL_METHOD_INFO_DIR, PROJECT_SOURCE_CODE_DIR, CVES_MAPPED_W_COMMITS_DIR 23 | 24 | 25 | from src.logger import Logger 26 | from src.queries import QUERIES 27 | from src.prompts import API_LABELLING_SYSTEM_PROMPT, API_LABELLING_USER_PROMPT 28 | from src.prompts import FUNC_PARAM_LABELLING_SYSTEM_PROMPT, FUNC_PARAM_LABELLING_USER_PROMPT 29 | from src.prompts import POSTHOC_FILTER_SYSTEM_PROMPT, POSTHOC_FILTER_USER_PROMPT, POSTHOC_FILTER_HINTS, SNIPPET_CONTEXT_SIZE 30 | from src.codeql_queries import QL_SOURCE_PREDICATE, QL_STEP_PREDICATE, QL_SINK_PREDICATE 31 | from src.codeql_queries import EXTENSION_YML_TEMPLATE, EXTENSION_SRC_SINK_YML_ENTRY, EXTENSION_SUMMARY_YML_ENTRY 32 | from src.codeql_queries import QL_METHOD_CALL_SOURCE_BODY_ENTRY, QL_FUNC_PARAM_SOURCE_ENTRY, QL_FUNC_PARAM_NAME_ENTRY 33 | from src.codeql_queries import QL_SUMMARY_BODY_ENTRY, QL_BODY_OR_SEPARATOR 34 | from src.codeql_queries import QL_SINK_BODY_ENTRY, QL_SINK_ARG_NAME_ENTRY, QL_SINK_ARG_THIS_ENTRY 35 | 36 | from src.modules.codeql_query_runner import CodeQLQueryRunner 37 | from src.modules.evaluation_pipeline import EvaluationPipeline 38 | 39 | 40 | class CodeQLSAPipeline: 41 | def __init__( 42 | self, 43 | project_name: str, 44 | query: str, 45 | evaluation_only: bool = False, 46 | overwrite: bool = False 47 | ): 48 | # Store basic information 49 | self.project_name = project_name 50 | self.query = query 51 | self.evaluation_only = evaluation_only 52 | self.overwrite = overwrite 53 | 54 | # Setup logger 55 | self.master_logger = Logger(f"{NEUROSYMSA_ROOT_DIR}/log") 56 | 57 | # Check if the query is valid 58 | if self.query in QUERIES: 59 | if "cwe_id" not in QUERIES[self.query]: 60 | self.master_logger.info(f"Processing {self.project_name} (Query: {self.query}, Trial: {self.run_id})...") 61 | self.master_logger.error(f"==> Query `{self.query}` is not a query for detecting CWE; aborting"); exit(1) 62 | else: 63 | self.master_logger.info(f"Processing {self.project_name} (Query: {self.query}, Trial: {self.run_id})...") 64 | self.master_logger.error(f"==> Unknown query `{self.query}`; aborting"); exit(1) 65 | self.cwe_id = QUERIES[self.query]["cwe_id"] 66 | self.experimental = QUERIES[self.query]["experimental"] 67 | self.cve_id = project_name.split("_")[3] 68 | 69 | # Load some basic information, such as commits and fixes related to the CVE 70 | self.all_cves_with_commit = pd.read_csv(CVES_MAPPED_W_COMMITS_DIR) 71 | 
self.project_cve_with_commit_info = self.all_cves_with_commit[self.all_cves_with_commit["cve"] == self.cve_id].iloc[0] 72 | self.cve_fixing_commits = self.project_cve_with_commit_info["commits"].split(";") 73 | self.fixed_methods = pd.read_csv(ALL_METHOD_INFO_DIR) 74 | self.project_fixed_methods = self.fixed_methods[self.fixed_methods["db_name"] == self.project_name] 75 | self.project_source_code_dir = f"{PROJECT_SOURCE_CODE_DIR}/{self.project_name}" 76 | 77 | # Basic path information 78 | self.project_output_path = f"{OUTPUT_DIR}/{self.project_name}/common" 79 | 80 | # Setup codeql database path 81 | self.project_codeql_db_path = f"{CODEQL_DB_PATH}/{self.project_name}" 82 | if not os.path.exists(f"{self.project_codeql_db_path}/db-java"): 83 | self.master_logger.info(f"Processing {self.project_name} (Query: {self.query}...") 84 | self.master_logger.error(f"==> Cannot find CodeQL database for {self.project_name}; aborting"); exit(1) 85 | 86 | # Setup query output path 87 | self.query_output_path = f"{self.project_output_path}/{self.query}" 88 | os.makedirs(self.query_output_path, exist_ok=True) 89 | self.query_output_result_sarif_path = f"{self.query_output_path}/results.sarif" 90 | self.query_output_result_csv_path = f"{self.query_output_path}/results.csv" 91 | self.final_output_json_path = f"{self.query_output_path}/results.json" 92 | 93 | # Function and Class locations 94 | self.func_locs_path = f"{self.project_output_path}/fetch_func_locs/results.csv" 95 | self.class_locs_path = f"{self.project_output_path}/fetch_class_locs/results.csv" 96 | 97 | def run_codeql_query(self): 98 | self.master_logger.info("==> Stage 1: Running CodeQL queries...") 99 | 100 | exp = "experimental/" if self.experimental else "" 101 | 102 | cmd = [ 103 | "codeql", 104 | "database", 105 | "analyze", 106 | self.project_codeql_db_path, 107 | f"--output={self.query_output_result_sarif_path}", 108 | f"{CODEQL_DIR}/qlpacks/codeql/java-queries/0.8.3/{exp}Security/CWE/CWE-{self.cwe_id}/" 109 | ] 110 | 111 | if self.overwrite: 112 | cmd += ["--rerun"] 113 | 114 | sp.run(cmd + ["--format=sarif-latest"]) 115 | 116 | sp.run(cmd + ["--format=csv"]) 117 | 118 | def run_simple_codeql_query(self, query, target_csv_path=None, suffix=None, dyn_queries={}): 119 | runner = CodeQLQueryRunner(self.project_name, self.project_output_path, self.project_codeql_db_path, self.master_logger) 120 | runner.run(query, target_csv_path, suffix, dyn_queries) 121 | 122 | def extract_class_locations(self): 123 | if not os.path.exists(self.class_locs_path): 124 | self.master_logger.info(f" ==> Class locations not found; running CodeQL query to extract...") 125 | self.run_simple_codeql_query("fetch_class_locs") 126 | 127 | def extract_func_locations(self): 128 | if not os.path.exists(self.func_locs_path): 129 | self.master_logger.info(f" ==> Function locations not found; running CodeQL query to extract...") 130 | self.run_simple_codeql_query("fetch_func_locs") 131 | 132 | def build_evaluation_pipeline(self): 133 | return EvaluationPipeline( 134 | self.project_fixed_methods, 135 | self.class_locs_path, 136 | self.func_locs_path, 137 | self.project_source_code_dir, 138 | query_output_result_sarif_path=self.query_output_result_sarif_path, 139 | final_output_json_path=self.final_output_json_path, 140 | overwrite=self.overwrite, 141 | project_logger=self.master_logger, 142 | ) 143 | 144 | def evaluate_result(self): 145 | self.master_logger.info("==> Stage 2: Evaluating results...") 146 | 147 | # 1. 
Extract class and function locations 148 | self.master_logger.info(" ==> Extracting function and class locations...") 149 | self.extract_class_locations() 150 | self.extract_func_locations() 151 | 152 | # 2. Build 153 | self.master_logger.info(" ==> Evaluating results...") 154 | eval_pipeline = self.build_evaluation_pipeline() 155 | eval_pipeline.run_vanilla_only() 156 | 157 | def run(self): 158 | if self.evaluation_only: 159 | self.evaluate_result() 160 | else: 161 | self.run_codeql_query() 162 | self.evaluate_result() 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("project", type=str) 168 | parser.add_argument("--query", type=str, default="cwe-022wCodeQL", required=True) 169 | parser.add_argument("--overwrite", action="store_true") 170 | parser.add_argument("--evaluation-only", action="store_true") 171 | args = parser.parse_args() 172 | 173 | pipeline = CodeQLSAPipeline( 174 | args.project, 175 | args.query, 176 | evaluation_only=args.evaluation_only, 177 | overwrite=args.overwrite, 178 | ) 179 | pipeline.run() 180 | -------------------------------------------------------------------------------- /src/codeql_vul_for_query.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import subprocess as sp 5 | import pandas as pd 6 | 7 | THIS_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | NEUROSYMSA_ROOT_DIR = os.path.abspath(f"{THIS_SCRIPT_DIR}/../../") 9 | sys.path.append(NEUROSYMSA_ROOT_DIR) 10 | 11 | from src.config import CVES_MAPPED_W_COMMITS_DIR, CVE_REPO_TAGS_DIR 12 | 13 | 14 | from src.queries import QUERIES 15 | 16 | def collect_projects_for_query(query, cwe_id, all_cves_with_commit, all_project_tags): 17 | for (_, proj_row) in all_cves_with_commit.iterrows(): 18 | # Check relevance 19 | if cwe_id not in proj_row["cwe"].split(";"): 20 | continue 21 | cve_id = proj_row["cve"] 22 | relevant_project_tag = all_project_tags[all_project_tags["cve"] == cve_id] 23 | if len(relevant_project_tag) == 0: 24 | continue 25 | project_name = relevant_project_tag.iloc[0]["project"] 26 | yield project_name 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("query", type=str, default="cwe-022wCodeQL") 31 | parser.add_argument("--evaluation-only", action="store_true") 32 | parser.add_argument("--overwrite", action="store_true") 33 | args = parser.parse_args() 34 | 35 | query = args.query 36 | if query not in QUERIES: 37 | print(f"Unknown query {query}") 38 | if "cwe_id_tag" not in QUERIES[query]: 39 | print(f"Not a CWE related query: {query}") 40 | cwe_id = QUERIES[query]["cwe_id_tag"] 41 | 42 | all_cves_with_commit = pd.read_csv(CVES_MAPPED_W_COMMITS_DIR).dropna(subset=["cwe", "cve", "commits"]) 43 | all_project_tags = pd.read_csv(CVE_REPO_TAGS_DIR).dropna(subset=["project", "cve", "tag"]) 44 | 45 | relevant_projects = list(collect_projects_for_query(query, cwe_id, all_cves_with_commit, all_project_tags)) 46 | 47 | for (i, project) in enumerate(relevant_projects): 48 | print("===========================================") 49 | print(f"[{i + 1}/{len(relevant_projects)}] STARTING RUNNING ON PROJECT: {project}") 50 | 51 | # Generate the command 52 | command = ["python", f"{THIS_SCRIPT_DIR}/codeql_vul.py", project, "--query", query] 53 | if args.evaluation_only: command += ["--evaluation-only"] 54 | if args.overwrite: command += ["--overwrite"] 55 | 56 | # Run the command 57 | sp.run(command) 58 | 
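# Example invocation (illustrative; assumes you run from the repository root with the
# `iris` conda environment active and the CodeQL databases already built):
#
#   python src/codeql_vul_for_query.py cwe-022wCodeQL --evaluation-only
#
# This runs codeql_vul.py once per benchmark project whose CVE is tagged with CWE-022,
# reusing existing SARIF results instead of re-running the CodeQL analysis.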
-------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | IRIS_ROOT_DIR = os.path.join(os.path.dirname(__file__), "..") 4 | 5 | # CODEQL_DIR should be the path of the patched version of CodeQL provided as a download in the releases section for Iris. 6 | CODEQL_DIR = f"{IRIS_ROOT_DIR}/codeql" 7 | 8 | # CODEQL_DB_PATH is the path to the directory that contains CodeQL databases. 9 | CODEQL_DB_PATH = f"{IRIS_ROOT_DIR}/data/codeql-dbs" 10 | 11 | # PROJECT_SOURCE_CODE_DIR contains the Java projects. 12 | PROJECT_SOURCE_CODE_DIR = f"{IRIS_ROOT_DIR}/data/cwe-bench-java/project-sources" 13 | 14 | # PACKAGE_MODULES_PATH contains each project's internal modules. 15 | PACKAGE_MODULES_PATH = f"{IRIS_ROOT_DIR}/data/cwe-bench-java/package-names" 16 | 17 | # OUTPUT_DIR is where the results from running Iris are stored. 18 | OUTPUT_DIR = f"{IRIS_ROOT_DIR}/output" 19 | 20 | # ALL_METHOD_INFO_DIR 21 | ALL_METHOD_INFO_DIR = f"{IRIS_ROOT_DIR}/data/cwe-bench-java/data/fix_info.csv" 22 | 23 | # CVES_MAPPED_W_COMMITS_DIR is the path to project_info.csv, which contains the mapping of vulnerabilities to projects in cwe-bench-java. 24 | CVES_MAPPED_W_COMMITS_DIR = f"{IRIS_ROOT_DIR}/data/cwe-bench-java/data/project_info.csv" 25 | 26 | # Path to cwe-bench-java directory submodule. 27 | CWE_BENCH_JAVA_DIR = f"{IRIS_ROOT_DIR}/data/cwe-bench-java" -------------------------------------------------------------------------------- /src/cwe-queries/GeneralQuerywLLM.ql: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/src/cwe-queries/GeneralQuerywLLM.ql -------------------------------------------------------------------------------- /src/cwe-queries/cwe-022/MyTaintedPathQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides dataflow configurations for tainted path queries. */ 2 | 3 | import java 4 | import semmle.code.java.frameworks.Networking 5 | import semmle.code.java.dataflow.DataFlow 6 | import semmle.code.java.dataflow.FlowSources 7 | private import semmle.code.java.dataflow.ExternalFlow 8 | import semmle.code.java.security.PathSanitizer 9 | 10 | import semmle.code.java.security.TaintedPathQuery 11 | import MySources 12 | import MySinks 13 | import MySummaries 14 | 15 | /** 16 | * A taint-tracking configuration for tracking flow from remote sources to the creation of a path. 17 | */ 18 | module MyTaintedPathConfig implements DataFlow::ConfigSig { 19 | predicate isSource(DataFlow::Node source) { 20 | isGPTDetectedSource(source) 21 | } 22 | 23 | predicate isSink(DataFlow::Node sink) { 24 | isGPTDetectedSink(sink) 25 | } 26 | 27 | predicate isBarrier(DataFlow::Node sanitizer) { 28 | sanitizer.getType() instanceof BoxedType or 29 | sanitizer.getType() instanceof PrimitiveType or 30 | sanitizer.getType() instanceof NumberType or 31 | sanitizer instanceof PathInjectionSanitizer 32 | } 33 | 34 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 35 | isGPTDetectedStep(n1, n2) 36 | } 37 | } 38 | 39 | /** Tracks flow from remote sources to the creation of a path. */ 40 | module MyTaintedPathFlow = TaintTracking::Global; 41 | 42 | 43 | /** 44 | * A taint-tracking configuration for tracking flow from remote sources to the creation of a path. 
45 | */ 46 | module MyTaintedPathSinksOnlyConfig implements DataFlow::ConfigSig { 47 | predicate isSource(DataFlow::Node source) { 48 | source instanceof ThreatModelFlowSource 49 | } 50 | 51 | predicate isSink(DataFlow::Node sink) { 52 | isGPTDetectedSink(sink) 53 | } 54 | 55 | predicate isBarrier(DataFlow::Node sanitizer) { 56 | sanitizer.getType() instanceof BoxedType or 57 | sanitizer.getType() instanceof PrimitiveType or 58 | sanitizer.getType() instanceof NumberType or 59 | sanitizer instanceof PathInjectionSanitizer 60 | } 61 | 62 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 63 | any(TaintedPathAdditionalTaintStep s).step(n1, n2) 64 | } 65 | } 66 | 67 | /** Tracks flow from remote sources to the creation of a path. */ 68 | module MyTaintedPathFlowSinksOnly = TaintTracking::Global; 69 | 70 | 71 | /** 72 | * A taint-tracking configuration for tracking flow from remote sources to the creation of a path. 73 | */ 74 | module MyTaintedPathSourcesOnlyConfig implements DataFlow::ConfigSig { 75 | predicate isSource(DataFlow::Node source) { 76 | isGPTDetectedSource(source) 77 | } 78 | 79 | predicate isSink(DataFlow::Node sink) { 80 | sinkNode(sink, "path-injection") 81 | } 82 | 83 | predicate isBarrier(DataFlow::Node sanitizer) { 84 | sanitizer.getType() instanceof BoxedType or 85 | sanitizer.getType() instanceof PrimitiveType or 86 | sanitizer.getType() instanceof NumberType or 87 | sanitizer instanceof PathInjectionSanitizer 88 | } 89 | 90 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 91 | any(TaintedPathAdditionalTaintStep s).step(n1, n2) 92 | } 93 | } 94 | 95 | /** Tracks flow from remote sources to the creation of a path. */ 96 | module MyTaintedPathFlowSourcesOnly = TaintTracking::Global; 97 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-022/cwe-022wLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Uncontrolled data used in path expression 3 | * @description Accessing paths influenced by users can allow an attacker to access unexpected resources. 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 7.5 7 | * @precision high 8 | * @id java/my-path-injection 9 | * @tags security 10 | * external/cwe/cwe-022 11 | * external/cwe/cwe-023 12 | * external/cwe/cwe-036 13 | * external/cwe/cwe-073 14 | */ 15 | 16 | import java 17 | import semmle.code.java.security.PathCreation 18 | import MyTaintedPathQuery 19 | import MyTaintedPathFlow::PathGraph 20 | 21 | /** 22 | * Gets the data-flow node at which to report a path ending at `sink`. 23 | * 24 | * Previously this query flagged alerts exclusively at `PathCreation` sites, 25 | * so to avoid perturbing existing alerts, where a `PathCreation` exists we 26 | * continue to report there; otherwise we report directly at `sink`. 
27 | */ 28 | DataFlow::Node getReportingNode(DataFlow::Node sink) { 29 | MyTaintedPathFlow::flowTo(sink) and 30 | if exists(PathCreation pc | pc.getAnInput() = sink.asExpr()) 31 | then result.asExpr() = any(PathCreation pc | pc.getAnInput() = sink.asExpr()) 32 | else result = sink 33 | } 34 | 35 | bindingset[src] 36 | string sourceType(DataFlow::Node src) { 37 | if exists(Parameter p | src.asParameter() = p) 38 | then result = "user-provided value as public function parameter" 39 | else result = "user-provided value from external api return value" 40 | } 41 | 42 | from 43 | MyTaintedPathFlow::PathNode source, MyTaintedPathFlow::PathNode sink 44 | where 45 | MyTaintedPathFlow::flowPath(source, sink) 46 | select 47 | getReportingNode(sink.getNode()), 48 | source, 49 | sink, 50 | "This path depends on a $@.", 51 | source.getNode(), 52 | sourceType(source.getNode()) 53 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-022/cwe-022wLLMAugmented.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Uncontrolled data used in path expression 3 | * @description Accessing paths influenced by users can allow an attacker to access unexpected resources. 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 7.5 7 | * @precision high 8 | * @id java/my-path-injection 9 | * @tags security 10 | * external/cwe/cwe-022 11 | * external/cwe/cwe-023 12 | * external/cwe/cwe-036 13 | * external/cwe/cwe-073 14 | */ 15 | 16 | import java 17 | import semmle.code.java.security.PathCreation 18 | 19 | //from semmle.code.java.security.TaintedPathQuery 20 | import semmle.code.java.frameworks.Networking 21 | import semmle.code.java.dataflow.DataFlow 22 | import semmle.code.java.dataflow.FlowSources 23 | private import semmle.code.java.dataflow.ExternalFlow 24 | import semmle.code.java.security.PathSanitizer 25 | import MySources 26 | import MySinks 27 | 28 | // predicate isGPTDetectedSourceMethod(Method m) { 29 | // ( 30 | // m.getName() = "getHeader" and 31 | // m.getDeclaringType().getAnAncestor().hasQualifiedName("javax.servlet.http", "HttpServletRequest") 32 | // ) or 33 | // ( 34 | // m.getName() = "getPathInfo" and 35 | // m.getDeclaringType().getAnAncestor().hasQualifiedName("javax.servlet.http", "HttpServletRequest") 36 | // ) 37 | // } 38 | 39 | // predicate isGPTDetectedSourceField(Field f) { 40 | // ( 41 | // f.getName() = "Form" and 42 | // f.getDeclaringType().getAnAncestor().hasQualifiedName("javax.servlet.http", "HttpServletRequest") 43 | // ) 44 | // } 45 | 46 | // predicate isGPTDetectedSinkMethodCall(Call c) { 47 | // ( 48 | // c.getCallee().getDeclaringType().getAnAncestor().hasQualifiedName("java.net", "URL") and 49 | // c.getCallee().getName() = "getFile" 50 | // ) 51 | // } 52 | 53 | // predicate isGPTDetectedSinkArgument(Argument a) { 54 | // ( 55 | // a.getCall().getCallee().getDeclaringType().getAnAncestor().hasQualifiedName("java.lang", "Runtime") and 56 | // a.getCall().getCallee().getName() = "exec" and 57 | // a.getPosition() = 0 58 | // ) 59 | // } 60 | 61 | // predicate isGPTDetectedTaintPropArgument(Argument a) { 62 | // ( 63 | // a.getCall().getCallee().getDeclaringType().getAnAncestor().hasQualifiedName("java.net", "URL") and 64 | // a.getCall().getCallee().getName() = "URL" 65 | // ) 66 | // } 67 | 68 | /** 69 | * A unit class for adding additional taint steps. 
70 | * 71 | * Extend this class to add additional taint steps that should apply to tainted path flow configurations. 72 | */ 73 | class TaintedPathAdditionalTaintStep extends Unit { 74 | abstract predicate step(DataFlow::Node n1, DataFlow::Node n2); 75 | } 76 | 77 | private class MyTaintedPathAdditionalTaintStep extends TaintedPathAdditionalTaintStep { 78 | override predicate step(DataFlow::Node src, Dataflow::Node sink) { 79 | exists(Argument arg | 80 | arg = src.asExpr() and 81 | arg.getCall() = sink.asExpr() and 82 | isGPTDetectedTaintPropArgument(arg) 83 | ) 84 | } 85 | } 86 | 87 | private class DefaultTaintedPathAdditionalTaintStep extends TaintedPathAdditionalTaintStep { 88 | override predicate step(DataFlow::Node n1, DataFlow::Node n2) { 89 | exists(Argument a | 90 | a = n1.asExpr() and 91 | a.getCall() = n2.asExpr() and 92 | a = any(TaintPreservingUriCtorParam tpp).getAnArgument() 93 | ) 94 | } 95 | } 96 | 97 | private class TaintPreservingUriCtorParam extends Parameter { 98 | TaintPreservingUriCtorParam() { 99 | exists(Constructor ctor, int idx, int nParams | 100 | ctor.getDeclaringType() instanceof TypeUri and 101 | this = ctor.getParameter(idx) and 102 | nParams = ctor.getNumberOfParameters() 103 | | 104 | // URI(String scheme, String ssp, String fragment) 105 | idx = 1 and nParams = 3 106 | or 107 | // URI(String scheme, String host, String path, String fragment) 108 | idx = [1, 2] and nParams = 4 109 | or 110 | // URI(String scheme, String authority, String path, String query, String fragment) 111 | idx = 2 and nParams = 5 112 | or 113 | // URI(String scheme, String userInfo, String host, int port, String path, String query, String fragment) 114 | idx = 4 and nParams = 7 115 | ) 116 | } 117 | } 118 | 119 | /** 120 | * A taint-tracking configuration for tracking flow from remote sources to the creation of a path. 121 | */ 122 | module MyTaintedPathConfig implements DataFlow::ConfigSig { 123 | predicate isSource(DataFlow::Node source) { 124 | // return value of method call 125 | isGPTDetectedSourceMethod(source.asExpr().(MethodCall).getMethod()) or 126 | source instanceof ThreatModelFlowSource 127 | 128 | // field read 129 | //isGPTDetectedSourceField(source.asExpr().(FieldAccess).getField()) 130 | } 131 | 132 | predicate isSink(DataFlow::Node sink) { 133 | // callee of a method call 134 | //isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 135 | 136 | // an argument to a method call 137 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) or 138 | sinkNode(sink, "path-injection") 139 | } 140 | 141 | predicate isBarrier(DataFlow::Node sanitizer) { 142 | sanitizer.getType() instanceof BoxedType or 143 | sanitizer.getType() instanceof PrimitiveType or 144 | sanitizer.getType() instanceof NumberType or 145 | sanitizer instanceof PathInjectionSanitizer 146 | } 147 | 148 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 149 | any(TaintedPathAdditionalTaintStep s).step(n1, n2) 150 | } 151 | } 152 | 153 | /** Tracks flow from remote sources to the creation of a path. */ 154 | module MyTaintedPathFlow = TaintTracking::Global; 155 | 156 | 157 | /** 158 | * Gets the data-flow node at which to report a path ending at `sink`. 159 | * 160 | * Previously this query flagged alerts exclusively at `PathCreation` sites, 161 | * so to avoid perturbing existing alerts, where a `PathCreation` exists we 162 | * continue to report there; otherwise we report directly at `sink`. 
163 | */ 164 | DataFlow::Node getReportingNode(DataFlow::Node sink) { 165 | MyTaintedPathFlow::flowTo(sink) and 166 | if exists(PathCreation pc | pc.getAnInput() = sink.asExpr()) 167 | then result.asExpr() = any(PathCreation pc | pc.getAnInput() = sink.asExpr()) 168 | else result = sink 169 | } 170 | 171 | from MyTaintedPathFlow::PathNode source, MyTaintedPathFlow::PathNode sink 172 | where MyTaintedPathFlow::flowPath(source, sink) 173 | select getReportingNode(sink.getNode()), source, sink, "This path depends on a $@.", 174 | source.getNode(), "user-provided value" 175 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-022/cwe-022wLLMSinksOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Uncontrolled data used in path expression 3 | * @description Accessing paths influenced by users can allow an attacker to access unexpected resources. 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 7.5 7 | * @precision high 8 | * @id java/my-path-injection-sinks-only 9 | * @tags security 10 | * external/cwe/cwe-022 11 | * external/cwe/cwe-023 12 | * external/cwe/cwe-036 13 | * external/cwe/cwe-073 14 | */ 15 | 16 | import java 17 | import semmle.code.java.security.PathCreation 18 | import MyTaintedPathQuery 19 | import MyTaintedPathFlowSinksOnly::PathGraph 20 | /** 21 | * Gets the data-flow node at which to report a path ending at `sink`. 22 | * 23 | * Previously this query flagged alerts exclusively at `PathCreation` sites, 24 | * so to avoid perturbing existing alerts, where a `PathCreation` exists we 25 | * continue to report there; otherwise we report directly at `sink`. 26 | */ 27 | DataFlow::Node getReportingNode(DataFlow::Node sink) { 28 | MyTaintedPathFlowSinksOnly::flowTo(sink) and 29 | if exists(PathCreation pc | pc.getAnInput() = sink.asExpr()) 30 | then result.asExpr() = any(PathCreation pc | pc.getAnInput() = sink.asExpr()) 31 | else result = sink 32 | } 33 | 34 | from MyTaintedPathFlowSinksOnly::PathNode source, MyTaintedPathFlowSinksOnly::PathNode sink 35 | where MyTaintedPathFlowSinksOnly::flowPath(source, sink) 36 | select getReportingNode(sink.getNode()), source, sink, "This path depends on a $@.", 37 | source.getNode(), "user-provided value" 38 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-022/cwe-022wLLMSourcesOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Uncontrolled data used in path expression 3 | * @description Accessing paths influenced by users can allow an attacker to access unexpected resources. 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 7.5 7 | * @precision high 8 | * @id java/my-path-injection-sinks-only 9 | * @tags security 10 | * external/cwe/cwe-022 11 | * external/cwe/cwe-023 12 | * external/cwe/cwe-036 13 | * external/cwe/cwe-073 14 | */ 15 | 16 | import java 17 | import semmle.code.java.security.PathCreation 18 | import MyTaintedPathQuery 19 | import MyTaintedPathFlowSourcesOnly::PathGraph 20 | /** 21 | * Gets the data-flow node at which to report a path ending at `sink`. 22 | * 23 | * Previously this query flagged alerts exclusively at `PathCreation` sites, 24 | * so to avoid perturbing existing alerts, where a `PathCreation` exists we 25 | * continue to report there; otherwise we report directly at `sink`. 
26 | */ 27 | DataFlow::Node getReportingNode(DataFlow::Node sink) { 28 | MyTaintedPathFlowSourcesOnly::flowTo(sink) and 29 | if exists(PathCreation pc | pc.getAnInput() = sink.asExpr()) 30 | then result.asExpr() = any(PathCreation pc | pc.getAnInput() = sink.asExpr()) 31 | else result = sink 32 | } 33 | 34 | from MyTaintedPathFlowSourcesOnly::PathNode source, MyTaintedPathFlowSourcesOnly::PathNode sink 35 | where MyTaintedPathFlowSourcesOnly::flowPath(source, sink) 36 | select getReportingNode(sink.getNode()), source, sink, "This path depends on a $@.", 37 | source.getNode(), "user-provided value" 38 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/CommandInjectionRuntimeExecLocalwLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Command Injection into Runtime.exec() with dangerous command 3 | * @description High-sensitivity, high-precision version of java/command-line-injection, designed to find cases of command injection that the default query does not find 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 6.1 7 | * @precision high 8 | * @id java/my-command-line-injection-extra-local 9 | * @tags security 10 | * experimental 11 | * local 12 | * external/cwe/cwe-078 13 | */ 14 | 15 | import MyCommandInjectionRuntimeExec 16 | import MyExecUserFlow::PathGraph 17 | 18 | class LocalSource extends Source instanceof LocalUserInput { } 19 | 20 | bindingset[src] 21 | string sourceType(DataFlow::Node src) { 22 | if exists(Parameter p | src.asParameter() = p) 23 | then result = "user-provided value as public function parameter" 24 | else result = "user-provided value from external api return value" 25 | } 26 | 27 | from 28 | MyExecUserFlow::PathNode source, 29 | MyExecUserFlow::PathNode sink 30 | //, DataFlow::Node sourceCmd, DataFlow::Node sinkCmd 31 | where 32 | MyExecUserFlow::flowPath(source, sink) 33 | // where mycallIsTaintedByUserInputAndDangerousCommand(source, sink, sourceCmd, sinkCmd) 34 | select sink, source, sink, 35 | "Call to dangerous java.lang.Runtime.exec() with command '$@' with arg from untrusted input", 36 | source.getNode(), 37 | sourceType(source.getNode()) 38 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/CommandInjectionRuntimeExecwLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Command Injection into Runtime.exec() with dangerous command 3 | * @description High-sensitivity, high-precision version of java/command-line-injection, designed to find cases of command injection that the default query does not find 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 6.1 7 | * @precision high 8 | * @id java/my-command-line-injection-extra 9 | * @tags security 10 | * experimental 11 | * external/cwe/cwe-078 12 | */ 13 | 14 | import MyCommandInjectionRuntimeExec 15 | import MyExecUserFlow::PathGraph 16 | 17 | 18 | from 19 | MyExecUserFlow::PathNode source, MyExecUserFlow::PathNode sink 20 | //, DataFlow::Node sourceCmd, DataFlow::Node sinkCmd 21 | where MyExecUserFlow::flowPath(source, sink) 22 | // where mycallIsTaintedByUserInputAndDangerousCommand(source, sink, sourceCmd, sinkCmd) 23 | select sink, source, sink, 24 | "Call to dangerous java.lang.Runtime.exec() with command '$@' with arg from untrusted input", 25 | source.getNode(), source.toString() 26 |
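// A hypothetical Java illustration of the kind of flow this query is meant to surface
// (not code from this repository); `request` and the variable names are assumed:
//
//   String cmd = request.getParameter("cmd");                        // LLM-identified source
//   Runtime.getRuntime().exec(new String[] {"/bin/sh", "-c", cmd});  // dangerous exec sink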
-------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/CommandInjectionRuntimeExecwLLMSinksOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Command Injection into Runtime.exec() with dangerous command 3 | * @description High-sensitivity, high-precision version of java/command-line-injection, designed to find cases of command injection that the default query does not find 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 6.1 7 | * @precision high 8 | * @id java/my-command-line-injection-extra-sinks-only 9 | * @tags security 10 | * experimental 11 | * external/cwe/cwe-078 12 | */ 13 | 14 | import MyCommandInjectionRuntimeExec 15 | import MyExecUserFlowSinksOnly::PathGraph 16 | 17 | 18 | from 19 | MyExecUserFlowSinksOnly::PathNode source, MyExecUserFlowSinksOnly::PathNode sink 20 | //, DataFlow::Node sourceCmd, DataFlow::Node sinkCmd 21 | where MyExecUserFlowSinksOnly::flowPath(source, sink) 22 | // where mycallIsTaintedByUserInputAndDangerousCommand(source, sink, sourceCmd, sinkCmd) 23 | select sink, source, sink, 24 | "Call to dangerous java.lang.Runtime.exec() with command '$@' with arg from untrusted input", 25 | source.getNode(), source.toString() 26 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/CommandInjectionRuntimeExecwLLMSourcesOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Command Injection into Runtime.exec() with dangerous command 3 | * @description High-sensitivity, high-precision version of java/command-line-injection, designed to find cases of command injection that the default query does not find 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 6.1 7 | * @precision high 8 | * @id java/my-command-line-injection-extra-sources-only 9 | * @tags security 10 | * experimental 11 | * external/cwe/cwe-078 12 | */ 13 | 14 | import MyCommandInjectionRuntimeExec 15 | import MyExecUserFlowSourcesOnly::PathGraph 16 | 17 | 18 | from 19 | MyExecUserFlowSourcesOnly::PathNode source, MyExecUserFlowSourcesOnly::PathNode sink 20 | //, DataFlow::Node sourceCmd, DataFlow::Node sinkCmd 21 | where MyExecUserFlowSourcesOnly::flowPath(source, sink) 22 | // where mycallIsTaintedByUserInputAndDangerousCommand(source, sink, sourceCmd, sinkCmd) 23 | select sink, source, sink, 24 | "Call to dangerous java.lang.Runtime.exec() with command '$@' with arg from untrusted input", 25 | source.getNode(), source.toString() 26 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/ExecRelativewLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Executing a command with a relative path 3 | * @description Executing a command with a relative path is vulnerable to 4 | * malicious changes in the PATH environment variable.
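 *              For example (hypothetical code, not from this repository),
 *              `Runtime.getRuntime().exec(new String[] {"zip", "-r", out, dir})` resolves
 *              "zip" through PATH, whereas an absolute path such as "/usr/bin/zip" would not be flagged.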
5 | * @kind problem 6 | * @problem.severity warning 7 | * @security-severity 9.8 8 | * @precision medium 9 | * @id java/myrelative-path-command 10 | * @tags security 11 | * external/cwe/cwe-078 12 | * external/cwe/cwe-088 13 | */ 14 | 15 | import semmle.code.java.Expr 16 | import semmle.code.java.security.RelativePaths 17 | import semmle.code.java.security.ExternalProcess 18 | 19 | from ArgumentToExec argument, string command 20 | where 21 | ( 22 | relativePath(argument, command) or 23 | arrayStartingWithRelative(argument, command) 24 | ) and 25 | not shellBuiltin(command) 26 | select argument, "Command with a relative path '" + command + "' is executed." 27 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/ExecTaintedLocalwLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Local-user-controlled command line 3 | * @description Using externally controlled strings in a command line is vulnerable to malicious 4 | * changes in the strings. 5 | * @kind path-problem 6 | * @problem.severity recommendation 7 | * @security-severity 9.8 8 | * @precision medium 9 | * @id java/mycommand-line-injection-local 10 | * @tags security 11 | * external/cwe/cwe-078 12 | * external/cwe/cwe-088 13 | */ 14 | 15 | import java 16 | import MyCommandLineQuery 17 | import MyExternalProcess 18 | import MyLocalUserInputToArgumentToExecFlow::PathGraph 19 | 20 | from 21 | MyLocalUserInputToArgumentToExecFlow::PathNode source, 22 | MyLocalUserInputToArgumentToExecFlow::PathNode sink, Expr e 23 | where 24 | MyLocalUserInputToArgumentToExecFlow::flowPath(source, sink) and 25 | myargumentToExec(e, sink.getNode()) 26 | select e, source, sink, "This command line depends on a $@.", source.getNode(), 27 | "user-provided value" 28 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/ExecTaintedwLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Uncontrolled command line 3 | * @description Using externally controlled strings in a command line is vulnerable to malicious 4 | * changes in the strings. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.8 8 | * @precision high 9 | * @id java/mycommand-line-injection 10 | * @tags security 11 | * external/cwe/cwe-078 12 | * external/cwe/cwe-088 13 | */ 14 | 15 | import java 16 | import MyCommandLineQuery 17 | import MyRemoteUserInputToArgumentToExecFlow::PathGraph 18 | 19 | from 20 | MyRemoteUserInputToArgumentToExecFlow::PathNode source, 21 | MyRemoteUserInputToArgumentToExecFlow::PathNode sink, Expr execArg 22 | where execIsTainted(source, sink, execArg) 23 | select execArg, source, sink, "This command line depends on a $@.", source.getNode(), 24 | "user-provided value" 25 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/ExecUnescapedwLLM.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Building a command line with string concatenation 3 | * @description Using concatenated strings in a command line is vulnerable to malicious 4 | * insertion of special characters in the strings. 
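 *              A hypothetical example of the flagged pattern (not code from this repository):
 *              `Runtime.getRuntime().exec("convert " + userFileName)`, where the concatenated
 *              value can smuggle extra arguments into the command line.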
5 | * @kind problem 6 | * @problem.severity error 7 | * @security-severity 9.8 8 | * @precision high 9 | * @id java/myconcatenated-command-line 10 | * @tags security 11 | * external/cwe/cwe-078 12 | * external/cwe/cwe-088 13 | */ 14 | 15 | import java 16 | import semmle.code.java.security.CommandLineQuery 17 | import semmle.code.java.security.ExternalProcess 18 | 19 | /** 20 | * Strings that are known to be sane by some simple local analysis. Such strings 21 | * do not need to be escaped, because the programmer can predict what the string 22 | * has in it. 23 | */ 24 | predicate saneString(Expr expr) { 25 | expr instanceof StringLiteral 26 | or 27 | expr instanceof NullLiteral 28 | or 29 | exists(Variable var | var.getAnAccess() = expr and exists(var.getAnAssignedValue()) | 30 | forall(Expr other | var.getAnAssignedValue() = other | saneString(other)) 31 | ) 32 | } 33 | 34 | predicate builtFromUncontrolledConcat(Expr expr) { 35 | exists(AddExpr concatExpr | concatExpr = expr | 36 | builtFromUncontrolledConcat(concatExpr.getAnOperand()) 37 | ) 38 | or 39 | exists(AddExpr concatExpr | concatExpr = expr | 40 | exists(Expr arg | arg = concatExpr.getAnOperand() | not saneString(arg)) 41 | ) 42 | or 43 | exists(Expr other | builtFromUncontrolledConcat(other) | 44 | exists(Variable var | var.getAnAssignedValue() = other and var.getAnAccess() = expr) 45 | ) 46 | } 47 | 48 | from StringArgumentToExec argument 49 | where 50 | builtFromUncontrolledConcat(argument) and 51 | not execIsTainted(_, _, argument) 52 | select argument, "Command line is built with string concatenation." 53 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/MyCommandArguments.qll: -------------------------------------------------------------------------------- 1 | /** 2 | * Definitions for reasoning about lists and arrays that are to be used as arguments to an external process. 3 | */ 4 | 5 | import java 6 | import semmle.code.java.dataflow.SSA 7 | import semmle.code.java.Collections 8 | 9 | /** 10 | * Holds if `ex` is used safely as an argument to a command; 11 | * i.e. it's not in the first position and it's not a shell command. 12 | */ 13 | predicate isSafeCommandArgument(Expr ex) { 14 | exists(ArrayInit ai, int i | 15 | ex = ai.getInit(i) and 16 | i > 0 and 17 | not isShell(ai.getInit(0)) 18 | ) 19 | or 20 | exists(CommandArgumentList cal | 21 | not cal.isShell() and 22 | ex = cal.getASubsequentAdd().getArgument(0) 23 | ) 24 | or 25 | exists(CommandArgArrayImmutableFirst caa | 26 | not caa.isShell() and 27 | ex = caa.getAWrite(any(int i | i > 0)) 28 | ) 29 | } 30 | 31 | /** 32 | * Holds if the given expression is the name of a shell command such as bash or python 33 | */ 34 | private predicate isShell(Expr ex) { 35 | exists(string cmd | cmd = ex.(StringLiteral).getValue() | 36 | cmd.regexpMatch(".*(sh|javac?|python[23]?|osascript|cmd)(\\.exe)?$") 37 | ) 38 | or 39 | exists(SsaVariable ssa | 40 | ex = ssa.getAUse() and 41 | isShell(ssa.getAnUltimateDefinition().(SsaExplicitUpdate).getDefiningExpr()) 42 | ) 43 | or 44 | isShell(ex.(Assignment).getRhs()) 45 | or 46 | isShell(ex.(LocalVariableDeclExpr).getInit()) 47 | } 48 | 49 | /** 50 | * A type that could be a list of strings. Includes raw `List` types. 
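 * For example (hypothetical code), both `List<String> cmd = new ArrayList<>();` and a raw
 * `List cmd = new ArrayList();` are matched.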
51 | */ 52 | private class ListOfStringType extends CollectionType { 53 | ListOfStringType() { 54 | this.getSourceDeclaration().getASourceSupertype*().hasQualifiedName("java.util", "List") and 55 | this.getElementType().getADescendant() instanceof TypeString 56 | } 57 | } 58 | 59 | /** 60 | * A variable that could be used as a list of arguments to a command. 61 | */ 62 | private class CommandArgumentList extends SsaExplicitUpdate { 63 | CommandArgumentList() { 64 | this.getSourceVariable().getType() instanceof ListOfStringType and 65 | forex(CollectionMutation ma | ma.getQualifier() = this.getAUse() | 66 | ma.getMethod().getName().matches("add%") 67 | ) 68 | } 69 | 70 | /** Gets a use of the variable for which the list could be empty. */ 71 | private VarRead getAUseBeforeFirstAdd() { 72 | result = this.getAFirstUse() 73 | or 74 | exists(VarRead mid | 75 | mid = this.getAUseBeforeFirstAdd() and 76 | adjacentUseUse(mid, result) and 77 | not exists(MethodCall ma | 78 | mid = ma.getQualifier() and 79 | ma.getMethod().hasName("add") 80 | ) 81 | ) 82 | } 83 | 84 | /** 85 | * Gets an addition to this list, i.e. a call to an `add` or `addAll` method. 86 | */ 87 | MethodCall getAnAdd() { 88 | result.getQualifier() = this.getAUse() and 89 | result.getMethod().getName().matches("add%") 90 | } 91 | 92 | /** Gets an addition to this list which could be its first element. */ 93 | MethodCall getAFirstAdd() { 94 | result = this.getAnAdd() and 95 | result.getQualifier() = this.getAUseBeforeFirstAdd() 96 | } 97 | 98 | /** Gets an addition to this list which is not the first element. */ 99 | MethodCall getASubsequentAdd() { 100 | result = this.getAnAdd() and 101 | not result = this.getAFirstAdd() 102 | } 103 | 104 | /** Holds if the first element of this list is a shell command. */ 105 | predicate isShell() { 106 | exists(MethodCall ma | ma = this.getAFirstAdd() and isShell(ma.getArgument(0))) 107 | } 108 | } 109 | 110 | /** 111 | * The type `String[]`. 112 | */ 113 | private class ArrayOfStringType extends Array { 114 | ArrayOfStringType() { this.getElementType() instanceof TypeString } 115 | } 116 | 117 | private predicate arrayVarWrite(ArrayAccess acc) { exists(Assignment a | a.getDest() = acc) } 118 | 119 | /** 120 | * A variable that could be an array of arguments to a command. 121 | */ 122 | private class CommandArgumentArray extends SsaExplicitUpdate { 123 | CommandArgumentArray() { 124 | this.getSourceVariable().getType() instanceof ArrayOfStringType and 125 | forall(ArrayAccess a | a.getArray() = this.getAUse() and arrayVarWrite(a) | 126 | a.getIndexExpr() instanceof CompileTimeConstantExpr 127 | ) 128 | } 129 | 130 | /** Gets an expression that is written to the given index of this array at the given use. */ 131 | Expr getAWrite(int index, VarRead use) { 132 | exists(Assignment a, ArrayAccess acc | 133 | acc.getArray() = use and 134 | use = this.getAUse() and 135 | index = acc.getIndexExpr().(CompileTimeConstantExpr).getIntValue() and 136 | acc = a.getDest() and 137 | result = a.getRhs() 138 | ) 139 | } 140 | 141 | /** Gets an expression that is written to the given index of this array. */ 142 | Expr getAWrite(int index) { result = this.getAWrite(index, _) } 143 | } 144 | 145 | /** 146 | * A `CommandArgArray` whose element at index 0 is never written to, except possibly once to initialise it. 
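 * A hypothetical example that would qualify: `String[] cmd = {"/usr/bin/git", "clone", url};`,
 * where index 0 is never reassigned after initialisation.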
147 | */ 148 | private class CommandArgArrayImmutableFirst extends CommandArgumentArray { 149 | CommandArgArrayImmutableFirst() { 150 | (exists(this.getAWrite(0)) or exists(firstElementOf(this.getDefiningExpr()))) and 151 | forall(VarRead use | exists(this.getAWrite(0, use)) | use = this.getAFirstUse()) 152 | } 153 | 154 | /** Gets the first element of this array. */ 155 | Expr getFirstElement() { 156 | result = this.getAWrite(0) 157 | or 158 | not exists(this.getAWrite(0)) and 159 | result = firstElementOf(this.getDefiningExpr()) 160 | } 161 | 162 | /** Holds if the first element of this array is a shell command. */ 163 | predicate isShell() { isShell(this.getFirstElement()) } 164 | } 165 | 166 | /** Gets the first element of an imutable array of strings */ 167 | private Expr firstElementOf(Expr arr) { 168 | arr.getType() instanceof ArrayOfStringType and 169 | ( 170 | result = firstElementOf(arr.(Assignment).getRhs()) 171 | or 172 | result = firstElementOf(arr.(LocalVariableDeclExpr).getInit()) 173 | or 174 | exists(CommandArgArrayImmutableFirst caa | arr = caa.getAUse() | result = caa.getFirstElement()) 175 | or 176 | exists(MethodCall ma, Method m | 177 | arr = ma and 178 | ma.getMethod() = m and 179 | m.getDeclaringType().hasQualifiedName("java.util", "Arrays") and 180 | m.hasName("copyOf") and 181 | result = firstElementOf(ma.getArgument(0)) 182 | ) 183 | or 184 | exists(Field f | 185 | f.isStatic() and 186 | arr.(FieldRead).getField() = f and 187 | result = firstElementOf(f.getInitializer()) 188 | ) 189 | or 190 | result = arr.(ArrayInit).getInit(0) 191 | or 192 | result = arr.(ArrayCreationExpr).getInit().getInit(0) 193 | ) 194 | } 195 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/MyCommandInjectionRuntimeExec.qll: -------------------------------------------------------------------------------- 1 | import java 2 | import semmle.code.java.frameworks.javaee.ejb.EJBRestrictions 3 | import semmle.code.java.dataflow.DataFlow 4 | private import semmle.code.java.dataflow.FlowSources 5 | private import semmle.code.java.dataflow.ExternalFlow 6 | import MySources 7 | import MySinks 8 | import MySummaries 9 | 10 | module MyExecUserFlowConfig implements DataFlow::ConfigSig { 11 | //predicate isSource(DataFlow::Node source) { source instanceof Source } 12 | predicate isSource(DataFlow::Node src) { 13 | isGPTDetectedSource(src) 14 | } 15 | 16 | predicate isSink(DataFlow::Node sink) { 17 | isGPTDetectedSink(sink) 18 | } 19 | 20 | predicate isBarrier(DataFlow::Node node) { 21 | node.getType() instanceof PrimitiveType or 22 | node.getType() instanceof BoxedType or 23 | node.getType() instanceof NumberType 24 | } 25 | 26 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 27 | isGPTDetectedStep(n1, n2) 28 | } 29 | } 30 | 31 | /** Tracks flow of unvalidated user input that is used in Runtime.Exec */ 32 | module MyExecUserFlow = TaintTracking::Global; 33 | 34 | 35 | module MyExecUserFlowConfigSinksOnly implements DataFlow::ConfigSig { 36 | //predicate isSource(DataFlow::Node source) { source instanceof Source } 37 | predicate isSource(DataFlow::Node src) { 38 | src instanceof ThreatModelFlowSource 39 | } 40 | 41 | predicate isSink(DataFlow::Node sink) { 42 | isGPTDetectedSink(sink) 43 | } 44 | 45 | 46 | 47 | predicate isBarrier(DataFlow::Node node) { 48 | node.getType() instanceof PrimitiveType or 49 | node.getType() instanceof BoxedType or 50 | node.getType() instanceof NumberType 51 | } 52 | } 53 | 54 | /** Tracks flow 
of unvalidated user input that is used in Runtime.Exec */ 55 | module MyExecUserFlowSinksOnly = TaintTracking::Global; 56 | 57 | 58 | 59 | module MyExecUserFlowConfigSourcesOnly implements DataFlow::ConfigSig { 60 | predicate isSource(DataFlow::Node src) { 61 | isGPTDetectedSource(src) 62 | } 63 | 64 | predicate isSink(DataFlow::Node sink) { 65 | sinkNode(sink, "command-injection") 66 | or 67 | exists(MethodCall call | 68 | call.getMethod() instanceof RuntimeExecMethod and 69 | sink.asExpr() = call.getArgument(_) and 70 | sink.asExpr().getType() instanceof Array 71 | ) 72 | } 73 | 74 | // predicate isSink(DataFlow::Node sink) { 75 | // exists(MethodCall call | 76 | // call.getMethod() instanceof RuntimeExecMethod and 77 | // sink.asExpr() = call.getArgument(_) and 78 | // sink.asExpr().getType() instanceof Array 79 | // ) 80 | // } 81 | 82 | 83 | predicate isBarrier(DataFlow::Node node) { 84 | node.getType() instanceof PrimitiveType or 85 | node.getType() instanceof BoxedType or 86 | node.getType() instanceof NumberType 87 | } 88 | } 89 | 90 | /** Tracks flow of unvalidated user input that is used in Runtime.Exec */ 91 | module MyExecUserFlowSourcesOnly = TaintTracking::Global; 92 | 93 | // array[3] = node 94 | class AssignToNonZeroIndex extends DataFlow::Node { 95 | AssignToNonZeroIndex() { 96 | exists(AssignExpr assign, ArrayAccess access | 97 | assign.getDest() = access and 98 | access.getIndexExpr().(IntegerLiteral).getValue().toInt() != 0 and 99 | assign.getSource() = this.asExpr() 100 | ) 101 | } 102 | } 103 | 104 | // String[] array = {"a", "b, "c"}; 105 | class ArrayInitAtNonZeroIndex extends DataFlow::Node { 106 | ArrayInitAtNonZeroIndex() { 107 | exists(ArrayInit init, int index | 108 | init.getInit(index) = this.asExpr() and 109 | index != 0 110 | ) 111 | } 112 | } 113 | 114 | // Stream.concat(Arrays.stream(array_1), Arrays.stream(array_2)) 115 | class StreamConcatAtNonZeroIndex extends DataFlow::Node { 116 | StreamConcatAtNonZeroIndex() { 117 | exists(MethodCall call, int index | 118 | call.getMethod().getQualifiedName() = "java.util.stream.Stream.concat" and 119 | call.getArgument(index) = this.asExpr() and 120 | index != 0 121 | ) 122 | } 123 | } 124 | 125 | // list of executables that execute their arguments 126 | // TODO: extend with data extensions 127 | class UnSafeExecutable extends string { 128 | bindingset[this] 129 | UnSafeExecutable() { 130 | this.regexpMatch("^(|.*/)([a-z]*sh|javac?|python.*|perl|[Pp]ower[Ss]hell|php|node|deno|bun|ruby|osascript|cmd|Rscript|groovy)(\\.exe)?$") and 131 | not this = "netsh.exe" 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/MyCommandLineQuery.qll: -------------------------------------------------------------------------------- 1 | /** 2 | * Provides classes and methods common to queries `java/command-line-injection`, `java/command-line-concatenation` 3 | * and their experimental derivatives. 4 | * 5 | * Do not import this from a library file, in order to reduce the risk of 6 | * unintentionally bringing a TaintTracking::Configuration into scope in an unrelated 7 | * query. 8 | */ 9 | 10 | import java 11 | private import semmle.code.java.dataflow.FlowSources 12 | private import semmle.code.java.dataflow.ExternalFlow 13 | private import MyCommandArguments 14 | private import MyExternalProcess 15 | import MySources 16 | import MySinks 17 | 18 | /** A sink for command injection vulnerabilities. 
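 * A typical (hypothetical) example is the argument of a process-launching call, e.g.
 * `Runtime.getRuntime().exec(cmd)` or an element added to a `ProcessBuilder` command list.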
*/ 19 | abstract class CommandInjectionSink extends DataFlow::Node { } 20 | 21 | /** A sanitizer for command injection vulnerabilities. */ 22 | abstract class CommandInjectionSanitizer extends DataFlow::Node { } 23 | 24 | /** 25 | * A unit class for adding additional taint steps. 26 | * 27 | * Extend this class to add additional taint steps that should apply to configurations related to command injection. 28 | */ 29 | class CommandInjectionAdditionalTaintStep extends Unit { 30 | /** 31 | * Holds if the step from `node1` to `node2` should be considered a taint 32 | * step for configurations related to command injection. 33 | */ 34 | abstract predicate step(DataFlow::Node node1, DataFlow::Node node2); 35 | } 36 | 37 | private class DefaultCommandInjectionSink extends CommandInjectionSink { 38 | DefaultCommandInjectionSink() { sinkNode(this, "command-injection") } 39 | } 40 | 41 | private class DefaultCommandInjectionSanitizer extends CommandInjectionSanitizer { 42 | DefaultCommandInjectionSanitizer() { 43 | this.getType() instanceof PrimitiveType 44 | or 45 | this.getType() instanceof BoxedType 46 | or 47 | this.getType() instanceof NumberType 48 | or 49 | isSafeCommandArgument(this.asExpr()) 50 | } 51 | } 52 | 53 | /** 54 | * A taint-tracking configuration for unvalidated user input that is used to run an external process. 55 | */ 56 | module MyRemoteUserInputToArgumentToExecFlowConfig implements DataFlow::ConfigSig { 57 | predicate isSource(DataFlow::Node src) { 58 | isGPTDetectedSourceMethod(src.asExpr().(MethodCall).getMethod()) 59 | // src instanceof ThreatModelFlowSource 60 | } 61 | 62 | predicate isSink(DataFlow::Node sink) { 63 | isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 64 | 65 | // an argument to a method call 66 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) 67 | //sink instanceof CommandInjectionSink 68 | } 69 | 70 | predicate isBarrier(DataFlow::Node node) { node instanceof CommandInjectionSanitizer } 71 | 72 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 73 | any(CommandInjectionAdditionalTaintStep s).step(n1, n2) 74 | } 75 | } 76 | 77 | /** 78 | * Taint-tracking flow for unvalidated user input that is used to run an external process. 79 | */ 80 | module MyRemoteUserInputToArgumentToExecFlow = 81 | TaintTracking::Global; 82 | 83 | /** 84 | * A taint-tracking configuration for unvalidated local user input that is used to run an external process. 85 | */ 86 | module MyLocalUserInputToArgumentToExecFlowConfig implements DataFlow::ConfigSig { 87 | predicate isSource(DataFlow::Node src) { 88 | isGPTDetectedSourceMethod(src.asExpr().(MethodCall).getMethod()) 89 | //src instanceof LocalUserInput 90 | } 91 | 92 | predicate isSink(DataFlow::Node sink) { 93 | isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 94 | // an argument to a method call 95 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) 96 | //sink instanceof CommandInjectionSink 97 | } 98 | 99 | predicate isBarrier(DataFlow::Node node) { node instanceof CommandInjectionSanitizer } 100 | 101 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 102 | any(CommandInjectionAdditionalTaintStep s).step(n1, n2) 103 | } 104 | } 105 | 106 | /** 107 | * Taint-tracking flow for unvalidated local user input that is used to run an external process. 108 | */ 109 | module MyLocalUserInputToArgumentToExecFlow = 110 | TaintTracking::Global; 111 | 112 | /** 113 | * Implementation of `ExecTainted.ql`. 
It is extracted to a QLL 114 | * so that it can be excluded from `ExecUnescaped.ql` to avoid 115 | * reporting overlapping results. 116 | */ 117 | predicate execIsTainted( 118 | MyRemoteUserInputToArgumentToExecFlow::PathNode source, 119 | MyRemoteUserInputToArgumentToExecFlow::PathNode sink, Expr execArg 120 | ) { 121 | MyRemoteUserInputToArgumentToExecFlow::flowPath(source, sink) and 122 | myargumentToExec(execArg, sink.getNode()) 123 | } 124 | 125 | /** 126 | * DEPRECATED: Use `execIsTainted` instead. 127 | * 128 | * Implementation of `ExecTainted.ql`. It is extracted to a QLL 129 | * so that it can be excluded from `ExecUnescaped.ql` to avoid 130 | * reporting overlapping results. 131 | */ 132 | deprecated predicate execTainted(DataFlow::PathNode source, DataFlow::PathNode sink, Expr execArg) { 133 | exists(MyRemoteUserInputToArgumentToExecFlowConfig conf | 134 | conf.hasFlowPath(source, sink) and myargumentToExec(execArg, sink.getNode()) 135 | ) 136 | } 137 | 138 | /** 139 | * DEPRECATED: Use `RemoteUserInputToArgumentToExecFlow` instead. 140 | * 141 | * A taint-tracking configuration for unvalidated user input that is used to run an external process. 142 | */ 143 | deprecated class MyRemoteUserInputToArgumentToExecFlowConfig extends TaintTracking::Configuration { 144 | MyRemoteUserInputToArgumentToExecFlowConfig() { 145 | this = "ExecCommon::RemoteUserInputToArgumentToExecFlowConfig" 146 | } 147 | 148 | override predicate isSource(DataFlow::Node src) { src instanceof RemoteFlowSource } 149 | 150 | override predicate isSink(DataFlow::Node sink) { sink instanceof CommandInjectionSink } 151 | 152 | override predicate isSanitizer(DataFlow::Node node) { node instanceof CommandInjectionSanitizer } 153 | 154 | override predicate isAdditionalTaintStep(DataFlow::Node n1, DataFlow::Node n2) { 155 | any(CommandInjectionAdditionalTaintStep s).step(n1, n2) 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-078/MyExternalProcess.qll: -------------------------------------------------------------------------------- 1 | /** Definitions related to external processes. */ 2 | 3 | import semmle.code.java.Member 4 | private import semmle.code.java.dataflow.DataFlow 5 | private import MyCommandLineQuery 6 | 7 | /** 8 | * DEPRECATED: A callable that executes a command. 9 | */ 10 | abstract deprecated class MyExecCallable extends Callable { 11 | /** 12 | * Gets the index of an argument that will be part of the command that is executed. 13 | */ 14 | abstract int getAnExecutedArgument(); 15 | } 16 | 17 | /** 18 | * An expression used as an argument to a call that executes an external command. For calls to 19 | * varargs method calls, this only includes the first argument, which will be the command 20 | * to be executed. 21 | */ 22 | class MyArgumentToExec extends Expr { 23 | MyArgumentToExec() { myargumentToExec(this, _) } 24 | } 25 | 26 | /** 27 | * Holds if `e` is an expression used as an argument to a call that executes an external command. 28 | * For calls to varargs method calls, this only includes the first argument, which will be the command 29 | * to be executed. 30 | */ 31 | predicate myargumentToExec(Expr e, CommandInjectionSink s) { 32 | s.asExpr() = e 33 | or 34 | e.(Argument).isNthVararg(0) and 35 | s.(DataFlow::ImplicitVarargsArray).getCall() = e.(Argument).getCall() 36 | } 37 | 38 | /** 39 | * An `ArgumentToExec` of type `String`. 
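 * For example (hypothetical code), the argument in `Runtime.getRuntime().exec("ls " + dir)`,
 * as opposed to a `String[]` argument.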
40 | */ 41 | class MyStringArgumentToExec extends MyArgumentToExec { 42 | MyStringArgumentToExec() { this.getType() instanceof TypeString } 43 | } 44 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/MyXSS.qll: -------------------------------------------------------------------------------- 1 | /** Provides classes to reason about Cross-site scripting (XSS) vulnerabilities. */ 2 | 3 | import java 4 | import semmle.code.java.frameworks.Servlets 5 | import semmle.code.java.frameworks.android.WebView 6 | import semmle.code.java.frameworks.spring.SpringController 7 | import semmle.code.java.frameworks.spring.SpringHttp 8 | import semmle.code.java.frameworks.javaee.jsf.JSFRenderer 9 | private import semmle.code.java.frameworks.hudson.Hudson 10 | import semmle.code.java.dataflow.DataFlow 11 | import semmle.code.java.dataflow.TaintTracking 12 | private import semmle.code.java.dataflow.ExternalFlow 13 | 14 | /** A sink that represent a method that outputs data without applying contextual output encoding. */ 15 | abstract class XssSink extends DataFlow::Node { } 16 | 17 | /** A sanitizer that neutralizes dangerous characters that can be used to perform a XSS attack. */ 18 | abstract class XssSanitizer extends DataFlow::Node { } 19 | 20 | /** 21 | * A sink that represent a method that outputs data without applying contextual output encoding, 22 | * and which should truncate flow paths such that downstream sinks are not flagged as well. 23 | */ 24 | abstract class XssSinkBarrier extends XssSink { } 25 | 26 | /** 27 | * A unit class for adding additional taint steps. 28 | * 29 | * Extend this class to add additional taint steps that should apply to the XSS 30 | * taint configuration. 31 | */ 32 | class XssAdditionalTaintStep extends Unit { 33 | /** 34 | * Holds if the step from `node1` to `node2` should be considered a taint 35 | * step for XSS taint configurations. 36 | */ 37 | abstract predicate step(DataFlow::Node node1, DataFlow::Node node2); 38 | } 39 | 40 | /** A default sink representing methods susceptible to XSS attacks. */ 41 | private class DefaultXssSink extends XssSink { 42 | DefaultXssSink() { 43 | sinkNode(this, ["html-injection", "js-injection"]) 44 | or 45 | exists(MethodCall ma | 46 | ma.getMethod() instanceof WritingMethod and 47 | MyXssVulnerableWriterSourceToWritingMethodFlow::flowToExpr(ma.getQualifier()) and 48 | this.asExpr() = ma.getArgument(_) 49 | ) 50 | } 51 | } 52 | 53 | /** A default sanitizer that considers numeric and boolean typed data safe for writing to output. */ 54 | private class DefaultXssSanitizer extends XssSanitizer { 55 | DefaultXssSanitizer() { 56 | this.getType() instanceof NumericType or 57 | this.getType() instanceof BooleanType or 58 | // Match `org.springframework.web.util.HtmlUtils.htmlEscape` and possibly other methods like it. 59 | this.asExpr().(MethodCall).getMethod().getName().regexpMatch("(?i)html_?escape.*") 60 | } 61 | } 62 | 63 | /** A configuration that tracks data from a servlet writer to an output method. 
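 * A hypothetical Java sketch of the tracked pattern (not code from this repository):
 *
 *   PrintWriter out = response.getWriter();   // vulnerable writer source
 *   out.print(untrusted);                     // `out` is the qualifier of a WritingMethod call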
*/ 64 | private module MyXssVulnerableWriterSourceToWritingMethodFlowConfig implements DataFlow::ConfigSig { 65 | predicate isSource(DataFlow::Node src) { src.asExpr() instanceof MyXssVulnerableWriterSource } 66 | 67 | predicate isSink(DataFlow::Node sink) { 68 | exists(MethodCall ma | 69 | sink.asExpr() = ma.getQualifier() and ma.getMethod() instanceof WritingMethod 70 | ) 71 | } 72 | } 73 | 74 | private module MyXssVulnerableWriterSourceToWritingMethodFlow = 75 | TaintTracking::Global; 76 | 77 | /** A method that can be used to output data to an output stream or writer. */ 78 | private class WritingMethod extends Method { 79 | WritingMethod() { 80 | this.getDeclaringType().getAnAncestor().hasQualifiedName("java.io", _) and 81 | ( 82 | this.getName().matches("print%") or 83 | this.getName() = "append" or 84 | this.getName() = "format" or 85 | this.getName() = "write" 86 | ) 87 | } 88 | } 89 | 90 | /** An output stream or writer that writes to a servlet, JSP or JSF response. */ 91 | class MyXssVulnerableWriterSource extends MethodCall { 92 | MyXssVulnerableWriterSource() { 93 | this.getMethod() instanceof ServletResponseGetWriterMethod 94 | or 95 | this.getMethod() instanceof ServletResponseGetOutputStreamMethod 96 | or 97 | exists(Method m | m = this.getMethod() | 98 | m.getDeclaringType().getQualifiedName() = "javax.servlet.jsp.JspContext" and 99 | m.getName() = "getOut" 100 | ) 101 | or 102 | this.getMethod() instanceof FacesGetResponseWriterMethod 103 | or 104 | this.getMethod() instanceof FacesGetResponseStreamMethod 105 | } 106 | } 107 | 108 | /** 109 | * Holds if `s` is an HTTP Content-Type vulnerable to XSS. 110 | */ 111 | bindingset[s] 112 | predicate isXssVulnerableContentType(string s) { 113 | s.regexpMatch("(?i)text/(html|xml|xsl|rdf|vtt|cache-manifest).*") or 114 | s.regexpMatch("(?i)application/(.*\\+)?xml.*") or 115 | s.regexpMatch("(?i)cache-manifest.*") or 116 | s.regexpMatch("(?i)image/svg\\+xml.*") 117 | } 118 | 119 | /** 120 | * Holds if `s` is an HTTP Content-Type that is not vulnerable to XSS. 121 | */ 122 | bindingset[s] 123 | predicate isXssSafeContentType(string s) { not isXssVulnerableContentType(s) } 124 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/MyXssLocalQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides a taint-tracking configuration to reason about cross-site scripting from a local source. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.FlowSources 5 | private import semmle.code.java.dataflow.TaintTracking 6 | private import MyXSS 7 | 8 | import MySources 9 | import MySinks 10 | import MySummaries 11 | 12 | /** 13 | * A taint-tracking configuration for reasoning about cross-site scripting vulnerabilities from a local source. 
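 * Here a "local" source is untrusted input that enters the program locally, for example
 * (hypothetically) a command-line argument or a value read from a local file that later ends up
 * in generated HTML, e.g. `out.println("<div>" + args[0] + "</div>");`.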
14 | */ 15 | module MyXssLocalConfig implements DataFlow::ConfigSig { 16 | predicate isSource(DataFlow::Node source) { 17 | source instanceof LocalUserInput 18 | } 19 | 20 | predicate isSink(DataFlow::Node sink) { 21 | isGPTDetectedSink(sink) 22 | } 23 | 24 | predicate isBarrier(DataFlow::Node node) { node instanceof XssSanitizer } 25 | 26 | predicate isBarrierOut(DataFlow::Node node) { node instanceof XssSinkBarrier } 27 | 28 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 29 | any(XssAdditionalTaintStep s).step(node1, node2) or 30 | isGPTDetectedStep(node1, node2) 31 | } 32 | } 33 | 34 | /** 35 | * Taint-tracking flow for cross-site scripting vulnerabilities from a local source. 36 | */ 37 | module MyXssLocalFlow = TaintTracking::Global; 38 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/MyXssQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides a taint tracking configuration to track cross site scripting. */ 2 | 3 | import java 4 | import semmle.code.java.dataflow.FlowSources 5 | import semmle.code.java.dataflow.TaintTracking 6 | import MyXSS 7 | 8 | import MySources 9 | import MySinks 10 | import MySummaries 11 | 12 | /** 13 | * A taint-tracking configuration for cross site scripting vulnerabilities. 14 | */ 15 | module MyXssConfig implements DataFlow::ConfigSig { 16 | predicate isSource(DataFlow::Node source) { 17 | isGPTDetectedSource(source) 18 | } 19 | 20 | predicate isSink(DataFlow::Node sink) { 21 | isGPTDetectedSink(sink) 22 | } 23 | 24 | predicate isBarrier(DataFlow::Node node) { node instanceof XssSanitizer } 25 | 26 | predicate isBarrierOut(DataFlow::Node node) { node instanceof XssSinkBarrier } 27 | 28 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 29 | any(XssAdditionalTaintStep s).step(node1, node2) or 30 | isGPTDetectedStep(node1, node2) 31 | } 32 | } 33 | 34 | /** Tracks flow from remote sources to cross site scripting vulnerabilities. */ 35 | module MyXssFlow = TaintTracking::Global; 36 | 37 | 38 | module MyXssConfigSinksOnly implements DataFlow::ConfigSig { 39 | predicate isSource(DataFlow::Node source) { 40 | source instanceof ThreatModelFlowSource 41 | } 42 | 43 | predicate isSink(DataFlow::Node sink) { 44 | isGPTDetectedSink(sink) 45 | } 46 | 47 | predicate isBarrier(DataFlow::Node node) { node instanceof XssSanitizer } 48 | 49 | predicate isBarrierOut(DataFlow::Node node) { node instanceof XssSinkBarrier } 50 | 51 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 52 | any(XssAdditionalTaintStep s).step(node1, node2) 53 | } 54 | } 55 | 56 | /** Tracks flow from remote sources to cross site scripting vulnerabilities. */ 57 | module MyXssFlowSinksOnly = TaintTracking::Global; 58 | 59 | module MyXssConfigSourcesOnly implements DataFlow::ConfigSig { 60 | predicate isSource(DataFlow::Node source) { 61 | isGPTDetectedSource(source) 62 | } 63 | 64 | predicate isSink(DataFlow::Node sink) { 65 | sink instanceof XssSink 66 | } 67 | 68 | predicate isBarrier(DataFlow::Node node) { node instanceof XssSanitizer } 69 | 70 | predicate isBarrierOut(DataFlow::Node node) { node instanceof XssSinkBarrier } 71 | 72 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 73 | any(XssAdditionalTaintStep s).step(node1, node2) 74 | } 75 | } 76 | 77 | /** Tracks flow from remote sources to cross site scripting vulnerabilities. 
*/ 78 | module MyXssFlowSourcesOnly = TaintTracking::Global; 79 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/XSS.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Cross-site scripting 3 | * @description Writing user input directly to a web page 4 | * allows for a cross-site scripting vulnerability. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 6.1 8 | * @precision high 9 | * @id java/myxss 10 | * @tags security 11 | * external/cwe/cwe-079 12 | */ 13 | 14 | import java 15 | import MyXssQuery 16 | import MyXssFlow::PathGraph 17 | 18 | bindingset[src] 19 | string sourceType(DataFlow::Node src) { 20 | if exists(Parameter p | src.asParameter() = p) 21 | then result = "user-provided value as public function parameter" 22 | else result = "user-provided value from external api return value" 23 | } 24 | 25 | from MyXssFlow::PathNode source, MyXssFlow::PathNode sink 26 | where MyXssFlow::flowPath(source, sink) 27 | select sink.getNode(), source, sink, "Cross-site scripting vulnerability due to a $@.", 28 | source.getNode(), 29 | sourceType(source.getNode()) 30 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/XSSLocal.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Cross-site scripting from local source 3 | * @description Writing user input directly to a web page 4 | * allows for a cross-site scripting vulnerability. 5 | * @kind path-problem 6 | * @problem.severity recommendation 7 | * @security-severity 6.1 8 | * @precision medium 9 | * @id java/myxss-local 10 | * @tags security 11 | * external/cwe/cwe-079 12 | */ 13 | 14 | import java 15 | import semmle.code.java.security.XssLocalQuery 16 | import MyXssLocalFlow::PathGraph 17 | 18 | from XssLocalFlow::PathNode source, XssLocalFlow::PathNode sink 19 | where XssLocalFlow::flowPath(source, sink) 20 | select sink.getNode(), source, sink, "Cross-site scripting vulnerability due to $@.", 21 | source.getNode(), "user-provided value" 22 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/XSSSinksOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Cross-site scripting 3 | * @description Writing user input directly to a web page 4 | * allows for a cross-site scripting vulnerability. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 6.1 8 | * @precision high 9 | * @id java/myxss-sinks-only 10 | * @tags security 11 | * external/cwe/cwe-079 12 | */ 13 | 14 | import java 15 | import MyXssQuery 16 | import MyXssFlowSinksOnly::PathGraph 17 | 18 | from MyXssFlowSinksOnly::PathNode source, MyXssFlowSinksOnly::PathNode sink 19 | where MyXssFlowSinksOnly::flowPath(source, sink) 20 | select sink.getNode(), source, sink, "Cross-site scripting vulnerability due to a $@.", 21 | source.getNode(), "user-provided value" 22 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-079/XSSSourcesOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Cross-site scripting 3 | * @description Writing user input directly to a web page 4 | * allows for a cross-site scripting vulnerability. 
5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 6.1 8 | * @precision high 9 | * @id java/myxss-sources-only 10 | * @tags security 11 | * external/cwe/cwe-079 12 | */ 13 | 14 | import java 15 | import MyXssQuery 16 | import MyXssFlowSourcesOnly::PathGraph 17 | 18 | from MyXssFlowSourcesOnly::PathNode source, MyXssFlowSourcesOnly::PathNode sink 19 | where MyXssFlowSourcesOnly::flowPath(source, sink) 20 | select sink.getNode(), source, sink, "Cross-site scripting vulnerability due to a $@.", 21 | source.getNode(), "user-provided value" 22 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/MySpelInjection.qll: -------------------------------------------------------------------------------- 1 | /** Provides classes to reason about SpEL injection attacks. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.DataFlow 5 | private import semmle.code.java.frameworks.spring.SpringExpression 6 | 7 | /** A data flow sink for unvalidated user input that is used to construct SpEL expressions. */ 8 | abstract class SpelExpressionEvaluationSink extends DataFlow::ExprNode { } 9 | 10 | /** 11 | * A unit class for adding additional taint steps. 12 | * 13 | * Extend this class to add additional taint steps that should apply to the `SpELInjectionConfig`. 14 | */ 15 | class SpelExpressionInjectionAdditionalTaintStep extends Unit { 16 | /** 17 | * Holds if the step from `node1` to `node2` should be considered a taint 18 | * step for the `SpELInjectionConfig` configuration. 19 | */ 20 | abstract predicate step(DataFlow::Node node1, DataFlow::Node node2); 21 | } 22 | 23 | /** A set of additional taint steps to consider when taint tracking SpEL related data flows. */ 24 | private class DefaultSpelExpressionInjectionAdditionalTaintStep extends SpelExpressionInjectionAdditionalTaintStep 25 | { 26 | override predicate step(DataFlow::Node node1, DataFlow::Node node2) { 27 | expressionParsingStep(node1, node2) 28 | } 29 | } 30 | 31 | /** 32 | * Holds if `node1` to `node2` is a dataflow step that parses a SpEL expression, 33 | * by calling `parser.parseExpression(tainted)`. 34 | */ 35 | private predicate expressionParsingStep(DataFlow::Node node1, DataFlow::Node node2) { 36 | exists(MethodCall ma, Method m | ma.getMethod() = m | 37 | m.getDeclaringType().getAnAncestor() instanceof ExpressionParser and 38 | m.hasName(["parseExpression", "parseRaw"]) and 39 | ma.getAnArgument() = node1.asExpr() and 40 | node2.asExpr() = ma 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/MySpelInjectionQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides taint tracking and dataflow configurations to be used in SpEL injection queries. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.FlowSources 5 | private import semmle.code.java.dataflow.TaintTracking 6 | private import semmle.code.java.frameworks.spring.SpringExpression 7 | private import MySpelInjection 8 | 9 | import MySources 10 | import MySinks 11 | import MySummaries 12 | 13 | /** 14 | * A taint-tracking configuration for unsafe user input 15 | * that is used to construct and evaluate a SpEL expression. 
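 *
 * A hypothetical Java example of the vulnerable pattern (not code from this repository):
 *
 *   ExpressionParser parser = new SpelExpressionParser();
 *   parser.parseExpression(request.getParameter("expr")).getValue();   // attacker-controlled SpEL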
16 | */ 17 | module MySpelInjectionConfig implements DataFlow::ConfigSig { 18 | predicate isSource(DataFlow::Node source) { 19 | //source instanceof ThreatModelFlowSource 20 | isGPTDetectedSource(source) 21 | } 22 | 23 | predicate isSink(DataFlow::Node sink) { 24 | //sink instanceof SpelExpressionEvaluationSink 25 | isGPTDetectedSink(sink) 26 | } 27 | 28 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 29 | any(SpelExpressionInjectionAdditionalTaintStep c).step(node1, node2) or 30 | isGPTDetectedStep(node1, node2) 31 | } 32 | } 33 | 34 | /** Tracks flow of unsafe user input that is used to construct and evaluate a SpEL expression. */ 35 | module MySpelInjectionFlow = TaintTracking::Global; 36 | 37 | 38 | module MySpelInjectionConfigSinksOnly implements DataFlow::ConfigSig { 39 | predicate isSource(DataFlow::Node source) { 40 | source instanceof ThreatModelFlowSource 41 | } 42 | 43 | predicate isSink(DataFlow::Node sink) { 44 | //sink instanceof SpelExpressionEvaluationSink 45 | isGPTDetectedSink(sink) 46 | } 47 | 48 | // predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 49 | // any(SpelExpressionInjectionAdditionalTaintStep c).step(node1, node2) 50 | // } 51 | } 52 | 53 | /** Tracks flow of unsafe user input that is used to construct and evaluate a SpEL expression. */ 54 | module MySpelInjectionFlowSinksOnly = TaintTracking::Global; 55 | 56 | module MySpelInjectionConfigSourcesOnly implements DataFlow::ConfigSig { 57 | predicate isSource(DataFlow::Node source) { 58 | //source instanceof ThreatModelFlowSource 59 | isGPTDetectedSource(source) 60 | } 61 | 62 | predicate isSink(DataFlow::Node sink) { 63 | sink instanceof SpelExpressionEvaluationSink 64 | } 65 | 66 | // predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 67 | // any(SpelExpressionInjectionAdditionalTaintStep c).step(node1, node2) 68 | // } 69 | } 70 | 71 | /** Tracks flow of unsafe user input that is used to construct and evaluate a SpEL expression. */ 72 | module MySpelInjectionFlowSourcesOnly = TaintTracking::Global; 73 | 74 | 75 | /** 76 | * A configuration for safe evaluation context that may be used in expression evaluation. 77 | */ 78 | private module SafeEvaluationContextFlowConfig implements DataFlow::ConfigSig { 79 | predicate isSource(DataFlow::Node source) { source instanceof SafeContextSource } 80 | 81 | predicate isSink(DataFlow::Node sink) { 82 | exists(MethodCall ma | 83 | ma.getMethod() instanceof ExpressionEvaluationMethod and 84 | ma.getArgument(0) = sink.asExpr() 85 | ) 86 | } 87 | 88 | int fieldFlowBranchLimit() { result = 0 } 89 | } 90 | 91 | private module SafeEvaluationContextFlow = DataFlow::Global; 92 | 93 | /** 94 | * A `ContextSource` that is safe from SpEL injection. 95 | */ 96 | private class SafeContextSource extends DataFlow::ExprNode { 97 | SafeContextSource() { 98 | isSimpleEvaluationContextConstructorCall(this.getExpr()) or 99 | isSimpleEvaluationContextBuilderCall(this.getExpr()) 100 | } 101 | } 102 | 103 | /** 104 | * Holds if `expr` constructs `SimpleEvaluationContext`. 105 | */ 106 | private predicate isSimpleEvaluationContextConstructorCall(Expr expr) { 107 | exists(ConstructorCall cc | 108 | cc.getConstructedType() instanceof SimpleEvaluationContext and 109 | cc = expr 110 | ) 111 | } 112 | 113 | /** 114 | * Holds if `expr` builds `SimpleEvaluationContext` via `SimpleEvaluationContext.Builder`, 115 | * for instance, `SimpleEvaluationContext.forReadWriteDataBinding().build()`. 
116 | */ 117 | private predicate isSimpleEvaluationContextBuilderCall(Expr expr) { 118 | exists(MethodCall ma, Method m | ma.getMethod() = m | 119 | m.getDeclaringType() instanceof SimpleEvaluationContextBuilder and 120 | m.hasName("build") and 121 | ma = expr 122 | ) 123 | } 124 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/MyTemplateInjection.qll: -------------------------------------------------------------------------------- 1 | /** Definitions related to the server-side template injection (SST) query. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.FlowSources 5 | private import semmle.code.java.dataflow.ExternalFlow 6 | private import semmle.code.java.dataflow.TaintTracking 7 | 8 | /** 9 | * A source for server-side template injection (SST) vulnerabilities. 10 | */ 11 | abstract class TemplateInjectionSource extends DataFlow::Node { 12 | /** Holds if this source has the specified `state`. */ 13 | predicate hasState(DataFlow::FlowState state) { state instanceof DataFlow::FlowStateEmpty } 14 | } 15 | 16 | /** 17 | * A sink for server-side template injection (SST) vulnerabilities. 18 | */ 19 | abstract class TemplateInjectionSink extends DataFlow::Node { 20 | /** Holds if this sink has the specified `state`. */ 21 | predicate hasState(DataFlow::FlowState state) { state instanceof DataFlow::FlowStateEmpty } 22 | } 23 | 24 | /** 25 | * A unit class for adding additional taint steps. 26 | * 27 | * Extend this class to add additional taint steps that should apply to flows related to 28 | * server-side template injection (SST) vulnerabilities. 29 | */ 30 | class TemplateInjectionAdditionalTaintStep extends Unit { 31 | /** 32 | * Holds if the step from `node1` to `node2` should be considered a taint 33 | * step for flows related to server-side template injection (SST) vulnerabilities. 34 | */ 35 | predicate isAdditionalTaintStep(DataFlow::Node node1, DataFlow::Node node2) { none() } 36 | 37 | /** 38 | * Holds if the step from `node1` to `node2` should be considered a taint 39 | * step for flows related toserver-side template injection (SST) vulnerabilities. 40 | * This step is only applicable in `state1` and updates the flow state to `state2`. 41 | */ 42 | predicate isAdditionalTaintStep( 43 | DataFlow::Node node1, DataFlow::FlowState state1, DataFlow::Node node2, 44 | DataFlow::FlowState state2 45 | ) { 46 | none() 47 | } 48 | } 49 | 50 | /** 51 | * A sanitizer for server-side template injection (SST) vulnerabilities. 52 | */ 53 | abstract class TemplateInjectionSanitizer extends DataFlow::Node { } 54 | 55 | /** 56 | * A sanitizer for server-side template injection (SST) vulnerabilities. 57 | * This sanitizer is only applicable when `TemplateInjectionSanitizerWithState::hasState` 58 | * holds for the flow state. 59 | */ 60 | abstract class TemplateInjectionSanitizerWithState extends DataFlow::Node { 61 | /** Holds if this sanitizer has the specified `state`. 
*/ 62 | abstract predicate hasState(DataFlow::FlowState state); 63 | } 64 | 65 | private class DefaultTemplateInjectionSource extends TemplateInjectionSource instanceof ThreatModelFlowSource 66 | { } 67 | 68 | private class DefaultTemplateInjectionSink extends TemplateInjectionSink { 69 | DefaultTemplateInjectionSink() { sinkNode(this, "template-injection") } 70 | } 71 | 72 | private class DefaultTemplateInjectionSanitizer extends TemplateInjectionSanitizer { 73 | DefaultTemplateInjectionSanitizer() { 74 | this.getType() instanceof PrimitiveType or 75 | this.getType() instanceof BoxedType or 76 | this.getType() instanceof NumericType 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/MyTemplateInjectionQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides a taint tracking configuration for server-side template injection (SST) vulnerabilities */ 2 | 3 | import java 4 | import semmle.code.java.dataflow.TaintTracking 5 | import semmle.code.java.dataflow.FlowSources 6 | import MyTemplateInjection 7 | import MySources 8 | import MySinks 9 | 10 | /** A taint tracking configuration to reason about server-side template injection (SST) vulnerabilities */ 11 | module MyTemplateInjectionFlowConfig implements DataFlow::StateConfigSig { 12 | class FlowState = DataFlow::FlowState; 13 | 14 | predicate isSource(DataFlow::Node source, FlowState state) { 15 | source.(TemplateInjectionSource).hasState(state) 16 | } 17 | 18 | predicate isSink(DataFlow::Node sink, FlowState state) { 19 | sink.(TemplateInjectionSink).hasState(state) 20 | } 21 | 22 | predicate isBarrier(DataFlow::Node sanitizer) { sanitizer instanceof TemplateInjectionSanitizer } 23 | 24 | predicate isBarrier(DataFlow::Node sanitizer, FlowState state) { 25 | sanitizer.(TemplateInjectionSanitizerWithState).hasState(state) 26 | } 27 | 28 | predicate isAdditionalFlowStep(DataFlow::Node node1, DataFlow::Node node2) { 29 | any(TemplateInjectionAdditionalTaintStep a).isAdditionalTaintStep(node1, node2) 30 | } 31 | 32 | predicate isAdditionalFlowStep( 33 | DataFlow::Node node1, FlowState state1, DataFlow::Node node2, FlowState state2 34 | ) { 35 | any(TemplateInjectionAdditionalTaintStep a).isAdditionalTaintStep(node1, state1, node2, state2) 36 | } 37 | } 38 | 39 | /** Tracks server-side template injection (SST) vulnerabilities */ 40 | module MyTemplateInjectionFlow = TaintTracking::GlobalWithState; 41 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/SpelInjection.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Expression language injection (Spring) 3 | * @description Evaluation of a user-controlled Spring Expression Language (SpEL) expression 4 | * may lead to remote code execution. 
5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.3 8 | * @precision high 9 | * @id java/my-spel-expression-injection 10 | * @tags security 11 | * external/cwe/cwe-094 12 | */ 13 | 14 | import java 15 | import MySpelInjectionQuery 16 | import semmle.code.java.dataflow.DataFlow 17 | import MySpelInjectionFlow::PathGraph 18 | 19 | bindingset[src] 20 | string sourceType(DataFlow::Node src) { 21 | if exists(Parameter p | src.asParameter() = p) 22 | then result = "user-provided value as public function parameter" 23 | else result = "user-provided value from external api return value" 24 | } 25 | 26 | from 27 | MySpelInjectionFlow::PathNode source, 28 | MySpelInjectionFlow::PathNode sink 29 | where 30 | MySpelInjectionFlow::flowPath(source, sink) 31 | select 32 | sink.getNode(), source, sink, 33 | "SpEL expression depends on a $@.", source.getNode(), 34 | sourceType(source.getNode()) 35 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/SpelInjectionSinksOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Expression language injection (Spring) 3 | * @description Evaluation of a user-controlled Spring Expression Language (SpEL) expression 4 | * may lead to remote code execution. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.3 8 | * @precision high 9 | * @id java/my-spel-expression-injection-sinks-only 10 | * @tags security 11 | * external/cwe/cwe-094 12 | */ 13 | 14 | import java 15 | import MySpelInjectionQuery 16 | import semmle.code.java.dataflow.DataFlow 17 | import MySpelInjectionFlowSinksOnly::PathGraph 18 | 19 | from MySpelInjectionFlowSinksOnly::PathNode source, MySpelInjectionFlowSinksOnly::PathNode sink 20 | where MySpelInjectionFlowSinksOnly::flowPath(source, sink) 21 | select sink.getNode(), source, sink, "SpEL expression depends on a $@.", source.getNode(), 22 | "user-provided value" 23 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/SpelInjectionSourcesOnly.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Expression language injection (Spring) 3 | * @description Evaluation of a user-controlled Spring Expression Language (SpEL) expression 4 | * may lead to remote code execution. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.3 8 | * @precision high 9 | * @id java/my-spel-expression-injection-sources-only 10 | * @tags security 11 | * external/cwe/cwe-094 12 | */ 13 | 14 | import java 15 | import MySpelInjectionQuery 16 | import semmle.code.java.dataflow.DataFlow 17 | import MySpelInjectionFlowSourcesOnly::PathGraph 18 | 19 | from MySpelInjectionFlowSourcesOnly::PathNode source, MySpelInjectionFlowSourcesOnly::PathNode sink 20 | where MySpelInjectionFlowSourcesOnly::flowPath(source, sink) 21 | select sink.getNode(), source, sink, "SpEL expression depends on a $@.", source.getNode(), 22 | "user-provided value" 23 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-094/TemplateInjection.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Server-side template injection 3 | * @description Untrusted input interpreted as a template can lead to remote code execution.
4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 9.3 7 | * @precision high 8 | * @id java/my-server-side-template-injection 9 | * @tags security 10 | * external/cwe/cwe-1336 11 | * external/cwe/cwe-094 12 | */ 13 | 14 | import java 15 | import MyTemplateInjectionQuery 16 | import MyTemplateInjectionFlow::PathGraph 17 | 18 | from MyTemplateInjectionFlow::PathNode source, MyTemplateInjectionFlow::PathNode sink 19 | where MyTemplateInjectionFlow::flowPath(source, sink) 20 | select sink.getNode(), source, sink, "Template, which may contain code, depends on a $@.", 21 | source.getNode(), "user-provided value" 22 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-295/InsecureTrustManager.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name `TrustManager` that accepts all certificates 3 | * @description Trusting all certificates allows an attacker to perform a machine-in-the-middle attack. 4 | * @kind path-problem 5 | * @problem.severity error 6 | * @security-severity 7.5 7 | * @precision high 8 | * @id java/my-insecure-trustmanager 9 | * @tags security 10 | * external/cwe/cwe-295 11 | */ 12 | 13 | import java 14 | import semmle.code.java.dataflow.DataFlow 15 | import MyInsecureTrustManagerQuery 16 | import MyInsecureTrustManagerFlow::PathGraph 17 | 18 | from MyInsecureTrustManagerFlow::PathNode source, MyInsecureTrustManagerFlow::PathNode sink 19 | where MyInsecureTrustManagerFlow::flowPath(source, sink) 20 | select sink, source, sink, "This uses $@, which is defined in $@ and trusts any certificate.", 21 | source, "TrustManager", 22 | source.getNode().asExpr().(ClassInstanceExpr).getConstructedType() as type, type.nestedName() 23 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-295/MyInsecureTrustManager.qll: -------------------------------------------------------------------------------- 1 | /** Provides classes and predicates to reason about insecure `TrustManager`s. */ 2 | 3 | import java 4 | private import semmle.code.java.controlflow.Guards 5 | private import semmle.code.java.security.Encryption 6 | private import semmle.code.java.security.SecurityFlag 7 | 8 | 9 | 10 | /** Holds if `node` is guarded by a flag that suggests an intentionally insecure use. */ 11 | private predicate isGuardedByInsecureFlag(DataFlow::Node node) { 12 | exists(Guard g | g.controls(node.asExpr().getBasicBlock(), _) | 13 | g = getASecurityFeatureFlagGuard() or g = getAnInsecureTrustManagerFlagGuard() 14 | ) 15 | } 16 | 17 | /** The `java.security.cert.CertificateException` class. 
*/ 18 | private class CertificateException extends RefType { 19 | CertificateException() { this.hasQualifiedName("java.security.cert", "CertificateException") } 20 | } 21 | 22 | /** 23 | * Holds if: 24 | * - `m` may `throw` a `CertificateException`, or 25 | * - `m` calls another method that may throw, or 26 | * - `m` calls a method declared to throw a `CertificateException`, but for which no source is available 27 | */ 28 | private predicate mayThrowCertificateException(Method m) { 29 | exists(ThrowStmt throwStmt | 30 | throwStmt.getThrownExceptionType().getAnAncestor() instanceof CertificateException 31 | | 32 | throwStmt.getEnclosingCallable() = m 33 | ) 34 | or 35 | exists(Method otherMethod | m.polyCalls(otherMethod) | 36 | mayThrowCertificateException(otherMethod) 37 | or 38 | not otherMethod.fromSource() and 39 | otherMethod.getAnException().getType().getAnAncestor() instanceof CertificateException 40 | ) 41 | } 42 | 43 | /** 44 | * Flags suggesting a deliberately insecure `TrustManager` usage. 45 | */ 46 | private class MyInsecureTrustManagerFlag extends FlagKind { 47 | MyInsecureTrustManagerFlag() { this = "InsecureTrustManagerFlag" } 48 | 49 | bindingset[result] 50 | override string getAFlagName() { 51 | result 52 | .regexpMatch("(?i).*(secure|disable|selfCert|selfSign|validat|verif|trust|ignore|nocertificatecheck).*") and 53 | result != "equalsIgnoreCase" 54 | } 55 | } 56 | 57 | /** Gets a guard that represents a (likely) flag controlling an insecure `TrustManager` use. */ 58 | private Guard getAnInsecureTrustManagerFlagGuard() { 59 | result = any(MyInsecureTrustManagerFlag flag).getAFlag().asExpr() 60 | } 61 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-295/MyInsecureTrustManagerQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides taint tracking configurations to be used in Trust Manager queries. */ 2 | 3 | import java 4 | import semmle.code.java.dataflow.FlowSources 5 | import MyInsecureTrustManager 6 | import MySources 7 | import MySinks 8 | 9 | /** 10 | * A configuration to model the flow of an insecure `TrustManager` 11 | * to the initialization of an SSL context. 12 | */ 13 | module MyInsecureTrustManagerConfig implements DataFlow::ConfigSig { 14 | predicate isSource(DataFlow::Node source) { 15 | //source instanceof InsecureTrustManagerSource 16 | isGPTDetectedSourceMethod(source.asExpr().(MethodCall).getMethod()) 17 | } 18 | 19 | predicate isSink(DataFlow::Node sink) { 20 | //sink instanceof InsecureTrustManagerSink 21 | (isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 22 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) ) 23 | and not isGuardedByInsecureFlag(sink) 24 | } 25 | 26 | predicate allowImplicitRead(DataFlow::Node node, DataFlow::ContentSet c) { 27 | (isSink(node) or isAdditionalFlowStep(node, _)) and 28 | node.getType() instanceof Array and 29 | c instanceof DataFlow::ArrayContent 30 | } 31 | } 32 | 33 | module MyInsecureTrustManagerFlow = DataFlow::Global<MyInsecureTrustManagerConfig>; 34 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-502/UnsafeDeserialization.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Deserialization of user-controlled data 3 | * @description Deserializing user-controlled data may allow attackers to 4 | * execute arbitrary code.
5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.8 8 | * @precision high 9 | * @id java/unsafe-deserialization 10 | * @tags security 11 | * external/cwe/cwe-502 12 | */ 13 | 14 | import java 15 | import semmle.code.java.security.UnsafeDeserializationQuery 16 | import UnsafeDeserializationFlow::PathGraph 17 | 18 | from UnsafeDeserializationFlow::PathNode source, UnsafeDeserializationFlow::PathNode sink 19 | where UnsafeDeserializationFlow::flowPath(source, sink) 20 | select sink.getNode().(UnsafeDeserializationSink).getMethodCall(), source, sink, 21 | "Unsafe deserialization depends on a $@.", source.getNode(), "user-provided value" 22 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-611/MyXxe.qll: -------------------------------------------------------------------------------- 1 | /** Provides classes to reason about XML eXternal Entity (XXE) vulnerabilities. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.DataFlow 5 | 6 | /** A node where insecure XML parsing takes place. */ 7 | abstract class XxeSink extends DataFlow::Node { } 8 | 9 | /** A node that acts as a sanitizer in configurations related to XXE vulnerabilities. */ 10 | abstract class XxeSanitizer extends DataFlow::Node { } 11 | 12 | /** 13 | * A unit class for adding additional taint steps. 14 | * 15 | * Extend this class to add additional taint steps that should apply to flows related to 16 | * XXE vulnerabilities. 17 | */ 18 | class XxeAdditionalTaintStep extends Unit { 19 | /** 20 | * Holds if the step from `node1` to `node2` should be considered a taint 21 | * step for flows related to XXE vulnerabilities. 22 | */ 23 | abstract predicate step(DataFlow::Node n1, DataFlow::Node n2); 24 | } 25 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-611/MyXxeQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides default definitions to be used in XXE queries. */ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.TaintTracking 5 | private import semmle.code.java.security.XmlParsers 6 | import semmle.code.java.security.Xxe 7 | 8 | /** 9 | * The default implementation of an XXE sink. 10 | * The argument of a parse call on an insecurely configured XML parser. 11 | */ 12 | private class DefaultXxeSink extends XxeSink { 13 | DefaultXxeSink() { 14 | not SafeSaxSourceFlow::flowTo(this) and 15 | exists(XmlParserCall parse | 16 | parse.getSink() = this.asExpr() and 17 | not parse.isSafe() 18 | ) 19 | } 20 | } 21 | 22 | /** 23 | * A taint-tracking configuration for safe XML readers used to parse XML documents. 24 | */ 25 | private module SafeSaxSourceFlowConfig implements DataFlow::ConfigSig { 26 | predicate isSource(DataFlow::Node src) { src.asExpr() instanceof SafeSaxSource } 27 | 28 | predicate isSink(DataFlow::Node sink) { sink.asExpr() = any(XmlParserCall parse).getSink() } 29 | 30 | int fieldFlowBranchLimit() { result = 0 } 31 | } 32 | 33 | private module SafeSaxSourceFlow = TaintTracking::Global<SafeSaxSourceFlowConfig>; 34 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-611/MyXxeRemoteQuery.qll: -------------------------------------------------------------------------------- 1 | /** Provides taint tracking configurations to be used in remote XXE queries.
*/ 2 | 3 | import java 4 | private import semmle.code.java.dataflow.FlowSources 5 | private import semmle.code.java.dataflow.TaintTracking 6 | private import MyXxeQuery 7 | import MySources 8 | import MySinks 9 | 10 | /** 11 | * A taint-tracking configuration for unvalidated remote user input that is used in XML external entity expansion. 12 | */ 13 | module MyXxeConfig implements DataFlow::ConfigSig { 14 | predicate isSource(DataFlow::Node src) { 15 | //src instanceof ThreatModelFlowSource 16 | isGPTDetectedSourceMethod(src.asExpr().(MethodCall).getMethod()) 17 | } 18 | 19 | predicate isSink(DataFlow::Node sink) { 20 | //sink instanceof XxeSink 21 | isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 22 | 23 | // an argument to a method call 24 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) 25 | } 26 | 27 | predicate isBarrier(DataFlow::Node sanitizer) { 28 | sanitizer instanceof XxeSanitizer 29 | } 30 | 31 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 32 | any(XxeAdditionalTaintStep s).step(n1, n2) 33 | } 34 | } 35 | 36 | /** 37 | * Detect taint flow of unvalidated remote user input that is used in XML external entity expansion. 38 | */ 39 | module MyXxeFlow = TaintTracking::Global<MyXxeConfig>; 40 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-611/XXE.ql: -------------------------------------------------------------------------------- 1 | /** 2 | * @name Resolving XML external entity in user-controlled data 3 | * @description Parsing user-controlled XML documents and allowing expansion of external entity 4 | * references may lead to disclosure of confidential data or denial of service. 5 | * @kind path-problem 6 | * @problem.severity error 7 | * @security-severity 9.1 8 | * @precision high 9 | * @id java/my-xxe 10 | * @tags security 11 | * external/cwe/cwe-611 12 | * external/cwe/cwe-776 13 | * external/cwe/cwe-827 14 | */ 15 | 16 | import java 17 | import semmle.code.java.dataflow.DataFlow 18 | import MyXxeRemoteQuery 19 | import MyXxeFlow::PathGraph 20 | 21 | from MyXxeFlow::PathNode source, MyXxeFlow::PathNode sink 22 | where MyXxeFlow::flowPath(source, sink) 23 | select sink.getNode(), source, sink, 24 | "XML parsing depends on a $@ without guarding against external entity expansion.", 25 | source.getNode(), "user-provided value" 26 | -------------------------------------------------------------------------------- /src/cwe-queries/cwe-general/GeneralSpecs.qll: -------------------------------------------------------------------------------- 1 | /** Provides dataflow configurations for tainted path queries. */ 2 | 3 | import java 4 | import semmle.code.java.frameworks.Networking 5 | import semmle.code.java.dataflow.DataFlow 6 | import semmle.code.java.dataflow.FlowSources 7 | private import semmle.code.java.dataflow.ExternalFlow 8 | 9 | import MySources 10 | import MySinks 11 | 12 | 13 | /** 14 | * A taint-tracking configuration for tracking flow from remote sources to the creation of a path.
15 | */ 16 | module MyPathConfig implements DataFlow::ConfigSig { 17 | predicate isSource(DataFlow::Node source) { 18 | isGPTDetectedSourceMethod(source.asExpr().(MethodCall).getMethod()) 19 | } 20 | 21 | predicate isSink(DataFlow::Node sink) { 22 | isGPTDetectedSinkMethodCall(sink.asExpr().(Call)) or 23 | 24 | // an argument to a method call 25 | isGPTDetectedSinkArgument(sink.asExpr().(Argument)) 26 | } 27 | 28 | predicate isBarrier(DataFlow::Node sanitizer) { 29 | sanitizer.getType() instanceof BoxedType or 30 | sanitizer.getType() instanceof PrimitiveType or 31 | sanitizer.getType() instanceof NumberType or 32 | sanitizer instanceof PathInjectionSanitizer 33 | } 34 | 35 | predicate isAdditionalFlowStep(DataFlow::Node n1, DataFlow::Node n2) { 36 | any(TaintedPathAdditionalTaintStep s).step(n1, n2) 37 | } 38 | } 39 | 40 | /** Tracks flow from remote sources to the creation of a path. */ 41 | module MyTaintedPathFlow = TaintTracking::Global; 42 | 43 | -------------------------------------------------------------------------------- /src/evaluate_spec_against_codeql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | import yaml 5 | import csv 6 | import json 7 | from tqdm import tqdm 8 | 9 | THIS_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) 10 | NEUROSYMSA_ROOT_DIR = os.path.abspath(f"{THIS_SCRIPT_DIR}/../../") 11 | sys.path.append(NEUROSYMSA_ROOT_DIR) 12 | 13 | from src.config import CODEQL_DIR, OUTPUT_DIR 14 | from src.queries import QUERIES 15 | 16 | YAML_DIR = f"{CODEQL_DIR}/qlpacks/codeql/java-all/0.8.3/ext" 17 | 18 | SINK_KIND = { 19 | "022": ["path-injection"], 20 | "078": ["command-injection"], 21 | "079": ["html-injection", "js-injection"], 22 | "094": ["template-injection"] 23 | } 24 | 25 | def extensible_model(model): 26 | if model == "sinkModel": return "sink" 27 | elif model == "sourceModel": return "source" 28 | else: return "none" 29 | 30 | def get_all_codeql_specs(query): 31 | cwe_id = QUERIES[query]["cwe_id"] 32 | storage = {} 33 | 34 | for yml_file_name in tqdm(list(os.listdir(YAML_DIR)), desc="Loading CodeQL Yamls"): 35 | if not yml_file_name.endswith(".model.yml"): 36 | continue 37 | package = ".".join(yml_file_name.split(".")[:-2]) 38 | content = yaml.safe_load(open(f"{YAML_DIR}/{yml_file_name}")) 39 | extensions = content["extensions"] 40 | for extension in extensions: 41 | ext = extension["addsTo"]["extensible"] 42 | kind = extensible_model(ext) 43 | data = extension["data"] 44 | for api in data: 45 | if ext == "summaryModel" or ext == "neutralModel": 46 | api_package = api[0] 47 | clazz = api[1] 48 | method = api[2] 49 | elif ext == "sinkModel": 50 | api_package = api[0] 51 | clazz = api[1] 52 | method = api[3] 53 | sink_kind = api[7] 54 | elif ext == "sourceModel": 55 | api_package = api[0] 56 | clazz = api[1] 57 | method = api[3] 58 | item = (api_package, clazz, method) 59 | if ext == "sinkModel": 60 | if sink_kind in SINK_KIND[cwe_id]: 61 | storage[item] = "sink" 62 | else: 63 | storage[item] = "none" 64 | else: 65 | storage[item] = kind 66 | 67 | return storage 68 | 69 | def load_all_llm_specs(query, run_id, llm): 70 | cwe_id = QUERIES[query]["cwe_id"] 71 | labels_dir = f"{OUTPUT_DIR}/common/{run_id}/cwe-{cwe_id}/api_labels_{llm}.json" 72 | labels_json = json.load(open(labels_dir)) 73 | labels = {} 74 | for item in labels_json: 75 | package = item["package"] 76 | clazz = item["class"] 77 | method = item["method"] 78 | llm_label = item["type"] 79 | labels[(package, 
clazz, method)] = llm_label 80 | return labels 81 | 82 | def find_intersection(codeql_specs, llm_labels): 83 | intersection = {} 84 | for (sig, label) in codeql_specs.items(): 85 | if sig in llm_labels: 86 | intersection[sig] = { 87 | "codeql_label": label, 88 | "llm_label": llm_labels[sig], 89 | } 90 | return intersection 91 | 92 | def evaluate(intersection): 93 | kind_id = {"none": 0, "source": 1, "sink": 2, "taint-propagator": 0, "propagator": 0, "unknown": 0, "other": 0} 94 | array = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] 95 | results = [[[], [], []], [[], [], []], [[], [], []]] 96 | for (sig, labels) in intersection.items(): 97 | codeql_label = labels["codeql_label"] 98 | llm_label = labels["llm_label"] 99 | array[kind_id[codeql_label]][kind_id[llm_label]] += 1 100 | results[kind_id[codeql_label]][kind_id[llm_label]].append(sig) 101 | 102 | print(array[0]) 103 | print(array[1]) 104 | print(array[2]) 105 | 106 | total = sum([sum(row) for row in array]) 107 | diagonal = sum([array[i][i] for i in range(3)]) 108 | accuracy = diagonal / total 109 | source_recall = 1 if sum(array[1]) == 0 else array[1][1] / sum(array[1]) 110 | sink_recall = 1 if sum(array[2]) == 0 else array[2][2] / sum(array[2]) 111 | print(f"Total: {total}, Accuracy: {accuracy:.4f}, Source Recall: {source_recall:.4f}, Sink Recall: {sink_recall:.4f}") 112 | 113 | return results 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("query") 118 | parser.add_argument("run_id") 119 | parser.add_argument("llm") 120 | args = parser.parse_args() 121 | 122 | codeql_specs = get_all_codeql_specs(args.query) 123 | # print(f"#specs: {len(codeql_specs)}") 124 | llm_specs = load_all_llm_specs(args.query, args.run_id, args.llm) 125 | # print(f"#llm_specs: {len(llm_specs)}") 126 | intersection = find_intersection(codeql_specs, llm_specs) 127 | # print(f"#intersections: {len(intersection)}") 128 | results = evaluate(intersection) 129 | 130 | print("false negative sources", results[1][0]) 131 | print("false negative sinks", results[2][0]) 132 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | class Logger: 5 | def __init__(self, logdir): 6 | self.logdir = logdir 7 | os.makedirs(self.logdir, exist_ok=True) 8 | t = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 9 | self._logfile=f"{self.logdir}/log{t}.txt" 10 | 11 | def log(self, message, logtype="info", phase="", no_new_line=False, printonly=False): 12 | message=str(message) 13 | if len(phase) > 0: 14 | phase=f" [{phase}]" 15 | t=datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 16 | s=f"[{logtype.upper()}] [{t}]{phase} {message}" 17 | if no_new_line: 18 | print(s, end="") 19 | else: 20 | print(s) 21 | if not printonly: 22 | with open(self._logfile, 'a') as f: 23 | f.write(s) 24 | f.write("\n") 25 | 26 | def info(self, message, phase="", no_new_line=False): 27 | self.log(message, "info", phase, no_new_line=no_new_line) 28 | 29 | def error(self, message, phase=""): 30 | self.log(message, "error", phase) 31 | 32 | def print(self, message, end=None): 33 | print(message, end=end) 34 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/codegen.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 9 | 10 | _model_name_map = { 11 | "codegen-16b-multi": "Salesforce/codegen-16B-multi", 12 | "codegen25-7b-instruct": "Salesforce/codegen25-7b-instruct", 13 | "codegen25-7b-multi": "Salesforce/codegen25-7b-multi" 14 | } 15 | 16 | class CodegenModel(LLM): 17 | def __init__(self, model_name, logger: MyLogger, **kwargs): 18 | super().__init__(model_name, logger, _model_name_map, **kwargs) 19 | 20 | def predict(self, main_prompt): 21 | # assuming 0 is system and 1 is user 22 | system_prompt = main_prompt[0]['content'] 23 | user_prompt = main_prompt[1]['content'] 24 | if 'instruct' in self.model_name: 25 | prompt = f"Instruction: {system_prompt}\\n Input: \\n {user_prompt}\\n Output:\\n" 26 | else: 27 | prompt = f"Input: \\n {user_prompt}\\n Output:\\n" 28 | #prompt = f"{user_prompt}" 29 | #inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") 30 | return self.predict_main(prompt) 31 | -------------------------------------------------------------------------------- /src/models/codellama.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" 9 | 10 | _model_name_map = { 11 | "codellama-70b-instruct": 'codellama/CodeLlama-70b-Instruct-hf', 12 | "codellama-34b": 'codellama/CodeLlama-34b-hf', 13 | "codellama-34b-python": 'codellama/CodeLlama-34b-Python-hf', 14 | "codellama-34b-instruct": 'codellama/CodeLlama-34b-Instruct-hf', 15 | "codellama-13b-instruct": 'codellama/CodeLlama-13b-Instruct-hf', 16 | "codellama-7b-instruct": 'codellama/CodeLlama-7b-Instruct-hf', 17 | } 18 | 19 | class CodeLlamaModel(LLM): 20 | def __init__(self, model_name, logger: MyLogger, **kwargs): 21 | super().__init__(model_name, logger, _model_name_map, **kwargs) 22 | self.terminators = [ 23 | self.pipe.tokenizer.eos_token_id, 24 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 25 | ] 26 | 27 | 28 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 29 | if batch_size > 0: 30 | prompts = [self.pipe.tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True) for p in main_prompt] 31 | #print(prompts[0]) 32 | self.model_hyperparams['temperature']=0.01 33 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 34 | else: 35 | prompt = self.pipe.tokenizer.apply_chat_template( 36 | main_prompt, 37 | tokenize=False, 38 | add_generation_prompt=True 39 | ) 40 | l=len(self.tokenizer.tokenize(prompt)) 41 | self.log("Prompt length:" +str(l)) 42 | limit=16000 if self.kwargs["max_input_tokens"] is None else self.kwargs["max_input_tokens"] 43 | if l > limit: 44 | return "Too long, skipping: "+str(l) 45 | if 'dataflow' in 
self.kwargs['system_prompt_type']: 46 | print(">Setting max tokens to ", 2048) 47 | self.model_hyperparams['max_new_tokens']=2048 48 | self.model_hyperparams['temperature']=0.01 49 | #print(prompt) 50 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 51 | 52 | 53 | # assuming 0 is system and 1 is user 54 | system_prompt = main_prompt[0]['content'] 55 | user_prompt = main_prompt[1]['content'] 56 | prompt = f"[INST] <>\\n{system_prompt}\\n<>\\n\\n{user_prompt}[/INST]" 57 | l=len(self.tokenizer.tokenize(prompt)) 58 | self.log("Prompt length:" +str(l)) 59 | limit=16000 if self.kwargs["max_input_tokens"] is None else self.kwargs["max_input_tokens"] 60 | if l > limit: 61 | return prompt, "Too long, skipping: "+str(l) 62 | if 'dataflow' in self.kwargs['system_prompt_type']: 63 | print(">Setting max tokens to ", 1024) 64 | self.model_hyperparams['max_new_tokens']=1024 65 | return self.predict_main(prompt) 66 | 67 | if __name__ == '__main__': 68 | system="You are a security researcher, expert in detecting vulnerabilities. Provide response in following format: 'vulnerability: | vulnerability type: | lines of code: " 69 | from data.bigvul import BigVul 70 | bigvul = BigVul(os.path.join(config.config['DATA_DIR_PATH'] ,"MSR_20_Code_vulnerability_CSV_Dataset")) 71 | id, row = bigvul.get_next() 72 | 73 | codellama_model = CodeLlamaModel(None) 74 | print(">>>Running CodeLlama") 75 | print(">>>ID:", str(id)) 76 | print(codellama_model.predict(system, 77 | f"Can you find any vulnerability in this code? ```{row}```") 78 | ) 79 | -------------------------------------------------------------------------------- /src/models/codet5.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 9 | 10 | _model_name_map = { 11 | "codet5p-16b-instruct": "Salesforce/instructcodet5p-16b", 12 | "codet5p-16b": "Salesforce/codet5p-16b", 13 | "codet5p-6b": "Salesforce/codet5p-6b", 14 | "codet5p-2b": "Salesforce/codet5p-2b" 15 | } 16 | 17 | class CodeT5PlusModel(): 18 | def __init__(self, model_name, logger: MyLogger, **kwargs): 19 | #super().__init__(model_name, logger, _model_name_map, **kwargs) 20 | self.model_name=model_name 21 | self.tokenizer = AutoTokenizer.from_pretrained(_model_name_map[model_name]) 22 | #dmap={'encoder': 0, 'decoder.transformer.wte': 0, 'decoder.transformer.drop': 0, 'decoder.transformer.h.0': 0, 'decoder.transformer.h.1': 1, 'decoder.transformer.h.2': 1, 'decoder.transformer.h.3': 1, 'decoder.transformer.h.4': 1, 'decoder.transformer.h.5': 2, 'decoder.transformer.h.6': 2, 'decoder.transformer.h.7': 2, 'decoder.transformer.h.8': 2, 'decoder.transformer.h.9': 3, 'decoder.transformer.h.10': 3, 'decoder.transformer.h.11': 3, 'decoder.transformer.h.12': 3, 'decoder.transformer.h.13': 4, 'decoder.transformer.h.14': 4, 'decoder.transformer.h.15': 4, 'decoder.transformer.h.16': 4, 'decoder.transformer.h.17': 5, 'decoder.transformer.h.18': 5, 'decoder.transformer.h.19': 5, 'decoder.transformer.h.20': 5, 'decoder.transformer.h.21': 6, 'decoder.transformer.h.22': 6, 'decoder.transformer.h.23': 6, 'decoder.transformer.h.24': 6, 'decoder.transformer.h.25': 4, 'decoder.transformer.h.26': 5, 'decoder.transformer.h.27': 6, 'decoder.transformer.h.28': 1, 
'decoder.transformer.h.29': 1, 'decoder.transformer.h.30': 1, 'decoder.transformer.h.31': 2, 'decoder.transformer.h.32': 2, 'decoder.transformer.h.33': 2, 'decoder.transformer.ln_f': 3, 'decoder.lm_head': 3, 'enc_to_dec_proj': 3} 23 | self.model = AutoModelForSeq2SeqLM.from_pretrained(_model_name_map[model_name], 24 | torch_dtype=torch.float16, 25 | low_cpu_mem_usage=True, 26 | trust_remote_code=True, 27 | device_map="auto" 28 | ) 29 | #print(self.model.hf_device_map) 30 | 31 | #self.model.to_bettertransformer() 32 | 33 | def predict(self, main_prompt): 34 | # assuming 0 is system and 1 is user 35 | system_prompt = main_prompt[0]['content'] 36 | user_prompt = main_prompt[1]['content'] 37 | if 'instruct' in self.model_name: 38 | prompt = f"Instruction: {system_prompt}\\n Input:\\n {user_prompt} \\n Output:\\n" 39 | else: 40 | prompt = f"Input:\\n {user_prompt} \\n Output:\\n" 41 | #prompt = f"{user_prompt}" 42 | #inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") 43 | #return self.predict_main(prompt) 44 | encoding = self.tokenizer(prompt, return_tensors="pt").to("cuda") 45 | if len(encoding) > 1000: 46 | return prompt, "Skipping, too long " + str(len(encoding)) 47 | #encoding=encoding.to('cuda:3') 48 | encoding['decoder_input_ids'] = encoding['input_ids'].clone() 49 | outputs = self.model.generate(**encoding, max_length=2000) 50 | return prompt, self.tokenizer.decode(outputs[0], skip_special_tokens=True) 51 | -------------------------------------------------------------------------------- /src/models/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | _MODEL_DIR_PATH = 'modeldirs/' 3 | _DATA_DIR_PATH = 'datasets/' 4 | 5 | # same params for all sizes 6 | _DEFAULT_PARAMS = { 7 | 'max_new_tokens': 1024, 8 | 'temperature': 0.0, 9 | 'top_p': 1.0 10 | } 11 | 12 | 13 | config = dict() 14 | config['MODEL_DIR_PATH']=_MODEL_DIR_PATH 15 | 16 | config['DATA_DIR_PATH']=_DATA_DIR_PATH 17 | config['DEFAULT_PARAMS']=_DEFAULT_PARAMS 18 | 19 | 20 | 21 | 22 | -------------------------------------------------------------------------------- /src/models/deepseek.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import models.config as config 3 | from utils.mylogger import MyLogger 4 | import os 5 | from models.llm import LLM 6 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" 7 | 8 | _model_name_map = { 9 | "deepseekcoder-33b": 'deepseek-ai/deepseek-coder-33b-instruct', 10 | "deepseekcoder-7b": 'deepseek-ai/deepseek-coder-7b-instruct-v1.5', 11 | "deepseekcoder-v2-15b": "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" 12 | } 13 | 14 | class DeepSeekModel(LLM): 15 | def __init__(self, model_name, logger: MyLogger, **kwargs): 16 | super().__init__(model_name, logger, _model_name_map, **kwargs) 17 | self.terminators = [ 18 | self.pipe.tokenizer.eos_token_id, 19 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 20 | ] 21 | 22 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 23 | def rename(d): 24 | newd = dict() 25 | newd["role"]="user" 26 | newd["content"]=d[0]['content'] + '\n'+ d[1]['content'] 27 | #print(d) 28 | #print(newd) 29 | return [newd] 30 | 31 | if batch_size > 0: 32 | prompts = [self.pipe.tokenizer.apply_chat_template(rename(p), tokenize=False, add_generation_prompt=True) for p in main_prompt] 33 | #print(prompts[0]) 34 | 
self.model_hyperparams['temperature']=0.0 35 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 36 | else: 37 | 38 | prompt = self.pipe.tokenizer.apply_chat_template( 39 | main_prompt, 40 | tokenize=False, 41 | add_generation_prompt=True 42 | ) 43 | l=len(self.tokenizer.tokenize(prompt)) 44 | self.log("Prompt length:" +str(l)) 45 | limit=16000 if self.kwargs["max_input_tokens"] is None else self.kwargs["max_input_tokens"] 46 | if l > limit: 47 | return "Too long, skipping: "+str(l) 48 | self.model_hyperparams['temperature']=0.01 49 | #print(prompt) 50 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 51 | 52 | -------------------------------------------------------------------------------- /src/models/gemini.py: -------------------------------------------------------------------------------- 1 | # pip install google-generativeai 2 | 3 | import models.config as config 4 | from utils.mylogger import MyLogger 5 | import os 6 | from models.llm import LLM 7 | import google.generativeai as genai 8 | from tqdm.contrib.concurrent import thread_map 9 | 10 | _model_name_map = { 11 | "gemini-1.5-pro": "gemini-1.5-pro-latest", 12 | "gemini-1.5-flash": "gemini-1.5-flash-latest", 13 | "gemini-pro": "gemini-pro", 14 | "gemini-pro-vision": "gemini-pro-vision", 15 | "gemini-1.0-pro-vision": "gemini-1.0-pro-vision-latest" 16 | } 17 | _GEMINI_DEFAULT_PARAMS = {"temperature": 0.4, "top_p": 1, "top_k": 32, "max_tokens": 2048 } 18 | 19 | class GeminiModel(LLM): 20 | def __init__(self, model_name, logger: MyLogger, **kwargs): 21 | # https://aistudio.google.com/app/apikey 22 | if ("google_api_key" in kwargs) and (kwargs["google_api_key"] is not None): 23 | api_key = kwargs["google_api_key"] 24 | else: 25 | api_key = os.getenv("GOOGLE_API_KEY") 26 | genai.configure(api_key=api_key) 27 | self.logprobs = None 28 | for k in _GEMINI_DEFAULT_PARAMS: 29 | if k in kwargs: 30 | #print(f"Setting {k}:{kwargs[k]}") 31 | _GEMINI_DEFAULT_PARAMS[k] = kwargs[k] 32 | genai.GenerationConfig(max_output_tokens=_GEMINI_DEFAULT_PARAMS["max_tokens"], 33 | temperature=_GEMINI_DEFAULT_PARAMS["temperature"], 34 | top_p=_GEMINI_DEFAULT_PARAMS["top_p"], 35 | top_k=_GEMINI_DEFAULT_PARAMS["top_k"]) 36 | self.client = genai.GenerativeModel(model_name=model_name) 37 | 38 | def predict(self, prompt, batch_size=0, no_progress_bar=False): 39 | if batch_size == 0: 40 | return self._predict(prompt) 41 | args = range(0, len(prompt)) 42 | responses = thread_map( 43 | lambda x: self._predict(prompt[x]), 44 | args, 45 | max_workers=batch_size, 46 | disable=no_progress_bar) 47 | return responses 48 | 49 | def _predict(self, main_prompt): 50 | # assuming 0 is system and 1 is user 51 | # https://www.googlecloudcommunity.com/gc/AI-ML/Gemini-Pro-Context-Option/m-p/684704/highlight/true#M4159 52 | # There is no direct way for 53 | history = [{"role": "user", "parts": [{"text": f"System prompt: {main_prompt[0]['content']}"}],}, 54 | {"role": "model", "parts": [{"text": "Understood."}],}, 55 | {"role": "user", "parts": [{"text": f"{main_prompt[1]['content']}"}],}] 56 | #print(_GEMINI_DEFAULT_PARAMS) 57 | response = self.client.generate_content(history) 58 | response = response.text 59 | #print(response) 60 | return response 61 | 62 | if __name__ == '__main__': 63 | gemini=GeminiModel('gemini-1.5-pro', None) 64 | system_prompt="" 65 | user_prompt="" 66 | gemini.predict([{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]) 67 | 
-------------------------------------------------------------------------------- /src/models/google.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" 9 | 10 | _model_name_map = { 11 | "gemma-7b": 'google/gemma-7b', 12 | "gemma-7b-it": 'google/gemma-1.1-7b-it', 13 | "gemma-2b": 'google/gemma-2b', 14 | "gemma-2b-it": 'google/gemma-1.1-2b-it', 15 | "codegemma-7b-it" : 'google/codegemma-7b-it', 16 | "gemma-2-27b" : 'google/gemma-2-27b-it', 17 | "gemma-2-9b": 'google/gemma-2-9b-it' 18 | } 19 | 20 | class GoogleModel(LLM): 21 | def __init__(self, model_name, logger: MyLogger, **kwargs): 22 | super().__init__(model_name, logger, _model_name_map, **kwargs) 23 | self.terminators = [ 24 | self.pipe.tokenizer.eos_token_id, 25 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 26 | ] 27 | 28 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 29 | def rename(d): 30 | newd = dict() 31 | newd["role"]="user" 32 | newd["content"]=d[0]['content'] + '\n'+ d[1]['content'] 33 | #print(d) 34 | #print(newd) 35 | return [newd] 36 | 37 | if batch_size > 0: 38 | prompts = [self.pipe.tokenizer.apply_chat_template(rename(p), tokenize=False, add_generation_prompt=True) for p in main_prompt] 39 | #print(prompts[0]) 40 | self.model_hyperparams['temperature']=0.0 41 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 42 | else: 43 | prompt = self.pipe.tokenizer.apply_chat_template( 44 | main_prompt, 45 | tokenize=False, 46 | add_generation_prompt=True 47 | ) 48 | self.model_hyperparams['temperature']=0.01 49 | #print(prompt) 50 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 51 | # assuming 0 is system and 1 is user 52 | system_prompt = main_prompt[0]['content'] 53 | user_prompt = main_prompt[1]['content'] 54 | prompt = f"{system_prompt}\\n{user_prompt}" 55 | l=len(self.tokenizer.tokenize(prompt)) 56 | self.log("Prompt length:" +str(l)) 57 | limit=16000 if self.kwargs["max_input_tokens"] is None else self.kwargs["max_input_tokens"] 58 | if l > limit: 59 | return prompt, "Too long, skipping: "+str(l) 60 | if 'dataflow' in self.kwargs['system_prompt_type']: 61 | print(">Setting max tokens to ", 1024) 62 | self.model_hyperparams['max_new_tokens']=1024 63 | return self.predict_main(prompt) 64 | 65 | if __name__ == '__main__': 66 | system="You are a security researcher, expert in detecting vulnerabilities. 
Provide response in following format: 'vulnerability: | vulnerability type: | lines of code: " 67 | from data.bigvul import BigVul 68 | from data.cvefixes import CVEFixes 69 | #bigvul = BigVul(os.path.join(config.config['DATA_DIR_PATH'] ,"MSR_20_Code_vulnerability_CSV_Dataset"), logger=None) 70 | cvefixes=CVEFixes("cvefixes-c-cpp-method", logger=None).df 71 | #id, row = bigvul.get_next() 72 | row=cvefixes.iloc[0]['code'] 73 | print(row) 74 | gemma_model = GoogleModel("gemma-7b-it", logger=None, max_input_tokens=1024, flash=False, system_prompt_type='') 75 | print(">>>Running Gemma") 76 | print(">>>ID:", str(id)) 77 | model_input = [{"role": "system", "content": system}, {"role": "user", "content": row}] 78 | print(gemma_model.predict(model_input)) 79 | -------------------------------------------------------------------------------- /src/models/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm.contrib.concurrent import thread_map 3 | from openai import OpenAI 4 | 5 | import src.models.config as config 6 | from src.utils.mylogger import MyLogger 7 | from src.models.llm import LLM 8 | 9 | _model_name_map = { 10 | "gpt-4": "gpt-4-0125-preview", 11 | "gpt-3.5": "gpt-3.5-turbo-0125", 12 | "gpt-4-1106": "gpt-4-1106-preview", 13 | "gpt-4-0613": "gpt-4-0613" 14 | } 15 | _OPENAI_DEFAULT_PARAMS = {"temperature": 0, "n": 1, "max_tokens": 4096, "stop": "", "seed": 345 } 16 | 17 | class GPTModel(LLM): 18 | def __init__(self, model_name, logger: MyLogger, **kwargs): 19 | super().__init__(model_name, logger, _model_name_map, **kwargs) 20 | if ("openai_api_key" in kwargs) and (kwargs["openai_api_key"] is not None): 21 | api_key = kwargs["openai_api_key"] 22 | else: 23 | api_key = os.getenv("OPENAI_API_KEY") 24 | self.client = OpenAI(api_key=api_key) 25 | self.logprobs = None 26 | for k in _OPENAI_DEFAULT_PARAMS: 27 | if k in kwargs: 28 | #print(f"Setting {k}:{kwargs[k]}") 29 | _OPENAI_DEFAULT_PARAMS[k] = kwargs[k] 30 | 31 | def predict(self, prompt, expect_json=False, batch_size=0, no_progress_bar=False): 32 | if batch_size == 0: 33 | return self._predict(prompt, expect_json) 34 | args = range(0, len(prompt)) 35 | responses = thread_map( 36 | lambda x: self._predict(prompt[x], expect_json), 37 | args, 38 | max_workers=batch_size, 39 | disable=no_progress_bar) 40 | return responses 41 | 42 | def _predict(self, main_prompt, expect_json=False): 43 | # assuming 0 is system and 1 is user 44 | system_prompt = main_prompt[0]['content'] 45 | user_prompt = main_prompt[1]['content'] 46 | prompt = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] 47 | if 'logprobs' in self.kwargs: 48 | _OPENAI_DEFAULT_PARAMS['logprobs']=self.kwargs["logprobs"] 49 | if 'top_logprobs' in self.kwargs: 50 | _OPENAI_DEFAULT_PARAMS['top_logprobs']=self.kwargs["top_logprobs"] 51 | if expect_json: 52 | response = self.client.chat.completions.create( 53 | model=self.model_id, 54 | messages=prompt, 55 | response_format={"type": "json_object"}, 56 | **_OPENAI_DEFAULT_PARAMS) 57 | else: 58 | response = self.client.chat.completions.create( 59 | model=self.model_id, 60 | messages=prompt, 61 | **_OPENAI_DEFAULT_PARAMS) 62 | if response.choices[0].logprobs != None: 63 | self.logprobs=response.choices[0].logprobs.content 64 | else: 65 | self.logprobs=None 66 | response=response.choices[0].message.content 67 | 68 | return response 69 | 70 | 71 | if __name__ == '__main__': 72 | from src.prompts import SYSTEM_PROMPTS, USER_PROMPTS 73 | 
gpt=GPTModel('gpt-4', None) 74 | system_prompt=SYSTEM_PROMPTS['SINK'] 75 | user_prompt=USER_PROMPTS["SINK"].format(cwe_description="Command Injection", 76 | cwe_id="78", 77 | functions=""" 78 | "java.lang","RuntimeException","RuntimeException" 79 | "java.lang","Runtime","exec" 80 | "java.lang","Runtime","getRuntime" 81 | "java.lang","Runtime","exec" 82 | "java.lang","Runtime","getRuntime" 83 | "java.lang","Runtime","addShutdownHook" 84 | "java.lang","Runtime","getRuntime" 85 | "java.lang","Runtime","removeShutdownHook" 86 | "java.lang","Runtime","getRuntime" 87 | "java.lang","RuntimeException","RuntimeException" 88 | "java.lang","Runtime","freeMemory" 89 | "java.lang","Runtime","getRuntime" 90 | "java.lang","Runtime","removeShutdownHook" 91 | "java.lang","Runtime","getRuntime" 92 | "java.lang","Runtime","addShutdownHook" 93 | "java.lang","Runtime","getRuntime" 94 | "java.lang","RuntimeException","RuntimeException" 95 | "java.lang","RuntimeException","RuntimeException" 96 | "java.lang","RuntimeException","RuntimeException" 97 | "java.lang","RuntimeException","RuntimeException" 98 | "java.lang","RuntimeException","RuntimeException" 99 | "java.lang","RuntimeException","RuntimeException" 100 | """ 101 | ) 102 | 103 | gpt.predict([{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]) 104 | -------------------------------------------------------------------------------- /src/models/llama.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from tqdm.contrib.concurrent import thread_map 6 | from together import Together 7 | from utils.mylogger import MyLogger 8 | import os 9 | from models.llm import LLM 10 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 11 | 12 | _model_name_map = { 13 | "llama-2-7b-chat": "meta-llama/Llama-2-7b-chat-hf", 14 | "llama-2-13b-chat": "meta-llama/Llama-2-13b-chat-hf", 15 | "llama-2-70b-chat": "meta-llama/Llama-2-70b-chat-hf", 16 | "llama-2-7b": "meta-llama/Llama-2-7b-hf", 17 | "llama-2-13b": "meta-llama/Llama-2-13b-hf", 18 | "llama-2-70b": "meta-llama/Llama-2-70b-hf", 19 | "llama-3-8b" : "meta-llama/Meta-Llama-3-8B-Instruct", 20 | "llama-3.1-8b" : "meta-llama/Meta-Llama-3.1-8B-Instruct", 21 | "llama-3-70b" : "meta-llama/Meta-Llama-3-70B-Instruct", 22 | "llama-3.1-70b" : "meta-llama/Meta-Llama-3.1-70B-Instruct", 23 | "llama-3-70b": "meta-llama/Meta-Llama-3-70B-Instruct", 24 | "llama-3-70b-tai": "meta-llama/Meta-Llama-3-70B-Instruct-Turbo" 25 | } 26 | 27 | class LlamaModel(LLM): 28 | def __init__(self, model_name, logger: MyLogger, **kwargs): 29 | super().__init__(model_name, logger, _model_name_map, **kwargs) 30 | if "-tai" in self.model_name: 31 | self.together_client = Together() 32 | else: 33 | self.terminators = [ 34 | self.pipe.tokenizer.eos_token_id, 35 | self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 36 | ] 37 | 38 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 39 | if "-tai" in self.model_name: 40 | return self.predict_with_together_ai(main_prompt, batch_size, no_progress_bar) 41 | else: 42 | return self.predict_local(main_prompt, batch_size, no_progress_bar) 43 | 44 | def predict_with_together_ai(self, main_prompt, batch_size, no_progress_bar): 45 | if batch_size == 0: 46 | return self.predict_one_with_together_ai(main_prompt) 47 | else: 48 | args = range(0, len(main_prompt)) 49 | responses = 
thread_map( 50 | lambda x: self.predict_one_with_together_ai(main_prompt[x]), 51 | args, 52 | max_workers=batch_size, 53 | disable=no_progress_bar, 54 | ) 55 | return responses 56 | 57 | 58 | def predict_one_with_together_ai(self, prompt): 59 | completion = self.together_client.chat.completions.create( 60 | model=_model_name_map[self.model_name], 61 | messages=prompt, 62 | # response_format={"type": "json_object"}, 63 | temperature=0) 64 | response = completion.choices[0].message.content 65 | return response 66 | 67 | def predict_local(self, main_prompt, batch_size=0, no_progress_bar=False): 68 | # assuming 0 is system and 1 is user 69 | 70 | #prompt = f"[INST] <>\\n{system_prompt}\\n<>\\n\\n{user_prompt}[/INST]" 71 | if batch_size > 0: 72 | prompts = [self.pipe.tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True) for p in main_prompt] 73 | self.model_hyperparams['temperature']=0.01 74 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 75 | else: 76 | prompt = self.pipe.tokenizer.apply_chat_template( 77 | main_prompt, 78 | tokenize=False, 79 | add_generation_prompt=True 80 | ) 81 | 82 | #prompt = f"{user_prompt}" 83 | #inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") 84 | self.model_hyperparams['temperature']=0.01 85 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 86 | -------------------------------------------------------------------------------- /src/models/mistral.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256" 9 | 10 | _model_name_map = { 11 | "mistral-7b-instruct": 'mistralai/Mistral-7B-Instruct-v0.2', 12 | "mixtral-8x7b-instruct": 'mistralai/Mixtral-8x7B-Instruct-v0.1', 13 | "mixtral-8x7b": 'mistralai/Mixtral-8x7B-v0.1', 14 | "mixtral-8x22b" : "mistralai/Mixtral-8x22B-Instruct-v0.1", 15 | "mistral-codestral-22b": "mistralai/Codestral-22B-v0.1" 16 | } 17 | 18 | class MistralModel(LLM): 19 | def __init__(self, model_name, logger: MyLogger, **kwargs): 20 | super().__init__(model_name, logger, _model_name_map, **kwargs) 21 | self.terminators = [ 22 | self.pipe.tokenizer.eos_token_id, 23 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 24 | ] 25 | 26 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 27 | def rename(d): 28 | newd = dict() 29 | newd["role"]="user" 30 | newd["content"]=d[0]['content'] + '\n'+ d[1]['content'] 31 | #print(d) 32 | #print(newd) 33 | return [newd] 34 | 35 | if batch_size > 0: 36 | prompts = [self.pipe.tokenizer.apply_chat_template(rename(p), tokenize=False, add_generation_prompt=True) for p in main_prompt] 37 | #print(prompts[0]) 38 | self.model_hyperparams['temperature']=0.01 39 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 40 | else: 41 | prompt = self.pipe.tokenizer.apply_chat_template( 42 | main_prompt, 43 | tokenize=False, 44 | add_generation_prompt=True 45 | ) 46 | self.model_hyperparams['temperature']=0.01 47 | #print(prompt) 48 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 49 | 50 | 51 | # assuming 0 is system and 1 is user 52 | system_prompt = main_prompt[0]['content'] 53 | user_prompt = 
main_prompt[1]['content'] 54 | prompt = f"[INST] \\n{system_prompt}\\n{user_prompt}[/INST]" 55 | l=len(self.tokenizer.tokenize(prompt)) 56 | self.log("Prompt length:" +str(l)) 57 | limit=16000 if self.kwargs.get("max_input_tokens", None) is None else self.kwargs["max_input_tokens"] 58 | if l > limit: 59 | return prompt, "Too long, skipping: "+str(l) 60 | # if 'dataflow' in self.kwargs['system_prompt_type']: 61 | # print(">Setting max tokens to ", 1024) 62 | # self.model_hyperparams['max_new_tokens']=1024 63 | return self.predict_main(prompt) 64 | 65 | if __name__ == '__main__': 66 | system="You are a security researcher, expert in detecting vulnerabilities. Provide response in following format: 'vulnerability: | vulnerability type: | lines of code: " 67 | from data.bigvul import BigVul 68 | from data.cvefixes import CVEFixes 69 | #bigvul = BigVul(os.path.join(config.config['DATA_DIR_PATH'] ,"MSR_20_Code_vulnerability_CSV_Dataset"), logger=None) 70 | cvefixes=CVEFixes("cvefixes-c-cpp-method", logger=None).df 71 | #id, row = bigvul.get_next() 72 | row=cvefixes.iloc[0]['code'] 73 | print(row) 74 | mistral_model = MistralModel("mixtral-8x7b-instruct", logger=None, max_input_tokens=1024, flash=False, system_prompt_type='') 75 | print(">>>Running Mistral") 76 | print(">>>ID:", str(id)) 77 | model_input = [{"role": "system", "content": system}, {"role": "user", "content": row}] 78 | print(mistral_model.predict(model_input)) 79 | -------------------------------------------------------------------------------- /src/models/ollama.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ollama 3 | from tqdm.contrib.concurrent import thread_map 4 | 5 | from src.models.llm import LLM 6 | from src.utils.mylogger import MyLogger 7 | 8 | _model_name_map = { 9 | "ollama-qwen-coder": "qwen2.5-coder:latest", 10 | "ollama-qwen": "qwen2.5:32b", 11 | "ollama-llama3": "llama3.2:latest", 12 | "ollama-deepseek-32b": "deepseek-r1:32b", 13 | "ollama-deepseek-7b": "deepseek-r1:latest", 14 | } 15 | 16 | # default model parameters, add or modify according to your needs 17 | # see https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values 18 | _OLLAMA_DEFAULT_OPTIONS = { 19 | "temperature": 0.8, 20 | "num_predict": -1, 21 | "stop": None, 22 | "seed": 0, 23 | } 24 | 25 | 26 | class OllamaModel(LLM): 27 | def __init__(self, model_name, logger: MyLogger, **kwargs): 28 | super().__init__(model_name, logger, _model_name_map, **kwargs) 29 | if host := os.environ.get("OLLAMA_HOST"): 30 | self.client = ollama.Client(host=host) 31 | else: 32 | self.log.error("Please set OLLAMA_HOST environment variable") 33 | # TODO: https://github.com/ollama/ollama/issues/2415 34 | # self.logprobs = None 35 | for k in _OLLAMA_DEFAULT_OPTIONS: 36 | if k in kwargs: 37 | _OLLAMA_DEFAULT_OPTIONS[k] = kwargs[k] 38 | 39 | def predict(self, prompt, batch_size=0, no_progress_bar=False): 40 | if batch_size == 0: 41 | return self._predict(prompt) 42 | args = range(0, len(prompt)) 43 | responses = thread_map( 44 | lambda x: self._predict(prompt[x]), 45 | args, 46 | max_workers=batch_size, 47 | disable=no_progress_bar, 48 | ) 49 | return responses 50 | 51 | def _predict(self, main_prompt): 52 | # assuming 0 is system and 1 is user 53 | system_prompt = main_prompt[0]["content"] 54 | user_prompt = main_prompt[1]["content"] 55 | prompt = [ 56 | {"role": "system", "content": system_prompt}, 57 | {"role": "user", "content": user_prompt}, 58 | ] 59 | try: 60 | response = 
self.client.chat( 61 | model=self.model_id, 62 | messages=prompt, 63 | options=_OLLAMA_DEFAULT_OPTIONS, 64 | ) 65 | except ollama.ResponseError as e: 66 | print("Ollama Response Error:", e.error) 67 | return None 68 | 69 | return response.message.content 70 | -------------------------------------------------------------------------------- /src/models/openaimodels.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | from utils.mylogger import MyLogger 4 | from utils.prompt_utils import generate_message_list, generate_validation_message_list 5 | import time 6 | 7 | _OPENAI_DEFAULT_PARAMS = {"temperature": 0, "n": 1, "max_tokens": 1024, "stop": ""} 8 | _DELAY_SECS = 5 9 | 10 | class OpenAIModel: 11 | def __init__(self, logger: MyLogger, model_name="gpt-4", **kwargs): 12 | if logger is None: 13 | self.log = lambda x: print(x) 14 | else: 15 | self.log = lambda x: logger.log(x) 16 | self.model_id = model_name 17 | self.model_params = _OPENAI_DEFAULT_PARAMS 18 | self.kwargs = kwargs 19 | self.log(f"Model: {model_name}") 20 | self.log(f"Model params: {self.model_params}") 21 | 22 | if ("openai_api_key" in kwargs) and (kwargs["openai_api_key"] is not None): 23 | openai.api_key = kwargs["openai_api_key"] 24 | else: 25 | openai.api_key = os.getenv("OPENAI_API_KEY") 26 | 27 | #print(openai.api_key) 28 | 29 | def call_openai(self, prompt, n_tries=5): 30 | while n_tries > 0: 31 | try: 32 | output = openai.ChatCompletion.create( 33 | model=self.model_id, messages=prompt, **self.model_params 34 | ) 35 | # Only return the first response 36 | return output["choices"][0]["message"] 37 | except Exception as e: 38 | n_tries -= 1 39 | error_message = "OpenAI call failed with Exception: ", str(e) 40 | self.log(error_message) 41 | # Add a time delay to recover the rate limit, if any 42 | time.sleep(_DELAY_SECS) 43 | # Report this as an error 44 | if n_tries == 0: 45 | return {"role": "error", "content": error_message} 46 | 47 | def get_prompt(self, snippet, prompt_cwe): 48 | cwe_specific = ( 49 | "cwe_specific" in self.kwargs["prompting_technique"] 50 | or "cwe_specific" in self.kwargs['prompt_type'] 51 | or "cwe_specific" in self.kwargs['system_prompt_type'] 52 | ) 53 | # Prompt with item CWE if asked 54 | prompt_cwe = prompt_cwe if cwe_specific else -1 55 | self.log(f"Prompting technique: {self.kwargs['prompting_technique']}") 56 | self.log(f"User Prompt: {self.kwargs['prompt_type']}") 57 | self.log(f"System Prompt: {self.kwargs['system_prompt_type']}") 58 | self.log(f"Prompt CWE: {prompt_cwe}") 59 | 60 | return generate_message_list( 61 | prompting_technique=self.kwargs["prompting_technique"], 62 | snippet=snippet, 63 | prompt_cwe=prompt_cwe, 64 | user_prompt=self.kwargs['prompt_type'], 65 | system_prompt=self.kwargs['system_prompt_type'], 66 | ) 67 | 68 | def predict(self, message): 69 | # Self validate results (using responses from the previous run) 70 | # We don't want to call the model with the previous prompts again 71 | if "validate_results_from_dir" in self.kwargs and self.kwargs["validate_results_from_dir"] is not None: 72 | main_prompt = generate_validation_message_list(message["id"], self.kwargs["validate_results_from_dir"]) 73 | response = self.call_openai(main_prompt) 74 | pred = "" if response["role"] == "error" else response["content"] 75 | main_prompt.append({"role": "assistant", "content": pred}) 76 | return self._stringify_chat(main_prompt), pred 77 | 78 | snippet = message["snippet"] 79 | prompt_cwe = message["prompt_cwe"] 80 
| main_prompt = self.get_prompt(snippet=snippet, prompt_cwe=prompt_cwe) 81 | # Maintain a running prompt with the chat history 82 | running_prompt = [] 83 | for prompt in main_prompt: 84 | running_prompt.append(prompt) 85 | if prompt["role"] == "user": 86 | # Predict when a user prompt is provided 87 | response = self.call_openai(running_prompt) 88 | if response["role"] == "error": 89 | # An empty string in the prediction will re-run the experiment on the sample after reload 90 | running_prompt.append({"role": "assistant", "content": ""}) 91 | break 92 | 93 | time.sleep(_DELAY_SECS) 94 | # Store the chat history 95 | running_prompt.append(response) 96 | # Do not continue the chat if an error occurred 97 | 98 | pred = running_prompt[-1]["content"] 99 | return self._stringify_chat(running_prompt), pred 100 | 101 | def _stringify_chat(self, chat_history): 102 | prompt_str = "" 103 | # Skip the final prediction 104 | for prompt in chat_history[:-1]: 105 | prompt_str += prompt["role"].upper() + "\n" 106 | prompt_str += prompt["content"] 107 | prompt_str += "\n-------------------\n" 108 | return prompt_str 109 | -------------------------------------------------------------------------------- /src/models/qwen.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 9 | _model_name_map = { 10 | "qwen2.5-coder-7b": "Qwen/Qwen2.5-Coder-7B-Instruct", 11 | "qwen2.5-coder-1.5b": "Qwen/Qwen2.5-Coder-1.5B-Instruct", 12 | "qwen2.5-14b": "Qwen/Qwen2.5-14B-Instruct", 13 | "qwen2.5-32b": "Qwen/Qwen2.5-32B-Instruct", 14 | "qwen2.5-72b" : "Qwen/Qwen2.5-72B-Instruct" 15 | 16 | } 17 | 18 | class QwenModel(LLM): 19 | def __init__(self, model_name, logger: MyLogger, **kwargs): 20 | super().__init__(model_name, logger, _model_name_map, **kwargs) 21 | self.terminators = [ 22 | self.pipe.tokenizer.eos_token_id 23 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 24 | ] 25 | 26 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 27 | # assuming 0 is system and 1 is user 28 | 29 | #prompt = f"[INST] <>\\n{system_prompt}\\n<>\\n\\n{user_prompt}[/INST]" 30 | if batch_size > 0: 31 | prompts = [self.pipe.tokenizer.apply_chat_template(p, tokenize=False, add_generation_prompt=True) for p in main_prompt] 32 | self.model_hyperparams['temperature']=0.01 33 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 34 | else: 35 | prompt = self.pipe.tokenizer.apply_chat_template( 36 | main_prompt, 37 | tokenize=False, 38 | add_generation_prompt=True 39 | ) 40 | 41 | #prompt = f"{user_prompt}" 42 | #inputs = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda") 43 | self.model_hyperparams['temperature']=0.01 44 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 45 | 46 | -------------------------------------------------------------------------------- /src/models/starcoder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | 
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 9 | _model_name_map = { 10 | "starcoder" : "bigcode/starcoder", 11 | "starcoder2-15b": "bigcode/starcoder2-15b" 12 | 13 | } 14 | 15 | class StarCoderModel(LLM): 16 | def __init__(self, model_name, logger: MyLogger, **kwargs): 17 | super().__init__(model_name, logger, _model_name_map, **kwargs) 18 | self.terminators = [ 19 | self.pipe.tokenizer.eos_token_id, 20 | # self.pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>") 21 | ] 22 | 23 | def predict(self, main_prompt, batch_size=0, no_progress_bar=False): 24 | # assuming 0 is system and 1 is user 25 | #system_prompt = main_prompt[0]['content' 26 | def rename(d): 27 | return d[0]['content'] + '\n'+ d[1]['content'] 28 | #print(d) 29 | #print(newd) 30 | 31 | 32 | if batch_size > 0: 33 | prompts = [rename(p) for p in main_prompt] 34 | #print(prompts[0]) 35 | self.model_hyperparams['temperature']=0.01 36 | return self.predict_main(prompts, batch_size=batch_size, no_progress_bar=no_progress_bar) 37 | else: 38 | prompt = self.pipe.tokenizer.apply_chat_template( 39 | main_prompt, 40 | tokenize=False, 41 | add_generation_prompt=True 42 | ) 43 | self.model_hyperparams['temperature']=0.01 44 | #print(prompt) 45 | return self.predict_main(prompt, no_progress_bar=no_progress_bar) 46 | -------------------------------------------------------------------------------- /src/models/wizarcoder.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForCausalLM 2 | import transformers 3 | import torch 4 | import models.config as config 5 | from utils.mylogger import MyLogger 6 | import os 7 | from models.llm import LLM 8 | 9 | #os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64" 10 | _model_name_map = { 11 | "wizardcoder-15b": "WizardLM/WizardCoder-15B-V1.0", 12 | "wizardcoder-34b-python": "WizardLM/WizardCoder-Python-34B-V1.0", 13 | "wizardcoder-13b-python": "WizardLM/WizardCoder-Python-13B-V1.0", 14 | "wizardlm-70b": "WizardLM/WizardLM-70B-V1.0", 15 | "wizardlm-13b": "WizardLM/WizardLM-13B-V1.2", 16 | "wizardlm-30b": "WizardLM/WizardLM-30B-V1.0" 17 | } 18 | class WizardCoderModel(LLM): 19 | def __init__(self, model_name, logger: MyLogger, **kwargs): 20 | super().__init__(model_name, logger, _model_name_map, **kwargs) 21 | 22 | def predict(self, main_prompt): 23 | # assuming 0 is system and 1 is user 24 | system_prompt = main_prompt[0]['content'] 25 | user_prompt = main_prompt[1]['content'] 26 | prompt = user_prompt 27 | prompt = f"Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{system_prompt} \n\n {user_prompt}\n\n### Response:" 28 | 29 | return self.predict_main(prompt) 30 | -------------------------------------------------------------------------------- /src/modules/codeql_query_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess as sp 4 | import pandas as pd 5 | import shutil 6 | import json 7 | import re 8 | import argparse 9 | import numpy as np 10 | import copy 11 | import math 12 | import random 13 | 14 | from src.config import CODEQL_DIR 15 | from src.queries import QUERIES 16 | 17 | CODEQL = f"{CODEQL_DIR}/codeql" 18 | CODEQL_CUSTOM_QUERY_DIR = f"{CODEQL_DIR}/qlpacks/codeql/java-queries/0.8.3/myqueries" 19 | 20 | ENTRY_SCRIPT_DIR = os.path.abspath(os.path.dirname(os.path.realpath(__file__)) + "/../") 21 | 22 | class CodeQLQueryRunner: 23 | def __init__(self, project_name, project_output_path, project_codeql_db_path, project_logger): 24 | self.project_name = project_name 25 | self.project_codeql_db_path = project_codeql_db_path 26 | self.project_output_path = project_output_path 27 | self.project_logger = project_logger 28 | 29 | def run(self, query, target_csv_path=None, suffix=None, dyn_queries={}): 30 | """ 31 | :param query, is a string that should be a key in the QUERIES dictionary 32 | :param target_csv_path, is a path where the result csv should be stored to 33 | :param suffix, ??? 34 | :param dyn_queries, is a dictionary {: } of dyanmically generated queries. 35 | The name needs to be ending with a `.ql` or `.qll` extension. 36 | """ 37 | # 0. Sanity check 38 | if query not in QUERIES: 39 | self.project_logger.error(f" ==> Unknown query `{query}`; aborting"); exit(1) 40 | 41 | # 1. Create the directory in CodeQL's queries path 42 | suffix_dir = "" if suffix is None else f"/{suffix}" 43 | codeql_query_dir = f"{CODEQL_CUSTOM_QUERY_DIR}/{self.project_name}/{query}{suffix_dir}" 44 | os.makedirs(codeql_query_dir, exist_ok=True) 45 | 46 | # 2. Copy the basic queries and supporting queries to the codeql directory 47 | for q in QUERIES[query]["queries"]: 48 | shutil.copy(f"{ENTRY_SCRIPT_DIR}/{q}", f"{codeql_query_dir}/") 49 | 50 | # 3. Write the dynamic queries 51 | for dyn_query_name, content in dyn_queries.items(): 52 | with open(f"{codeql_query_dir}/{dyn_query_name}", "w") as f: 53 | f.write(content) 54 | 55 | # 4. Setup the paths 56 | main_query = QUERIES[query]["queries"][0] 57 | main_query_name = main_query.split("/")[-1] 58 | codeql_query_path = f"{codeql_query_dir}/{main_query_name}" 59 | 60 | query_result_path = f"{self.project_output_path}/{query}{suffix_dir}" 61 | query_result_bqrs_path = f"{self.project_output_path}/{query}{suffix_dir}/results.bqrs" 62 | query_result_csv_path = f"{self.project_output_path}/{query}{suffix_dir}/results.csv" 63 | os.makedirs(query_result_path, exist_ok=True) 64 | 65 | # 5. Run the query and generate result bqrs 66 | sp.run([CODEQL, "query", "run", f"--database={self.project_codeql_db_path}", f"--output={query_result_bqrs_path}", "--", codeql_query_path]) 67 | if not os.path.exists(query_result_bqrs_path): 68 | self.project_logger.error(f" ==> Failed to run query `{query}`; aborting"); exit(1) 69 | 70 | # 6. 
Decode the query 71 | sp.run([CODEQL, "bqrs", "decode", query_result_bqrs_path, "--format=csv", f"--output={query_result_csv_path}"]) 72 | if not os.path.exists(query_result_csv_path): 73 | self.project_logger.error(f" ==> Failed to decode result bqrs from `{query}`; aborting"); exit(1) 74 | 75 | # 7. Copy the query out 76 | if target_csv_path is not None: 77 | shutil.copy(query_result_csv_path, target_csv_path) 78 | -------------------------------------------------------------------------------- /src/modules/postprocess_cwe_query.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess as sp 4 | import pandas as pd 5 | import shutil 6 | import json 7 | import re 8 | import argparse 9 | import numpy as np 10 | import copy 11 | import math 12 | import random 13 | 14 | from src.config import CODEQL_DIR 15 | 16 | CODEQL = f"{CODEQL_DIR}/codeql" 17 | CODEQL_CUSTOM_QUERY_DIR = f"{CODEQL_DIR}/qlpacks/codeql/java-queries/0.8.3/myqueries" 18 | 19 | class CWEQueryResultPostprocessor: 20 | def __init__(self): 21 | pass 22 | -------------------------------------------------------------------------------- /src/neusym_vul_for_query.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import subprocess as sp 5 | import pandas as pd 6 | 7 | THIS_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) 8 | NEUROSYMSA_ROOT_DIR = os.path.abspath(f"{THIS_SCRIPT_DIR}/../") 9 | sys.path.append(NEUROSYMSA_ROOT_DIR) 10 | 11 | from src.config import CVES_MAPPED_W_COMMITS_DIR 12 | from src.queries import QUERIES 13 | 14 | def collect_projects_for_query(query, cwe_id, all_cves_with_commit, all_project_tags): 15 | for (_, proj_row) in all_cves_with_commit.iterrows(): 16 | # Check relevance 17 | if f"CWE-{cwe_id}" not in proj_row["cwe_id"].split(";"): 18 | continue 19 | cve_id = proj_row["cve_id"] 20 | relevant_project_tag = all_project_tags[all_project_tags["cve_id"] == cve_id] 21 | if len(relevant_project_tag) == 0: 22 | continue 23 | project_name = relevant_project_tag.iloc[0]["project_slug"] 24 | yield project_name 25 | 26 | if __name__ == '__main__': 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--query", type=str, default="cwe-022wLLM") 29 | parser.add_argument("--llm", type=str, default="gpt-4") 30 | parser.add_argument("--run-id", type=str, default="default") 31 | parser.add_argument("--seed", type=int, default=1234) 32 | parser.add_argument("--label-api-batch-size", type=int, default=30) 33 | parser.add_argument("--label-func-param-batch-size", type=int, default=20) 34 | parser.add_argument("--num-threads", type=int, default=3) 35 | parser.add_argument("--no-summary-model", action="store_true") 36 | parser.add_argument("--use-exhaustive-qll", action="store_true") 37 | parser.add_argument("--skip-huge-project", action="store_true") 38 | parser.add_argument("--skip-huge-project-num-apis-threshold", type=int, default=3000) 39 | parser.add_argument("--skip-project", type=str, nargs="+") 40 | parser.add_argument("--filter-project", type=str, nargs="+") 41 | parser.add_argument("--skip-posthoc-filter", action="store_true") 42 | parser.add_argument("--skip-evaluation", action="store_true") 43 | parser.add_argument("--filter-by-module", action="store_true") 44 | parser.add_argument("--posthoc-filtering-skip-fp", action="store_true") 45 | parser.add_argument("--posthoc-filtering-rerun-skipped-fp", action="store_true") 46 | 
parser.add_argument("--evaluation-only", action="store_true") 47 | parser.add_argument("--overwrite", action="store_true") 48 | parser.add_argument("--overwrite-api-candidates", action="store_true") 49 | parser.add_argument("--overwrite-func-param-candidates", action="store_true") 50 | parser.add_argument("--overwrite-labelled-apis", action="store_true") 51 | parser.add_argument("--overwrite-llm-cache", action="store_true") 52 | parser.add_argument("--overwrite-labelled-func-param", action="store_true") 53 | parser.add_argument("--overwrite-cwe-query-result", action="store_true") 54 | parser.add_argument("--overwrite-posthoc-filter", action="store_true") 55 | parser.add_argument("--overwrite-debug-info", action="store_true") 56 | parser.add_argument("--debug-source", action="store_true") 57 | parser.add_argument("--debug-sink", action="store_true") 58 | parser.add_argument("--test-run", action="store_true") 59 | args = parser.parse_args() 60 | 61 | query = args.query 62 | if query not in QUERIES: 63 | print(f"Unknown query {query}") 64 | if "cwe_id_tag" not in QUERIES[query]: 65 | print(f"Not a CWE related query: {query}") 66 | cwe_id = QUERIES[query]["cwe_id"] 67 | 68 | all_cves_with_commit = pd.read_csv(CVES_MAPPED_W_COMMITS_DIR) 69 | all_project_tags = all_cves_with_commit.dropna(subset=["project_slug", "cve_id", "github_tag"]) 70 | 71 | relevant_projects = list(collect_projects_for_query(query, cwe_id, all_cves_with_commit, all_project_tags)) 72 | 73 | for (i, project) in enumerate(relevant_projects): 74 | print("===========================================") 75 | print(f"[{i + 1}/{len(relevant_projects)}] STARTING RUNNING ON PROJECT: {project}") 76 | 77 | # Skip if not desired 78 | if args.skip_project is not None: 79 | need_skip = False 80 | for skip_project_filter in args.skip_project: 81 | if skip_project_filter in project: 82 | need_skip = True; break 83 | if need_skip: 84 | continue 85 | 86 | if args.filter_project is not None: 87 | need_skip = True 88 | for filter_project_filter in args.filter_project: 89 | if filter_project_filter in project: 90 | need_skip = False; break 91 | if need_skip: 92 | continue 93 | 94 | # Generate the command 95 | command = [ 96 | "python", f"{THIS_SCRIPT_DIR}/neusym_vul.py", 97 | project, 98 | "--query", query, 99 | "--llm", args.llm, 100 | "--run-id", args.run_id, 101 | "--seed", str(args.seed), 102 | "--label-api-batch-size", str(args.label_api_batch_size), 103 | "--num-threads", str(args.num_threads), 104 | ] 105 | 106 | # Adding store_true arguments 107 | if args.no_summary_model: command += ["--no-summary-model"] 108 | if args.use_exhaustive_qll: command += ["--use-exhaustive-qll"] 109 | if args.skip_huge_project: command += ["--skip-huge-project", "--skip-huge-project-num-apis-threshold", str(args.skip_huge_project_num_apis_threshold)] 110 | if args.skip_posthoc_filter: command += ["--skip-posthoc-filter"] 111 | if args.skip_evaluation: command += ["--skip-evaluation"] 112 | if args.posthoc_filtering_skip_fp: command += ["--posthoc-filtering-skip-fp"] 113 | if args.posthoc_filtering_rerun_skipped_fp: command += ["--posthoc-filtering-rerun-skipped-fp"] 114 | if args.evaluation_only: command += ["--evaluation-only"] 115 | if args.filter_by_module: command += ["--filter-by-module"] 116 | 117 | # Overwrites 118 | if args.overwrite: command += ["--overwrite"] 119 | if args.overwrite_api_candidates: command += ["--overwrite-api-candidates"] 120 | if args.overwrite_func_param_candidates: command += ["--overwrite-func-param-candidates"] 121 | if 
args.overwrite_labelled_apis: command += ["--overwrite-labelled-apis"] 122 | if args.overwrite_llm_cache: command += ["--overwrite-llm-cache"] 123 | if args.overwrite_labelled_func_param: command += ["--overwrite-labelled-func-param"] 124 | if args.overwrite_cwe_query_result: command += ["--overwrite-cwe-query-result"] 125 | if args.overwrite_posthoc_filter: command += ["--overwrite-posthoc-filter"] 126 | if args.debug_source: command += ["--debug-source"] 127 | if args.debug_sink: command += ["--debug-sink"] 128 | if args.test_run: command += ["--test-run"] 129 | 130 | # Run the command 131 | sp.run(command) 132 | -------------------------------------------------------------------------------- /src/prompts.py: -------------------------------------------------------------------------------- 1 | API_LABELLING_SYSTEM_PROMPT = """\ 2 | You are a security expert. \ 3 | You are given a list of APIs to be labeled as potential taint sources, sinks, or APIs that propagate taints. \ 4 | Taint sources are values that an attacker can use for unauthorized and malicious operations when interacting with the system. \ 5 | Taint source APIs usually return strings or custom object types. Setter methods are typically NOT taint sources. \ 6 | Taint sinks are program points that can use tainted data in an unsafe way, which directly exposes vulnerability under attack. \ 7 | Taint propagators carry tainted information from input to the output without sanitization, and typically have non-primitive input and outputs. \ 8 | Return the result as a json list with each object in the format: 9 | 10 | { "package": , 11 | "class": , 12 | "method": , 13 | "signature": , 14 | "sink_args": , 15 | "type": <"source", "sink", or "taint-propagator"> } 16 | 17 | DO NOT OUTPUT ANYTHING OTHER THAN JSON.\ 18 | """ 19 | 20 | API_LABELLING_USER_PROMPT = """\ 21 | {cwe_long_description} 22 | 23 | Some example source/sink/taint-propagator methods are: 24 | {cwe_examples} 25 | 26 | Among the following methods, \ 27 | assuming that the arguments passed to the given function is malicious, \ 28 | what are the functions that are potential source, sink, or taint-propagators to {cwe_description} attack (CWE-{cwe_id})? 29 | 30 | Package,Class,Method,Signature 31 | {methods} 32 | """ 33 | 34 | FUNC_PARAM_LABELLING_SYSTEM_PROMPT = """\ 35 | You are a security expert. \ 36 | You are given a list of APIs implemented in established Java libraries, \ 37 | and you need to identify whether some of these APIs could be potentially invoked by downstream libraries with malicious end-user (not programmer) inputs. \ 38 | For instance, functions that deserialize or parse inputs might be used by downstream libraries and would need to add sanitization for malicious user inputs. \ 39 | On the other hand, functions like HTTP request handlers are typically final and won't be called by a downstream package. \ 40 | Utility functions that are not related to the primary purpose of the package should also be ignored. \ 41 | Return the result as a json list with each object in the format: 42 | 43 | { "package": , 44 | "class": , 45 | "method": , 46 | "signature": , 47 | "tainted_input": } 48 | 49 | In the result list, only keep the functions that might be used by downstream libraries and is potentially invoked with malicious end-user inputs. \ 50 | Do not output anything other than JSON.\ 51 | """ 52 | 53 | FUNC_PARAM_LABELLING_USER_PROMPT = """\ 54 | You are analyzing the Java package {project_username}/{project_name}. 
\ 55 | Here is the package summary: 56 | 57 | {project_readme_summary} 58 | 59 | Please look at the following public methods in the library and their documentations (if present). \ 60 | What are the most important functions that look like can be invoked by a downstream Java package that is dependent on {project_name}, \ 61 | and that the function can be called with potentially malicious end-user inputs? \ 62 | If the package does not seem to be a library, just return empty list as the result. \ 63 | Utility functions that are not related to the primary purpose of the package should also be ignored 64 | 65 | Package,Class,Method,Doc 66 | {methods} 67 | """ 68 | 69 | POSTHOC_FILTER_SYSTEM_PROMPT = """\ 70 | You are an expert in detecting security vulnerabilities. \ 71 | You are given the starting point (source) and the ending point (sink) of a dataflow path in a Java project that may be a potential vulnerability. \ 72 | Analyze the given taint source and sink and predict whether the given dataflow can be part of a vulnerability or not, and store it as a boolean in "is_vulnerable". \ 73 | Note that, the source must be either a) the formal parameter of a public library function which might be invoked by a downstream package, or b) the result of a function call that returns tainted input from end-user. \ 74 | If the given source or sink do not satisfy the above criteria, mark the result as NOT VULNERABLE. \ 75 | Please provide a very short explanation associated with the verdict. \ 76 | Assume that the intermediate path has no sanitizer. 77 | 78 | Answer in JSON object with the following format: 79 | 80 | { "explanation": , 81 | "source_is_false_positive": , 82 | "sink_is_false_positive": , 83 | "is_vulnerable": } 84 | 85 | Do not include anything else in the response.\ 86 | """ 87 | 88 | POSTHOC_FILTER_USER_PROMPT = """\ 89 | Analyze the following dataflow path in a Java project and predict whether it contains a {cwe_description} vulnerability ({cwe_id}), or a relevant vulnerability. 90 | {hint} 91 | 92 | Source ({source_msg}): 93 | ``` 94 | {source} 95 | ``` 96 | 97 | Steps: 98 | {intermediate_steps} 99 | 100 | Sink ({sink_msg}): 101 | ``` 102 | {sink} 103 | ```\ 104 | """ 105 | 106 | POSTHOC_FILTER_USER_PROMPT_W_CONTEXT = """\ 107 | Analyze the following dataflow path in a Java project and predict whether it contains a {cwe_description} vulnerability ({cwe_id}), or a relevant vulnerability. 108 | {hint} 109 | 110 | Source ({source_msg}): 111 | ``` 112 | {source} 113 | ``` 114 | 115 | Steps: 116 | {intermediate_steps} 117 | 118 | Sink ({sink_msg}): 119 | ``` 120 | {sink} 121 | ``` 122 | 123 | {context}\ 124 | """ 125 | # The key should be the CWE number without any string prefixes. 126 | # The value should be sentences describing more specific details for detecting the CWE. 127 | POSTHOC_FILTER_HINTS = { 128 | "022": "Note: please be careful about defensing against absolute paths and \"..\" paths. Just canonicalizing paths might not be sufficient for the defense.", 129 | "078": "Note that other than typical Runtime.exec which is directly executing command, using Java Reflection to create dynamic objects with unsanitized inputs might also cause OS Command injection vulnerability. This includes deserializing objects from untrusted strings and similar functionalities. Writing to config files about library data may also induce unwanted execution of OS commands.", 130 | "079": "Please be careful about reading possibly tainted HTML input. 
During sanitization, do not assume the sanitization to be sufficient.", 131 | "094": "Please note that dubious error messages can sometimes be handled by downstream code for execution, resulting in CWE-094 vulnerability. Injection of malicious values might lead to arbitrary code execution as well.", 132 | } 133 | 134 | SNIPPET_CONTEXT_SIZE = 4 135 | -------------------------------------------------------------------------------- /src/queries/fetch_class_locs.ql: -------------------------------------------------------------------------------- 1 | import java 2 | 3 | from 4 | RefType c 5 | where 6 | c.fromSource() and 7 | c.getName() != "" 8 | select 9 | c.getName() as name, 10 | c.getFile().getRelativePath() as file, 11 | c.getLocation().getStartLine() as start_line, 12 | c.getLocation().getEndLine() + c.getTotalNumberOfLines() as end_line 13 | -------------------------------------------------------------------------------- /src/queries/fetch_external_apis.ql: -------------------------------------------------------------------------------- 1 | import java 2 | 3 | predicate isExternallCall(Call c) { 4 | ( 5 | not c.getCallee().getDeclaringType().getPackage().getName().matches("org.junit%") and 6 | not c.getCallee().getDeclaringType().getPackage().getName().matches("org.hamcrest%") and 7 | not c.getCallee().getDeclaringType().getPackage().getName().matches("org.mockito%") and 8 | not c.getCallee().getDeclaringType().getPackage().getName().matches("junit.framework%") 9 | ) 10 | } 11 | 12 | bindingset[m] 13 | string fullSignature(Callable m) { 14 | if m instanceof Constructor 15 | then 16 | result = m.getName() + "(" + concat(int i | i = [0 .. m.getNumberOfParameters()] | m.getParameter(i).getType().getName() + " " + m.getParameter(i).getName(), ", " order by i asc) + ")" 17 | else 18 | result = m.getReturnType().getName() + " " + m.getName() + "(" + concat(int i | i = [0 .. m.getNumberOfParameters()] | m.getParameter(i).getType().getName() + " " + m.getParameter(i).getName(), ", " order by i asc) + ")" 19 | } 20 | 21 | bindingset[m] 22 | string paramTypes(Callable m) { 23 | result = concat(int i | i = [0 .. m.getNumberOfParameters()] | m.getParameter(i).getType().getName(), ";" order by i asc) 24 | } 25 | 26 | 27 | string isStaticAsString(Callable m) { 28 | if m.isStatic() 29 | then result = "true" 30 | else result = "false" 31 | } 32 | 33 | bindingset[m] 34 | string getJavadocString(Callable m) { 35 | ( 36 | exists(Javadoc d | m.getDoc().getJavadoc() = d) and 37 | result = concat(int i | i = [0 .. 
m.getDoc().getJavadoc().getNumChild()] | m.getDoc().getJavadoc().getChild(i).getText(), " " order by i asc) 38 | ) 39 | or 40 | result = "" 41 | } 42 | 43 | from 44 | Call api 45 | where 46 | isExternallCall(api) and 47 | api.getCallee().getStringSignature() != "()" and 48 | api.getCallee().getDeclaringType().getSourceDeclaration().getName() != "Object" 49 | select 50 | api as callstr, 51 | api.getCallee().getDeclaringType().getSourceDeclaration().getPackage() as package, 52 | api.getCallee().getDeclaringType().getSourceDeclaration() as clazz, 53 | fullSignature(api.getCallee()) as full_signature, 54 | api.getCallee().getStringSignature() as internal_signature, 55 | api.getCallee() as func, 56 | isStaticAsString(api.getCallee()) as is_static, 57 | api.getFile() as file, 58 | api.getLocation().toString() as location, 59 | paramTypes(api.getCallee()) as parameter_types, 60 | api.getCallee().getReturnType().getName() as return_type, 61 | getJavadocString(api.getCallee()) as doc 62 | -------------------------------------------------------------------------------- /src/queries/fetch_field_reads.ql: -------------------------------------------------------------------------------- 1 | import java 2 | 3 | predicate isExternalClass(Class c){ 4 | //not c.getPackage().fromSource() and // eliminates internal packages 5 | not c.getPackage().getName().matches("org.junit%") and 6 | not c.getPackage().getName().matches("org.hamcrest%") and 7 | not c.getPackage().getName().matches("junit.framework%") 8 | } 9 | 10 | 11 | from FieldRead fr, Field f 12 | where isExternalClass(f.getDeclaringType()) and fr.getField() = f and fr.getCompilationUnit().getPackage().fromSource() 13 | select fr as fieldread, 14 | f as field, 15 | f.getDeclaringType() as clazz, 16 | f.getDeclaringType().getPackage() as package, 17 | fr.getFile() as file, 18 | fr.getLocation().toString() as location 19 | 20 | -------------------------------------------------------------------------------- /src/queries/fetch_func_locs.ql: -------------------------------------------------------------------------------- 1 | import java 2 | 3 | from 4 | Method c 5 | where 6 | c.fromSource() and 7 | c.getName() != "" 8 | select 9 | c.getName() as name, 10 | c.getFile().getRelativePath() as file, 11 | c.getLocation().getStartLine() as start_line, 12 | c.getLocation().getEndLine() + c.getTotalNumberOfLines() as end_line 13 | -------------------------------------------------------------------------------- /src/queries/fetch_func_params.ql: -------------------------------------------------------------------------------- 1 | import java 2 | 3 | bindingset[m] 4 | string fullSignature(Callable m) { 5 | if m instanceof Constructor 6 | then 7 | result = m.getName() + "(" + concat(int i | i = [0 .. m.getNumberOfParameters()] | m.getParameter(i).getType().getName() + " " + m.getParameter(i).getName(), ", " order by i asc) + ")" 8 | else 9 | result = m.getReturnType().getName() + " " + m.getName() + "(" + concat(int i | i = [0 .. 
m.getNumberOfParameters()] | m.getParameter(i).getType().getName() + " " + m.getParameter(i).getName(), ", " order by i asc) + ")" 10 | } 11 | 12 | predicate isTested(Callable m) { 13 | exists(Call c | 14 | c.getCallee() = m and 15 | c.getLocation().toString().indexOf("src/test") >= 0 16 | ) 17 | } 18 | 19 | predicate isNotInvokedByInternalFunction(Callable m) { 20 | not exists(Call c | 21 | c.getCallee() = m and 22 | c.getLocation().toString().indexOf("src/test") < 0 23 | ) 24 | } 25 | 26 | bindingset[m] 27 | string paramTypes(Callable m) { 28 | result = concat(int i | i = [0 .. m.getNumberOfParameters()] | m.getParameter(i).getType().getName(), ";" order by i asc) 29 | } 30 | 31 | bindingset[m] 32 | string getJavadocString(Callable m) { 33 | if (exists(Javadoc d | m.getDoc().getJavadoc() = d)) 34 | then 35 | result = concat(int i | i = [0 .. m.getDoc().getJavadoc().getNumChild()] | m.getDoc().getJavadoc().getChild(i).getText(), " " order by i asc) 36 | else 37 | result = "" 38 | } 39 | 40 | from 41 | Callable method 42 | where 43 | method.fromSource() and 44 | method.isPublic() and 45 | not method.hasNoParameters() and 46 | isTested(method) and 47 | isNotInvokedByInternalFunction(method) 48 | select 49 | method.getDeclaringType().getSourceDeclaration().getPackage() as package, 50 | method.getDeclaringType().getSourceDeclaration() as clazz, 51 | method.getName() as func, 52 | fullSignature(method) as full_signature, 53 | method.getStringSignature() as internal_signature, 54 | method.getLocation().toString() as location, 55 | paramTypes(method) as parameter_types, 56 | method.getReturnType().getName() as return_type, 57 | getJavadocString(method) as doc 58 | -------------------------------------------------------------------------------- /src/queries/fetch_sinks.ql: -------------------------------------------------------------------------------- 1 | import java 2 | import semmle.code.java.dataflow.DataFlow 3 | private import semmle.code.java.dataflow.ExternalFlow 4 | 5 | import MySinks 6 | 7 | from 8 | DataFlow::Node node 9 | where 10 | isGPTDetectedSink(node) 11 | select 12 | node.toString() as node_str, 13 | node.getLocation() as loc 14 | -------------------------------------------------------------------------------- /src/queries/fetch_sources.ql: -------------------------------------------------------------------------------- 1 | import java 2 | import semmle.code.java.dataflow.DataFlow 3 | private import semmle.code.java.dataflow.ExternalFlow 4 | 5 | import MySources 6 | 7 | from 8 | DataFlow::Node node 9 | where 10 | isGPTDetectedSource(node) 11 | select 12 | node.toString() as node_str, 13 | node.getLocation() as loc 14 | -------------------------------------------------------------------------------- /src/queries/getpackages.ql: -------------------------------------------------------------------------------- 1 | import java 2 | from Package p 3 | where p.fromSource() 4 | select p, p.getName() 5 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iris-sast/iris/0355004de110ecc425b8ca45024c6b4465ca1c2e/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/cwe_top_25.txt: -------------------------------------------------------------------------------- 1 | 787 2 | 79 3 | 89 4 | 416 5 | 78 6 | 20 7 | 125 8 | 22 9 | 352 10 | 434 11 | 862 12 | 476 13 | 287 14 | 
190 15 | 502 16 | 77 17 | 119 18 | 798 19 | 918 20 | 306 21 | 362 22 | 269 23 | 94 24 | 863 25 | 276 26 | -------------------------------------------------------------------------------- /src/utils/cwenames.txt: -------------------------------------------------------------------------------- 1 | name,id 2 | OS Command Injection,78 3 | Cryptographic,327 4 | weak hash,328 5 | LDAP Injection,90 6 | Path Traversal,22 7 | Sensitive Cookie|Secure Cookie,614 8 | SQL Injection,89 9 | trust boundary violation,501 10 | Insufficiently Random Values,330 11 | XPath Injection,643 12 | xss|cross-site scripting|cross site scripting,79 13 | Out-of-bounds read|out of bounds read,125 14 | Race Condition|Concurrent Execution using Shared Resource with Improper Synchronization|Improper Synchronization,362 15 | Command Injection,77 16 | Out-of-bounds Write|Out of bounds Write,787 17 | use after free,416 18 | Improper Input Validation,20 19 | Improper Privilege Management,269 20 | NULL Pointer Dereference,476 21 | integer overflow,190 22 | Code Injection,94 23 | CSRF|cross site scripting|cross-site scripting,352 24 | Missing Authorization,862 25 | Server-Side Request Forgery|SSRF|Cross Site Port Attack|xspa,918 26 | Improper Restriction of Operations within the Bounds of a Memory Buffer,119 27 | Deserialization of Untrusted Data,502 28 | Improper Authentication,287 29 | Unrestricted Upload of File with Dangerous Type,434 30 | Use of Hard-coded Credentials,798 31 | Missing Authentication for Critical Function,306 32 | Incorrect Authorization,863 33 | Incorrect Default Permissions,276 34 | -------------------------------------------------------------------------------- /src/utils/cwenames_top25.txt: -------------------------------------------------------------------------------- 1 | name,id 2 | OS Command Injection,78 3 | Cryptographic,327 4 | weak hash,328 5 | LDAP Injection,90 6 | Path Traversal,22 7 | SQL Injection,89 8 | trust boundary violation,501 9 | Insufficiently Random Values,330 10 | XPath Injection,643 11 | Cross Site Scripting,79 12 | Out-of-bounds Read,125 13 | Race Condition or Concurrent Execution using Shared Resource with Improper Synchronization,362 14 | Command Injection,77 15 | Out-of-bounds Write,787 16 | Use After Free,416 17 | Improper Input Validation,20 18 | Improper Privilege Management,269 19 | NULL Pointer Dereference,476 20 | Integer Overflow,190 21 | Improper Control of Generation of Code or Code Injection,94 22 | Cross-Site Request Forgery,352 23 | Missing Authorization,862 24 | Server-Side Request Forgery (SSRF),918 25 | Improper Restriction of Operations within the Bounds of a Memory Buffer,119 26 | Deserialization of Untrusted Data,502 27 | Improper Authentication,287 28 | Unrestricted Upload of File with Dangerous Type,434 29 | Use of Hard-coded Credentials,798 30 | Missing Authentication for Critical Function,306 31 | Incorrect Authorization,863 32 | Incorrect Default Permissions,276 33 | Sensitive Cookie in HTTPS Session Without 'Secure' Attribute,614 34 | -------------------------------------------------------------------------------- /src/utils/cweparser.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def get_cwe_mappings(): 4 | cwefile=f"cwec_v4.12.xml" 5 | 6 | import xml.etree.ElementTree as ET 7 | import re 8 | tree = ET.parse(cwefile) 9 | root = tree.getroot() 10 | 11 | natures=dict() 12 | namemap=dict() 13 | for element in root[0]: 14 | #if re.match(r'Weakness', element.tag): 15 | #print(element.tag, element.attrib['ID']) 
16 | #print(",".join([element.attrib['ID'], element.attrib['Name']])) 17 | 18 | for child in element: 19 | if "Related_Weakness" in child.tag: 20 | for w in child: 21 | if "Related_Weakness" in w.tag: 22 | natures[w.attrib['Nature']] = natures.get(w.attrib['Nature'], 0) + 1 23 | if w.attrib['Nature'] == 'ChildOf': 24 | print(";".join([element.attrib['ID'], element.attrib['Name'], w.attrib['Nature'], w.attrib['CWE_ID']])) 25 | #print(">", child.tag, child.text) 26 | # print(">", child.tag, child.attrib['ID']) 27 | #print(natures) 28 | 29 | #print(elements) 30 | 31 | def is_parent(parent, child, df): 32 | if parent == child: 33 | return True 34 | children = df[df['parentid'] == parent]['childid'].tolist() 35 | if len(children) == 0: 36 | return False 37 | for c in children: 38 | if is_parent(c, child, df): 39 | return True 40 | return False 41 | 42 | 43 | def check_cwe(true_id, predicted_id): 44 | import pandas as pd 45 | true_id = int(true_id) 46 | predicted_id = int(predicted_id) 47 | df = pd.read_csv("cwemappings.csv", delimiter=";") 48 | # if predicted id is equal to target id or a parent of target id, return true 49 | if true_id == predicted_id: 50 | return True 51 | else: 52 | # check if target id is a child of predicted id 53 | # if len(df[(df['childid'] == true_id) & (df['relation'] == 'ChildOf') & (df['parentid'] == predicted_id)]) > 0: 54 | # return True 55 | # else: 56 | # return False 57 | return is_parent(predicted_id, true_id, df) 58 | 59 | if __name__ == '__main__': 60 | import sys 61 | print(check_cwe(sys.argv[1], sys.argv[2])) 62 | #print(check_cwe(1004, 732)) 63 | #get_cwe_mappings() 64 | 65 | -------------------------------------------------------------------------------- /src/utils/metrics_test.py: -------------------------------------------------------------------------------- 1 | from utils import compute_results, compute_precision_recall_accuracy, compute_prec_recall_multiclass, group_metrics 2 | import pandas as pd 3 | import sys 4 | import os 5 | import tabulate 6 | import argparse 7 | 8 | 9 | def gen_table(output_folder, group_by_col=None, dataset_csv_path=None, dataset_index_col=None, top_cwe=False, indices=None, max_samples=None, do_cwe=False): 10 | results = compute_results(output_folder) 11 | table = [] 12 | df = pd.DataFrame.from_dict(results, orient="index") 13 | if top_cwe: 14 | top25=open('utils/cwe_top_25.txt').read().strip().split('\n') 15 | print("Filtering by top 25 cwes..") 16 | 17 | df=df[df['true_cwe'].isin(top25)] 18 | if indices: 19 | print("Filtering by indices") 20 | indices = open(indices).read().strip().split('\n') 21 | df = df[df.index.isin(indices)] 22 | if max_samples: 23 | df=df.iloc[:max_samples] 24 | 25 | if group_by_col: 26 | if not dataset_csv_path: 27 | raise Exception("Dataset CSV Path not provided for grouped metrics computation") 28 | df = group_metrics(df, group_by_col, dataset_csv_path, dataset_index_col) 29 | 30 | all = compute_precision_recall_accuracy(df, "true_label", "llm_label") 31 | table.append( 32 | [ 33 | "All", 34 | len(df), 35 | all["TP"], 36 | all["TN"], 37 | all["FP"], 38 | all["FN"], 39 | all["accuracy"], 40 | all["accuracy_balanced"], 41 | all["precision"], 42 | all["recall"], 43 | all["F1"] 44 | ] 45 | ) 46 | if do_cwe: 47 | cwes = list(df["true_cwe"].unique()) 48 | for cwe in cwes: 49 | cwe_df = df[df["true_cwe"] == cwe] 50 | prec_recall = compute_precision_recall_accuracy( 51 | cwe_df, "true_label", "llm_label" 52 | ) 53 | table.append( 54 | [ 55 | "CWE-" + str(cwe), 56 | len(cwe_df), 57 | int(prec_recall["TP"]), 
58 | int(prec_recall["TN"]), 59 | int(prec_recall["FP"]), 60 | int(prec_recall["FN"]), 61 | prec_recall["accuracy"], prec_recall["accuracy_balanced"], 62 | prec_recall["precision"], 63 | prec_recall["recall"], 64 | prec_recall["F1"] 65 | 66 | ] 67 | ) 68 | print( 69 | tabulate.tabulate( 70 | table, 71 | headers=["CWE", "Count", "TP", "TN", "FP", "FN", "Accuracy", "AccBalanced", "Precision", "Recall", "F1"], 72 | tablefmt="orgtbl", 73 | ) 74 | ) 75 | print( 76 | tabulate.tabulate( 77 | table, 78 | headers=["CWE","Count", "TP", "TN", "FP", "FN", "Accuracy", "AccBalanced", "Precision", "Recall", "F1"], 79 | tablefmt="latex", 80 | floatfmt=(".0f", ".0f", ".0f", ".0f", ".0f", ".2f",".2f", ".2f", ".2f"), 81 | ) 82 | ) 83 | 84 | 85 | def get_results_from_folder(output_folder, logger=None): 86 | results = compute_results(output_folder) 87 | if logger is None: 88 | _log = lambda x: print(x) 89 | else: 90 | _log = lambda x: logger.log(x) 91 | df = pd.DataFrame.from_dict(results, orient="index") 92 | # df.to_csv(os.path.join(output_folder, "results.csv")) 93 | 94 | prec_recall = compute_precision_recall_accuracy(df, "true_label", "llm_label") 95 | # print results 96 | _log(">>Total samples: " + str(len(df))) 97 | _log(">>Total vulnerable: " + str(len(df[df["true_label"] == True]))) 98 | _log(">>Total not vulnerable: " + str(len(df[df["true_label"] == False]))) 99 | 100 | _log(">>Accuracy: " + str(prec_recall["accuracy"])) 101 | _log(">>Recall: " + str(prec_recall["recall"])) 102 | _log(">>Precision: " + str(prec_recall["precision"])) 103 | _log(">>F1: " + str(prec_recall["F1"])) 104 | 105 | _log(">>Total correct CWE: " + str(len(df[df["cwe_correct"] == True]))) 106 | _log( 107 | ">>Total correct CWE and Label: " 108 | + str(len(df[(df["cwe_correct"] == True) & (df["correct"] == True)])) 109 | ) 110 | 111 | # cwe specific results 112 | precision_dict, recall_dict, accuracy_dict = compute_prec_recall_multiclass( 113 | df, "true_cwe", "llm_cwe" 114 | ) 115 | for k in precision_dict.keys(): 116 | _log( 117 | ">>CWE: " 118 | + str(k) 119 | + ",Accuracy: " 120 | + str(accuracy_dict[k]) 121 | + ",Precision: " 122 | + str(precision_dict[k]) 123 | + ",Recall: " 124 | + str(recall_dict[k]) 125 | 126 | ) 127 | 128 | 129 | if __name__ == "__main__": 130 | # get_results_from_folder(sys.argv[1]) 131 | argparse = argparse.ArgumentParser() 132 | argparse.add_argument("--results_dir", type=str, help="Path to results directory") 133 | argparse.add_argument("--group_by", type=str, default=None, help="Column in the dataset to be used for grouping results (default: None)") 134 | argparse.add_argument("--dataset_csv_path", type=str, help="Path to dataset CSV. 
Required for grouped metrics") 135 | argparse.add_argument("--dataset_index_col", type=str, default=None, help="Column in the dataset that maps to the results indexes (used for group joins)") 136 | argparse.add_argument("--top_cwe", action='store_true') 137 | argparse.add_argument("--indices", type=str, default=None) 138 | argparse.add_argument("--max_samples", type=int, default=None) 139 | argparse.add_argument("--cwe", action='store_true') 140 | args = argparse.parse_args() 141 | 142 | gen_table(args.results_dir, args.group_by, args.dataset_csv_path, args.dataset_index_col, args.top_cwe, args.indices, args.max_samples, args.cwe) 143 | -------------------------------------------------------------------------------- /src/utils/mylogger.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | class MyLogger: 4 | def __init__(self, logfile): 5 | self.logfile=logfile 6 | os.makedirs(os.path.dirname(self.logfile), exist_ok=True) 7 | 8 | 9 | def log(self, text, do_print=True): 10 | if do_print: 11 | print(text) 12 | with open(self.logfile, 'a') as f: 13 | f.write(str(text)) 14 | f.write("\n") 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- /src/utils/prompt_utils.py: -------------------------------------------------------------------------------- 1 | from data.prompt import PROMPTS, PROMPTS_SYSTEM 2 | import pandas as pd 3 | import os 4 | 5 | cwenames = pd.read_csv("utils/cwenames_top25.txt", index_col="id") 6 | 7 | def get_cwe_name_from_id(id): 8 | if int(id) == -1: 9 | return "any vulnerability" 10 | return f"CWE-{str(id)} ({cwenames.loc[int(id)]['name']})" 11 | 12 | def generate_message_list(prompting_technique, snippet, prompt_cwe=-1, user_prompt="generic", system_prompt="generic"): 13 | if "self_reflection" in prompting_technique: 14 | prompt_message_list = generate_self_reflection_message_list( 15 | snippet=snippet, prompt_cwe=prompt_cwe 16 | ) 17 | elif "instruction_cot" in prompting_technique: 18 | prompt_message_list = generate_system_heuristics_cot_message_list( 19 | snippet=snippet, prompt_cwe=prompt_cwe, validate=True 20 | ) 21 | elif "step_by_step_dataflow_analysis" in prompting_technique: 22 | prompt_message_list = generate_step_by_step_dataflow_analysis_message_list( 23 | snippet=snippet, prompt_cwe=prompt_cwe, system_prompt_type=None 24 | ) 25 | elif prompting_technique == "few_shot_cot": 26 | prompt_message_list = generate_few_shot_cot_message_list( 27 | snippet=snippet, prompt_cwe=prompt_cwe 28 | ) 29 | elif "basic" in prompting_technique: 30 | prompt_message_list = generate_basic_message_list( 31 | snippet=snippet, 32 | prompt_cwe=prompt_cwe, 33 | prompt_type=user_prompt, 34 | system_prompt_type=system_prompt 35 | ) 36 | else: 37 | raise Exception(f"Prompting technique: {prompting_technique} not found") 38 | return prompt_message_list 39 | 40 | def generate_basic_message_list(snippet, prompt_cwe, prompt_type, system_prompt_type, validate=False): 41 | """ 42 | Generate an OpenAI API-style message with a system and a user prompt 43 | """ 44 | query = PROMPTS[prompt_type].format(snippet, get_cwe_name_from_id(prompt_cwe)) 45 | system_prompt = PROMPTS_SYSTEM[system_prompt_type] 46 | message_list = [{"role": "system", "content": system_prompt}, {"role": "user", "content": query}] 47 | if validate: 48 | message_list.append({ 49 | "role": "user", 50 | "content": "Is this analysis correct?" 
51 | }) 52 | 53 | return message_list 54 | 55 | def generate_self_reflection_message_list(snippet, prompt_cwe=-1): 56 | """ 57 | Generate an OpenAI API-style message list with self reflection messages 58 | """ 59 | return [ 60 | { 61 | "role": "user", 62 | "content": PROMPTS["taint_analysis"].format(snippet) 63 | }, 64 | { 65 | "role": "user", 66 | "content": "Is this analysis correct?" 67 | }, 68 | { 69 | "role": "user", 70 | "content": f"Based on this analysis, is the given code snippet prone to {get_cwe_name_from_id(prompt_cwe)}? Provide response only in following format: '$$ vulnerability: | vulnerability type: | lines of code: | explanation: $$'." 71 | } 72 | ] 73 | 74 | def generate_system_heuristics_cot_message_list(snippet, prompt_cwe=-1, validate=False): 75 | """ 76 | Generate an OpenAI API-style message list with a well-crafted system prompt + cot messages 77 | Add a self reflection style prompt if validate = True 78 | """ 79 | messages = [ 80 | { 81 | "role": "system", 82 | "content": PROMPTS_SYSTEM["heuristics"] 83 | }, 84 | { 85 | "role": "user", 86 | "content": PROMPTS["zero_shot_cot"].format(snippet, get_cwe_name_from_id(prompt_cwe)) 87 | } 88 | ] 89 | if validate: 90 | messages.append( 91 | { 92 | "role": "user", 93 | "content": "Is this analysis correct?" 94 | } 95 | ) 96 | messages.append( 97 | { 98 | "role": "user", 99 | "content": f"Based on this analysis, is the given code snippet prone to {get_cwe_name_from_id(prompt_cwe)}? Provide response only in following format: '$$ vulnerability: | vulnerability type: | lines of code: | explanation: $$'." 100 | } 101 | ) 102 | return messages 103 | 104 | def generate_few_shot_cot_message_list(snippet, prompt_cwe=-1, system_prompt_type=None): 105 | """ 106 | Generate an OpenAI API-style message list with few shot messages 107 | Prepend a system prompt if system_prompt_type is provided 108 | """ 109 | messages = [] 110 | if system_prompt_type: 111 | messages.append({ 112 | "role": "system", 113 | "content": PROMPTS_SYSTEM[system_prompt_type] 114 | }) 115 | messages.append({ 116 | "role": "user", 117 | "content": PROMPTS["cpp_few_shot"].format(snippet) 118 | }) 119 | return messages 120 | 121 | def generate_step_by_step_dataflow_analysis_message_list(snippet, prompt_cwe=-1, system_prompt_type=None): 122 | """ 123 | Generate an OpenAI API-style message list for step-by-step dataflow analysis 124 | Prepend a system prompt if system_prompt_type is provided 125 | """ 126 | messages = [] 127 | if system_prompt_type: 128 | messages.append({ 129 | "role": "system", 130 | "content": PROMPTS_SYSTEM[system_prompt_type] 131 | }) 132 | messages.extend([ 133 | { 134 | "role": "user", 135 | "content": PROMPTS["identify_sources_sinks_sanitizers"].format(snippet) 136 | }, 137 | { 138 | "role": "user", 139 | "content": "Now find the flows between these identified sources and sinks that are not sanitized." 140 | }, 141 | { 142 | "role": "user", 143 | "content": f"Based on this analysis, is the given code snippet prone to {get_cwe_name_from_id(prompt_cwe)}? Provide response only in following format: '$$ vulnerability: | vulnerability type: | lines of code: | explanation: $$'." 
144 | } 145 | ]) 146 | return messages 147 | 148 | def generate_validation_message_list(id, dataset_results_dir): 149 | existing_results_dir = os.path.join(dataset_results_dir, id) 150 | prompt_log = open(os.path.join(existing_results_dir, "query.txt")).read().strip() 151 | pred = open(os.path.join(existing_results_dir, "pred.txt")).read().strip() 152 | 153 | prompt_sep = "-------------------" 154 | prompts = prompt_log.split(prompt_sep) 155 | # Each prompt of the format 156 | # ROLE 157 | # 158 | # 159 | # for prompt in prompts: 160 | # print(prompt) 161 | # print("+++++++++++++++=") 162 | messages = [ 163 | { 164 | "role": prompt.strip().splitlines()[0].lower(), 165 | "content": "\n".join(prompt.strip().splitlines()[2:]) 166 | } for prompt in prompts if len(prompt.strip()) > 0] 167 | 168 | # Add the validation prompt 169 | messages.extend([ 170 | {"role": "assistant", "content": pred}, 171 | {"role": "user", "content": PROMPTS["validation"]}]) 172 | return messages 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /src/utils/sample_spec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import json 4 | import os 5 | 6 | mapping = { 7 | 'llama-8b': 'test-llama', 8 | 'llama-70b': 'test-llama70-f', 9 | 'gemma-7b': 'test-gemma', 10 | 'deepseek-33b': 'test-deepseekcoder-33b', 11 | 'deepseek-7b': 'test-deepseekcoder-7b', 12 | 'mistral-7b': 'test-mistral-7b', 13 | 'gpt4': 'test0', 14 | 'gpt3.5': 'test-gpt35' 15 | } 16 | 17 | llm_name = { 18 | 'llama-8b': 'llama-3-8b', 19 | 'llama-70b': 'llama-3-70b', 20 | 'gemma-7b': 'gemma-7b-it', 21 | 'deepseek-33b': 'deepseekcoder-33b', 22 | 'deepseek-7b': 'deepseekcoder-7b', 23 | 'mistral-7b': 'mistral-7b-instruct', 24 | 'gpt4': 'gpt-4', 25 | 'gpt3.5': 'gpt-3.5', 26 | } 27 | 28 | def sample(ty, llm, cwe, output, amount): 29 | specs = json.load(open(f"shared/v2/outputs/common/{mapping[llm]}/cwe-{cwe}/api_labels_{llm_name[llm]}.json")) 30 | filtered_specs = [s for s in specs if "type" in s and s["type"] == ty] 31 | random.shuffle(filtered_specs) 32 | sampled_specs = filtered_specs[:amount] 33 | json.dump(sampled_specs, open(f"{output}/sampled_{ty}_{llm}_{cwe}.json", "w")) 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--output", type=str, default="shared/v2/sampled-specs") 37 | parser.add_argument("--amount", type=int, default=10) 38 | parser.add_argument("--seed", type=int, default=1234) 39 | args = parser.parse_args() 40 | 41 | random.seed(args.seed) 42 | 43 | os.makedirs(args.output, exist_ok=True) 44 | for ty in ["source", "sink"]: 45 | for llm in mapping.keys(): 46 | for cwe in ["022", "078", "079", "094"]: 47 | sample(ty, llm, cwe, args.output, args.amount) 48 | --------------------------------------------------------------------------------
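
Minimal usage sketch for the model wrapper and logger listed above (hypothetical, not part of the repository). It assumes the script is run from the src/ directory, that OPENAI_API_KEY is set, and that data/prompt.py defines "generic" entries in PROMPTS and PROMPTS_SYSTEM; the model name, prompt types, and snippet below are illustrative only.

from utils.mylogger import MyLogger
from models.openaimodels import OpenAIModel

# MyLogger creates the log directory if needed and appends every message to the file.
logger = MyLogger("output/example-run/log.txt")

# The kwargs mirror the keys read by OpenAIModel.get_prompt; "generic" is an assumed prompt type.
model = OpenAIModel(logger, model_name="gpt-4",
                    prompting_technique="basic",
                    prompt_type="generic",
                    system_prompt_type="generic")

# predict() expects the code snippet and the CWE id of interest; it returns the
# stringified chat history and the model's final prediction text.
chat_log, pred = model.predict({"snippet": 'String path = request.getParameter("file");',
                                "prompt_cwe": 22})
print(pred)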