├── .github └── workflows │ └── close_stale_ticket.yml ├── .gitmodules ├── CodeQueries_preparation ├── README.md ├── aggregate_file_level.py ├── data_ingestion │ ├── create_raw_codeql_queryset.py │ ├── create_raw_programs_dataset.py │ ├── download_and_serialize_dataset.sh │ ├── queries.csv │ ├── raw_codeql_queryset.proto │ ├── raw_programs_dataset.proto │ ├── run_create_raw_codeql_queryset.py │ ├── run_create_raw_programs_dataset.py │ ├── test_create_raw_codeql_queryset.py │ └── test_create_raw_programs_dataset.py └── data_preparation │ ├── commands.md │ ├── contexts │ ├── basecontexts.py │ ├── conflictingattributesinbaseclasses.py │ ├── defineequalswhenaddingattributes.py │ ├── distributable.py │ ├── equalsorhash.py │ ├── flaskdebug.py │ ├── get_builtin_stub.py │ ├── get_context.py │ ├── get_mro.py │ ├── incompleteordering.py │ ├── incorrectcomparisonusingis.py │ ├── initcallssubclassmethod.py │ ├── iterreturnsnoniterator.py │ ├── missingcalltoinit.py │ ├── my-languages.so │ ├── noncallablecalled.py │ ├── signatureoverriddenmethod.py │ ├── test_conflictingattributesinbaseclasses.py │ ├── test_data │ │ ├── test__aux_res.csv │ │ └── test_file_content.py │ ├── test_defineequalswhenaddingattributes.py │ ├── test_distributable.py │ ├── test_equalsorhash.py │ ├── test_flaskdebug.py │ ├── test_incompleteordering.py │ ├── test_incorrectcomparisonusingis.py │ ├── test_initcallssubclassmethod.py │ ├── test_iterreturnsnoniterator.py │ ├── test_missingcalltoinit.py │ ├── test_noncallablecalled.py │ ├── test_run.sh │ ├── test_signatureoverriddenmethod.py │ ├── test_unusedimport.py │ ├── test_useimplicitnonereturnvalue.py │ ├── test_wrongnumberargumentsincall.py │ ├── test_wrongnumberargumentsinclassinstantiation.py │ ├── unusedimport.py │ ├── useimplicitnonereturnvalue.py │ ├── wrongnumberargumentsincall.py │ └── wrongnumberargumentsinclassinstantiation.py │ ├── create_block_subtokens_labels.py │ ├── create_blocks_labels_dataset.py │ ├── create_blocks_relevance_labels_dataset.py │ ├── create_groupwise_prediction_dataset.py │ ├── create_query_result.py │ ├── create_relevance_prediction_examples.py │ ├── create_single_example.py │ ├── create_span_prediction_training_examples.py │ ├── create_tokenized_files_labels.py │ ├── dataset_with_context.proto │ ├── merge_negative_positive_examples.py │ ├── my-languages.so │ ├── run_create_block_subtokens_labels.py │ ├── run_create_blocks_labels_dataset.py │ ├── run_create_groupwise_prediction_dataset.py │ ├── run_create_query_result.py │ ├── run_create_relevance_prediction_examples.py │ ├── run_create_span_prediction_training_examples.py │ ├── run_create_tokenized_files_labels.py │ ├── test_create_block_labels_dataset.py │ ├── test_create_cubert_model_examples.py │ ├── test_create_query_result.py │ ├── test_create_subtokens_labels.py │ ├── test_create_tokenized_files_labels.py │ └── vocab.txt ├── Codequeries_Statistics.pdf ├── LICENSE ├── README.md ├── analyze_classified_spans.py ├── evaluate_generated_spans.py ├── evaluate_relevance.py ├── evaluate_spanprediction.py ├── figures └── QA_Task.png ├── generate_spans.py ├── get_sampled_data.py ├── pretrained_model_configs ├── README.md ├── config_1024.json ├── config_512.json └── vocab.txt ├── prompt_templates ├── ex_with_sf.j2 ├── ex_wo_sf.j2 ├── span_highlight_0shot.j2 ├── span_highlight_fewshot.j2 └── span_highlight_fewshot_sf.j2 ├── requirements.txt ├── resources ├── codequeries_meta.json ├── query_folderName_map.pkl ├── sampled_test_data.pkl └── sampled_train_all_data.pkl ├── setup.sh ├── 
train_relevanceprediction.py ├── train_spanprediction.py ├── using_CodeQueries.ipynb ├── utils.py └── utils_openai.py /.github/workflows/close_stale_ticket.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "14 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | days-before-issue-stale: 14 16 | days-before-issue-close: 7 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 14 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "CodeQueries_preparation/data_preparation/cubert"] 2 | path = CodeQueries_preparation/data_preparation/cubert 3 | url = https://github.com/thepurpleowl/cubert.git 4 | -------------------------------------------------------------------------------- /CodeQueries_preparation/README.md: -------------------------------------------------------------------------------- 1 | ### Steps for data preparation 2 | 1. Compile the proto definitions in the `data_ingestion` and `data_preparation` dirs. For a quick reference, follow this [blog](https://www.freecodecamp.org/news/googles-protocol-buffers-in-python/). 3 | `protoc -I=. --python_out=. <proto_file>` 4 | 2. Run `download_and_serialize_dataset.sh`. This script downloads required files and runs the data ingestion phase (i.e., serializes query and file data into proto format). 5 | 3. Run files from `data_preparation` in the following order. Input for each file can be checked from the corresponding test files. 6 | 3.1. `run_create_query_result.py` 7 | 3.2. `run_create_tokenized_files_labels.py` 8 | 3.3. `run_create_blocks_labels_dataset.py` 9 | 3.4. `run_create_block_subtokens_labels.py` 10 | 3.5. `run_create_groupwise_prediction_dataset.py`/`run_create_relevance_prediction_examples.py`/`run_create_span_prediction_training_examples.py` for twostep/relevance prediction/span prediction data. -------------------------------------------------------------------------------- /CodeQueries_preparation/aggregate_file_level.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from tqdm import tqdm 3 | import datasets 4 | 5 | examples_data = datasets.load_dataset("thepurpleowl/codequeries", "twostep", split=datasets.Split.VALIDATION) 6 | 7 | # %% 8 | # to get all blocks of a file, use indices of twostep_dict 9 | # in a similar fashion, one can aggregate spans.
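# A minimal usage sketch (hypothetical key values; assumes twostep_dict below has been built):
#   key = ('<query_name>', '<code_file_path>')
#   file_blocks = [examples_data[i] for i in twostep_dict[key]]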
10 | twostep_dict = {} # dict(query_name, code_file_path) = indices of code blocks 11 | for i, example_instance in enumerate(tqdm(examples_data)): 12 | twostep_key = (example_instance['query_name'], example_instance['code_file_path']) 13 | if twostep_key not in twostep_dict: 14 | twostep_dict[twostep_key] = [i] 15 | else: 16 | twostep_dict[twostep_key].append(i) 17 | 18 | print(len(twostep_dict.keys())) 19 | # %% 20 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/create_raw_programs_dataset.py: -------------------------------------------------------------------------------- 1 | import raw_programs_dataset_pb2 2 | 3 | 4 | def CreateRawProgramsDatasetGithub(repo_name: str, 5 | program_language: str, 6 | program_path_and_content: list): 7 | """ 8 | This function helps ingest data into a protobuf, to create a raw 9 | programs dataset. This dataset will eventually be used to run 10 | CodeQL queries on, and to get answers to those queries. This function 11 | is used when the source of the programs' dataset is GitHub. 12 | Args: 13 | repo_name: When the programs are taken from GitHub, we need 14 | to store the name of the repository here, as a string. 15 | program_language: Here we mention the programming language in which 16 | the files in this dataset are written. 17 | program_path_and_content: This is a list of tuples. Each tuple contains 18 | the path to a program file, and the corresponding 19 | content of the file. Both stored as strings. 20 | Returns: 21 | A protobuf containing the entire programs dataset, containing all 22 | kinds of information about the dataset. 23 | 24 | """ 25 | # Protobuf for the entire dataset. 26 | entire_programs_dataset = raw_programs_dataset_pb2.RawProgramDataset() 27 | for i in range(len(program_path_and_content)): 28 | file_path_in_dataset = raw_programs_dataset_pb2.GitHubFilePath() 29 | 30 | file_path_in_dataset.repo = repo_name 31 | # This must be a unique path to the GitHub file. (blob) 32 | file_path_in_dataset.unique_path = program_path_and_content[i].path 33 | 34 | path_of_file = raw_programs_dataset_pb2.FilePath() 35 | path_of_file.github_file_path.CopyFrom(file_path_in_dataset) 36 | 37 | raw_program_file = raw_programs_dataset_pb2.RawProgramFile() 38 | raw_program_file.file_path.CopyFrom(path_of_file) 39 | raw_program_file.language = raw_programs_dataset_pb2.Languages.Value( 40 | program_language) 41 | 42 | # Storing the program contents. 43 | raw_program_file.file_content = program_path_and_content[i].content 44 | 45 | # Storing the entire dataset. 46 | entire_programs_dataset.raw_program_dataset.append(raw_program_file) 47 | 48 | return entire_programs_dataset 49 | 50 | 51 | def CreateRawProgramsDatasetNonGithub(dataset_name: str, split_name: str, 52 | program_language: str, 53 | program_content_files: list 54 | ): 55 | """ 56 | This function helps ingest data into a protobuf, to create a raw 57 | programs dataset. This dataset will eventually be used to run CodeQL 58 | queries on, and to get answers to those queries. This function is 59 | used when the source of the dataset is some place other than GitHub. 60 | Args: 61 | dataset_name: When the programs are taken from some place other than 62 | GitHub, we need to store the name of the dataset here, 63 | as a string. 64 | split_name: If the creators of the dataset have split the dataset 65 | into train/test/validation, then we pass this information 66 | here as a string.
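Allowed values are TRAIN/TEST/VALIDATION/UNDEFINED, matching the datasetsplit enum in raw_programs_dataset.proto.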
67 | program_language: Here we mention the programming language in which 68 | the files in this dataset are written. 69 | program_content_files: This is a list of tuples. Each tuple contains 70 | the path to a program file, and the corresponding 71 | content of the file. Both stored as strings. 72 | Returns: 73 | A protobuf containing the entire programs dataset, containing 74 | all kinds of information about the dataset. 75 | 76 | """ 77 | # Protobuf for the entire dataset. 78 | entire_programs_dataset = raw_programs_dataset_pb2.RawProgramDataset() 79 | for i in range(len(program_content_files)): 80 | # Protobuf to store information about program file paths. 81 | file_path_in_dataset = raw_programs_dataset_pb2.DatasetFilePath() 82 | file_path_in_dataset.source_name = dataset_name 83 | file_path_in_dataset.split = ( 84 | raw_programs_dataset_pb2.datasetsplit.Value( 85 | split_name)) 86 | file_path_in_dataset.unique_file_path = program_content_files[i].path 87 | 88 | path_of_file = raw_programs_dataset_pb2.FilePath() 89 | path_of_file.dataset_file_path.CopyFrom(file_path_in_dataset) 90 | 91 | raw_program_file = raw_programs_dataset_pb2.RawProgramFile() 92 | raw_program_file.file_path.CopyFrom(path_of_file) 93 | 94 | # Storing programming language information. 95 | raw_program_file.language = raw_programs_dataset_pb2.Languages.Value( 96 | program_language) 97 | 98 | # Storing the program content. 99 | raw_program_file.file_content = program_content_files[i].content 100 | 101 | entire_programs_dataset.raw_program_dataset.append(raw_program_file) 102 | 103 | return entire_programs_dataset 104 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/download_and_serialize_dataset.sh: -------------------------------------------------------------------------------- 1 | wget http://files.srl.inf.ethz.ch/data/py150_files.tar.gz 2 | wget https://github.com/google-research-datasets/eth_py150_open/raw/master/train__manifest.json 3 | wget https://github.com/google-research-datasets/eth_py150_open/raw/master/eval__manifest.json 4 | wget https://github.com/google-research-datasets/eth_py150_open/raw/master/dev__manifest.json 5 | 6 | touch dev_program_files_paths.txt 7 | 8 | jq -r ".[] | .filepath" dev__manifest.json | while read i; do 9 | echo $i >> dev_program_files_paths.txt 10 | done 11 | 12 | touch train_program_files_paths.txt 13 | 14 | jq -r ".[] | .filepath" train__manifest.json | while read i; do 15 | echo $i >> train_program_files_paths.txt 16 | done 17 | 18 | touch eval_program_files_paths.txt 19 | 20 | jq -r ".[] | .filepath" eval__manifest.json | while read i; do 21 | echo $i >> eval_program_files_paths.txt 22 | done 23 | 24 | tar -xvzf py150_files.tar.gz 25 | 26 | tar -xvzf data.tar.gz 27 | 28 | python run_create_raw_programs_dataset.py --data_source=other --source_name=eth_py150_open \ 29 | --split_name=TRAIN --programs_file_path=$(pwd)/train_program_files_paths.txt \ 30 | --dataset_programming_language=Python --downloaded_dataset_location=$(pwd)/data \ 31 | --save_dataset_location=$(pwd)/train_raw_programs_serialized 32 | 33 | python run_create_raw_programs_dataset.py --data_source=other --source_name=eth_py150_open \ 34 | --split_name=VALIDATION --programs_file_path=$(pwd)/dev_program_files_paths.txt \ 35 | --dataset_programming_language=Python --downloaded_dataset_location=$(pwd)/data \ 36 | --save_dataset_location=$(pwd)/dev_raw_programs_serialized 37 | 38 | python run_create_raw_programs_dataset.py --data_source=other
--source_name=eth_py150_open \ 39 | --split_name=TEST --programs_file_path=$(pwd)/eval_program_files_paths.txt \ 40 | --dataset_programming_language=Python --downloaded_dataset_location=$(pwd)/data \ 41 | --save_dataset_location=$(pwd)/eval_raw_programs_serialized 42 | 43 | python run_create_raw_codeql_queryset.py --labeled_queries_file_path=$(pwd)/$3 \ 44 | --target_programming_language=Python --github_auth=$1:$2 --save_queryset_location=$(pwd)/raw_queries_serialized 45 | 46 | rm -rf dev_program_files_paths.txt train_program_files_paths.txt eval_program_files_paths.txt \ 47 | python50k_eval.txt python100k_train.txt \ 48 | dev__manifest.json data.tar.gz py150_files.tar.gz \ 49 | README.md github_repos.txt data train__manifest.json eval__manifest.json -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/raw_codeql_queryset.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package codeqlqueries; 4 | 5 | message GitHubFilePath { 6 | optional string repo = 1; 7 | optional string unique_path = 2; 8 | } 9 | 10 | message QueryMetadata { 11 | optional string name = 1; 12 | optional string description = 2; 13 | optional string severity = 3; 14 | optional string message = 4; 15 | optional string full_metadata = 5; 16 | } 17 | 18 | enum Hops { 19 | undefined_hop = 0; 20 | single_hop = 1; 21 | multiple_hops = 2; 22 | } 23 | 24 | enum Languages { 25 | Python = 0; 26 | C = 1; 27 | Cpp = 2; 28 | Java = 3; 29 | } 30 | 31 | enum Span { 32 | undefined_span = 0; 33 | single_span = 1; 34 | multiple_spans = 2; 35 | } 36 | 37 | message Query { 38 | optional GitHubFilePath query_path = 1; 39 | optional bytes queryID = 2; 40 | optional string content = 3; 41 | optional QueryMetadata metadata = 4; 42 | optional Languages language = 5; 43 | optional Hops hops = 6; 44 | optional Span span = 7; 45 | optional bool distributable = 8; 46 | } 47 | 48 | message RawQueryList { 49 | repeated Query raw_query_set = 1; 50 | } -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/raw_programs_dataset.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package rawprogramsdataset; 4 | 5 | message GitHubFilePath { 6 | // Repository name from where the raw program files are taken 7 | optional string repo = 1; 8 | // blob 9 | optional string unique_path = 2; 10 | } 11 | 12 | // Determining path when programs are taken from any arbitrary 13 | // dataset 14 | enum datasetsplit { 15 | UNDEFINED = 0; 16 | TRAIN = 1; 17 | TEST = 2; 18 | VALIDATION = 3; 19 | } 20 | 21 | message DatasetFilePath { 22 | optional string source_name = 1; 23 | optional datasetsplit split = 2; 24 | // Relative to the topmost level of the files. 
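// (e.g., for eth_py150_open these are the manifest 'filepath' values, read relative to the downloaded data/ directory)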
25 | optional string unique_file_path = 3; 26 | } 27 | 28 | // Unique identification of program filepath irrespective of source 29 | message FilePath { 30 | oneof filepathunion { 31 | GitHubFilePath github_file_path = 1; 32 | DatasetFilePath dataset_file_path = 2; 33 | } 34 | } 35 | 36 | enum Languages { 37 | Python = 0; 38 | C = 1; 39 | Cpp = 2; 40 | Java = 3; 41 | } 42 | 43 | message RawProgramFile { 44 | optional FilePath file_path = 1; 45 | optional Languages language = 2; 46 | optional string file_content = 3; 47 | } 48 | 49 | message RawProgramDataset { repeated RawProgramFile raw_program_dataset = 1; } -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/run_create_raw_codeql_queryset.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | import sys 3 | import create_raw_codeql_queryset 4 | 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | flags.DEFINE_string( 9 | "labeled_queries_file_path", 10 | None, 11 | "Path of the file which contains the paths of the CodeQL query files \ 12 | and manually labeled Hops and Span information." 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "target_programming_language", 17 | None, 18 | "Programming language for which the CodeQL query files can be used \ 19 | to analyze the codebase." 20 | ) 21 | 22 | flags.DEFINE_string( 23 | "github_auth", 24 | None, 25 | "Github API Basic Authentication to access blobs. \ 26 | Value given as github_username:personal_access_token" 27 | ) 28 | 29 | flags.DEFINE_string( 30 | "save_queryset_location", 31 | None, 32 | "File path where CodeQL query set should be serialized to." 33 | ) 34 | 35 | 36 | if __name__ == "__main__": 37 | argv = FLAGS(sys.argv) 38 | 39 | create_raw_codeql_queryset.serialize_queryset( 40 | FLAGS.labeled_queries_file_path, 41 | FLAGS.target_programming_language, 42 | FLAGS.github_auth, 43 | FLAGS.save_queryset_location 44 | ) 45 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/run_create_raw_programs_dataset.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | import sys 3 | import create_raw_programs_dataset 4 | from collections import namedtuple 5 | 6 | FLAGS = flags.FLAGS 7 | 8 | 9 | flags.DEFINE_string( 10 | "data_source", 11 | None, 12 | "Specify if the programs' dataset is taken from github or from some\ 13 | other source. Write github or other." 14 | ) 15 | 16 | flags.DEFINE_string( 17 | "repository_name", 18 | None, 19 | "If the programs' dataset is taken from some GitHub repository, provide\ 20 | the name of the repository." 21 | ) 22 | 23 | flags.DEFINE_string( 24 | "file_with_unique_paths", 25 | None, 26 | "Provide the path of the file which stores the unique paths of the\ 27 | program files if taken from GitHub. (blob)" 28 | ) 29 | 30 | flags.DEFINE_string( 31 | "source_name", 32 | None, 33 | "If the programs dataset is taken from some other resource, provide\ 34 | the name of the source."
35 | ) 36 | 37 | flags.DEFINE_string( 38 | "split_name", 39 | None, 40 | "If the programs dataset is taken from some other resource, provide\ 41 | the name of the split (TRAIN/TEST/VALIDATION/UNDEFINED)" 42 | ) 43 | 44 | flags.DEFINE_string( 45 | "programs_file_path", 46 | None, 47 | "If the programs' dataset is taken from some other resource, provide\ 48 | the path of the file which contains the paths of the program files\ 49 | relative to the topmost level." 50 | ) 51 | 52 | flags.DEFINE_string( 53 | "dataset_programming_language", 54 | None, 55 | "Provide the programming language used in the programs in the dataset\ 56 | files." 57 | ) 58 | 59 | flags.DEFINE_string( 60 | "downloaded_dataset_location", 61 | None, 62 | "Provide the path to the folder where the data has been downloaded to." 63 | ) 64 | 65 | flags.DEFINE_string( 66 | "save_dataset_location", 67 | None, 68 | "File path where raw programs' dataset should be serialized to." 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | argv = FLAGS(sys.argv) 74 | File_Path_Content = namedtuple("File_Path_Content", "path content") 75 | 76 | if(FLAGS.data_source == "other"): 77 | files = [] 78 | with open(FLAGS.programs_file_path, "r") as f: 79 | for line in f: 80 | line = line.strip() 81 | files.append(line) 82 | 83 | file_content = [] 84 | for i in range(len(files)): 85 | with open(FLAGS.downloaded_dataset_location + "/" + files[i], 86 | "r") as f: 87 | temp = File_Path_Content(files[i], f.read()) 88 | file_content.append(temp) 89 | 90 | dataset = ( 91 | create_raw_programs_dataset.CreateRawProgramsDatasetNonGithub( 92 | FLAGS.source_name, 93 | FLAGS.split_name, 94 | FLAGS.dataset_programming_language, 95 | file_content 96 | )) 97 | 98 | elif(FLAGS.data_source == "github"): 99 | file_paths = [] 100 | with open(FLAGS.file_with_unique_paths, "r") as f: 101 | for line in f: 102 | line = line.strip() 103 | file_paths.append(line) 104 | 105 | file_content = [] 106 | for i in range(len(file_paths)): 107 | with open(FLAGS.downloaded_dataset_location + "/" + file_paths[i], 108 | "r") as f: 109 | temp = File_Path_Content(file_paths[i], f.read()) 110 | file_content.append(temp) 111 | 112 | dataset = create_raw_programs_dataset.CreateRawProgramsDatasetGithub( 113 | FLAGS.repository_name, 114 | FLAGS.dataset_programming_language, 115 | file_content 116 | ) 117 | 118 | with open(FLAGS.save_dataset_location, "wb") as fd: 119 | fd.write(dataset.SerializeToString()) 120 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_ingestion/test_create_raw_programs_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import create_raw_programs_dataset 3 | from collections import namedtuple 4 | 5 | LANGUAGE = 0 6 | CONTENTS = [[""], 7 | ["from __future__ import unicode_literals", 8 | "from ctypes import windll, Structure, c_uint", 9 | "from ctypes.wintypes import HWND, UINT, LPCWSTR, BOOL", "", 10 | "shell32 = windll.shell32", 11 | "SHFileOperationW = shell32.SHFileOperationW", "", "", 12 | "class SHFILEOPSTRUCTW(Structure):", 13 | " _fields_ = [", " ('hwnd, HWND),", 14 | " ('wFunc, UINT),", " ('pFrom, LPCWSTR),", 15 | " ('pTo, LPCWSTR),", " ('fFlags, c_uint),", 16 | " ('fAnyOperationsAborted, BOOL),", 17 | " ('hNameMappings, c_uint),", 18 | " ('lpszProgressTitle, LPCWSTR),", " ]", "", "", 19 | "FO_MOVE = 1", "FO_COPY = 2", "FO_DELETE = 3", "FO_RENAME = 4", 20 | ""]] 21 | 22 | SOURCE_NAME = "Test Dataset" 23 | SPLIT = 3 24 | PATH =
["empty_file.py", "plat_win.py"] 25 | 26 | 27 | class TestCreateRawProgramsDatasetNonGithub(unittest.TestCase): 28 | """ 29 | This tests all functionalities when creating dataset using 30 | Python files taken from sources other than GitHub. 31 | """ 32 | 33 | def setUp(self): 34 | File_Path_Content = namedtuple("File_Path_Content", "path content") 35 | files = ["empty_file.py", "plat_win.py"] 36 | 37 | file_content = [] 38 | for i in range(len(files)): 39 | temp = File_Path_Content(files[i], "\n".join( 40 | k for k in CONTENTS[i])) 41 | file_content.append(temp) 42 | 43 | self.returned_data = ( 44 | create_raw_programs_dataset.CreateRawProgramsDatasetNonGithub( 45 | "Test Dataset", 46 | "VALIDATION", 47 | "Python", 48 | file_content 49 | 50 | )) 51 | 52 | # Test if the programming language information is stored correctly 53 | def test_language(self): 54 | for i in range(len(self.returned_data.raw_program_dataset)): 55 | self.assertEqual( 56 | self.returned_data.raw_program_dataset[i].language, LANGUAGE) 57 | 58 | # Test if the file contents are stored correctly. 59 | def test_file_content(self): 60 | for i in range(len(self.returned_data.raw_program_dataset)): 61 | returned_content = ( 62 | self.returned_data.raw_program_dataset[i].file_content) 63 | actual_content = "\n".join(j for j in CONTENTS[i]) 64 | self.assertEqual(returned_content, actual_content) 65 | 66 | # Test if the file paths are stored correctly. 67 | def test_file_path(self): 68 | for i in range(len(self.returned_data.raw_program_dataset)): 69 | source = ( 70 | self.returned_data.raw_program_dataset[i] 71 | .file_path.dataset_file_path.source_name) 72 | split_name = self.returned_data.raw_program_dataset[i]\ 73 | .file_path.dataset_file_path.split 74 | unique_path = self.returned_data.raw_program_dataset[ 75 | i].file_path.dataset_file_path.unique_file_path 76 | 77 | self.assertEqual(source, SOURCE_NAME) 78 | self.assertEqual(split_name, SPLIT) 79 | self.assertEqual(unique_path, PATH[i]) 80 | 81 | 82 | if __name__ == "__main__": 83 | unittest.main() 84 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/commands.md: -------------------------------------------------------------------------------- 1 | python run_create_perturbed_blocks_labels_dataset.py --serialized_query_path=/home/suryaprakash/data/queryset_serialized --tokenized_block_label_protobuf_file=/home/suryaprakash/data/block_labels/block_eval --save_dataset_location=/home/suryaprakash/data/perturbs/nlalt_block_labels/block_eval --perturbation_type=NL-ALT --query_name_alt_csv=/home/suryaprakash/data/alt_query.csv 2 | 3 | python run_create_block_subtokens_labels.py --model_type=cubert --ordering_of_blocks=line_number --vocab_file=/home/suryaprakash/data/vocab.txt --tokenized_block_label_protobuf_file=/home/suryaprakash/data/perturbs/nlalt_block_labels/block_eval --save_dataset_location=/home/suryaprakash/data/perturbs/lineorderd_block_subtoken/eval_block_subtoken 4 | 5 | python run_create_single_model_baseline_examples.py --model_type=cubert --block_subtoken_label_protobuf_file=/home/suryaprakash/data/perturbs/lineorderd_block_subtoken/eval_block_subtoken --vocab_file=/home/suryaprakash/data/vocab.txt --save_dataset_location=/home/suryaprakash/perturbs/nl_alt_singlemodel_cubert_eval_examples -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/conflictingattributesinbaseclasses.py: 
-------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | import re 3 | 4 | 5 | class ConflictingAttributesInBaseClasses(BaseContexts): 6 | """ 7 | This module extracts module-level context and corresponding metadata from 8 | program content. 9 | Metadata contains start/end line number information for corresponding context. 10 | """ 11 | def __init__(self, parser): 12 | """ 13 | Args: 14 | parser: Tree sitter parser object 15 | """ 16 | super().__init__(parser) 17 | 18 | 19 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 20 | """ 21 | This function returns relevant Blocks as query specific context. 22 | Args: 23 | program_content: Program in string format from which we need to 24 | extract classes 25 | parser : tree_sitter_parser 26 | file_path : file of program_content 27 | message : CodeQL message 28 | result_span: CodeQL-treesitter adjusted namedtuple of 29 | (start_line, start_col, end_line, end_col) 30 | aux_result_df: auxiliary query results dataframe 31 | Returns: 32 | A list consisting of relevant Blocks 33 | """ 34 | start_line = result_span.start_line 35 | end_line = result_span.end_line 36 | 37 | context_object = ConflictingAttributesInBaseClasses(parser) 38 | 39 | local_block = context_object.get_local_block(program_content, start_line, end_line) 40 | # local_block will always be a Block outside parent_class_blocks 41 | local_block.relevant = True 42 | required_blocks = [local_block] 43 | 44 | _, parent_class_blocks = context_object.get_local_and_super_class(program_content, start_line, end_line) 45 | all_blocks = context_object.get_all_blocks(program_content) 46 | for block in all_blocks: 47 | for p_block in parent_class_blocks: 48 | if (block.start_line >= p_block.start_line 49 | and block.end_line <= p_block.end_line): 50 | required_blocks.append(block) 51 | 52 | # MRO only for conflicting attribute 53 | conflicting_attributes = [] 54 | for msg in message.split('\n'): 55 | attr = re.findall(r"\'(.*)\':", msg) 56 | if attr: 57 | conflicting_attributes.append(attr[0].strip()) 58 | 59 | for class_block in parent_class_blocks: 60 | cls = class_block.metadata.split('.')[-1] 61 | for conflicting_attr in conflicting_attributes: 62 | mro_func_block = context_object.get_mro_function_block(conflicting_attr, 63 | cls, program_content) 64 | if(mro_func_block is not None): 65 | if(mro_func_block in required_blocks): 66 | for block in required_blocks: 67 | if(block == mro_func_block): 68 | block.relevant = True 69 | else: 70 | mro_func_block.relevant = True 71 | required_blocks.append(mro_func_block) 72 | 73 | # mark relevance 74 | lines = program_content.split('\n') 75 | for block in required_blocks: 76 | # function level check 77 | if(block.block_type == CLASS_FUNCTION): 78 | # check if function name is __init__ 79 | func_name = block.metadata.split('.')[-1] 80 | if(func_name == '__init__'): 81 | block.relevant = True 82 | 83 | # check for equal functions in all classes 84 | for o_block in required_blocks: 85 | if(o_block.block_type == CLASS_FUNCTION and block != o_block): 86 | o_func_name = o_block.metadata.split('.')[-1] 87 | if(func_name == o_func_name 88 | and func_name not in ['__init__', 'process_request']): 89 | block.relevant = True 90 | o_block.relevant = True 91 | # class field level checks 92 | else: 93 | # as these are class blocks, the only possible block_types are CLASS_FUNCTION & CLASS_OTHER 94 | # hence checking other_lines covers the
whole class body 95 | for i in block.other_lines: 96 | for conflicting_attr in conflicting_attributes: 97 | if(('self.' + conflicting_attr) in lines[i] 98 | or conflicting_attr in lines[i]): 99 | block.relevant = True 100 | break 101 | if(block.relevant): 102 | break 103 | 104 | return required_blocks 105 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/defineequalswhenaddingattributes.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class DefineEqualsWhenAddingAttributes(BaseContexts): 5 | """ 6 | This module extracts module-level context and corresponding metadata from 7 | program content. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This function returns relevant Blocks as query specific context. 21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting of relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = DefineEqualsWhenAddingAttributes(parser) 37 | local_class_block, parent_class_blocks = context_object.get_local_and_super_class(program_content, start_line, end_line) 38 | all_blocks = context_object.get_all_blocks(program_content) 39 | 40 | required_blocks = [] 41 | local_block = context_object.get_local_block(program_content, start_line, end_line) 42 | # sanity check 43 | if local_class_block is not None: 44 | for block in all_blocks: 45 | # if a block in base class 46 | if (block.start_line >= local_class_block.start_line 47 | and block.end_line <= local_class_block.end_line): 48 | if(block == local_block): 49 | block.relevant = True 50 | required_blocks.append(block) 51 | # if a block in any of the parent class 52 | else: 53 | for p_block in parent_class_blocks: 54 | if(block.start_line >= p_block.start_line 55 | and block.end_line <= p_block.end_line): 56 | required_blocks.append(block) 57 | 58 | local_class = local_class_block.metadata.split('.')[-1] 59 | mro_eq_block = context_object.get_mro_function_block('__eq__', 60 | local_class, 61 | program_content) 62 | 63 | # for __eq__ : with MRO 64 | # by definition of query, there has to be an mro_eq_block, i.e., __eq__() 65 | # from some super class, but not the case because of single file restriction 66 | if(mro_eq_block is not None): 67 | if(mro_eq_block in required_blocks): 68 | for block in required_blocks: 69 | if(block == mro_eq_block): 70 | block.relevant = True 71 | else: 72 | mro_eq_block.relevant = True 73 | required_blocks.append(mro_eq_block) 74 | 75 | # for __init__ : __init__ specific resolution - all __init__ in required_blocks are relevant 76 | # super_class = None 77 | # super_class = context_object.check_multiple_inheritance_super(program_content, 78 | # local_class_block) 79 | 80 | # if(super_class is
None): 81 | # super_class = local_class_block.metadata.split('.')[-1] 82 | 83 | for block in required_blocks: 84 | if(block.block_type == CLASS_FUNCTION): 85 | func_name = block.metadata.split('.')[-1] 86 | 87 | if(func_name == '__init__'): 88 | block.relevant = True 89 | 90 | return required_blocks 91 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/distributable.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts 2 | 3 | 4 | class Distributable(BaseContexts): 5 | """ 6 | This module extracts Blocks and corresponding metadata from program content for 7 | distributable queries. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This function returns relevant Blocks as query specific context. 21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting of relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | context_object = Distributable(parser) 36 | 37 | local_block = context_object.get_local_block(program_content, start_line, end_line) 38 | local_block.relevant = True 39 | required_blocks = [local_block] 40 | 41 | return required_blocks 42 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/equalsorhash.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class EqualsOrHash(BaseContexts): 5 | """ 6 | This module extracts module-level context and corresponding metadata from 7 | program content. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This function returns relevant Blocks as query specific context.
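For example, if the CodeQL message reports that '__hash__' is not defined, the '__eq__' block of the flagged class is marked relevant, along with an MRO-resolved '__hash__' block from a super class, if one exists.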
21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting of relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = EqualsOrHash(parser) 37 | local_class_block = context_object.get_local_class(program_content, start_line, end_line) 38 | 39 | eq_or_hash = set(['__eq__', '__hash__']) 40 | not_implemented = set([message.split()[-1].split('.')[0].strip()]) 41 | implemented = next(iter(eq_or_hash - not_implemented)) 42 | not_implemented = next(iter(not_implemented)) 43 | 44 | all_blocks = context_object.get_all_blocks(program_content) 45 | required_blocks = [] 46 | local_block = context_object.get_local_block(program_content, start_line, end_line) 47 | if local_class_block is not None: 48 | for block in all_blocks: 49 | # if a block in the class 50 | if (block.start_line >= local_class_block.start_line 51 | and block.end_line <= local_class_block.end_line): 52 | if(block == local_block): 53 | block.relevant = True 54 | elif(block.block_type == CLASS_FUNCTION 55 | and block.metadata.split('.')[-1] == implemented): 56 | block.relevant = True 57 | required_blocks.append(block) 58 | 59 | # add `not_implemented` from super class with MRO 60 | current_class = local_class_block.metadata.split('.')[-1] 61 | mro_func_block = context_object.get_mro_function_block(not_implemented, 62 | current_class, 63 | program_content) 64 | if(mro_func_block is not None): 65 | mro_func_block.relevant = True 66 | if(mro_func_block not in required_blocks): 67 | required_blocks.append(mro_func_block) 68 | # else block shouldn't occur, as the class 69 | # hasn't implemented the `not_implemented` 70 | 71 | return required_blocks 72 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/flaskdebug.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts 2 | 3 | 4 | class FlaskDebug(BaseContexts): 5 | """ 6 | This module extracts module-level context and corresponding metadata from 7 | program content. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | self.postordered_nodes = [] 17 | 18 | def post_order_traverse(self, root_node, program_content): 19 | """ 20 | This function collects the postorder traversal of leaf nodes and the corresponding 21 | node literals into self.postordered_nodes. 22 | Args: 23 | root_node: root node of tree_sitter tree 24 | program_content: Program in string format from which we need to 25 | extract node literal 26 | Returns: 27 | None 28 | """ 29 | if(len(root_node.children) == 0): 30 | literal = bytes(program_content, "utf8")[ 31 | root_node.start_byte:root_node.end_byte 32 | ].decode("utf8") 33 | self.postordered_nodes.append((literal, root_node)) 34 | else: 35 | for ch in root_node.children: 36 | self.post_order_traverse(ch, program_content) 37 | 38 | def get_app_definition_line(self, root_node, program_content): 39 | """ 40 | This function returns the line span of the node which contains the Flask app definition.
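For example, for `app = Flask(__name__)`, the returned span covers the `Flask` identifier node.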
41 | Args: 42 | root_node: root node of tree_sitter tree 43 | program_content: Program in string format from which we need to 44 | extract node literal. 45 | Returns: 46 | a (start_line, end_line) tuple for the matched node 47 | """ 48 | self.post_order_traverse(root_node, program_content) 49 | 50 | app_index = (-1, -1) 51 | for i, node in enumerate(self.postordered_nodes): 52 | # no need to check IndexError as CodeQL raises this flag 53 | # only if a flask app is running 54 | if(node[0] == 'Flask' 55 | and self.postordered_nodes[i + 1][0] == '(' 56 | and self.postordered_nodes[i + 2][0] == '__name__' 57 | and self.postordered_nodes[i + 3][0] == ')'): 58 | app_index = (node[1].start_point[0], node[1].end_point[0]) 59 | break 60 | 61 | return app_index 62 | 63 | 64 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 65 | """ 66 | This function returns relevant Blocks as query specific context. 67 | Args: 68 | program_content: Program in string format from which we need to 69 | extract classes 70 | parser : tree_sitter_parser 71 | file_path : file of program_content 72 | message : CodeQL message 73 | result_span: CodeQL-treesitter adjusted namedtuple of 74 | (start_line, start_col, end_line, end_col) 75 | aux_result_df: auxiliary query results dataframe 76 | Returns: 77 | A list consisting of relevant Blocks 78 | """ 79 | start_line = result_span.start_line 80 | end_line = result_span.end_line 81 | 82 | context_object = FlaskDebug(parser) 83 | all_blocks = context_object.get_all_blocks(program_content) 84 | 85 | tree = parser.parse(bytes(program_content, "utf8")) 86 | root_node = tree.root_node 87 | 88 | # app definition indices 89 | required_blocks = [] 90 | local_block = context_object.get_local_block(program_content, start_line, end_line) 91 | 92 | flask_def_start, flask_def_end = context_object.get_app_definition_line(root_node, program_content) 93 | for block in all_blocks: 94 | block_start_line = block.start_line 95 | block_end_line = block.end_line 96 | if(len(block.other_lines) == 0): 97 | contains_def = (flask_def_start >= block_start_line and flask_def_end <= block_end_line) 98 | # start_line and end_line mark use of debug mode 99 | contains_debug = (start_line >= block_start_line and end_line <= block_end_line) 100 | if(contains_def or contains_debug): 101 | block.relevant = True 102 | else: 103 | block_specific_lines = block.other_lines 104 | contains_def = (flask_def_start in block_specific_lines and flask_def_end in block_specific_lines) 105 | contains_debug = (start_line in block_specific_lines and end_line in block_specific_lines) 106 | if(contains_def or contains_debug): 107 | block.relevant = True 108 | 109 | # to avoid adding the same Block twice, in case 110 | # app definition and debug mode belong to one Block 111 | if(block.relevant and (block not in required_blocks)): 112 | required_blocks.append(block) 113 | 114 | # Add local_block, if not present 115 | local_block.relevant = True 116 | if(local_block not in required_blocks): 117 | required_blocks.append(local_block) 118 | 119 | return required_blocks 120 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/get_builtin_stub.py: -------------------------------------------------------------------------------- 1 | # from basecontexts import Block, ROOT_BLOCK_TYPE, STUB 2 | 3 | 4 | call_stub = '''\ 5 | class <CLASS_NAME>(object): 6 | def __call__(self): 7 | raise NotImplementedError''' 8 | 9 | eq_stub
= '''\ 10 | class <CLASS_NAME>(object): 11 | def __eq__(self, value): 12 | return self==value''' 13 | 14 | custom_function_stub = '''\ 15 | class <CLASS_NAME>(object): 16 | def <FUNCTION_NAME>(self): 17 | raise NotImplementedError''' 18 | 19 | 20 | def get_class_specific_call_stub(builtin_type): 21 | """ 22 | This function returns a specific builtin class stub. 23 | Args: 24 | builtin_type: builtin object type in python 25 | Returns: 26 | A stub class string with __call__ content 27 | """ 28 | class_specific_stub = call_stub.replace('<CLASS_NAME>', builtin_type) 29 | # stub_block = Block(-1, -1, [], class_specific_stub, ROOT_BLOCK_TYPE, STUB) 30 | 31 | return class_specific_stub 32 | 33 | 34 | def get_class_specific_eq_stub(builtin_type): 35 | """ 36 | This function returns a specific builtin class stub. 37 | Args: 38 | builtin_type: builtin object type in python 39 | Returns: 40 | A stub class string with __eq__ content 41 | """ 42 | class_specific_stub = eq_stub.replace('<CLASS_NAME>', builtin_type) 43 | # stub_block = Block(-1, -1, [], class_specific_stub, ROOT_BLOCK_TYPE, STUB) 44 | 45 | return class_specific_stub 46 | 47 | 48 | def get_custom_stub(builtin_type, function_name): 49 | """ 50 | This function returns a custom class stub. 51 | Args: 52 | builtin_type: builtin object type in python 53 | function_name: required function name for STUB block 54 | Returns: 55 | A stub class string with builtin_type and function_name substituted 56 | """ 57 | custom_stub = custom_function_stub.replace('<CLASS_NAME>', builtin_type) 58 | custom_stub = custom_stub.replace('<FUNCTION_NAME>', function_name) 59 | # stub_block = Block(-1, -1, [], class_specific_stub, ROOT_BLOCK_TYPE, STUB) 60 | 61 | return custom_stub 62 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/get_mro.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from itertools import islice 3 | from typing import List, Dict, Tuple, Optional 4 | 5 | 6 | class LinearizationDeque(deque): 7 | """ 8 | A deque to represent linearization of a class 9 | """ 10 | @property 11 | def head(self) -> Optional[type]: 12 | """ 13 | Return head of deque 14 | """ 15 | try: 16 | return self[0] 17 | except IndexError: 18 | return None 19 | 20 | @property 21 | def tail(self) -> islice: # type: ignore 22 | """ 23 | Return an islice object, which suffices for iteration or calling `in` 24 | """ 25 | try: 26 | return islice(self, 1, self.__len__()) 27 | except (ValueError, IndexError): 28 | return islice([], 0, 0) 29 | 30 | 31 | class LinearizationDequeList: 32 | """ 33 | A class that represents a list of linearizations (dependencies) 34 | The last element of LinearizationDequeList is a list of parents. 35 | It's needed so that the merge process preserves the local 36 | precedence order of direct parent classes.
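For example, when linearizing `class F(D, B)`, the trailing list passed to the merge step is [D, B].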
37 | """ 38 | def __init__(self, *lists: Tuple[List[type]]) -> None: 39 | self._lists = [LinearizationDeque(i) for i in lists] 40 | 41 | def __contains__(self, item: type) -> bool: 42 | """ 43 | Return True if any linearization's tail contains an item 44 | """ 45 | return any([item in dep_list.tail for dep_list in self._lists]) 46 | 47 | def __len__(self): 48 | size = len(self._lists) 49 | return (size - 1) if size else 0 50 | 51 | def __repr__(self): 52 | return self._lists.__repr__() 53 | 54 | @property 55 | def heads(self) -> List[Optional[type]]: 56 | return [h.head for h in self._lists] 57 | 58 | @property 59 | def tails(self) -> 'LinearizationDequeList': 60 | """ 61 | Return self so that __contains__ could be called 62 | Used for readability reasons only 63 | """ 64 | return self 65 | 66 | @property 67 | def exhausted(self) -> bool: 68 | """ 69 | Return True if all elements of the lists are exhausted 70 | """ 71 | return all(map(lambda x: len(x) == 0, self._lists)) 72 | 73 | def remove(self, item: Optional[type]) -> None: 74 | """ 75 | Remove head from all LinearizationDeque 76 | """ 77 | for i in self._lists: 78 | if i and i.head == item: 79 | i.popleft() 80 | 81 | 82 | def _merge(*lists) -> list: 83 | """ 84 | Return self so that __contains__ could be called 85 | Used for readability reasons only 86 | Args: 87 | A list of LinearizationDeque 88 | Returns: 89 | A list of classes in order corresponding to Python's MRO 90 | """ 91 | result: List[Optional[type]] = [] 92 | linearizations = LinearizationDequeList(*lists) 93 | 94 | while True: 95 | if linearizations.exhausted: 96 | return result 97 | 98 | for head in linearizations.heads: 99 | if head and (head not in linearizations.tails): 100 | result.append(head) 101 | linearizations.remove(head) 102 | 103 | # Once candidate added to result, next 104 | # candidate selection iteration starts 105 | break 106 | else: 107 | # Loop never broke, no linearization could possibly be found 108 | raise ValueError('Cannot compute linearization, a cycle found') 109 | 110 | 111 | def mro(bases_dict: Dict, cls: str): 112 | """ 113 | Return mro as per c3 linearization 114 | Args: 115 | bases_dict: dict with classes as key and class.__bases__ as value 116 | cls: target class 117 | Returns: 118 | A list of classes in order corresponding to Python's MRO 119 | """ 120 | all_class_names = bases_dict.keys() 121 | result = [cls] 122 | 123 | if not bases_dict[cls]: 124 | return result 125 | else: 126 | return (result 127 | + _merge(*[mro(bases_dict, kls) for kls in bases_dict[cls] if kls in all_class_names], 128 | bases_dict[cls])) 129 | 130 | # #example 131 | # cc = {'A': [], 'B': ['A'],'L': [], 'C': ['A', 'L'], 132 | # 'D': ['C'],' E': ['B', 'C'], 'F': ['D', 'B']} 133 | 134 | # print(mro(cc, 'F')) 135 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/incompleteordering.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class IncompleteOrdering(BaseContexts): 5 | """ 6 | This module extracts module level context and corrsponding metadata from 7 | program content. 8 | Metadata contains start/end line number informations for corresponding context. 
9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This functions returns relevant Blocks as query specific context. 21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = IncompleteOrdering(parser) 37 | local_class_block = context_object.get_local_class(program_content, start_line, end_line) 38 | 39 | all_blocks = context_object.get_all_blocks(program_content) 40 | required_blocks = [] 41 | local_block = context_object.get_local_block(program_content, start_line, end_line) 42 | if local_class_block is not None: 43 | order_func_names = set(['__lt__', '__gt__', '__le__', '__ge__']) 44 | local_order_func_names = set() 45 | for block in all_blocks: 46 | # if a block in base class 47 | if (block.start_line >= local_class_block.start_line 48 | and block.end_line <= local_class_block.end_line): 49 | if(block == local_block): 50 | block.relevant = True 51 | if(block.block_type == CLASS_FUNCTION): 52 | func_name = block.metadata.split('.')[-1] 53 | if(func_name in order_func_names): 54 | local_order_func_names.add(func_name) 55 | required_blocks.append(block) 56 | 57 | # add ordering functions from super class if 58 | # ordering fucntion not in local class 59 | current_class = local_class_block.metadata.split('.')[-1] 60 | remaining_order_func_names = order_func_names - local_order_func_names 61 | if(remaining_order_func_names): 62 | for func in sorted(remaining_order_func_names): 63 | mro_func_block = context_object.get_mro_function_block(func, 64 | current_class, 65 | program_content) 66 | if(mro_func_block is not None 67 | and mro_func_block not in required_blocks): 68 | required_blocks.append(mro_func_block) 69 | 70 | # mark relevance 71 | for block in required_blocks: 72 | if(block.block_type == CLASS_FUNCTION): 73 | func_name = block.metadata.split('.')[-1] 74 | if(func_name in order_func_names): 75 | block.relevant = True 76 | 77 | return required_blocks 78 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/initcallssubclassmethod.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class InitCallsSubclassMethod(BaseContexts): 5 | """ 6 | This module extracts module level context and corrsponding metadata from 7 | program content. 8 | Metadata contains start/end line number informations for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This functions returns relevant Blocks as query specific context. 
21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = InitCallsSubclassMethod(parser) 37 | local_class_block, child_class_blocks = context_object.get_local_and_child_class(program_content, start_line, end_line) 38 | all_blocks = context_object.get_all_blocks(program_content) 39 | 40 | required_blocks = [] 41 | local_block = context_object.get_local_block(program_content, start_line, end_line) 42 | 43 | # for intersecting function check 44 | parent_class_functions = set() 45 | child_class_functions = dict() # dictionary of set 46 | 47 | required_parent_class_blocks = [] 48 | required_child_class_blocks = [] 49 | if local_class_block is not None: 50 | for block in all_blocks: 51 | # if a block in local class 52 | if (block.start_line >= local_class_block.start_line 53 | and block.end_line <= local_class_block.end_line): 54 | if(block == local_block): 55 | block.relevant = True 56 | required_parent_class_blocks.append(block) 57 | # for intersecting function check 58 | if(block.block_type == CLASS_FUNCTION): 59 | parent_class_functions.add(block.metadata.split('.')[-1]) 60 | # if a block in any of the child class 61 | else: 62 | for p_block in child_class_blocks: 63 | class_name = p_block.metadata.split('.')[-1] 64 | if (block.start_line >= p_block.start_line 65 | and block.end_line <= p_block.end_line): 66 | required_child_class_blocks.append(block) 67 | # for intersecting function check 68 | if(block.block_type == CLASS_FUNCTION): 69 | func_name = block.metadata.split('.')[-1] 70 | if(class_name in child_class_functions): 71 | child_class_functions[class_name].add(func_name) 72 | else: 73 | child_class_functions[class_name] = set([func_name]) 74 | 75 | # get intersecting functions 76 | intersecting_functions = dict() 77 | if(parent_class_functions): 78 | for cls, func_set in child_class_functions.items(): 79 | intersecting_functions[cls] = parent_class_functions.intersection(func_set) 80 | 81 | # mark relevance 82 | init_func = '__init__' 83 | for block in required_child_class_blocks: 84 | if(block.block_type == CLASS_FUNCTION): 85 | class_name = block.metadata.split('.')[-2] 86 | func_name = block.metadata.split('.')[-1] 87 | # check if function is not __init__ 88 | # and in intersecting functions 89 | if(func_name != init_func 90 | # for classes without functions/ intersecting functions 91 | and class_name in intersecting_functions.keys() 92 | and func_name in intersecting_functions[class_name]): 93 | block.relevant = True 94 | 95 | for block in required_parent_class_blocks: 96 | if(block.block_type == CLASS_FUNCTION): 97 | class_name = block.metadata.split('.')[-2] 98 | func_name = block.metadata.split('.')[-1] 99 | # check if function is __init__ 100 | # or in intersecting functions 101 | if(func_name == init_func): 102 | block.relevant = True 103 | break 104 | 105 | required_blocks.extend(required_child_class_blocks) 106 | required_blocks.extend(required_parent_class_blocks) 107 | 108 | return required_blocks 109 | -------------------------------------------------------------------------------- 
/CodeQueries_preparation/data_preparation/contexts/iterreturnsnoniterator.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class IterReturnsNonIterator(BaseContexts): 5 | """ 6 | This module extracts module-level context and corresponding metadata from 7 | program content. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This function returns relevant Blocks as query specific context. 21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting of relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = IterReturnsNonIterator(parser) 37 | 38 | all_blocks = context_object.get_all_blocks(program_content) 39 | local_class_block = context_object.get_local_class(program_content, start_line, end_line) 40 | local_block = context_object.get_local_block(program_content, start_line, end_line) 41 | 42 | required_blocks = [] 43 | assert message.split()[0].strip() == 'Class' 44 | target_class = message.split()[1].strip() 45 | if local_class_block is not None: 46 | for block in all_blocks: 47 | # if a block in local class 48 | if (block.start_line >= local_class_block.start_line 49 | and block.end_line <= local_class_block.end_line): 50 | # local block is CLASS_OTHER 51 | if(block == local_block): 52 | block.relevant = True 53 | required_blocks.append(block) 54 | elif(block.block_type == CLASS_FUNCTION 55 | and block.metadata.split('.')[-1] == '__iter__'): 56 | block.relevant = True 57 | required_blocks.append(block) 58 | # if other class __iter__ returns target_class iterator object 59 | elif(block.block_type == CLASS_FUNCTION 60 | and block.metadata.split('.')[-1] == '__iter__' 61 | and ('return ' + target_class) in (' '.join(block.content.split()))): 62 | block.relevant = True 63 | required_blocks.append(block) 64 | 65 | return required_blocks 66 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/missingcalltoinit.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | 3 | 4 | class MissingCallToInit(BaseContexts): 5 | """ 6 | This module extracts module-level context and corresponding metadata from 7 | program content. 8 | Metadata contains start/end line number information for corresponding context. 9 | """ 10 | def __init__(self, parser): 11 | """ 12 | Args: 13 | parser: Tree sitter parser object 14 | """ 15 | super().__init__(parser) 16 | 17 | 18 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 19 | """ 20 | This function returns relevant Blocks as query specific context.
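For example, the __init__ blocks of the flagged class and of its parent classes are marked relevant, along with the local block.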
21 | Args: 22 | program_content: Program in string format from which we need to 23 | extract classes 24 | parser : tree_sitter_parser 25 | file_path : file of program_content 26 | message : CodeQL message 27 | result_span: CodeQL-treesitter adjusted namedtuple of 28 | (start_line, start_col, end_line, end_col) 29 | aux_result_df: auxiliary query results dataframe 30 | Returns: 31 | A list consisting of relevant Blocks 32 | """ 33 | start_line = result_span.start_line 34 | end_line = result_span.end_line 35 | 36 | context_object = MissingCallToInit(parser) 37 | local_class_block, parent_class_blocks = context_object.get_local_and_super_class(program_content, start_line, end_line) 38 | all_blocks = context_object.get_all_blocks(program_content) 39 | local_block = context_object.get_local_block(program_content, start_line, end_line) 40 | 41 | required_blocks = [] 42 | # sanity check 43 | if local_class_block is not None: 44 | for block in all_blocks: 45 | # if a block is in the flagged class 46 | if (block.start_line >= local_class_block.start_line 47 | and block.end_line <= local_class_block.end_line): 48 | # local block is CLASS_OTHER 49 | if(block == local_block): 50 | block.relevant = True 51 | required_blocks.append(block) 52 | elif(block.block_type == CLASS_FUNCTION 53 | and block.metadata.split('.')[-1] == '__init__'): 54 | block.relevant = True 55 | required_blocks.append(block) 56 | # if a block is in any of the parent classes 57 | else: 58 | for p_block in parent_class_blocks: 59 | if(block.start_line >= p_block.start_line 60 | and block.end_line <= p_block.end_line): 61 | if(block.block_type == CLASS_FUNCTION 62 | and block.metadata.split('.')[-1] == '__init__'): 63 | block.relevant = True 64 | required_blocks.append(block) 65 | 66 | return required_blocks 67 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/my-languages.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/CodeQueries_preparation/data_preparation/contexts/my-languages.so -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/noncallablecalled.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts 2 | from basecontexts import CLASS_FUNCTION, MODULE_OTHER, STUB, INBUILT_TYPES 3 | from get_builtin_stub import get_class_specific_call_stub 4 | 5 | INBUILT = 'builtin-class' 6 | INCODE = 'class' 7 | IMPORT = 'module-import' 8 | 9 | 10 | class NonCallableCalled(BaseContexts): 11 | """ 12 | This module extracts module-level context and corresponding metadata from 13 | program content. 14 | Metadata contains start/end line number information for the corresponding context. 15 | """ 16 | def __init__(self, parser): 17 | """ 18 | Args: 19 | parser: Tree sitter parser object 20 | """ 21 | super().__init__(parser) 22 | 23 | def get_target_class(self, message): 24 | """ 25 | This function returns the class type which gave rise to the 26 | CodeQL flag. 
27 | Args: 28 | message : CodeQL message 29 | Returns: 30 | class_type : one of INBUILT, INCODE, IMPORT 31 | class_name : class name 32 | """ 33 | class_type = message.split(':::')[0] 34 | if(class_type.startswith('class')): 35 | return INCODE, class_type.split()[-1].strip() 36 | elif(class_type.startswith('builtin-class') 37 | and class_type.split()[-1] in INBUILT_TYPES): 38 | return INBUILT, class_type.split()[-1].strip() 39 | else: 40 | return IMPORT, class_type.split()[-1].strip() 41 | 42 | 43 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 44 | """ 45 | This function returns relevant Blocks as query-specific context. 46 | Args: 47 | program_content: Program in string format from which we need to 48 | extract classes 49 | parser : tree_sitter_parser 50 | file_path : file of program_content 51 | message : CodeQL message 52 | result_span: CodeQL-treesitter adjusted namedtuple of 53 | (start_line, start_col, end_line, end_col) 54 | aux_result_df: auxiliary query results dataframe 55 | Returns: 56 | A list consisting of relevant Blocks 57 | """ 58 | start_line = result_span.start_line 59 | end_line = result_span.end_line 60 | 61 | context_object = NonCallableCalled(parser) 62 | 63 | local_block = context_object.get_local_block(program_content, start_line, end_line) 64 | required_blocks = [local_block] 65 | 66 | ''' 67 | Three class_types 68 | 1.user-defined class 69 | class DiffieHellman 70 | 2.builtin-class 71 | (list, dict, set, frozenset, tuple, bool, int, float, complex, str, NoneType/None) 72 | memoryview, bytearray, bytes 73 | 3.module imports 74 | built-in method: /babble/babble/include/jython/Lib/test/test_func_jy.py 75 | builtin-class _io.TextIOWrapper 76 | builtin-class module 77 | ''' 78 | class_type, class_name = context_object.get_target_class(message) 79 | if(class_type == INBUILT): 80 | class_specific_stub = get_class_specific_call_stub(class_name) 81 | stub_blocks = context_object.get_all_blocks(class_specific_stub, STUB) 82 | for block in stub_blocks: 83 | if(block.block_type == CLASS_FUNCTION 84 | and block.metadata.split('.')[-1] == '__call__'): 85 | block.relevant = True 86 | block.start_line = -1 87 | block.end_line = -1 88 | block.block_type = STUB 89 | # when deciding for STUB, all blocks will 90 | # be in intermediate context 91 | required_blocks.append(block) 92 | else: 93 | all_blocks = context_object.get_all_blocks(program_content) 94 | if(class_type == INCODE): 95 | all_classes = context_object.get_all_classes(program_content) 96 | cls_start = -1 97 | cls_end = -1 98 | for cls in all_classes: 99 | if(cls.metadata.split('.')[-1] == class_name): 100 | cls_start = cls.start_line 101 | cls_end = cls.end_line 102 | break 103 | 104 | # if the class is an inner class 105 | if(cls_start == -1 and cls_end == -1): 106 | for block in all_blocks: 107 | # check if the class is defined in some other type of block 108 | class_definition = class_type + ' ' + class_name 109 | if(class_definition in block.content): 110 | if(block in required_blocks): 111 | for added_block in required_blocks: 112 | if(added_block == block): 113 | added_block.relevant = True 114 | else: 115 | block.relevant = True 116 | required_blocks.append(block) 117 | # if the class is a module-level class 118 | else: 119 | mro_call_block = context_object.get_mro_function_block('__call__', 120 | class_name, 121 | program_content) 122 | 123 | for block in all_blocks: 124 | # as target blocks are inside some class, 125 | # no preprocessing with 
Block.other_lines is required 126 | if(block.start_line >= cls_start 127 | and block.end_line <= cls_end 128 | and block not in required_blocks): 129 | if(block.block_type == CLASS_FUNCTION 130 | and block == mro_call_block): 131 | block.relevant = True 132 | required_blocks.append(block) 133 | 134 | # if `__call__` is inherited 135 | if(mro_call_block is not None): 136 | mro_call_block.relevant = True # because of L129 137 | if(mro_call_block not in required_blocks): 138 | required_blocks.append(mro_call_block) 139 | elif(class_type == IMPORT): 140 | for block in all_blocks: 141 | # check if the single module_other is 142 | # also local_block and add to required_blocks 143 | if(block not in required_blocks 144 | and block.block_type == MODULE_OTHER): 145 | block.relevant = True 146 | required_blocks.append(block) 147 | 148 | # update local_block relevance 149 | # this update is done at the end for the correct checks at lines 121, 147 and 153 150 | local_block.relevant = True 151 | 152 | return required_blocks 153 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/signatureoverriddenmethod.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, CLASS_FUNCTION 2 | import math 3 | 4 | 5 | class SignatureOverriddenMethod(BaseContexts): 6 | """ 7 | This module extracts module-level context and corresponding metadata from 8 | program content. 9 | Metadata contains start/end line number information for the corresponding context. 10 | """ 11 | def __init__(self, parser): 12 | """ 13 | Args: 14 | parser: Tree sitter parser object 15 | """ 16 | super().__init__(parser) 17 | 18 | def get_mro_index(self, child_class_mro, class_name): 19 | """ 20 | This function returns the index of class_name in 21 | child_class_mro (-1 if class_name is absent) 22 | Args: 23 | child_class_mro : mro of flagged class 24 | class_name : target class 25 | Returns: 26 | Index of class_name in child_class_mro 27 | """ 28 | class_mro_index = -1 29 | for i, cls in enumerate(child_class_mro): 30 | if(cls == class_name): 31 | class_mro_index = i 32 | return class_mro_index 33 | 34 | def is_parent_block_relevant(self, program_content, child_class, target_func_name, target_func_class, parent_class_blocks): 35 | """ 36 | This function checks whether a specific parent CLASS_FUNCTION Block 37 | should be marked relevant 38 | Args: 39 | program_content: Program in string format from which we need to 40 | extract classes 41 | child_class : flagged class name 42 | target_func_name : CLASS_FUNCTION Block function name 43 | target_func_class : CLASS_FUNCTION Block class name 44 | parent_class_blocks : potential relevant parent class Blocks 45 | Returns: 46 | A Boolean value denoting whether the parent Block should be marked relevant 47 | """ 48 | child_class_mro = self.get_class_MRO(program_content, child_class) 49 | 50 | func_class_index = self.get_mro_index(child_class_mro, target_func_class) 51 | if(func_class_index == -1): 52 | return False 53 | 54 | relevant_class_index = math.inf 55 | mro_func_class = None 56 | for block in parent_class_blocks: 57 | if(block.block_type == CLASS_FUNCTION): 58 | func_name = block.metadata.split('.')[-1] 59 | class_name = block.metadata.split('.')[-2] 60 | class_index = self.get_mro_index(child_class_mro, class_name) 61 | if(func_name == target_func_name 62 | and (class_index != -1) 63 | and (class_index < relevant_class_index)): 64 | relevant_class_index = class_index 65 | mro_func_class = class_name 66 | 
return (mro_func_class == target_func_class) 67 | 68 | 69 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None): 70 | """ 71 | This function returns relevant Blocks as query-specific context. 72 | Args: 73 | program_content: Program in string format from which we need to 74 | extract classes 75 | parser : tree_sitter_parser 76 | file_path : file of program_content 77 | message : CodeQL message 78 | result_span: CodeQL-treesitter adjusted namedtuple of 79 | (start_line, start_col, end_line, end_col) 80 | aux_result_df: auxiliary query results dataframe 81 | Returns: 82 | A list consisting of relevant Blocks 83 | """ 84 | start_line = result_span.start_line 85 | end_line = result_span.end_line 86 | 87 | context_object = SignatureOverriddenMethod(parser) 88 | local_class_block, parent_class_blocks = context_object.get_local_and_super_class(program_content, start_line, end_line) 89 | all_blocks = context_object.get_all_blocks(program_content) 90 | 91 | required_blocks = [] 92 | local_block = context_object.get_local_block(program_content, start_line, end_line) 93 | 94 | # for intersecting function check 95 | child_class_functions = set() 96 | parent_class_functions = dict() # dictionary of sets 97 | 98 | required_child_class_blocks = [] 99 | required_parent_class_blocks = [] 100 | if local_class_block is not None: 101 | for block in all_blocks: 102 | # if a block is in the local class 103 | if (block.start_line >= local_class_block.start_line 104 | and block.end_line <= local_class_block.end_line): 105 | if(block == local_block): 106 | block.relevant = True 107 | required_child_class_blocks.append(block) 108 | # for intersecting function check 109 | if(block.block_type == CLASS_FUNCTION): 110 | child_class_functions.add(block.metadata.split('.')[-1]) 111 | # if a block is in any of the parent classes 112 | else: 113 | for p_block in parent_class_blocks: 114 | class_name = p_block.metadata.split('.')[-1] 115 | if (block.start_line >= p_block.start_line 116 | and block.end_line <= p_block.end_line): 117 | required_parent_class_blocks.append(block) 118 | # for intersecting function check 119 | if(block.block_type == CLASS_FUNCTION): 120 | func_name = block.metadata.split('.')[-1] 121 | if(class_name in parent_class_functions): 122 | parent_class_functions[class_name].add(func_name) 123 | else: 124 | parent_class_functions[class_name] = set([func_name]) 125 | 126 | # functions in parent blocks that would require MRO 127 | mro_functions = dict() 128 | if(child_class_functions): 129 | for cls, func_set in parent_class_functions.items(): 130 | mro_functions[cls] = child_class_functions - func_set 131 | 132 | for cls, func_set in mro_functions.items(): 133 | for func in func_set: 134 | mro_func_block = context_object.get_mro_function_block(func, cls, program_content) 135 | if(mro_func_block is not None 136 | and mro_func_block not in required_parent_class_blocks): 137 | # Blocks retrieved with MRO are not relevant 138 | # because of declaredAttribute 139 | required_parent_class_blocks.append(mro_func_block) 140 | 141 | # mark relevance 142 | for block in required_child_class_blocks: 143 | for p_block in required_parent_class_blocks: 144 | if(block.block_type == CLASS_FUNCTION and p_block.block_type == CLASS_FUNCTION): 145 | overriding_func_name = block.metadata.split('.')[-1] 146 | child_class = block.metadata.split('.')[-2] 147 | 148 | overridden_func_name = p_block.metadata.split('.')[-1] 149 | overridden_func_class = p_block.metadata.split('.')[-2] 150 | 
if(overriding_func_name == overridden_func_name 151 | and context_object.is_parent_block_relevant(program_content, child_class, 152 | overridden_func_name, overridden_func_class, 153 | required_parent_class_blocks)): 154 | block.relevant = True 155 | p_block.relevant = True 156 | 157 | required_blocks.extend(required_child_class_blocks) 158 | required_blocks.extend(required_parent_class_blocks) 159 | 160 | return required_blocks 161 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_data/test__aux_res.csv: -------------------------------------------------------------------------------- 1 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Employee.__init__""|""relative:///test_unused_imp.py:5:5:5:33""]]' is called","/test_unused_imp.py","11","9","11","34" 2 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""SalaryEmployee.calculate_payroll""|""relative:///test_unused_imp.py:15:5:15:32""]]' is called","/test_unused_imp.py","13","22","13","45" 3 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Employee.__init__""|""relative:///test_unused_imp.py:5:5:5:33""]]' is called","/test_unused_imp.py","20","9","20","34" 4 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""HourlyEmployee.calculate_payroll""|""relative:///test_unused_imp.py:25:5:25:32""]]' is called","/test_unused_imp.py","23","22","23","45" 5 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Employee.__init__""|""relative:///test_unused_imp.py:5:5:5:33""]]' is called","/test_unused_imp.py","30","9","30","34" 6 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""HourlyEmployee.__init__""|""relative:///test_unused_imp.py:19:5:19:58""]]' is called","/test_unused_imp.py","41","9","41","71" 7 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Intern.get_bonus""|""relative:///test_unused_imp.py:44:5:44:31""]]' is called","/test_unused_imp.py","55","22","55","54" 8 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""get_festive_bonus""|""relative:///test_unused_imp.py:61:1:61:34""]]' is called","/test_unused_imp.py","57","39","57","57" 9 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Intern.get_work_desc""|""relative:///test_unused_imp.py:50:5:50:48""]]' is called","/test_unused_imp.py","58","16","58","56" 10 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Intern.__init__""|""relative:///test_unused_imp.py:40:5:40:58""]]' is called","/test_unused_imp.py","64","10","64","34" 11 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""Intern.work""|""relative:///test_unused_imp.py:53:5:53:19""]]' is called","/test_unused_imp.py","65","1","65","13" 12 | "Function call map","Function call map for all function, `@classmethod` and `__init__`","warning","'[[""get_festive_bonus""|""relative:///test_unused_imp.py:61:1:61:34""]]' is called","/test_unused_imp.py","66","7","66","31" 13 | "Used function call map","Function call map for only used function and without 
`__init__`","warning","'[[""SalaryEmployee.calculate_payroll""|""relative:///test_unused_imp.py:15:5:15:32""]]' is called","/test_unused_imp.py","13","22","13","45" 14 | "Used function call map","Function call map for only used function and without `__init__`","warning","'[[""HourlyEmployee.calculate_payroll""|""relative:///test_unused_imp.py:25:5:25:32""]]' is called","/test_unused_imp.py","23","22","23","45" 15 | "Used function call map","Function call map for only used function and without `__init__`","warning","'[[""Intern.get_bonus""|""relative:///test_unused_imp.py:44:5:44:31""]]' is called","/test_unused_imp.py","55","22","55","54" 16 | "Used function call map","Function call map for only used function and without `__init__`","warning","'[[""Intern.get_work_desc""|""relative:///test_unused_imp.py:50:5:50:48""]]' is called","/test_unused_imp.py","58","16","58","56" 17 | "Used function call map","Function call map for only used function and without `__init__`","warning","'[[""get_festive_bonus""|""relative:///test_unused_imp.py:61:1:61:34""]]' is called","/test_unused_imp.py","66","7","66","31" 18 | "Used import","Expressions where import functionality used","warning","Global Variable np ::: is used.","/test_unused_imp.py","32","26","32","27" 19 | "Used import","Expressions where import functionality used","warning","Global Variable np ::: is used.","/test_unused_imp.py","34","26","34","27" 20 | "Used import","Expressions where import functionality used","warning","Global Variable np ::: is used.","/test_unused_imp.py","54","20","54","21" 21 | "Used import","Expressions where import functionality used","warning","Global Variable np ::: is used.","/test_unused_imp.py","54","42","54","43" 22 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_distributable.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # train_data: mutated- /peterhudec/authomatic/tests/functional_tests/expected_values/tumblr.py 21 | # SOURCE_CODE = [ 22 | # '''class Person: 23 | # "This is a person class" 24 | # age = 10 25 | 26 | # def greet(self): 27 | # print('Hello') 28 | 29 | # birth_year = 2000 30 | # current_year = birth_year + age 31 | 32 | # print(Person.age) 33 | 34 | # def get_age(age): 35 | # print (age) 36 | 37 | # # create a new object of Person class 38 | # harry = Person() 39 | # # Calling object's greet() method 40 | # harry.greet()'''] 41 | 42 | SOURCE_CODE = ['''class Person:\n "This is a person class"\n age = 10\n\n def greet(self):\n print('Hello')\n\n birth_year = 2000\n current_year = birth_year + age\n\nprint(Person.age)\n\ndef get_age(age):\n print (age)\n\n# create a new object of Person class\nharry = Person()\n# Calling object's greet() method\nharry.greet()'''] 43 | SPAN = [Span(8, -1, 8, -1)] 
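# Added note (not in the original test): Span uses the 0-based,
# CodeQL-to-tree-sitter adjusted convention seen across this directory,
# so Span(8, -1, 8, -1) presumably targets the whole of source line 8,
# 'current_year = birth_year + age', with the -1 columns appearing to
# mean that no specific column range is requested. A quick sanity check:
#     assert SOURCE_CODE[0].split('\n')[8].strip() == 'current_year = birth_year + age'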
44 | 45 | 46 | class TestDistributableQueryContext(unittest.TestCase): 47 | desired_block = [[Block(0, 48 | 8, 49 | [0, 1, 2, 3, 6, 7, 8], 50 | '''class Person:\n "This is a person class"\n age = 10\n\n\n birth_year = 2000\n current_year = birth_year + age''', 51 | 'root.Person', 52 | CLASS_OTHER, 53 | True, 54 | 'module', 55 | ('__', '__class__'))]] 56 | 57 | def test_relevant_block(self): 58 | for i, code in enumerate(SOURCE_CODE): 59 | span = SPAN[i] 60 | generated_block = get_span_context('Redundant assignment', 61 | code, tree_sitter_parser, '', 62 | '', span, None) 63 | 64 | for j, gen_block in enumerate(generated_block): 65 | self.assertEqual(self.desired_block[j][i], gen_block) 66 | 67 | 68 | if __name__ == "__main__": 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_flaskdebug.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import MODULE_FUNCTION, Block, MODULE_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # train_data: /peterhudec/authomatic/tests/functional_tests/expected_values/tumblr.py 21 | # '''from __future__ import absolute_import 22 | # from flask import Flask, request 23 | 24 | # app = Flask(__name__) 25 | 26 | # def setup(): 27 | # with open(SECRET_FILE, 'w') as f: 28 | # f.write(''.join(utils.random_string(size=42))) 29 | 30 | # setup() 31 | # with open(SECRET_FILE) as f: 32 | # app.secret_key = f.read() 33 | 34 | # @app.route('/'+app.secret_key, methods=['POST']) 35 | # def main(): 36 | # return "OK" 37 | 38 | # def start(app): 39 | # app.run('0.0.0.0', port=5678, debug=True) 40 | 41 | # if __name__ == "__main__": 42 | # start(app)''' 43 | 44 | SOURCE_CODE = ['''from __future__ import absolute_import\nfrom flask import Flask, request\n\napp = Flask(__name__)\n\ndef setup():\n with open(SECRET_FILE, \'w\') as f:\n f.write(\'\'.join(utils.random_string(size=42)))\n\nsetup()\nwith open(SECRET_FILE) as f:\n app.secret_key = f.read()\n\n@app.route(\'/\'+app.secret_key, methods=[\'POST\'])\ndef main():\n return "OK"\n\ndef start(app):\n app.run(\'0.0.0.0\', port=5678, debug=True)\n\nif __name__ == "__main__":\n start(app)'''] 45 | SPANS = [Span(18, 4, 18, 65)] 46 | 47 | 48 | class TestDistributableQueryContext(unittest.TestCase): 49 | desired_block = [[Block(17, 50 | 18, 51 | [], 52 | '''def start(app):\n app.run('0.0.0.0', port=5678, debug=True)''', 53 | 'root.start', 54 | MODULE_FUNCTION, 55 | True, 56 | 'module', 57 | ('__', '__class__')), 58 | Block(0, 59 | 21, 60 | [0, 1, 2, 3, 4, 8, 9, 10, 11, 12, 16, 19, 20, 21], 61 | '''from __future__ import absolute_import\nfrom flask import Flask, request\n\napp = Flask(__name__)\n\n\nsetup()\nwith open(SECRET_FILE) as f:\n app.secret_key = f.read()\n\n\n\nif __name__ == "__main__":\n start(app)''', 62 | 'root', 63 | MODULE_OTHER, 64 | True, 65 | 'module', 66 | ('__', 
'__class__'))]] 67 | 68 | def test_relevant_block(self): 69 | for i, code in enumerate(SOURCE_CODE): 70 | span = SPANS[i] 71 | generated_block = get_span_context('Flask app is run in debug mode', 72 | code, tree_sitter_parser, '', 73 | '', span, None) 74 | 75 | for j, gen_block in enumerate(generated_block): 76 | self.assertEqual(self.desired_block[i][j], gen_block) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_incompleteordering.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_FUNCTION, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | tree_sitter_parser = Parser() 19 | tree_sitter_parser.set_language(PY_LANGUAGE) 20 | 21 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 22 | 23 | # train_data: /clips/pattern/pattern/server/cherrypy/cherrypy/lib/httputil.py 24 | # import re 25 | # import urllib 26 | 27 | 28 | # def protocol_from_http(protocol_str): 29 | # return int(protocol_str[5]), int(protocol_str[7]) 30 | 31 | # class HeaderElement(object): 32 | # def __init__(self, value, params=None): 33 | # self.value = value 34 | # if params is None: 35 | # params = {} 36 | # self.params = params 37 | 38 | # def __cmp__(self, other): 39 | # return cmp(self.value, other.value) 40 | 41 | # def __lt__(self, other): 42 | # return self.value < other.value 43 | 44 | # def parse(elementstr): 45 | # atoms = [x.strip() for x in elementstr.split(";") if x.strip()] 46 | # if not atoms: 47 | # initial_value = '' 48 | # else: 49 | # initial_value = atoms.pop(0).strip() 50 | # return initial_value, params 51 | # parse = staticmethod(parse) 52 | 53 | SOURCE_CODE = ['''import os\n\ncurr_dir = os.cwd()\n\nclass Element(object):\n """\n Base Element class\n """\n def __init__(self, value):\n self.value = value\n\n def __le__(self, other):\n return self.value <= other.value\n\nclass HeaderElement(HHElement):\n def __init__(self, value, params=None):\n super().__init__(value)\n self.value = value\n if params is None:\n params = {}\n self.params = params\n\n def __cmp__(self, other):\n return cmp(self.value, other.value)\n\n def __lt__(self, other):\n return self.value < other.value\n\n def parse(elementstr):\n atoms = [x.strip() for x in elementstr.split(";") if x.strip()]\n if not atoms:\n initial_value = \'\'\n else:\n initial_value = atoms.pop(0).strip()\n return initial_value, params\n parse = staticmethod(parse)\n\nclass HHElement(Element):\n """\n Base Element class\n """\n def __init__(self, value):\n self.value = value\n\n def __ge__(self, other):\n return self.value <= other.value'''] 54 | SPAN = [Span(14, 0, 14, 31)] 55 | 56 | 57 | class TestDistributableQueryContext(unittest.TestCase): 58 | desired_block = [[Block(15, 59 | 20, 60 | [], 61 | '''def __init__(self, value, params=None):\n super().__init__(value)\n self.value = value\n if params 
is None:\n params = {}\n self.params = params''', 62 | 'root.HeaderElement.__init__', 63 | CLASS_FUNCTION, 64 | False, 65 | 'class HeaderElement(HHElement):', 66 | ('__', '__class__', 'HeaderElement', 'HHElement', 'Element')), 67 | Block(22, 68 | 23, 69 | [], 70 | '''def __cmp__(self, other):\n return cmp(self.value, other.value)''', 71 | 'root.HeaderElement.__cmp__', 72 | CLASS_FUNCTION, 73 | False, 74 | 'class HeaderElement(HHElement):', 75 | ('__', '__class__', 'HeaderElement', 'HHElement', 'Element')), 76 | Block(25, 77 | 26, 78 | [], 79 | '''def __lt__(self, other):\n return self.value < other.value''', 80 | 'root.HeaderElement.__lt__', 81 | CLASS_FUNCTION, 82 | True, 83 | 'class HeaderElement(HHElement):', 84 | ('__', '__class__', 'HeaderElement', 'HHElement', 'Element')), 85 | Block(28, 86 | 34, 87 | [], 88 | '''def parse(elementstr):\n atoms = [x.strip() for x in elementstr.split(";") if x.strip()]\n if not atoms:\n initial_value = \'\'\n else:\n initial_value = atoms.pop(0).strip()\n return initial_value, params''', 89 | 'root.HeaderElement.parse', 90 | CLASS_FUNCTION, 91 | False, 92 | 'class HeaderElement(HHElement):', 93 | ('__', '__class__', 'HeaderElement', 'HHElement', 'Element')), 94 | Block(14, 95 | 35, 96 | [14, 21, 24, 27, 35], 97 | '''class HeaderElement(HHElement):\n\n\n\n parse = staticmethod(parse)''', 98 | 'root.HeaderElement', 99 | CLASS_OTHER, 100 | True, 101 | 'module', 102 | ('__', '__class__', 'HHElement')), 103 | Block(44, 104 | 45, 105 | [], 106 | '''def __ge__(self, other):\n return self.value <= other.value''', 107 | 'root.HHElement.__ge__', 108 | CLASS_FUNCTION, 109 | True, 110 | 'class HHElement(Element):', 111 | ('__', '__class__', 'HHElement', 'Element')), 112 | Block(11, 113 | 12, 114 | [], 115 | '''def __le__(self, other):\n return self.value <= other.value''', 116 | 'root.Element.__le__', 117 | CLASS_FUNCTION, 118 | True, 119 | 'class Element(object):', 120 | ('__', '__class__', 'Element'))]] 121 | 122 | def test_relevant_block(self): 123 | for i, code in enumerate(SOURCE_CODE): 124 | span = SPAN[i] 125 | generated_block = get_span_context('Incomplete ordering', 126 | code, tree_sitter_parser, '', 127 | '', span, None) 128 | 129 | for j, gen_block in enumerate(generated_block): 130 | self.assertEqual(self.desired_block[i][j], gen_block) 131 | 132 | 133 | if __name__ == "__main__": 134 | unittest.main() 135 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_initcallssubclassmethod.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_FUNCTION, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # class Super(object): 21 | # def __init__(self, arg): 22 | # self._state = "Not OK" 23 | # self.set_up(arg) 24 | # self._state = "OK" 25 | 26 | # def set_up(self, arg): 27 | # "Do some set up" 28 | 29 | # class 
Sub(Super): 30 | # def __init__(self, arg): 31 | # Super.__init__(self, arg) 32 | # self.important_state = "OK" 33 | 34 | # def set_up(self, arg): 35 | # Super.set_up(self, arg) 36 | 37 | SOURCE_CODE = ['''class Super(object):\n def __init__(self, arg):\n self._state = "Not OK"\n self.set_up(arg)\n self._state = "OK"\n\n def set_up(self, arg):\n "Do some set up"\n\nclass Sub(Super):\n def __init__(self, arg):\n Super.__init__(self, arg)\n self.important_state = "OK"\n\n def set_up(self, arg):\n Super.set_up(self, arg)'''] 38 | SPAN = [Span(3, 8, 3, 24)] 39 | 40 | 41 | class TestDistributableQueryContext(unittest.TestCase): 42 | desired_block = [[Block(10, 43 | 12, 44 | [], 45 | '''def __init__(self, arg):\n Super.__init__(self, arg)\n self.important_state = "OK"''', 46 | 'root.Sub.__init__', 47 | CLASS_FUNCTION, 48 | False, 49 | 'class Sub(Super):', 50 | ('__', '__class__', 'Sub', 'Super')), 51 | Block(14, 52 | 15, 53 | [], 54 | '''def set_up(self, arg):\n Super.set_up(self, arg)''', 55 | 'root.Sub.set_up', 56 | CLASS_FUNCTION, 57 | True, 58 | 'class Sub(Super):', 59 | ('__', '__class__', 'Sub', 'Super')), 60 | Block(9, 61 | 15, 62 | [9, 13], 63 | '''class Sub(Super):\n''', 64 | 'root.Sub', 65 | CLASS_OTHER, 66 | False, 67 | 'module', 68 | ('__', '__class__', 'Super')), 69 | Block(1, 70 | 4, 71 | [], 72 | '''def __init__(self, arg):\n self._state = "Not OK"\n self.set_up(arg)\n self._state = "OK"''', 73 | 'root.Super.__init__', 74 | CLASS_FUNCTION, 75 | True, 76 | 'class Super(object):', 77 | ('__', '__class__', 'Super')), 78 | Block(6, 79 | 7, 80 | [], 81 | '''def set_up(self, arg):\n "Do some set up"''', 82 | 'root.Super.set_up', 83 | CLASS_FUNCTION, 84 | False, 85 | 'class Super(object):', 86 | ('__', '__class__', 'Super')), 87 | Block(0, 88 | 7, 89 | [0, 5], 90 | '''class Super(object):\n''', 91 | 'root.Super', 92 | CLASS_OTHER, 93 | False, 94 | 'module', 95 | ('__', '__class__', 'object'))]] 96 | 97 | def test_relevant_block(self): 98 | for i, code in enumerate(SOURCE_CODE): 99 | span = SPAN[i] 100 | generated_block = get_span_context('`__init__` method calls overridden method', 101 | code, tree_sitter_parser, '', 102 | '', span, None) 103 | 104 | for j, gen_block in enumerate(generated_block): 105 | self.assertEqual(self.desired_block[i][j], gen_block) 106 | 107 | 108 | if __name__ == "__main__": 109 | unittest.main() 110 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_iterreturnsnoniterator.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_FUNCTION, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # from gluon.tools import Crud 21 | 22 | # class MongoCursorWrapper: 23 | # def __init__ (self, cursor): 24 | # self.__cursor = cursor 25 | 26 | # def __iter__ (self): 27 | # return MongoWrapperIter (self.__cursor) 28 | 29 | # class 
MongoWrapper: 30 | # def __init__ (self, cursor): 31 | # self.__dict__['cursor'] = cursor 32 | 33 | # def __nonzero__ (self): 34 | # if self.cursor is None: 35 | # return False 36 | # return len (self.cursor) != 0 37 | 38 | # def __iter__ (self): 39 | # return MongoWrapperIter (self.cursor) 40 | 41 | # class MongoWrapperIter: 42 | # def __init__ (self, cursor): 43 | # self.__cursor = iter (cursor) 44 | 45 | # def __iter__ (self): 46 | # return self 47 | 48 | SOURCE_CODE = ['''from gluon.tools import Crud\n\nclass MongoCursorWrapper:\n def __init__ (self, cursor):\n self.__cursor = cursor\n \n def __iter__ (self):\n return MongoWrapperIter (self.__cursor)\n\nclass MongoWrapper:\n def __init__ (self, cursor):\n self.__dict__['cursor'] = cursor\n\n def __nonzero__ (self):\n if self.cursor is None:\n return False\n return len (self.cursor) != 0\n\n def __iter__ (self):\n return MongoWrapperIter (self.cursor)\n\nclass MongoWrapperIter:\n def __init__ (self, cursor):\n self.__cursor = iter (cursor)\n\n def __iter__ (self):\n return self'''] 49 | MESSAGE = ['''Class MongoWrapperIter is returned as an iterator (by [["__iter__"|"relative:///py_file_466.py:7:5:7:24"]]) but does not fully implement the iterator interface.\nClass MongoWrapperIter is returned as an iterator (by [["__iter__"|"relative:///py_file_466.py:19:5:19:24"]]) but does not fully implement the iterator interface.'''] 50 | SPAN = [Span(21, 0, 21, 23)] 51 | 52 | 53 | class TestDistributableQueryContext(unittest.TestCase): 54 | desired_block = [[Block(6, 55 | 7, 56 | [], 57 | '''def __iter__ (self):\n return MongoWrapperIter (self.__cursor)''', 58 | 'root.MongoCursorWrapper.__iter__', 59 | CLASS_FUNCTION, 60 | True, 61 | 'class MongoCursorWrapper:', 62 | ('__', '__class__', 'MongoCursorWrapper')), 63 | Block(18, 64 | 19, 65 | [], 66 | '''def __iter__ (self):\n return MongoWrapperIter (self.cursor)''', 67 | 'root.MongoWrapper.__iter__', 68 | CLASS_FUNCTION, 69 | True, 70 | 'class MongoWrapper:', 71 | ('__', '__class__', 'MongoWrapper')), 72 | Block(25, 73 | 26, 74 | [], 75 | '''def __iter__ (self):\n return self''', 76 | 'root.MongoWrapperIter.__iter__', 77 | CLASS_FUNCTION, 78 | True, 79 | 'class MongoWrapperIter:', 80 | ('__', '__class__', 'MongoWrapperIter')), 81 | Block(21, 82 | 26, 83 | [21, 24], 84 | '''class MongoWrapperIter:\n''', 85 | 'root.MongoWrapperIter', 86 | CLASS_OTHER, 87 | True, 88 | 'module', 89 | ('__', '__class__'))]] 90 | 91 | def test_relevant_block(self): 92 | for i, code in enumerate(SOURCE_CODE): 93 | span = SPAN[i] 94 | message = MESSAGE[i] 95 | generated_block = get_span_context('`__iter__` method returns a non-iterator', 96 | code, tree_sitter_parser, '', 97 | message, span, None) 98 | 99 | for j, gen_block in enumerate(generated_block): 100 | self.assertEqual(self.desired_block[i][j], gen_block) 101 | 102 | 103 | if __name__ == "__main__": 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_missingcalltoinit.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_FUNCTION, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # 
PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # class Vehicle(object): 21 | 22 | # def __init__(self): 23 | # self.mobile = True 24 | 25 | # class Car(Vehicle): 26 | 27 | # def __init__(self): 28 | # Vehicle.__init__(self) 29 | # self.car_init() 30 | 31 | # #Car.__init__ is missed out. 32 | # class SportsCar(Car, Vehicle): 33 | 34 | # def __init__(self): 35 | # Vehicle.__init__(self) 36 | # self.sports_car_init() 37 | 38 | # #Fix SportsCar by calling Car.__init__ 39 | # class FixedSportsCar(Car, Vehicle): 40 | 41 | # def __init__(self): 42 | # Car.__init__(self) 43 | # self.sports_car_init() 44 | 45 | SOURCE_CODE = ['''class Vehicle(object):\n \n def __init__(self):\n self.mobile = True\n \nclass Car(Vehicle):\n \n def __init__(self):\n Vehicle.__init__(self)\n self.car_init()\n \n#Car.__init__ is missed out.\nclass SportsCar(Car, Vehicle):\n \n def __init__(self):\n Vehicle.__init__(self)\n self.sports_car_init()\n \n#Fix SportsCar by calling Car.__init__\nclass FixedSportsCar(Car, Vehicle):\n \n def __init__(self):\n Car.__init__(self)\n self.sports_car_init()'''] 46 | SPAN = [Span(12, 0, 12, 29)] 47 | 48 | 49 | class TestDistributableQueryContext(unittest.TestCase): 50 | desired_block = [[Block(2, 51 | 3, 52 | [], 53 | 'def __init__(self):\n self.mobile = True', 54 | 'root.Vehicle.__init__', 55 | CLASS_FUNCTION, 56 | True, 57 | 'class Vehicle(object):', 58 | ('__', '__class__', 'Vehicle')), 59 | Block(7, 60 | 9, 61 | [], 62 | 'def __init__(self):\n Vehicle.__init__(self)\n self.car_init()', 63 | 'root.Car.__init__', 64 | CLASS_FUNCTION, 65 | True, 66 | 'class Car(Vehicle):', 67 | ('__', '__class__', 'Car', 'Vehicle')), 68 | Block(14, 69 | 16, 70 | [], 71 | 'def __init__(self):\n Vehicle.__init__(self)\n self.sports_car_init()', 72 | 'root.SportsCar.__init__', 73 | CLASS_FUNCTION, 74 | True, 75 | 'class SportsCar(Car, Vehicle):', 76 | ('__', '__class__', 'SportsCar', 'Car', 'Vehicle')), 77 | Block(12, 78 | 16, 79 | [12, 13], 80 | 'class SportsCar(Car, Vehicle):\n ', 81 | 'root.SportsCar', 82 | CLASS_OTHER, 83 | True, 84 | 'module', 85 | ('__', '__class__', 'Car', 'Vehicle'))]] 86 | 87 | def test_relevant_block(self): 88 | for i, code in enumerate(SOURCE_CODE): 89 | span = SPAN[i] 90 | generated_block = get_span_context('Missing call to `__init__` during object initialization', 91 | code, tree_sitter_parser, '', 92 | '', span, None) 93 | 94 | for j, gen_block in enumerate(generated_block): 95 | self.assertEqual(self.desired_block[i][j], gen_block) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # echo $(ls|grep test_) 3 | for python_file_name in $(ls|grep test_) 4 | do 5 | python $python_file_name 6 | done -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_signatureoverriddenmethod.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, CLASS_FUNCTION 3 | from tree_sitter import Language, 
Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # import logging 21 | # log = logging.getLogger(__name__) 22 | 23 | # class Resource(): 24 | # def get(self, request): 25 | # self.instance = request 26 | # return self.instance 27 | 28 | # class NonFieldResource(Resource): 29 | # def __init__(self): 30 | # self.rsc = None 31 | 32 | # class FieldResource(): 33 | # def get(self, request, pk): 34 | # instance = self.get_object(request, pk=pk) 35 | # return self.prepare(request, instance) 36 | 37 | # class FieldsResource(FieldResource, NonFieldResource): 38 | # def is_not_found(self, request, response, *args, **kwargs): 39 | # return False 40 | 41 | # def get(self, request): 42 | # params = self.get_params(request) 43 | # queryset = self.get_queryset(request, params) 44 | 45 | SOURCE_CODE = ['''import logging\nlog = logging.getLogger(__name__)\n\nclass Resource():\n def get(self, request):\n self.instance = request\n return self.instance\n\nclass NonFieldResource(Resource):\n def __init__(self):\n self.rsc = None\n\nclass FieldResource():\n def get(self, request, pk):\n instance = self.get_object(request, pk=pk)\n return self.prepare(request, instance)\n\nclass FieldsResource(FieldResource, NonFieldResource):\n def is_not_found(self, request, response, *args, **kwargs):\n return False\n\n def get(self, request):\n params = self.get_params(request)\n queryset = self.get_queryset(request, params)'''] 46 | SPAN = [Span(21, 4, 21, 27)] 47 | 48 | 49 | class TestDistributableQueryContext(unittest.TestCase): 50 | desired_block = [[Block(18, 51 | 19, 52 | [], 53 | '''def is_not_found(self, request, response, *args, **kwargs):\n return False''', 54 | 'root.FieldsResource.is_not_found', 55 | CLASS_FUNCTION, 56 | False, 57 | 'class FieldsResource(FieldResource, NonFieldResource):', 58 | ('__', '__class__', 'FieldsResource', 'FieldResource', 'NonFieldResource', 'Resource')), 59 | Block(21, 60 | 23, 61 | [], 62 | 'def get(self, request):\n params = self.get_params(request)\n queryset = self.get_queryset(request, params)', 63 | 'root.FieldsResource.get', 64 | 'CLASS_FUNCTION', 65 | True, 66 | 'class FieldsResource(FieldResource, NonFieldResource):', 67 | ('__', '__class__', 'FieldsResource', 'FieldResource', 'NonFieldResource', 'Resource')), 68 | Block(17, 69 | 23, 70 | [17, 20], 71 | 'class FieldsResource(FieldResource, NonFieldResource):\n', 72 | 'root.FieldsResource', 73 | 'CLASS_OTHER', 74 | False, 75 | 'module', 76 | ('__', '__class__', 'FieldResource', 'NonFieldResource')), 77 | Block(4, 78 | 6, 79 | [], 80 | 'def get(self, request):\n self.instance = request\n return self.instance', 81 | 'root.Resource.get', 82 | 'CLASS_FUNCTION', 83 | False, 84 | 'class Resource():', 85 | ('__', '__class__', 'Resource')), 86 | Block(9, 87 | 10, 88 | [], 89 | 'def __init__(self):\n self.rsc = None', 90 | 'root.NonFieldResource.__init__', 91 | 'CLASS_FUNCTION', 92 | False, 93 | 'class NonFieldResource(Resource):', 94 | ('__', '__class__', 'NonFieldResource', 'Resource')), 95 | Block(13, 96 | 15, 97 | [], 98 | 'def 
get(self, request, pk):\n instance = self.get_object(request, pk=pk)\n return self.prepare(request, instance)', 99 | 'root.FieldResource.get', 100 | 'CLASS_FUNCTION', 101 | True, 102 | 'class FieldResource():', 103 | ('__', '__class__', 'FieldResource')), 104 | Block(3, 105 | 6, 106 | [3], 107 | 'class Resource():', 108 | 'root.Resource', 109 | 'CLASS_OTHER', 110 | False, 111 | 'module', 112 | ('__', '__class__')), 113 | Block(8, 114 | 10, 115 | [8], 116 | 'class NonFieldResource(Resource):', 117 | 'root.NonFieldResource', 118 | 'CLASS_OTHER', 119 | False, 120 | 'module', 121 | ('__', '__class__', 'Resource')), 122 | Block(12, 123 | 15, 124 | [12], 125 | 'class FieldResource():', 126 | 'root.FieldResource', 127 | 'CLASS_OTHER', 128 | False, 129 | 'module', 130 | ('__', '__class__'))]] 131 | 132 | def test_relevant_block(self): 133 | for i, code in enumerate(SOURCE_CODE): 134 | span = SPAN[i] 135 | generated_block = get_span_context('Signature mismatch in overriding method', 136 | code, tree_sitter_parser, '', 137 | '', span, None) 138 | 139 | for j, gen_block in enumerate(generated_block): 140 | self.assertEqual(self.desired_block[i][j], gen_block) 141 | 142 | 143 | if __name__ == "__main__": 144 | unittest.main() 145 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_useimplicitnonereturnvalue.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context, __columns__ 2 | from basecontexts import Block, MODULE_FUNCTION, MODULE_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import pandas as pd 6 | import unittest 7 | # to handle import while stand-alone test 8 | import sys 9 | sys.path.insert(0, '..') 10 | PY_LANGUAGE = Language("../my-languages.so", "python") 11 | 12 | # to handle bazel tests 13 | # PATH_PREFIX = "../code-cubert/data_preparation/" 14 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 15 | 16 | tree_sitter_parser = Parser() 17 | tree_sitter_parser.set_language(PY_LANGUAGE) 18 | 19 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 20 | 21 | # SOURCE_CODE = [ 22 | # '''import os 23 | # import numpy as np 24 | # 25 | # class Employee: 26 | # def __init__(self, id, name): 27 | # self.id = id 28 | # self.name = name 29 | # 30 | # class SalaryEmployee(Employee): 31 | # def __init__(self, id, name, weekly_salary): 32 | # super().__init__(id, name) 33 | # self.weekly_salary = weekly_salary 34 | # self.total = self.calculate_payroll() 35 | # 36 | # def calculate_payroll(self): 37 | # return self.weekly_salary 38 | # 39 | # class HourlyEmployee(Employee): 40 | # def __init__(self, id, name, hours_worked, hour_rate): 41 | # super().__init__(id, name) 42 | # self.hours_worked = hours_worked 43 | # self.hour_rate = hour_rate 44 | # self.total = self.calculate_payroll() 45 | # 46 | # def calculate_payroll(self): 47 | # return self.hours_worked * self.hour_rate 48 | # 49 | # class FulltimeEmployee(Employee): 50 | # def __init__(self, id, name, employee_type='F'): 51 | # super().__init__(id, name) 52 | # if(employee_type=='F'): 53 | # self.ratio = np.array([0.5, 0.3, 0.2]) 54 | # else: 55 | # self.ratio = np.array([1, 0, 0]) 56 | # 57 | # def work(self): 58 | # print(f'{self.name} gets {self.div} in compensation') 59 | # 60 | # class Intern(HourlyEmployee, FulltimeEmployee): 61 | # def __init__(self, id, name, hours_worked, hour_rate): 62 | # 
super(Intern, self).__init__(id, name, hours_worked, hour_rate) 63 | # self.bonus = 0 64 | # 65 | # def get_bonus(self, hours): 66 | # if(self.hours_worked < 10): 67 | # self.bonus = 2 68 | # else: 69 | # self.bonus = 10 70 | # 71 | # def get_work_desc(self, name, hours, bonus): 72 | # return self.name + " worked " + str(hours) + " hours | " + str(self.div) + " : added bonus - " + str(bonus) 73 | # 74 | # def work(self): 75 | # self.div = np.array(self.ratio * np.array([self.total for i in range(len(self.ratio))]), dtype=int) 76 | # self.bonus = self.get_bonus(self.hours_worked) 77 | # if(self.bonus > 0): 78 | # self.bonus = self.bonus + get_festive_bonus() 79 | # desc = self.get_work_desc(self.name, self.bonus) 80 | # print(desc) 81 | # 82 | # def get_festive_bonus(intern_obj): 83 | # intern_obj.bonus = intern_obj.bonus * 1.2 84 | # 85 | # intern = Intern(4, "II1", 20, 1.5) 86 | # intern.work() 87 | # print(get_festive_bonus(intern))'''] 88 | 89 | SOURCE_CODE = ['''import os\nimport numpy as np\n\nclass Employee:\n def __init__(self, id, name):\n self.id = id\n self.name = name\n\nclass SalaryEmployee(Employee):\n def __init__(self, id, name, weekly_salary):\n super().__init__(id, name)\n self.weekly_salary = weekly_salary\n self.total = self.calculate_payroll()\n\n def calculate_payroll(self):\n return self.weekly_salary\n\t\nclass HourlyEmployee(Employee):\n def __init__(self, id, name, hours_worked, hour_rate):\n super().__init__(id, name)\n self.hours_worked = hours_worked\n self.hour_rate = hour_rate\n self.total = self.calculate_payroll()\n\n def calculate_payroll(self):\n return self.hours_worked * self.hour_rate\n\nclass FulltimeEmployee(Employee):\n def __init__(self, id, name, employee_type=\'F\'):\n super().__init__(id, name)\n if(employee_type==\'F\'):\n self.ratio = np.array([0.5, 0.3, 0.2])\n else:\n self.ratio = np.array([1, 0, 0])\n\n def work(self):\n print(f\'{self.name} gets {self.div} in compensation\')\n\nclass Intern(HourlyEmployee, FulltimeEmployee):\n def __init__(self, id, name, hours_worked, hour_rate):\n super(Intern, self).__init__(id, name, hours_worked, hour_rate)\n self.bonus = 0\n\n def get_bonus(self, hours):\n if(self.hours_worked < 10):\n self.bonus = 2\n else:\n self.bonus = 10\n \n def get_work_desc(self, name, hours, bonus):\n return self.name + " worked " + str(hours) + " hours | " + str(self.div) + " : added bonus - " + str(bonus)\n \n def work(self):\n self.div = np.array(self.ratio * np.array([self.total for i in range(len(self.ratio))]), dtype=int)\n self.bonus = self.get_bonus(self.hours_worked)\n if(self.bonus > 0):\n self.bonus = self.bonus + get_festive_bonus()\n desc = self.get_work_desc(self.name, self.bonus)\n print(desc)\n\ndef get_festive_bonus(intern_obj):\n intern_obj.bonus = intern_obj.bonus * 1.2\n\nintern = Intern(4, "II1", 20, 1.5)\nintern.work()\nprint(get_festive_bonus(intern))'''] 90 | SPANS = [Span(65, 6, 65, 31)] 91 | 92 | 93 | class TestDistributableQueryContext(unittest.TestCase): 94 | desired_block = [[Block(0, 95 | 65, 96 | [0, 1, 2, 7, 16, 26, 37, 59, 62, 63, 64, 65], 97 | '''import os\nimport numpy as np\n\n\n\t\n\n\n\n\nintern = Intern(4, "II1", 20, 1.5)\nintern.work()\nprint(get_festive_bonus(intern))''', 98 | 'root', 99 | MODULE_OTHER, 100 | True, 101 | 'module', 102 | ('__', '__class__')), 103 | Block(60, 104 | 61, 105 | [], 106 | '''def get_festive_bonus(intern_obj):\n intern_obj.bonus = intern_obj.bonus * 1.2''', 107 | 'root.get_festive_bonus', 108 | MODULE_FUNCTION, 109 | True, 110 | 'module', 111 | ('__', 
'__class__'))]] 112 | 113 | def test_relevant_block(self): 114 | test_aux_result_df = pd.read_csv('./test_data/test__aux_res.csv', 115 | names=__columns__) 116 | for i, code in enumerate(SOURCE_CODE): 117 | span = SPANS[i] 118 | generated_block = get_span_context('Use of the return value of a procedure', 119 | code, tree_sitter_parser, 'test_unused_imp.py', 120 | '', span, test_aux_result_df) 121 | 122 | for j, gen_block in enumerate(generated_block): 123 | self.assertEqual(self.desired_block[i][j], gen_block) 124 | 125 | 126 | if __name__ == "__main__": 127 | unittest.main() 128 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/test_wrongnumberargumentsinclassinstantiation.py: -------------------------------------------------------------------------------- 1 | from get_context import get_span_context 2 | from basecontexts import Block, MODULE_FUNCTION, CLASS_FUNCTION, CLASS_OTHER 3 | from tree_sitter import Language, Parser 4 | from collections import namedtuple 5 | import unittest 6 | # to handle import while stand-alone test 7 | import sys 8 | sys.path.insert(0, '..') 9 | PY_LANGUAGE = Language("../my-languages.so", "python") 10 | 11 | # to handle bazel tests 12 | # PATH_PREFIX = "../code-cubert/data_preparation/" 13 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 14 | 15 | tree_sitter_parser = Parser() 16 | tree_sitter_parser.set_language(PY_LANGUAGE) 17 | 18 | Span = namedtuple('Span', 'start_line start_col end_line end_col') 19 | 20 | # class Point(object): 21 | # def __init__(self, x, y): 22 | # self.x = x 23 | # self.y = y 24 | 25 | # def sum(self, x, y): 26 | # return x + y 27 | 28 | # def get_obj(): 29 | # p = Point(1,2,3) 30 | 31 | # if __name__ == '__main__': 32 | # get_obj() 33 | 34 | SOURCE_CODE = ['''class Point(object):\n def __init__(self, x, y):\n self.x = x\n self.y = y\n\n def sum(self, x, y):\n return x + y\n\ndef get_obj():\n p = Point(1,2,3)\n\nif __name__ == '__main__':\n get_obj()'''] 35 | SPAN = [Span(9, 8, 9, 20)] 36 | 37 | 38 | class TestDistributableQueryContext(unittest.TestCase): 39 | desired_block = [[Block(8, 40 | 9, 41 | [], 42 | '''def get_obj():\n p = Point(1,2,3)''', 43 | 'root.get_obj', 44 | MODULE_FUNCTION, 45 | True, 46 | 'module', 47 | ('__', '__class__')), 48 | Block(1, 49 | 3, 50 | [], 51 | '''def __init__(self, x, y):\n self.x = x\n self.y = y''', 52 | 'root.Point.__init__', 53 | CLASS_FUNCTION, 54 | True, 55 | 'class Point(object):', 56 | ('__', '__class__', 'Point')), 57 | Block(5, 58 | 6, 59 | [], 60 | '''def sum(self, x, y):\n return x + y''', 61 | 'root.Point.sum', 62 | CLASS_FUNCTION, 63 | False, 64 | 'class Point(object):', 65 | ('__', '__class__', 'Point')), 66 | Block(0, 67 | 6, 68 | [0, 4], 69 | '''class Point(object):\n''', 70 | 'root.Point', 71 | CLASS_OTHER, 72 | False, 73 | 'module', 74 | ('__', '__class__', 'object'))]] 75 | 76 | def test_relevant_block(self): 77 | for i, code in enumerate(SOURCE_CODE): 78 | span = SPAN[i] 79 | generated_block = get_span_context('Wrong number of arguments in a class instantiation', 80 | code, tree_sitter_parser, '', 81 | '', span, None) 82 | 83 | for j, gen_block in enumerate(generated_block): 84 | self.assertEqual(self.desired_block[i][j], gen_block) 85 | 86 | 87 | if __name__ == "__main__": 88 | unittest.main() 89 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/unusedimport.py: 
-------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts, MODULE_OTHER 2 | from get_context import __columns__ 3 | 4 | 5 | class UnusedImport(BaseContexts): 6 | """ 7 | This module extracts module-level context and corresponding metadata from 8 | program content. 9 | Metadata contains start/end line number information for the corresponding context. 10 | """ 11 | def __init__(self, parser): 12 | """ 13 | Args: 14 | parser: Tree sitter parser object 15 | """ 16 | super().__init__(parser) 17 | 18 | def get_import_use_lines(self, file_path, aux_result_df): 19 | """ 20 | This function returns the source lines where imported functionality is used. 21 | Args: 22 | file_path : file of program_content 23 | aux_result_df: auxiliary query results dataframe 24 | Returns: 25 | A set consisting of lines using import functionality 26 | """ 27 | used_import_df = aux_result_df.loc[(aux_result_df["Name"] == 'Used import') & (aux_result_df["Path"] == file_path), 28 | __columns__] 29 | if(used_import_df.shape[0] == 0): 30 | return set() 31 | 32 | # Get lines between start and end line, with 33 | # consideration that CodeQL has a 1-based index 34 | used_import_df["Lines"] = (used_import_df.apply( 35 | lambda x: [i for i in range(int(x.Start_line) - 1, int(x.End_line))], 36 | axis=1)) 37 | 38 | lines_using_import = set([line for line_sublist in used_import_df["Lines"].tolist() 39 | for line in line_sublist]) 40 | 41 | return lines_using_import 42 | 43 | 44 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df): 45 | """ 46 | This function returns relevant Blocks as query-specific context. 47 | Args: 48 | program_content: Program in string format from which we need to 49 | extract classes 50 | parser : tree_sitter_parser 51 | file_path : file of program_content 52 | message : CodeQL message 53 | result_span: CodeQL-treesitter adjusted namedtuple of 54 | (start_line, start_col, end_line, end_col) 55 | aux_result_df: auxiliary query results dataframe 56 | Returns: 57 | A list consisting of relevant Blocks 58 | """ 59 | start_line = result_span.start_line 60 | end_line = result_span.end_line 61 | 62 | context_object = UnusedImport(parser) 63 | all_blocks = context_object.get_all_blocks(program_content) 64 | 65 | file_path = '/' + file_path 66 | lines_using_import = context_object.get_import_use_lines(file_path, aux_result_df) 67 | 68 | relevant_blocks = [] 69 | local_block = context_object.get_local_block(program_content, start_line, end_line) 70 | for block in all_blocks: 71 | # MODULE_OTHER is always relevant 72 | if(block.block_type == MODULE_OTHER): 73 | block.relevant = True 74 | elif(block == local_block): 75 | block.relevant = True 76 | else: 77 | block_lines = context_object.get_block_lines(block) 78 | for line in block_lines: 79 | if(line in lines_using_import): 80 | block.relevant = True 81 | 82 | relevant_blocks.append(block) 83 | 84 | return relevant_blocks 85 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/contexts/useimplicitnonereturnvalue.py: -------------------------------------------------------------------------------- 1 | from basecontexts import BaseContexts 2 | from basecontexts import MODULE_FUNCTION, CLASS_FUNCTION 3 | from get_context import __columns__ 4 | from collections import namedtuple 5 | import re 6 | 7 | Call = namedtuple('Call', 'func_name start_line start_col end_line end_col') 8 | 9 | 10 | class 
UseImplicitNoneReturnValue(BaseContexts): 11 | """ 12 | This module extracts module level context and corrsponding metadata from 13 | program content. 14 | Metadata contains start/end line number informations for corresponding context. 15 | """ 16 | def __init__(self, parser): 17 | """ 18 | Args: 19 | parser: Tree sitter parser object 20 | """ 21 | super().__init__(parser) 22 | 23 | def create_used_func_calls(self, message, start_line, start_column, end_line, end_column): 24 | """ 25 | Args: 26 | message: CodeQL result message 27 | start_line: start_line of auxillary span 28 | start_column: start_column of auxillary span 29 | end_line: end_line of auxillary span 30 | end_column: end_column of auxillary span 31 | Returns: 32 | Fully qualified function name 33 | """ 34 | matches = re.findall(r"\[\[""(.*)""\\|", message) 35 | qualified_func = matches[0].strip('"') 36 | 37 | # CodeQL has 1-based index 38 | return Call(qualified_func, int(start_line) - 1, int(start_column) - 1, 39 | int(end_line) - 1, int(end_column)) 40 | 41 | def get_used_functions(self, file_path, aux_result_df): 42 | """ 43 | This functions returns relevant Blocks as query specific context. 44 | Args: 45 | file_path : file of program_content 46 | aux_result_df: auxiliary query results dataframe 47 | Returns: 48 | A list consisting lines using import functionality 49 | """ 50 | used_import_df = aux_result_df.loc[(aux_result_df["Name"] == 'Used function call map') & (aux_result_df["Path"] == file_path), 51 | __columns__] 52 | if(used_import_df.shape[0] == 0): 53 | return set() 54 | 55 | # Get lines in start and end line, with 56 | # consideration that CodeQL has 1-based index 57 | used_import_df["call_map"] = (used_import_df.apply( 58 | lambda x: self.create_used_func_calls(x.Message, 59 | x.Start_line, x.Start_column, 60 | x.End_line, x.End_column), 61 | axis=1)) 62 | 63 | used_func_calls = used_import_df["call_map"].tolist() 64 | return used_func_calls 65 | 66 | 67 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df): 68 | """ 69 | This functions returns relevant Blocks as query specific context. 
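A minimal usage sketch for the module above (illustrative only; it assumes a compiled tree-sitter grammar at `../my-languages.so` and an auxiliary CodeQL results CSV such as `test_data/test__aux_res.csv`, mirroring the accompanying tests; the file name and span are made up for the example):

from collections import namedtuple
import pandas as pd
from tree_sitter import Language, Parser
from get_context import __columns__
from unusedimport import get_query_specific_context

Span = namedtuple('Span', 'start_line start_col end_line end_col')
PY_LANGUAGE = Language("../my-languages.so", "python")
parser = Parser()
parser.set_language(PY_LANGUAGE)

code = "import os\nimport sys\n\n\ndef main():\n    print(sys.argv)\n"
aux_df = pd.read_csv('./test_data/test__aux_res.csv', names=__columns__)
# 0-based (tree-sitter style) span of the flagged `import os` statement
blocks = get_query_specific_context(code, parser, 'sample.py', '',
                                    Span(0, 0, 0, 9), aux_df)
for block in blocks:
    print(block.relevant, block.metadata)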
--------------------------------------------------------------------------------
/CodeQueries_preparation/data_preparation/contexts/useimplicitnonereturnvalue.py:
--------------------------------------------------------------------------------
1 | from basecontexts import BaseContexts
2 | from basecontexts import MODULE_FUNCTION, CLASS_FUNCTION
3 | from get_context import __columns__
4 | from collections import namedtuple
5 | import re
6 | 
7 | Call = namedtuple('Call', 'func_name start_line start_col end_line end_col')
8 | 
9 | 
10 | class UseImplicitNoneReturnValue(BaseContexts):
11 |     """
12 |     This module extracts module-level context and corresponding metadata from
13 |     program content.
14 |     Metadata contains start/end line number information for the corresponding context.
15 |     """
16 |     def __init__(self, parser):
17 |         """
18 |         Args:
19 |             parser: Tree sitter parser object
20 |         """
21 |         super().__init__(parser)
22 | 
23 |     def create_used_func_calls(self, message, start_line, start_column, end_line, end_column):
24 |         """
25 |         Args:
26 |             message: CodeQL result message
27 |             start_line: start_line of auxiliary span
28 |             start_column: start_column of auxiliary span
29 |             end_line: end_line of auxiliary span
30 |             end_column: end_column of auxiliary span
31 |         Returns:
32 |             A Call namedtuple with the fully qualified function name and a 0-based span
33 |         """
34 |         matches = re.findall(r"\[\[""(.*)""\\|", message)
35 |         qualified_func = matches[0].strip('"')
36 | 
37 |         # CodeQL has a 1-based index
38 |         return Call(qualified_func, int(start_line) - 1, int(start_column) - 1,
39 |                     int(end_line) - 1, int(end_column))
40 | 
41 |     def get_used_functions(self, file_path, aux_result_df):
42 |         """
43 |         This function returns the function calls used in the file.
44 |         Args:
45 |             file_path : file of program_content
46 |             aux_result_df: auxiliary query results dataframe
47 |         Returns:
48 |             A list of used function Call namedtuples
49 |         """
50 |         used_import_df = aux_result_df.loc[(aux_result_df["Name"] == 'Used function call map') & (aux_result_df["Path"] == file_path),
51 |                                            __columns__]
52 |         if(used_import_df.shape[0] == 0):
53 |             return set()
54 | 
55 |         # Build a Call map per result row, keeping in mind
56 |         # that CodeQL has a 1-based index
57 |         used_import_df["call_map"] = (used_import_df.apply(
58 |             lambda x: self.create_used_func_calls(x.Message,
59 |                                                   x.Start_line, x.Start_column,
60 |                                                   x.End_line, x.End_column),
61 |             axis=1))
62 | 
63 |         used_func_calls = used_import_df["call_map"].tolist()
64 |         return used_func_calls
65 | 
66 | 
67 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df):
68 |     """
69 |     This function returns relevant Blocks as query-specific context.
70 |     Args:
71 |         program_content: Program in string format from which we need to
72 |                          extract classes
73 |         parser : tree_sitter_parser
74 |         file_path : file of program_content
75 |         message : CodeQL message
76 |         result_span: CodeQL-treesitter adjusted namedtuple of
77 |                      (start_line, start_col, end_line, end_col)
78 |         aux_result_df: auxiliary query results dataframe
79 |     Returns:
80 |         A list of relevant Blocks
81 |     """
82 |     start_line = result_span.start_line
83 |     start_col = result_span.start_col
84 |     end_line = result_span.end_line
85 |     end_col = result_span.end_col
86 | 
87 |     context_object = UseImplicitNoneReturnValue(parser)
88 | 
89 |     # local_block is relevant
90 |     local_block = context_object.get_local_block(program_content, start_line, end_line)
91 |     local_block.relevant = True
92 |     required_blocks = [local_block]
93 | 
94 |     local_block_lines = local_block.other_lines
95 |     if(not local_block_lines):
96 |         local_block_lines = [i for i in range(local_block.start_line, local_block.end_line + 1)]
97 | 
98 |     file_path = '/' + file_path
99 |     used_func_calls = context_object.get_used_functions(file_path, aux_result_df)
100 |     relevant_call_maps = []
101 |     for call_map in used_func_calls:
102 |         if((call_map.start_line, call_map.start_col, call_map.end_line, call_map.end_col)
103 |                 == (start_line, start_col, end_line, end_col)):
104 |             relevant_call_maps.append(call_map)
105 | 
106 |     all_blocks = context_object.get_all_blocks(program_content)
107 |     for block in all_blocks:
108 |         add_block = False
109 |         if(block.block_type == CLASS_FUNCTION):
110 |             block_func_name = block.metadata.split('.')[-1]
111 |             block_class_name = block.metadata.split('.')[-2]
112 | 
113 |             for call_map in used_func_calls:
114 |                 used_func_name = call_map.func_name.split('.')[-1]
115 |                 used_func_class = call_map.func_name.split('.')[0]  # enclosing class in case of inner class
116 |                 if(block_func_name == used_func_name
117 |                         and block_class_name == used_func_class
118 |                         and (call_map.start_line in local_block_lines
119 |                              or call_map.end_line in local_block_lines)):
120 |                     add_block = True
121 |                     # if this specific used call is relevant
122 |                     if(call_map in relevant_call_maps):
123 |                         block.relevant = True
124 | 
125 |         elif(block.block_type == MODULE_FUNCTION):
126 |             block_func_name = block.metadata.split('.')[-1]
127 | 
128 |             for call_map in used_func_calls:
129 |                 used_func_name = call_map.func_name
130 |                 if(block_func_name == used_func_name
131 |                         and (call_map.start_line in local_block_lines
132 |                              or call_map.end_line in local_block_lines)):
133 |                     add_block = True
134 |                     # if this specific used call is relevant
135 |                     if(call_map in relevant_call_maps):
136 |                         block.relevant = True
137 | 
138 |         # If not already in required_blocks and add_block == True,
139 |         # then add the block to required_blocks. This condition is
140 |         # required to avoid duplicating local_block
141 |         if(add_block and (block not in required_blocks)):
142 |             required_blocks.append(block)
143 | 
144 |     return required_blocks
145 | 
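The regex in `create_used_func_calls` extracts the qualified function name from CodeQL's `[[...|...]]` message markup. A small illustration of the parsing step (the message shape here is hypothetical, inferred only from the pattern):

import re

message = '[[foo.helper|relative:///module.py:12:5:12:20]]'
# this pattern is equivalent to the concatenated literal r"\[\[""(.*)""\\|" used above
matches = re.findall(r"\[\[(.*)\|", message)
print(matches[0].strip('"'))  # foo.helper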
--------------------------------------------------------------------------------
/CodeQueries_preparation/data_preparation/contexts/wrongnumberargumentsincall.py:
--------------------------------------------------------------------------------
1 | from basecontexts import BaseContexts
2 | from basecontexts import MODULE_FUNCTION, CLASS_FUNCTION
3 | from get_context import __columns__
4 | from collections import namedtuple
5 | import re
6 | 
7 | Call = namedtuple('Call', 'func_name start_line start_col end_line end_col')
8 | 
9 | 
10 | class WrongNumberArgumentsInCall(BaseContexts):
11 |     """
12 |     This module extracts module-level context and corresponding metadata from
13 |     program content.
14 |     Metadata contains start/end line number information for the corresponding context.
15 |     """
16 |     def __init__(self, parser):
17 |         """
18 |         Args:
19 |             parser: Tree sitter parser object
20 |         """
21 |         super().__init__(parser)
22 | 
23 |     def create_usable_func_map(self, message, start_line, start_column, end_line, end_column):
24 |         """
25 |         Args:
26 |             message: CodeQL result message
27 |             start_line: start_line of auxiliary span
28 |             start_column: start_column of auxiliary span
29 |             end_line: end_line of auxiliary span
30 |             end_column: end_column of auxiliary span
31 |         Returns:
32 |             A Call namedtuple with the fully qualified function name and a 0-based span
33 |         """
34 |         matches = re.findall(r"\[\[""(.*)""\\|", message)
35 |         qualified_func = matches[0].strip('"')
36 | 
37 |         # CodeQL has a 1-based index
38 |         return Call(qualified_func, int(start_line) - 1, int(start_column) - 1,
39 |                     int(end_line) - 1, int(end_column))
40 | 
41 |     def get_used_functions(self, file_path, aux_result_df):
42 |         """
43 |         This function returns the function calls in the file.
44 |         Args:
45 |             file_path : file of program_content
46 |             aux_result_df: auxiliary query results dataframe
47 |         Returns:
48 |             A list of function Call namedtuples
49 |         """
50 |         used_import_df = aux_result_df.loc[(aux_result_df["Name"] == 'Function call map') & (aux_result_df["Path"] == file_path),
51 |                                            __columns__]
52 |         if(used_import_df.shape[0] == 0):
53 |             return set()
54 | 
55 |         # Build a Call map per result row, keeping in mind
56 |         # that CodeQL has a 1-based index
57 |         used_import_df["call_map"] = (used_import_df.apply(
58 |             lambda x: self.create_usable_func_map(x.Message,
59 |                                                   x.Start_line, x.Start_column,
60 |                                                   x.End_line, x.End_column),
61 |             axis=1))
62 | 
63 |         used_func_calls = used_import_df["call_map"].tolist()
64 |         return used_func_calls
65 | 
66 | 
67 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df):
68 |     """
69 |     This function returns relevant Blocks as query-specific context.
70 |     Args:
71 |         program_content: Program in string format from which we need to
72 |                          extract classes
73 |         parser : tree_sitter_parser
74 |         file_path : file of program_content
75 |         message : CodeQL message
76 |         result_span: CodeQL-treesitter adjusted namedtuple of
77 |                      (start_line, start_col, end_line, end_col)
78 |         aux_result_df: auxiliary query results dataframe
79 |     Returns:
80 |         A list of relevant Blocks
81 |     """
82 |     start_line = result_span.start_line
83 |     start_col = result_span.start_col
84 |     end_line = result_span.end_line
85 |     end_col = result_span.end_col
86 | 
87 |     context_object = WrongNumberArgumentsInCall(parser)
88 | 
89 |     # local_block is relevant
90 |     local_block = context_object.get_local_block(program_content, start_line, end_line)
91 |     local_block.relevant = True
92 |     required_blocks = [local_block]
93 | 
94 |     local_block_lines = local_block.other_lines
95 |     if(not local_block_lines):
96 |         local_block_lines = [i for i in range(local_block.start_line, local_block.end_line + 1)]
97 | 
98 |     file_path = '/' + file_path
99 |     used_func_calls = context_object.get_used_functions(file_path, aux_result_df)
100 |     relevant_call_maps = []
101 |     for call_map in used_func_calls:
102 |         if((call_map.start_line, call_map.start_col, call_map.end_line, call_map.end_col)
103 |                 == (start_line, start_col, end_line, end_col)):
104 |             relevant_call_maps.append(call_map)
105 | 
106 |     all_blocks = context_object.get_all_blocks(program_content)
107 |     for block in all_blocks:
108 |         add_block = False
109 |         if(block.block_type == CLASS_FUNCTION):
110 |             block_func_name = block.metadata.split('.')[-1]
111 |             block_class_name = block.metadata.split('.')[-2]
112 | 
113 |             for call_map in used_func_calls:
114 |                 used_func_name = call_map[0].split('.')[-1]
115 |                 used_func_class = call_map[0].split('.')[0]  # enclosing class in case of inner class
116 |                 if(block_func_name == used_func_name
117 |                         and block_class_name == used_func_class
118 |                         and (call_map.start_line in local_block_lines
119 |                              or call_map.end_line in local_block_lines)):
120 |                     add_block = True
121 |                     # if this specific used call is relevant
122 |                     if(call_map in relevant_call_maps):
123 |                         block.relevant = True
124 | 
125 |         elif(block.block_type == MODULE_FUNCTION):
126 |             block_func_name = block.metadata.split('.')[-1]
127 | 
128 |             for call_map in used_func_calls:
129 |                 used_func_name = call_map[0]
130 |                 if(block_func_name == used_func_name
131 |                         and (call_map.start_line in local_block_lines
132 |                              or call_map.end_line in local_block_lines)):
133 |                     add_block = True
134 |                     # if this specific used call is relevant
135 |                     if(call_map in relevant_call_maps):
136 |                         block.relevant = True
137 | 
138 |         # If not already in required_blocks and add_block == True,
139 |         # then add the block to required_blocks. This condition is
140 |         # required to avoid duplicating local_block
141 |         if(add_block and (block not in required_blocks)):
142 |             required_blocks.append(block)
143 | 
144 |     return required_blocks
145 | 
--------------------------------------------------------------------------------
/CodeQueries_preparation/data_preparation/contexts/wrongnumberargumentsinclassinstantiation.py:
--------------------------------------------------------------------------------
1 | from basecontexts import BaseContexts, CLASS_FUNCTION, CALL_NODE_TYPE
2 | from basecontexts import ContextRetrievalError
3 | import re
4 | 
5 | 
6 | class WrongNumberArgumentsInClassInstantiation(BaseContexts):
7 |     """
8 |     This module extracts module-level context and corresponding metadata from
9 |     program content.
10 |     Metadata contains start/end line number information for the corresponding context.
11 |     """
12 |     def __init__(self, parser):
13 |         """
14 |         Args:
15 |             parser: Tree sitter parser object
16 |         """
17 |         super().__init__(parser)
18 |         self.postordered_nodes = []
19 |         self.target_node = None
20 | 
21 |     def set_target_node(self, root_node, start_line, end_line, start_col, end_col):
22 |         """
23 |         This function finds the call node which spans start_line to end_line
24 |         in the tree_sitter tree and stores it in self.target_node.
25 |         Args:
26 |             root_node: root node of tree_sitter tree
27 |             start_line : start line of a span
28 |             end_line : end line of a span
29 |             start_col : start column of answer span
30 |             end_col : end column of answer span
31 |         Returns:
32 |             None (the matching node is stored in self.target_node)
33 |         """
34 |         if(root_node.type == CALL_NODE_TYPE
35 |                 and root_node.start_point[0] == start_line
36 |                 and root_node.end_point[0] == end_line
37 |                 and root_node.start_point[1] == start_col
38 |                 and root_node.end_point[1] == end_col):
39 |             self.target_node = root_node
40 |             return
41 |         else:
42 |             for ch in root_node.children:
43 |                 self.set_target_node(ch, start_line, end_line, start_col, end_col)
44 | 
45 |         return
46 | 
47 |     def postorder_traverse(self, root_node, program_content):
48 |         """
49 |         This function collects a postorder traversal of the leaf nodes and the
50 |         corresponding node literals in self.postordered_nodes.
51 |         Args:
52 |             root_node: root node of tree_sitter tree
53 |             program_content: Program in string format from which we need to
54 |                              extract node literal.
55 |         Returns:
56 |             None
57 |         """
58 |         if(len(root_node.children) == 0):
59 |             literal = bytes(program_content, "utf8")[
60 |                 root_node.start_byte:root_node.end_byte
61 |             ].decode("utf8")
62 |             self.postordered_nodes.append((literal, root_node))
63 |         else:
64 |             for ch in root_node.children:
65 |                 self.postorder_traverse(ch, program_content)
66 | 
67 |     def get_target_class_name(self, root_node, program_content, start_line, end_line, start_col, end_col):
68 |         """
69 |         This function returns the name of the class which is incorrectly instantiated.
70 |         Args:
71 |             root_node: root node of tree_sitter tree
72 |             program_content: Program in string format from which we need to
73 |                              extract node literal
74 |             start_line : start line of a span
75 |             end_line : end line of a span
76 |             start_col : start column of answer span
77 |             end_col : end column of answer span
78 |         Returns:
79 |             Name of the instantiated class
80 |         """
81 |         self.set_target_node(root_node, start_line, end_line, start_col, end_col)
82 |         self.postorder_traverse(self.target_node, program_content)
83 |         for i, node in enumerate(self.postordered_nodes):
84 |             if(node[0] == '('):
85 |                 class_name = self.postordered_nodes[i - 1][0]
86 |                 break
87 | 
88 |         return class_name
89 | 
90 | 
91 | def get_query_specific_context(program_content, parser, file_path, message, result_span, aux_result_df=None):
92 |     """
93 |     This function returns relevant Blocks as query-specific context.
94 |     Args:
95 |         program_content: Program in string format from which we need to
96 |                          extract classes
97 |         parser : tree_sitter_parser
98 |         file_path : file of program_content
99 |         message : CodeQL message
100 |         result_span: CodeQL-treesitter adjusted namedtuple of
101 |                      (start_line, start_col, end_line, end_col)
102 |         aux_result_df: auxiliary query results dataframe
103 |     Returns:
104 |         A list of relevant Blocks
105 |     """
106 |     start_line = result_span.start_line
107 |     # start_col = result_span.start_col
108 |     end_line = result_span.end_line
109 |     # end_col = result_span.end_col
110 | 
111 |     context_object = WrongNumberArgumentsInClassInstantiation(parser)
112 | 
113 |     # tree = parser.parse(bytes(program_content, "utf8"))
114 |     # root_node = tree.root_node
115 | 
116 |     # class instantiation Block is relevant
117 |     local_block = context_object.get_local_block(program_content, start_line, end_line)
118 |     local_block.relevant = True
119 |     required_blocks = [local_block]
120 | 
121 |     # class_name = context_object.get_target_class_name(root_node, program_content,
122 |     #                                                   start_line, end_line,
123 |     #                                                   start_col, end_col)
124 | 
125 |     matches = re.findall(r"relative:\/\/\/[a-zA-Z0-9_.]*:(\d+):(\d+):(\d+):(\d+)", message)
126 |     # get the class of class_name
127 |     req_class_block = None
128 |     if len(matches) == 1:
129 |         # len should always be one for this query;
130 |         # if len == 0, then it is a call to some built-in class
131 |         all_classes = context_object.get_all_classes(program_content)
132 |         init_start_line = int(matches[0][0])
133 |         init_end_line = int(matches[0][2])
134 | 
135 |         for block in all_classes:
136 |             block_start_line = block.start_line
137 |             block_end_line = block.end_line
138 |             if (block_start_line <= init_start_line
139 |                     and block_end_line >= init_end_line):
140 |                 req_class_block = block
141 |                 break
142 | 
143 |     if(len(matches) == 1 and req_class_block is None):
144 |         raise ContextRetrievalError({"message": "No __init__ in class",
145 |                                      "type": "Wrong number of arg in class instantiation"})
146 | 
147 |     # req_class_block is None if the class is not from the current module,
148 |     # is an inner class, or couldn't be found
149 |     if(req_class_block is not None):
150 |         # Blocks from the corresponding class
151 |         all_blocks = context_object.get_all_blocks(program_content)
152 |         for block in all_blocks:
153 |             if(block.start_line >= req_class_block.start_line
154 |                     and block.end_line <= req_class_block.end_line):
155 |                 if(block.block_type == CLASS_FUNCTION):
156 |                     func_name = block.metadata.split('.')[-1]
157 |                     if(func_name == '__init__'):
158 |                         block.relevant = True
159 |                         required_blocks.append(block)
160 | 
161 |     return required_blocks
162 | 
--------------------------------------------------------------------------------
/CodeQueries_preparation/data_preparation/create_groupwise_prediction_dataset.py:
--------------------------------------------------------------------------------
1 | from create_span_prediction_training_examples import PREPAREVOCAB
2 | import dataset_with_context_pb2
3 | from transformers import RobertaTokenizer
4 | from tqdm import tqdm
5 | 
6 | 
7 | def create_groupwise_prediction_dataset(block_query_subtokens_labels_dataset, vocab_file: str, model_type: str):
8 |     """
9 |     This function creates examples for the groupwise relevance prediction dataset, to be
10 |     eventually used for span prediction.
11 | Args: 12 | block_query_subtokens_labels_dataset: BlockQuerySubtokensLabelsDataset protobuf 13 | vocab_file: model vocab file 14 | model_type: cubert/codebert 15 | Returns: 16 | ExampleForGroupwisePredictionDataset protobuf 17 | """ 18 | prepare_vocab_object = PREPAREVOCAB(vocab_file) 19 | codebert_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") 20 | 21 | if(model_type == 'cubert'): 22 | cls_token = '[CLS]_' 23 | sep_token = '[SEP]_' 24 | cls_token_id = prepare_vocab_object.convert_by_vocab([cls_token]) 25 | sep_token_id = prepare_vocab_object.convert_by_vocab([sep_token]) 26 | elif(model_type == 'codebert'): 27 | cls_token = codebert_tokenizer.cls_token 28 | sep_token = codebert_tokenizer.sep_token 29 | cls_token_id = codebert_tokenizer.convert_tokens_to_ids([cls_token]) 30 | sep_token_id = codebert_tokenizer.convert_tokens_to_ids([sep_token]) 31 | 32 | example_for_group_pred = dataset_with_context_pb2.ExampleForGroupwisePredictionDataset() 33 | 34 | for i in tqdm(range(len( 35 | block_query_subtokens_labels_dataset.block_query_subtokens_labels_item)), 36 | desc="Groupwise_prediction"): 37 | 38 | single_group = dataset_with_context_pb2.SingleGroupExample() 39 | single_group.example_type = (block_query_subtokens_labels_dataset. 40 | example_types[i]) 41 | single_group.distributable = (block_query_subtokens_labels_dataset. 42 | block_query_subtokens_labels_item[i]. 43 | block_query_subtokens_labels_group_item[0].distributable) 44 | 45 | for k in range(len( 46 | block_query_subtokens_labels_dataset.block_query_subtokens_labels_item[i]. 47 | block_query_subtokens_labels_group_item)): 48 | assert single_group.distributable == (block_query_subtokens_labels_dataset. 49 | block_query_subtokens_labels_item[i]. 50 | block_query_subtokens_labels_group_item[k].distributable) 51 | 52 | group_item = dataset_with_context_pb2.SingleGroupItem() 53 | 54 | group_item.query_id = (block_query_subtokens_labels_dataset. 55 | block_query_subtokens_labels_item[i]. 56 | block_query_subtokens_labels_group_item[k].query_id) 57 | 58 | group_item.block_id = (block_query_subtokens_labels_dataset. 59 | block_query_subtokens_labels_item[i]. 60 | block_query_subtokens_labels_group_item[k].block. 61 | unique_block_id) 62 | 63 | group_item.relevance = (block_query_subtokens_labels_dataset. 64 | block_query_subtokens_labels_item[i]. 65 | block_query_subtokens_labels_group_item[k]. 66 | block.relevance_label) 67 | 68 | query_subtokens = [] 69 | query_subtokens.extend(block_query_subtokens_labels_dataset. 70 | block_query_subtokens_labels_item[i]. 71 | block_query_subtokens_labels_group_item[k]. 72 | query_name_subtokens) 73 | 74 | program_subtokens = [] 75 | labels = [] 76 | 77 | for j in range(len(block_query_subtokens_labels_dataset. 78 | block_query_subtokens_labels_item[i]. 79 | block_query_subtokens_labels_group_item[k]. 80 | block_subtokens_labels)): 81 | program_subtokens.append(block_query_subtokens_labels_dataset. 82 | block_query_subtokens_labels_item[i]. 83 | block_query_subtokens_labels_group_item[k]. 84 | block_subtokens_labels[j].program_subtoken) 85 | labels.append(block_query_subtokens_labels_dataset. 86 | block_query_subtokens_labels_item[i]. 87 | block_query_subtokens_labels_group_item[k]. 
88 | block_subtokens_labels[j].label) 89 | 90 | program_subtokens_ids = prepare_vocab_object.convert_by_vocab( 91 | program_subtokens) 92 | query_subtokens_ids = prepare_vocab_object.convert_by_vocab( 93 | query_subtokens) 94 | 95 | input_ids = [] 96 | input_mask = [] 97 | segment_ids = [] 98 | labels_ids = [] 99 | 100 | input_ids.extend(cls_token_id) 101 | segment_ids.append(0) 102 | input_mask.append(1) 103 | labels_ids.append(dataset_with_context_pb2.OutputLabels.Value("_")) 104 | 105 | for k in query_subtokens_ids: 106 | input_ids.append(k) 107 | segment_ids.append(0) 108 | input_mask.append(1) 109 | labels_ids.append( 110 | dataset_with_context_pb2.OutputLabels.Value("_")) 111 | 112 | input_ids.extend(sep_token_id) 113 | segment_ids.append(0) 114 | input_mask.append(1) 115 | labels_ids.append(dataset_with_context_pb2.OutputLabels.Value("_")) 116 | 117 | for h, k in enumerate(program_subtokens_ids): 118 | input_ids.append(k) 119 | segment_ids.append(1) 120 | input_mask.append(1) 121 | labels_ids.append(labels[h]) 122 | 123 | group_item.input_ids.extend(input_ids) 124 | group_item.input_mask.extend(input_mask) 125 | group_item.segment_ids.extend(segment_ids) 126 | group_item.label_ids.extend(labels_ids) 127 | group_item.program_ids.extend(program_subtokens_ids) 128 | group_item.program_label_ids.extend(labels) 129 | 130 | assert len(group_item.program_ids) == len( 131 | group_item.program_label_ids) 132 | 133 | single_group.group_items.append(group_item) 134 | 135 | single_group.query_name_token_ids.extend(query_subtokens_ids) 136 | 137 | example_for_group_pred.examples.append(single_group) 138 | 139 | return example_for_group_pred 140 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/create_relevance_prediction_examples.py: -------------------------------------------------------------------------------- 1 | from create_span_prediction_training_examples import PREPAREVOCAB 2 | import dataset_with_context_pb2 3 | from transformers import RobertaTokenizer 4 | from tqdm import tqdm 5 | 6 | 7 | def create_relevance_prediction_examples(block_query_subtokens_labels_dataset, vocab_file: str, 8 | include_single_hop_examples, model_type): 9 | """ 10 | This function creates examples for relevance prediction. 
11 | Args: 12 | block_query_subtokens_labels_dataset: BlockQuerySubtokensLabelsDataset protobuf 13 | vocab_file: model vocab file 14 | include_single_hop_examples: True/False 15 | model_type: cubert/codebert 16 | Returns: 17 | ExampleforRelevancePredictionDataset protobuf 18 | """ 19 | prepare_vocab_object = PREPAREVOCAB(vocab_file) 20 | codebert_tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") 21 | 22 | if(model_type == 'cubert'): 23 | cls_token = '[CLS]_' 24 | sep_token = '[SEP]_' 25 | cls_token_id = prepare_vocab_object.convert_by_vocab([cls_token]) 26 | sep_token_id = prepare_vocab_object.convert_by_vocab([sep_token]) 27 | elif(model_type == 'codebert'): 28 | cls_token = codebert_tokenizer.cls_token 29 | sep_token = codebert_tokenizer.sep_token 30 | cls_token_id = codebert_tokenizer.convert_tokens_to_ids([cls_token]) 31 | sep_token_id = codebert_tokenizer.convert_tokens_to_ids([sep_token]) 32 | 33 | example_for_model_dataset = dataset_with_context_pb2.ExampleforRelevancePredictionDataset() 34 | 35 | for i in tqdm(range(len( 36 | block_query_subtokens_labels_dataset.block_query_subtokens_labels_item)), 37 | desc="Relevance_dataset"): 38 | 39 | if include_single_hop_examples is False: 40 | if(block_query_subtokens_labels_dataset.block_query_subtokens_labels_item[i]. 41 | block_query_subtokens_labels_group_item[0].distributable == 1): 42 | continue 43 | 44 | for k in range(len( 45 | block_query_subtokens_labels_dataset.block_query_subtokens_labels_item[i]. 46 | block_query_subtokens_labels_group_item)): 47 | 48 | example_for_model_dataset_item = dataset_with_context_pb2.ExampleforRelevancePrediction() 49 | example_for_model_dataset_item.example_type = (block_query_subtokens_labels_dataset. 50 | example_types[i]) 51 | 52 | query_subtokens = [] 53 | query_subtokens.extend(block_query_subtokens_labels_dataset. 54 | block_query_subtokens_labels_item[i]. 55 | block_query_subtokens_labels_group_item[k]. 56 | query_name_subtokens) 57 | 58 | example_for_model_dataset_item.query_id = (block_query_subtokens_labels_dataset. 59 | block_query_subtokens_labels_item[i]. 60 | block_query_subtokens_labels_group_item[k]. 61 | query_id) 62 | 63 | example_for_model_dataset_item.block_id = (block_query_subtokens_labels_dataset. 64 | block_query_subtokens_labels_item[i]. 65 | block_query_subtokens_labels_group_item[k]. 66 | block.unique_block_id) 67 | 68 | example_for_model_dataset_item.relevance = (block_query_subtokens_labels_dataset. 69 | block_query_subtokens_labels_item[i]. 70 | block_query_subtokens_labels_group_item[k]. 71 | block.relevance_label) 72 | 73 | example_for_model_dataset_item.program_path = (block_query_subtokens_labels_dataset. 74 | block_query_subtokens_labels_item[i]. 75 | block_query_subtokens_labels_group_item[k]. 76 | raw_file.file_path.dataset_file_path.unique_file_path) 77 | 78 | program_subtokens = [] 79 | 80 | for j in range(len(block_query_subtokens_labels_dataset. 81 | block_query_subtokens_labels_item[i]. 82 | block_query_subtokens_labels_group_item[k]. 83 | block_subtokens_labels)): 84 | program_subtokens.append(block_query_subtokens_labels_dataset. 85 | block_query_subtokens_labels_item[i]. 86 | block_query_subtokens_labels_group_item[k]. 
87 |                                          block_subtokens_labels[j].program_subtoken)
88 | 
89 |             program_subtokens_ids = prepare_vocab_object.convert_by_vocab(
90 |                 program_subtokens)
91 |             query_subtokens_ids = prepare_vocab_object.convert_by_vocab(
92 |                 query_subtokens)
93 | 
94 |             input_ids = []
95 |             input_mask = []
96 |             segment_ids = []
97 | 
98 |             input_ids.extend(cls_token_id)
99 |             segment_ids.append(0)
100 |             input_mask.append(1)
101 | 
102 |             for k in query_subtokens_ids:
103 |                 input_ids.append(k)
104 |                 segment_ids.append(0)
105 |                 input_mask.append(1)
106 | 
107 |             input_ids.extend(sep_token_id)
108 |             segment_ids.append(0)
109 |             input_mask.append(1)
110 | 
111 |             for k in program_subtokens_ids:
112 |                 input_ids.append(k)
113 |                 segment_ids.append(1)
114 |                 input_mask.append(1)
115 | 
116 |             example_for_model_dataset_item.input_ids.extend(input_ids)
117 |             example_for_model_dataset_item.input_mask.extend(input_mask)
118 |             example_for_model_dataset_item.segment_ids.extend(segment_ids)
119 | 
120 |             example_for_model_dataset.block_relevance_example.append(
121 |                 example_for_model_dataset_item)
122 |     return example_for_model_dataset
--------------------------------------------------------------------------------
/CodeQueries_preparation/data_preparation/create_single_example.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from absl import flags
3 | from importlib import import_module
4 | from subprocess import call
5 | 
6 | sys.path.insert(0, "../data_ingestion/")
7 | raw_codeql_queryset_pb2 = import_module('raw_codeql_queryset_pb2')
8 | 
9 | FLAGS = flags.FLAGS
10 | 
11 | flags.DEFINE_string(
12 |     "save_location",
13 |     None,
14 |     "Path to store intermediate data and the example"
15 | )
16 | 
17 | flags.DEFINE_string(
18 |     "file_with_path",
19 |     None,
20 |     "Provide the path of the file which stores the\
21 |     paths of the program files"
22 | )
23 | 
24 | flags.DEFINE_string(
25 |     "raw_queries_protobuf_file",
26 |     None,
27 |     "Path to serialized raw queries protobuf file."
28 | )
29 | 
30 | flags.DEFINE_string(
31 |     "data_path",
32 |     None,
33 |     "Path to all examples"
34 | )
35 | 
36 | flags.DEFINE_string(
37 |     "result_file",
38 |     None,
39 |     "Path to CodeQL result csv"
40 | )
41 | 
42 | flags.DEFINE_string(
43 |     "example_type",
44 |     None,
45 |     "positive/negative"
46 | )
47 | 
48 | flags.DEFINE_string(
49 |     "aux_result_path",
50 |     None,
51 |     "Path to auxiliary CodeQL query results csv"
52 | )
53 | 
54 | flags.DEFINE_string(
55 |     "model_type",
56 |     'cubert',
57 |     "cubert/codebert"
58 | )
59 | 
60 | flags.DEFINE_string(
61 |     "block_ordering",
62 |     'line_number',
63 |     "line_number/random/--"
64 | )
65 | 
66 | flags.DEFINE_string(
67 |     "vocab_file",
68 |     None,
69 |     "Cubert vocab file path, not required when model_type is CodeBERT"
70 | )
71 | 
72 | if __name__ == '__main__':
73 |     argv = FLAGS(sys.argv)
74 | 
75 |     save_location = FLAGS.save_location
76 |     result_file = FLAGS.result_file
77 |     eg_type = FLAGS.example_type
78 |     aux_result_path = FLAGS.aux_result_path
79 |     model_type = FLAGS.model_type
80 |     block_ordering = FLAGS.block_ordering
81 |     vocab_file = FLAGS.vocab_file
82 |     data_path = FLAGS.data_path
83 |     path_file = FLAGS.file_with_path
84 |     serialized_queries_path = FLAGS.raw_queries_protobuf_file
85 | 
86 |     serialized_src_file_path = save_location + '/source_file_serialized'
87 | 
88 |     call("python ../data_ingestion/run_create_raw_programs_dataset.py \
89 |         --data_source=other --source_name=pyeth150_open --split_name=TEST \
90 |         --dataset_programming_language=Python --programs_file_path=" + path_file
91 |          + " --downloaded_dataset_location=" + data_path
92 |          + " --save_dataset_location=" + serialized_src_file_path, shell=True)
93 | 
94 |     merged_query_results = save_location + '/merged_query_result'
95 |     call("python run_create_query_result.py --raw_programs_protobuf_file="
96 |          + serialized_src_file_path + " --raw_queries_protobuf_file=" + serialized_queries_path
97 |          + " --results_file=" + result_file + " --save_dataset_location=" + merged_query_results, shell=True)
98 | 
99 |     tokenized_src_file = save_location + '/tokenized_src_file'
100 |     call("python run_create_tokenized_files_labels.py --merged_query_result_protobuf_file="
101 |          + merged_query_results + " --save_dataset_location=" + tokenized_src_file
102 |          + " --positive_or_negative_examples=" + eg_type
103 |          + " --number_of_dataset_splits=1", shell=True)
104 | 
105 |     tokenized_block_label = save_location + '/tokenized_block_label'
106 |     call("python run_create_blocks_labels_dataset.py --tokenized_file_protobuf_file="
107 |          + tokenized_src_file + " --save_dataset_location=" + tokenized_block_label
108 |          + " --aux_result_path=" + aux_result_path
109 |          + " --number_of_previous_dataset_splits=1 --number_of_dataset_splits=1", shell=True)
110 | 
111 |     block_subtoken_labels = save_location + '/block_subtoken_labels'
112 |     call("python run_create_block_subtokens_labels.py --model_type="
113 |          + model_type + " --ordering_of_blocks=" + block_ordering + " --vocab_file=" + vocab_file
114 |          + " --tokenized_block_label_protobuf_file=" + tokenized_block_label
115 |          + " --save_dataset_location=" + block_subtoken_labels
116 |          + " --number_of_previous_dataset_splits=1 --number_of_dataset_splits=1", shell=True)
117 | 
118 |     model_example = save_location + '/model_example'
119 |     call("python run_create_single_model_baseline_examples.py --model_type="
120 |          + model_type + " --block_subtoken_label_protobuf_file=" + block_subtoken_labels
121 |          + " --vocab_file=" + vocab_file + " --save_dataset_location=" + model_example
122 |          + " --number_of_dataset_splits=1",
shell=True) 123 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/merge_negative_positive_examples.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | 5 | 6 | FLAGS = flags.FLAGS 7 | 8 | 9 | flags.DEFINE_string( 10 | "positive_examples_files", 11 | None, 12 | "Path to positive examples." 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "negative_examples_files", 17 | None, 18 | "Path to negative examples." 19 | ) 20 | 21 | flags.DEFINE_integer( 22 | "number_of_dataset_splits", 23 | 4, 24 | "In how many splits to save the dataset." 25 | ) 26 | 27 | flags.DEFINE_string( 28 | "save_dataset_location", 29 | None, 30 | "Path to store the final block subtokens files." 31 | ) 32 | 33 | if __name__ == "__main__": 34 | argv = FLAGS(sys.argv) 35 | 36 | data = dataset_with_context_pb2.ExampleforSpanPredictionDataset() 37 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 38 | split_dataset = (dataset_with_context_pb2.ExampleforSpanPredictionDataset()) 39 | with open(FLAGS.positive_examples_files + str(i), "rb") as fd: 40 | split_dataset.ParseFromString(fd.read()) 41 | # join 42 | data.examples.extend( 43 | split_dataset.examples 44 | ) 45 | for i in range(len(data.examples)): 46 | assert data.examples[i].example_type == 1 47 | print(len(data.examples)) 48 | 49 | negative_data = dataset_with_context_pb2.ExampleforSpanPredictionDataset() 50 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 51 | split_dataset = (dataset_with_context_pb2.ExampleforSpanPredictionDataset()) 52 | with open(FLAGS.negative_examples_files + str(i), "rb") as fd: 53 | split_dataset.ParseFromString(fd.read()) 54 | # join 55 | negative_data.examples.extend( 56 | split_dataset.examples 57 | ) 58 | for i in range(len(negative_data.examples)): 59 | assert negative_data.examples[i].example_type == 0 60 | print(len(negative_data.examples)) 61 | 62 | data.examples.extend(negative_data.examples) 63 | print(len(data.examples)) 64 | 65 | # split the data 66 | dataset_len = len(data.examples) 67 | split_len = (dataset_len / FLAGS.number_of_dataset_splits) 68 | 69 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 70 | temp = dataset_with_context_pb2.ExampleforSpanPredictionDataset() 71 | 72 | lower = (i - 1) * split_len 73 | upper = (i) * split_len 74 | 75 | for j in range(int(lower), int(upper)): 76 | temp.examples.append( 77 | data.examples[j] 78 | ) 79 | 80 | with open(FLAGS.save_dataset_location + str(i), "wb") as fd: 81 | fd.write(temp.SerializeToString()) 82 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/my-languages.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/CodeQueries_preparation/data_preparation/my-languages.so -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_block_subtokens_labels.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | from tree_sitter import Language, Parser 4 | import dataset_with_context_pb2 5 | from create_block_subtokens_labels import (create_cubert_subtokens_labels, 6 | create_codebert_subtokens_labels) 7 | 8 | from 
graphcodebertutils import create_graphcodebert_dataflow_subtokens_labels 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | flags.DEFINE_string( 13 | "model_type", 14 | None, 15 | "'cubert' / 'codebert' / 'graphcodebert'" 16 | ) 17 | 18 | flags.DEFINE_string( 19 | "ordering_of_blocks", 20 | None, 21 | "What ordering of blocks to follow?\ 22 | ('default_proto'/'line_number'/'random')" 23 | ) 24 | 25 | flags.DEFINE_string( 26 | "vocab_file", 27 | None, 28 | "Path to Cubert vocabulary file." 29 | ) 30 | 31 | flags.DEFINE_string( 32 | "tokenized_block_label_protobuf_file", 33 | None, 34 | "Path to tokenized block label protobuf file." 35 | ) 36 | 37 | flags.DEFINE_integer( 38 | "number_of_previous_dataset_splits", 39 | 20, 40 | "In how many splits previous dataset was stored." 41 | ) 42 | 43 | flags.DEFINE_integer( 44 | "number_of_dataset_splits", 45 | 20, 46 | "In how many splits to save the dataset." 47 | ) 48 | 49 | flags.DEFINE_string( 50 | "save_dataset_location", 51 | None, 52 | "Path to store the Cubert examples dataset." 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | argv = FLAGS(sys.argv) 58 | 59 | tokenized_block_query_labels_dataset = (dataset_with_context_pb2. 60 | TokenizedBlockQueryLabelsDataset()) 61 | 62 | for i in range(1, FLAGS.number_of_previous_dataset_splits + 1): 63 | split_dataset = (dataset_with_context_pb2.TokenizedBlockQueryLabelsDataset()) 64 | with open(FLAGS.tokenized_block_label_protobuf_file + str(i), "rb") as fd: 65 | split_dataset.ParseFromString(fd.read()) 66 | # join 67 | tokenized_block_query_labels_dataset.tokenized_block_query_labels_item.extend( 68 | split_dataset.tokenized_block_query_labels_item 69 | ) 70 | tokenized_block_query_labels_dataset.example_types.extend( 71 | split_dataset.example_types 72 | ) 73 | 74 | if(FLAGS.model_type == 'cubert'): 75 | dataset = create_cubert_subtokens_labels(FLAGS.ordering_of_blocks, 76 | tokenized_block_query_labels_dataset, 77 | FLAGS.vocab_file) 78 | elif(FLAGS.model_type == 'codebert'): 79 | dataset = create_codebert_subtokens_labels(FLAGS.ordering_of_blocks, 80 | tokenized_block_query_labels_dataset) 81 | elif(FLAGS.model_type == 'graphcodebert'): 82 | PY_LANGUAGE = Language("./my-languages.so", "python") 83 | tree_sitter_parser = Parser() 84 | tree_sitter_parser.set_language(PY_LANGUAGE) 85 | 86 | dataset = create_graphcodebert_dataflow_subtokens_labels(FLAGS.ordering_of_blocks, 87 | tokenized_block_query_labels_dataset, 88 | tree_sitter_parser) 89 | 90 | # split the data 91 | dataset_len = len(dataset.block_query_subtokens_labels_item) 92 | split_len = (dataset_len / FLAGS.number_of_dataset_splits) 93 | 94 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 95 | temp = dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset() 96 | 97 | lower = (i - 1) * split_len 98 | upper = (i) * split_len 99 | 100 | for j in range(int(lower), int(upper)): 101 | temp.block_query_subtokens_labels_item.append( 102 | dataset.block_query_subtokens_labels_item[j] 103 | ) 104 | temp.example_types.append( 105 | dataset.example_types[j] 106 | ) 107 | 108 | with open(FLAGS.save_dataset_location + str(i), "wb") as fd: 109 | fd.write(temp.SerializeToString()) 110 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_blocks_labels_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pandas as pd 3 | from absl import flags 4 | from tree_sitter import Language, Parser 5 | import 
dataset_with_context_pb2 6 | from create_blocks_labels_dataset import create_blocks_labels_dataset 7 | from create_blocks_relevance_labels_dataset import create_blocks_relevance_labels_dataset 8 | from contexts.get_context import __columns__ 9 | 10 | PATH_PREFIX = "../code-cubert/data_preparation/" 11 | FLAGS = flags.FLAGS 12 | 13 | flags.DEFINE_string( 14 | "tokenized_file_protobuf_file", 15 | None, 16 | "Path to tokenized source file protobuf file." 17 | ) 18 | 19 | flags.DEFINE_integer( 20 | "number_of_previous_dataset_splits", 21 | 20, 22 | "In how many splits previous dataset was stored." 23 | ) 24 | 25 | flags.DEFINE_integer( 26 | "number_of_dataset_splits", 27 | 20, 28 | "In how many splits to save the dataset." 29 | ) 30 | 31 | flags.DEFINE_string( 32 | "save_dataset_location", 33 | None, 34 | "Path to store the Cubert examples dataset." 35 | ) 36 | 37 | flags.DEFINE_string( 38 | "aux_result_path", 39 | None, 40 | "Path to store the Cubert examples dataset." 41 | ) 42 | 43 | flags.DEFINE_string( 44 | "with_header", 45 | 'yes', 46 | "yes/no" 47 | ) 48 | 49 | flags.DEFINE_string( 50 | "with_simplified_relevance", 51 | None, 52 | "yes/no" 53 | ) 54 | 55 | flags.DEFINE_string( 56 | "only_relevant_blocks", 57 | None, 58 | "yes/no - whether to get only relevant blocks or all blocks" 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | argv = FLAGS(sys.argv) 64 | PY_LANGUAGE = Language("./my-languages.so", "python") 65 | # PY_LANGUAGE = Language(PATH_PREFIX + "my-languages.so", "python") 66 | 67 | tree_sitter_parser = Parser() 68 | tree_sitter_parser.set_language(PY_LANGUAGE) 69 | 70 | tokenized_files_with_labels = dataset_with_context_pb2.TokenizedQueryProgramLabelsDataset() 71 | for i in range(1, FLAGS.number_of_previous_dataset_splits + 1): 72 | split_dataset = (dataset_with_context_pb2.TokenizedQueryProgramLabelsDataset()) 73 | with open(FLAGS.tokenized_file_protobuf_file + str(i), "rb") as fd: 74 | split_dataset.ParseFromString(fd.read()) 75 | # join 76 | tokenized_files_with_labels.tokens_and_labels.extend( 77 | split_dataset.tokens_and_labels 78 | ) 79 | if(i == 1): 80 | tokenized_files_with_labels.example_type = split_dataset.example_type 81 | else: 82 | assert tokenized_files_with_labels.example_type == split_dataset.example_type 83 | 84 | keep_header = (FLAGS.with_header == 'yes') 85 | aux_result_df = pd.read_csv(FLAGS.aux_result_path, 86 | names=__columns__) 87 | if(FLAGS.with_simplified_relevance == 'yes'): 88 | tokenized_block_query_labels_dataset = create_blocks_relevance_labels_dataset( 89 | tokenized_files_with_labels, tree_sitter_parser, 90 | keep_header, aux_result_df, FLAGS.only_relevant_blocks 91 | ) 92 | else: 93 | tokenized_block_query_labels_dataset = create_blocks_labels_dataset( 94 | tokenized_files_with_labels, tree_sitter_parser, 95 | keep_header, aux_result_df 96 | ) 97 | 98 | # split the data 99 | dataset_len = len(tokenized_block_query_labels_dataset.tokenized_block_query_labels_item) 100 | split_len = (dataset_len / FLAGS.number_of_dataset_splits) 101 | 102 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 103 | temp = dataset_with_context_pb2.TokenizedBlockQueryLabelsDataset() 104 | 105 | lower = (i - 1) * split_len 106 | upper = (i) * split_len 107 | 108 | for j in range(int(lower), int(upper)): 109 | temp.tokenized_block_query_labels_item.append( 110 | tokenized_block_query_labels_dataset.tokenized_block_query_labels_item[j] 111 | ) 112 | temp.example_types.append( 113 | tokenized_block_query_labels_dataset.example_types[j] 114 | ) 115 | 116 | 
with open(FLAGS.save_dataset_location + str(i), "wb") as fd: 117 | fd.write(temp.SerializeToString()) 118 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_groupwise_prediction_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | from create_groupwise_prediction_dataset import create_groupwise_prediction_dataset 5 | 6 | PATH_PREFIX = "../code-cubert/data_preparation/" 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string( 10 | "model_type", 11 | None, 12 | "'cubert' / 'codebert'" 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "block_subtoken_label_protobuf_file", 17 | None, 18 | "Path to tokenized block subtoken label protobuf file." 19 | ) 20 | 21 | flags.DEFINE_integer( 22 | "number_of_previous_dataset_splits", 23 | 20, 24 | "In how many splits previous dataset was stored." 25 | ) 26 | 27 | flags.DEFINE_integer( 28 | "number_of_dataset_splits", 29 | 4, 30 | "In how many splits to save the dataset." 31 | ) 32 | 33 | flags.DEFINE_string( 34 | "vocab_file", 35 | None, 36 | "Path to cubert vocabulary file." 37 | ) 38 | 39 | flags.DEFINE_string( 40 | "save_dataset_location", 41 | None, 42 | "Path to store the cubert/codebert examples dataset." 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | argv = FLAGS(sys.argv) 48 | 49 | dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 50 | 51 | for i in range(1, FLAGS.number_of_previous_dataset_splits + 1): 52 | split_dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 53 | with open(FLAGS.block_subtoken_label_protobuf_file + str(i), "rb") as fd: 54 | split_dataset.ParseFromString(fd.read()) 55 | # join 56 | dataset.block_query_subtokens_labels_item.extend( 57 | split_dataset.block_query_subtokens_labels_item 58 | ) 59 | dataset.example_types.extend( 60 | split_dataset.example_types 61 | ) 62 | 63 | examples_for_group_pred = create_groupwise_prediction_dataset( 64 | dataset, FLAGS.vocab_file, FLAGS.model_type 65 | ) 66 | 67 | # split the data 68 | dataset_len = len(examples_for_group_pred.examples) 69 | split_len = (dataset_len / FLAGS.number_of_dataset_splits) 70 | 71 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 72 | temp = dataset_with_context_pb2.ExampleForGroupwisePredictionDataset() 73 | 74 | lower = (i - 1) * split_len 75 | upper = (i) * split_len 76 | 77 | for j in range(int(lower), int(upper)): 78 | temp.examples.append( 79 | examples_for_group_pred.examples[j] 80 | ) 81 | 82 | with open(FLAGS.save_dataset_location + str(i), "wb") as fd: 83 | fd.write(temp.SerializeToString()) 84 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_query_result.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | from create_query_result import create_query_result 5 | from create_query_result import create_query_result_merged 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string( 10 | "raw_programs_protobuf_file", 11 | None, 12 | "Path to serialized raw programs protobuf file." 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "raw_queries_protobuf_file", 17 | None, 18 | "Path to serialized raw queries protobuf file." 
19 | ) 20 | 21 | flags.DEFINE_string( 22 | "results_file", 23 | None, 24 | "Path to CodeQL analysis results file." 25 | ) 26 | 27 | flags.DEFINE_string( 28 | "save_dataset_location", 29 | None, 30 | "Path to store the merged query results dataset." 31 | ) 32 | 33 | flags.DEFINE_string( 34 | "positive_or_negative_examples", 35 | None, 36 | "Are the examples positive or negative? ('positive'/'negative')" 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | argv = FLAGS(sys.argv) 42 | 43 | raw_programs = dataset_with_context_pb2.RawProgramDataset() 44 | raw_queries = dataset_with_context_pb2.RawQueryList() 45 | 46 | with open(FLAGS.raw_programs_protobuf_file, "rb") as fd: 47 | raw_programs.ParseFromString(fd.read()) 48 | 49 | with open(FLAGS.raw_queries_protobuf_file, "rb") as fd: 50 | raw_queries.ParseFromString(fd.read()) 51 | 52 | dataset = create_query_result( 53 | raw_programs, raw_queries, 54 | FLAGS.results_file, FLAGS.positive_or_negative_examples) 55 | 56 | dataset_merged = create_query_result_merged(dataset) 57 | 58 | with open(FLAGS.save_dataset_location, "wb") as fd: 59 | fd.write(dataset_merged.SerializeToString()) 60 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_relevance_prediction_examples.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | from create_relevance_prediction_examples import create_relevance_prediction_examples 5 | 6 | PATH_PREFIX = "../code-cubert/data_preparation/" 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string( 10 | "model_type", 11 | None, 12 | "'cubert' / 'codebert'" 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "block_subtoken_label_protobuf_file", 17 | None, 18 | "Path to tokenized block subtoken label protobuf file." 19 | ) 20 | 21 | flags.DEFINE_integer( 22 | "number_of_dataset_splits", 23 | 20, 24 | "In how many splits to save the dataset." 25 | ) 26 | 27 | flags.DEFINE_string( 28 | "vocab_file", 29 | None, 30 | "Path to Cubert vocabulary file." 31 | ) 32 | 33 | flags.DEFINE_string( 34 | "include_single_hop_examples", 35 | None, 36 | "If single hop examples should be included (yes/no)." 37 | ) 38 | 39 | flags.DEFINE_string( 40 | "save_dataset_location", 41 | None, 42 | "Path to store the Cubert examples dataset." 
43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | argv = FLAGS(sys.argv) 48 | 49 | dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 50 | 51 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 52 | split_dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 53 | with open(FLAGS.block_subtoken_label_protobuf_file + str(i), "rb") as fd: 54 | split_dataset.ParseFromString(fd.read()) 55 | # join 56 | dataset.block_query_subtokens_labels_item.extend( 57 | split_dataset.block_query_subtokens_labels_item 58 | ) 59 | dataset.example_types.extend( 60 | split_dataset.example_types 61 | ) 62 | 63 | if(FLAGS.include_single_hop_examples == "yes"): 64 | hop = True 65 | else: 66 | hop = False 67 | 68 | examples_for_relevance_detection = create_relevance_prediction_examples( 69 | dataset, FLAGS.vocab_file, hop, FLAGS.model_type 70 | ) 71 | 72 | with open(FLAGS.save_dataset_location, "wb") as f: 73 | f.write(examples_for_relevance_detection.SerializeToString()) 74 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_span_prediction_training_examples.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | from create_span_prediction_training_examples import create_span_prediction_training_examples 5 | 6 | PATH_PREFIX = "../code-cubert/data_preparation/" 7 | FLAGS = flags.FLAGS 8 | 9 | flags.DEFINE_string( 10 | "model_type", 11 | None, 12 | "'cubert' / 'codebert'" 13 | ) 14 | 15 | flags.DEFINE_string( 16 | "block_subtoken_label_protobuf_file", 17 | None, 18 | "Path to tokenized block subtoken label protobuf file." 19 | ) 20 | 21 | flags.DEFINE_integer( 22 | "number_of_dataset_splits", 23 | 20, 24 | "Into how many splits previous dataset was split." 25 | ) 26 | 27 | flags.DEFINE_string( 28 | "vocab_file", 29 | None, 30 | "Path to Cubert vocabulary file." 31 | ) 32 | 33 | flags.DEFINE_string( 34 | "save_dataset_location", 35 | None, 36 | "Path to store the Cubert examples dataset." 
37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | argv = FLAGS(sys.argv) 42 | 43 | dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 44 | 45 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 46 | split_dataset = (dataset_with_context_pb2.BlockQuerySubtokensLabelsDataset()) 47 | with open(FLAGS.block_subtoken_label_protobuf_file + str(i), "rb") as fd: 48 | split_dataset.ParseFromString(fd.read()) 49 | # join 50 | dataset.block_query_subtokens_labels_item.extend( 51 | split_dataset.block_query_subtokens_labels_item 52 | ) 53 | dataset.example_types.extend( 54 | split_dataset.example_types 55 | ) 56 | 57 | examples_for_span_pred = create_span_prediction_training_examples( 58 | dataset, FLAGS.vocab_file, FLAGS.model_type 59 | ) 60 | 61 | with open(FLAGS.save_dataset_location, "wb") as f: 62 | f.write(examples_for_span_pred.SerializeToString()) 63 | -------------------------------------------------------------------------------- /CodeQueries_preparation/data_preparation/run_create_tokenized_files_labels.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from absl import flags 3 | import dataset_with_context_pb2 4 | from create_tokenized_files_labels import create_tokenized_files_labels 5 | 6 | FLAGS = flags.FLAGS 7 | 8 | flags.DEFINE_string( 9 | "merged_query_result_protobuf_file", 10 | None, 11 | "Path to serialized query results protobuf file." 12 | ) 13 | 14 | flags.DEFINE_integer( 15 | "number_of_dataset_splits", 16 | 20, 17 | "In how many splits to save the dataset." 18 | ) 19 | 20 | flags.DEFINE_string( 21 | "save_dataset_location", 22 | None, 23 | "Path to store the Cubert examples dataset." 24 | ) 25 | 26 | flags.DEFINE_string( 27 | "positive_or_negative_examples", 28 | None, 29 | "Are the examples positive or negative? ('positive'/'negative')" 30 | ) 31 | 32 | 33 | if __name__ == "__main__": 34 | argv = FLAGS(sys.argv) 35 | 36 | dataset_merged = dataset_with_context_pb2.RawMergedResultDataset() 37 | 38 | with open(FLAGS.merged_query_result_protobuf_file, "rb") as fd: 39 | dataset_merged.ParseFromString(fd.read()) 40 | 41 | tokenized_files_with_labels = create_tokenized_files_labels( 42 | dataset_merged, FLAGS.positive_or_negative_examples) 43 | 44 | # split the data 45 | dataset_len = len(tokenized_files_with_labels.tokens_and_labels) 46 | split_len = (dataset_len / FLAGS.number_of_dataset_splits) 47 | 48 | for i in range(1, FLAGS.number_of_dataset_splits + 1): 49 | temp = dataset_with_context_pb2.TokenizedQueryProgramLabelsDataset() 50 | 51 | lower = (i - 1) * split_len 52 | upper = (i) * split_len 53 | 54 | for j in range(int(lower), int(upper)): 55 | temp.tokens_and_labels.append( 56 | tokenized_files_with_labels.tokens_and_labels[j] 57 | ) 58 | temp.example_type = (dataset_with_context_pb2. 
59 |                              ExampleType.Value(FLAGS.positive_or_negative_examples))
60 | 
61 |     with open(FLAGS.save_dataset_location + str(i), "wb") as fd:
62 |         fd.write(temp.SerializeToString())
63 | 
--------------------------------------------------------------------------------
/Codequeries_Statistics.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/Codequeries_Statistics.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
### The work was accepted at `ISEC 2024` and is now available in the [ACM Digital Library](https://dl.acm.org/doi/10.1145/3641399.3641408).

# CodeQueries Benchmark

CodeQueries is a dataset for evaluating methodologies on answering semantic queries over code. Existing datasets for question answering in the context of programming languages target comparatively simpler tasks, such as predicting binary yes/no answers to a question, or operate over a localized context (e.g., a single source-code method). In contrast, in CodeQueries, a source-code file is annotated with the spans required to answer a code-analysis query about semantic aspects of the code. Given a query and code, a `Span Predictor` system is expected to identify the answer and supporting-fact spans in the code for the query.

![CodeQueries task definition](figures/QA_Task.png)
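To take a quick look at the data (a minimal sketch; it assumes the HuggingFace `datasets` library is installed, and uses the same loading call as `evaluate_relevance.py`):

```python
import datasets

# "twostep" is one of the dataset settings; see the HuggingFace dataset
# card for the other settings and splits.
examples = datasets.load_dataset("thepurpleowl/codequeries", "twostep",
                                 split=datasets.Split.TEST)
print(examples[0].keys())
```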
The dataset statistics and some additional results are provided in the [Codequeries_Statistics](https://github.com/thepurpleowl/codequeries-benchmark/blob/main/Codequeries_Statistics.pdf) file. [Here](https://github.com/thepurpleowl/codequeries-benchmark/blob/main/using_CodeQueries.ipynb) is the starter code to get started with the dataset.

More details on the curated dataset for this benchmark are available on [HuggingFace](https://huggingface.co/datasets/thepurpleowl/codequeries).

### Steps
-----------
The repo provides scripts to evaluate the dataset with LLM generations and in a two-step setup. Follow these steps to use the scripts:
1. Clone the repo in a virtual environment.
2. Run `setup.sh` to set up the workspace.
3. Run the following commands to get the performance metric values.


#### LLM experiment evaluation
-----------
We used the GPT-3.5-Turbo model from OpenAI with different prompt templates (provided at `/prompt_templates`) to generate the required answer and supporting-fact spans for a query. We generate 10 samples for each input; the generated results are downloaded as part of the setup. The following commands evaluate the LLM results for the different prompts.
To evaluate the zero-shot prompt,
&nbsp;&nbsp;&nbsp;&nbsp;`python evaluate_generated_spans.py --g=test_dir_file_0shot/logs`
To evaluate the few-shot prompt with BM25 retrieval,
&nbsp;&nbsp;&nbsp;&nbsp;`python evaluate_generated_spans.py --g=test_dir_file_fewshot/logs`
To evaluate the few-shot prompt with supporting facts,
&nbsp;&nbsp;&nbsp;&nbsp;`python evaluate_generated_spans.py --g=test_dir_file_fewshot_sf/logs --with_sf=True`

#### Two-step setup evaluation
-----------
In many cases, the entire file contents do not fit in the input to the model. However, not all code is relevant for answering a given query. We identify the relevant code blocks using the CodeQL results during data preparation and implement a two-step procedure to deal with the problem of scaling to large-size code:
&nbsp;&nbsp;&nbsp;&nbsp;Step 1: We first apply a relevance classifier to every block in the given code and select the code blocks that are likely to be relevant for answering the given query.
&nbsp;&nbsp;&nbsp;&nbsp;Step 2: We then apply the span prediction model to the set of selected code blocks to predict the answer and supporting-fact spans.

To evaluate the two-step setup, run
`python3 evaluate_spanprediction.py --example_types_to_evaluate=<positive/negative> --setting=twostep --span_type=<span type> --span_model_checkpoint_path=<path to span model checkpoint> --relevance_model_checkpoint_path=<path to relevance model checkpoint>`
A sketch of this two-step inference flow is shown below.
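This is only a conceptual sketch: `relevance_model` and `span_model` stand in for the finetuned checkpoints loaded by `evaluate_spanprediction.py`, whose exact interfaces in `utils.py` differ.

```python
from typing import Callable, List, Tuple

def two_step_predict(query: str,
                     blocks: List[str],
                     relevance_model: Callable[[str, str], bool],
                     span_model: Callable[[str, List[str]], List[Tuple[int, int]]]
                     ) -> List[Tuple[int, int]]:
    # Step 1: keep only the blocks classified as relevant for the query.
    relevant_blocks = [b for b in blocks if relevance_model(query, b)]
    # Step 2: predict answer/supporting-fact spans over the selected blocks.
    return span_model(query, relevant_blocks)
```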
58 | 
59 | 
60 | ### Experiment results on sampled test data
61 | -----------
62 | #### LLM experiment
63 | <table>
64 | <tr>
65 | <th></th>
66 | <th colspan="2">Zero-shot prompting <br>(Answer span prediction)</th>
67 | <th colspan="2">Few-shot prompting with BM25 retrieval <br>(Answer span prediction)</th>
68 | <th>Few-shot prompting with supporting fact <br>(Answer &amp; supporting-fact span prediction)</th>
69 | </tr>
70 | <tr>
71 | <th>Pass@k</th>
72 | <th>Positive</th>
73 | <th>Negative</th>
74 | <th>Positive</th>
75 | <th>Negative</th>
76 | <th>Positive</th>
77 | </tr>
78 | <tr>
79 | <td>1</td>
80 | <td>9.82</td>
81 | <td>12.83</td>
82 | <td>16.45</td>
83 | <td>44.25</td>
84 | <td>21.88</td>
85 | </tr>
86 | <tr>
87 | <td>2</td>
88 | <td>13.06</td>
89 | <td>17.42</td>
90 | <td>21.14</td>
91 | <td>55.53</td>
92 | <td>28.06</td>
93 | </tr>
94 | <tr>
95 | <td>5</td>
96 | <td>17.47</td>
97 | <td>22.85</td>
98 | <td>27.69</td>
99 | <td>65.43</td>
100 | <td>34.94</td>
101 | </tr>
102 | <tr>
103 | <td>10</td>
104 | <td>20.84</td>
105 | <td>26.77</td>
106 | <td>32.66</td>
107 | <td>70.0</td>
108 | <td>39.08</td>
109 | </tr>
110 | </table>
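111 | 
112 | The Pass@k numbers are computed from the 10 generated samples per input. As a reference, here is a minimal sketch of the standard unbiased Pass@k estimator (an assumption that this repo follows the usual definition; the actual evaluation is done by `evaluate_generated_spans.py`):
113 | ```python
114 | from math import comb
115 | 
116 | def pass_at_k(n: int, c: int, k: int) -> float:
117 |     """Unbiased pass@k: the chance that at least one of k samples,
118 |     drawn without replacement from n generations, is correct, given
119 |     that c of the n generations are correct."""
120 |     if n - c < k:
121 |         return 1.0  # every size-k draw contains at least one correct sample
122 |     return 1.0 - comb(n - c, k) / comb(n, k)
123 | 
124 | # e.g., 3 correct generations out of the 10 drawn per input:
125 | print(pass_at_k(n=10, c=3, k=5))
126 | ```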
127 | 
128 | #### Two-step setup
129 | <table>
130 | <tr>
131 | <th></th>
132 | <th colspan="2">Answer span prediction</th>
133 | <th>Answer &amp; supporting-fact span prediction</th>
134 | </tr>
135 | <tr>
136 | <th>Variant</th>
137 | <th>Positive</th>
138 | <th>Negative</th>
139 | <th>Positive</th>
140 | </tr>
141 | <tr>
142 | <td>Two-step(20, 20)</td>
143 | <td>9.42</td>
144 | <td>92.13</td>
145 | <td>8.42</td>
146 | </tr>
147 | <tr>
148 | <td>Two-step(all, 20)</td>
149 | <td>15.03</td>
150 | <td>94.49</td>
151 | <td>13.27</td>
152 | </tr>
153 | <tr>
154 | <td>Two-step(20, all)</td>
155 | <td>32.87</td>
156 | <td>96.26</td>
157 | <td>30.66</td>
158 | </tr>
159 | <tr>
160 | <td>Two-step(all, all)</td>
161 | <td>51.90</td>
162 | <td>95.67</td>
163 | <td>49.30</td>
164 | </tr>
165 | </table>
166 | 
167 | 
168 | ### Experiment results on complete test data
169 | -----------
170 | | Variants           | Positive | Negative |
171 | |--------------------|----------|----------|
172 | | Two-step(20, 20)   | 3.74     | 95.54    |
173 | | Two-step(all, 20)  | 7.81     | 97.87    |
174 | | Two-step(20, all)  | 33.41    | 96.23    |
175 | | Two-step(all, all) | 52.61    | 96.73    |
176 | | Prefix             | 36.60    | 93.80    |
177 | | Sliding window     | 51.91    | 85.75    |
178 | 
-------------------------------------------------------------------------------- /evaluate_relevance.py: --------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import datasets
4 | from absl import flags
5 | import pickle
6 | import os
7 | from pathlib import Path
8 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
9 | from utils import Relevance_Classification_Model, DEVICE
10 | from utils import get_relevance_dataloader_input, get_relevance_dataloader, eval_fn_relevance_prediction
11 | 
12 | FLAGS = flags.FLAGS
13 | 
14 | 
15 | flags.DEFINE_string(
16 |     "vocab_file",
17 |     "pretrained_model_configs/vocab.txt",
18 |     "Path to Cubert vocabulary file."
19 | )
20 | 
21 | flags.DEFINE_string(
22 |     "model_checkpoint_path",
23 |     "finetuned_ckpts/Twostep_Relevance-512",
24 |     "Path to relevance model checkpoint."
25 | )
26 | 
27 | flags.DEFINE_string(
28 |     "data",
29 |     "all",
30 |     "all/sampled"
31 | )
32 | 
33 | 
34 | def get_relevance_model_performance(vocab_file, data):
35 |     if data == "all":  # fixed: compare with the string "all", not the builtin `all`
36 |         examples_data = datasets.load_dataset("thepurpleowl/codequeries", "twostep",
37 |                                               split=datasets.Split.TEST, use_auth_token=True)
38 |     else:
39 |         with open('resources/twostep_TEST.pkl', 'rb') as f:
40 |             examples_data = pickle.load(f)
41 |     # evaluation
42 |     model = Relevance_Classification_Model()
43 |     model.to(DEVICE)
44 |     model.load_state_dict(torch.load(FLAGS.model_checkpoint_path, map_location=DEVICE))
45 | 
46 |     model_input_ids, model_segment_ids, model_input_mask, model_labels_ids = get_relevance_dataloader_input(examples_data,
47 |                                                                                                             vocab_file, False)
48 |     eval_relevance_data_loader, _ = get_relevance_dataloader(
49 |         model_input_ids,
50 |         model_input_mask,
51 |         model_segment_ids,
52 |         model_labels_ids
53 |     )
54 | 
55 |     eval_relevance_out, eval_relevance_targets, _ = eval_fn_relevance_prediction(
56 |         eval_relevance_data_loader, model, DEVICE, False)
57 | 
58 |     assert len(eval_relevance_targets) == len(eval_relevance_out)
59 | 
60 |     store_path = "analyses/relevance"
61 |     if not Path(store_path).exists():
62 |         os.makedirs(store_path)
63 |     with open(f'{store_path}/{data}_eval_relevance_targets.pkl', 'wb') as f:
64 |         pickle.dump(eval_relevance_targets, f)
65 |     with open(f'{store_path}/{data}_eval_relevance_out.pkl', 'wb') as f:
66 |         pickle.dump(eval_relevance_out, f)
67 |     scores = precision_recall_fscore_support(eval_relevance_targets, eval_relevance_out)
68 |     relevance_accuracy = accuracy_score(eval_relevance_targets, eval_relevance_out)
69 |     print(scores)
70 |     print("Accuracy: ", relevance_accuracy, "Precision: ", scores[0][1], ", Recall: ", scores[1][1])
71 | 
72 | 
73 | if __name__ == "__main__":
74 |     argv = FLAGS(sys.argv)
75 |     get_relevance_model_performance(FLAGS.vocab_file, FLAGS.data)
76 | 
-------------------------------------------------------------------------------- /figures/QA_Task.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/figures/QA_Task.png
-------------------------------------------------------------------------------- /get_sampled_data.py: --------------------------------------------------------------------------------
1 | # %%
2 | import pickle
3 | import json
4 | from tqdm import tqdm
5 | import datasets
6 | import tiktoken
7 | import random
8 | random.seed(42)
9 | 
10 | TOKENIZER_MODEL_ALIAS_MAP = {
11 |     "gpt-4": "gpt-3.5-turbo",
12 |     "gpt-35-turbo": "gpt-3.5-turbo",
13 | }
14 | query_data: dict = json.load(open("resources/codequeries_meta.json", "r"))
15 | reccos_dict: dict = {q["name"]: q["reccomendation"] for q in query_data}
16 | all_queries = reccos_dict.keys()
17 | 
18 | 
19 | def count_file_tokens(file_path: str, model_name: str = "gpt-35-turbo"):
20 |     with open('/home/t-susahu/CodeQueries/data' + f'/{file_path}', 'r') as f:  # NOTE: hardcoded path from the original setup; adjust before reuse
21 |         input_str = f.read()
22 |     if model_name in TOKENIZER_MODEL_ALIAS_MAP:
23 |         model_name = TOKENIZER_MODEL_ALIAS_MAP[model_name]
24 |     encoding = tiktoken.encoding_for_model(model_name)
25 |     num_tokens = len(encoding.encode(input_str))
26 |     return {'file_tokens': num_tokens}
27 | 
28 | 
29 | # To get files that fit into the prompt; this is the data from which the `sampled test data` is created
30 | def generate_querywise_test_samples():
31 |     dataset = datasets.load_dataset("thepurpleowl/codequeries", "ideal", split=datasets.Split.TEST)
32 |     partitioned_data_all = {query['name']: dataset.map(lambda x: count_file_tokens(x['code_file_path'])).filter(lambda x: x["query_name"] == query['name'] and x["file_tokens"] < 3000) for query in query_data}
33 | 
34 |     with open('resources/partitioned_data_all.pkl', 'wb') as f:
35 |         pickle.dump(partitioned_data_all, f)
36 | 
37 | 
38 | # To get the per-query train data from which the 20 files for twostep training are sampled
39 | def generate_querywise_train_all_samples():
40 |     dataset = datasets.load_dataset("thepurpleowl/codequeries", "ideal", split=datasets.Split.TRAIN)
41 |     partitioned_data_train_all = {query['name']: dataset.filter(lambda x: x["query_name"] == query['name']) for query in query_data}
42 | 
43 |     with open('resources/partitioned_data_train_all.pkl', 'wb') as f:
44 |         pickle.dump(partitioned_data_train_all, f)
45 | 
46 | 
47 | # To get train files to be used as examples with LLM prompting. If you want to run the LLM experiment with supporting-fact prompting, you need to run this first.
48 | def generate_querywise_train_samples():
49 |     dataset = datasets.load_dataset("thepurpleowl/codequeries", "ideal", split=datasets.Split.TRAIN)
50 |     partitioned_data_train_1000 = {query['name']: dataset.map(lambda x: count_file_tokens(x['code_file_path'])).filter(lambda x: x["query_name"] == query['name'] and x["file_tokens"] < 1000) for query in query_data}
51 | 
52 |     with open('resources/partitioned_data_train_1000.pkl', 'wb') as f:
53 |         pickle.dump(partitioned_data_train_1000, f)
54 | 
55 | 
56 | # To check how many files remain after sampling
57 | def get_querywise_test_stat(partitioned_data_path):
58 |     dataset = datasets.load_dataset("thepurpleowl/codequeries", "ideal", split=datasets.Split.TEST)
59 |     with open(partitioned_data_path, 'rb') as f:
60 |         partitioned_data_all = pickle.load(f)
61 | 
62 |     all_files = dataset['code_file_path']
63 |     total = 0
64 |     for query in all_queries:
65 |         total += partitioned_data_all[query].shape[0]
66 |         for ff in partitioned_data_all[query]['code_file_path']:
67 |             assert ff in all_files
68 | 
69 |     print(dataset.shape[0], total)
70 | 
71 | 
72 | def sample_data(query_data, split, s=10):
73 |     if split == 'train':
74 |         pos_files = set(query_data.filter(lambda x: x["example_type"] == 1)['code_file_path'])
75 |         neg_files = set(query_data.filter(lambda x: x["example_type"] == 0)['code_file_path'])
76 |     else:
77 |         pos_files = set(query_data.filter(lambda x: x["example_type"] == 1 and x['file_tokens'] <= 2000)['code_file_path'])
78 |         neg_files = set(query_data.filter(lambda x: x["example_type"] == 0 and x['file_tokens'] <= 2000)['code_file_path'])
79 |     assert len(pos_files.intersection(neg_files)) == 0
80 | 
81 |     all_files = random.sample(sorted(pos_files), min(s, len(pos_files))) + random.sample(sorted(neg_files), min(s, len(neg_files)))  # sorted(): random.sample requires a sequence (not a set) on Python 3.11+, and sorting keeps the seeded sampling deterministic
82 | 
83 |     # get answer and sf spans
84 |     metadata_with_spans = {}
85 |     for ff in all_files:
86 |         ans_spans = []
87 |         sf_spans = []
88 |         file_data = query_data.filter(lambda x: x['code_file_path'] == ff and x["example_type"] == 1)
89 |         for row in file_data:
90 |             ans_spans += row['answer_spans']
91 |             sf_spans += row['supporting_fact_spans']
92 |         metadata_with_spans[ff] = {'ans_spans': ans_spans, 'sf_spans': sf_spans}
93 |         if ff in neg_files:
94 |             assert not ans_spans and not sf_spans
95 | 
96 |     return metadata_with_spans
97 | 
98 | 
99 | # From partitioned data, get metadata for the `test sampled data` and the sampled (max 20) train files for two-step.
100 | # The output of this function is provided in the repo and can be used directly with `get_instances_for_files`.
101 | def get_sampled_file_metadata(partitioned_data_path, output_path, split='train'):
102 |     with open(partitioned_data_path, 'rb') as f:
103 |         partitioned_data = pickle.load(f)
104 | 
105 |     sampled_partitioned_data = {}
106 |     for query in tqdm(all_queries):
107 |         sampled_partitioned_data[query] = sample_data(partitioned_data[query], split)
108 | 
109 |     with open(output_path, 'wb') as f:
110 |         pickle.dump(sampled_partitioned_data, f)
111 | 
112 | 
113 | # To get the twostep data for the `sampled test data` or the sampled (max 20) train files
114 | def get_instances_for_files(sampled_data_path, dataset_setting, split):
115 |     if split == 'TEST':
116 |         dataset = datasets.load_dataset("thepurpleowl/codequeries", dataset_setting, split=datasets.Split.TEST)
117 |     else:
118 |         dataset = datasets.load_dataset("thepurpleowl/codequeries", dataset_setting, split=datasets.Split.TRAIN)
119 |     with open(sampled_data_path, 'rb') as f:
120 |         sampled_data = pickle.load(f)
121 | 
122 |     filtered_data = dataset.filter(lambda x: x['code_file_path'] in list(sampled_data[x["query_name"]].keys()))
123 |     with open(f'resources/{dataset_setting}_{split}.pkl', 'wb') as f:
124 |         pickle.dump(filtered_data, f)
125 | 
126 | 
127 | if __name__ == '__main__':
128 |     # get_sampled_file_metadata('resources/partitioned_data_all.pkl', 'resources/sampled_test_data.pkl', 'test')
129 |     get_instances_for_files('resources/sampled_test_data.pkl', 'twostep', 'TEST')
130 | 
-------------------------------------------------------------------------------- /pretrained_model_configs/README.md: --------------------------------------------------------------------------------
1 | This directory contains the CuBERT vocabulary and CuBERT checkpoint configs for input length limits of 512 and 1024 tokens.
-------------------------------------------------------------------------------- /pretrained_model_configs/config_1024.json: --------------------------------------------------------------------------------
1 | {
2 |     "attention_probs_dropout_prob": 0.1,
3 |     "hidden_act": "gelu",
4 |     "hidden_dropout_prob": 0.1,
5 |     "hidden_size": 1024,
6 |     "initializer_range": 0.02,
7 |     "intermediate_size": 4096,
8 |     "num_attention_heads": 16,
9 |     "num_hidden_layers": 24,
10 |     "type_vocab_size": 2,
11 |     "vocab_size": 49558,
12 |     "max_position_embeddings": 1024
13 | }
-------------------------------------------------------------------------------- /pretrained_model_configs/config_512.json: --------------------------------------------------------------------------------
1 | {
2 |     "attention_probs_dropout_prob": 0.1,
3 |     "hidden_act": "gelu",
4 |     "hidden_dropout_prob": 0.1,
5 |     "hidden_size": 1024,
6 |     "initializer_range": 0.02,
7 |     "intermediate_size": 4096,
8 |     "num_attention_heads": 16,
9 |     "num_hidden_layers": 24,
10 |     "type_vocab_size": 2,
11 |     "vocab_size": 49558,
12 |     "max_position_embeddings": 512
13 | }
-------------------------------------------------------------------------------- /prompt_templates/ex_with_sf.j2: --------------------------------------------------------------------------------
1 | Example code snippet with answer span(s) matching the query description with supporting fact span(s)
2 | ```python
3 | {{ positive_context }}
4 | ```
5 | 
6 | 
7 | Answer span(s)
8 | ```python
9 | {% for span in positive_spans %}
10 | {{span}}
11 | {% endfor %}
12 | ```
13 | 
14 | 
15 | Supporting fact span(s)
16 | ```python
17 | {% for span in supporting_fact_spans %}
18 | {{span}}
19 | {% endfor %}
20 | ```END
--------------------------------------------------------------------------------
/prompt_templates/ex_wo_sf.j2: -------------------------------------------------------------------------------- 1 | Example code snippet with answer span(s) matching the query description but without supporting fact span(s) 2 | ```python 3 | {{ positive_context }} 4 | ``` 5 | 6 | 7 | Answer span(s) 8 | ```python 9 | {% for span in positive_spans %} 10 | {{span}} 11 | {% endfor %} 12 | ``` 13 | 14 | 15 | Supporting fact span(s) 16 | ```python 17 | N/A 18 | ```END -------------------------------------------------------------------------------- /prompt_templates/span_highlight_0shot.j2: -------------------------------------------------------------------------------- 1 | You are an expert software developer. Please help identify the results of evaluating the CodeQL query titled "{{ query_name }}" on a code snippet. The results should be given as code spans or fragments (if any) from the code snippet. The description of the CodeQL query "{{ query_name }}" is - {{ description }} 2 | 3 | 4 | If there are spans that match the query description, print them out one per line. If no spans matching the query description are present, say N/A. 5 | 6 | 7 | Code snippet 8 | ```python 9 | {{ input_code }} 10 | ``` 11 | 12 | 13 | Code span(s) 14 | ```python 15 | -------------------------------------------------------------------------------- /prompt_templates/span_highlight_fewshot.j2: -------------------------------------------------------------------------------- 1 | You are an expert software developer. Please help identify the results of evaluating the CodeQL query titled "{{ query_name }}" on a code snippet. The results should be given as code spans or fragments (if any) from the code snippet. The description of the CodeQL query "{{ query_name }}" is - {{ description }} 2 | 3 | 4 | If there are spans that match the query description, print them out one per line. If no spans matching the query description are present, say N/A. 5 | 6 | 7 | The following are some examples of code snippets with and without spans matching the query description. 8 | Example code snippet with span(s) matching the query description 9 | ```python 10 | {{ positive_context }} 11 | ``` 12 | 13 | 14 | Code span(s) 15 | ```python 16 | {% for span in positive_spans %} 17 | {{span}} 18 | {% endfor %} 19 | ``` 20 | 21 | 22 | Example code snippet with no span(s) matching the query description 23 | ```python 24 | {{ negative_context }} 25 | ``` 26 | 27 | 28 | Code span(s) 29 | ```python 30 | N/A 31 | ``` 32 | 33 | 34 | Code snippet 35 | ```python 36 | {{ input_code }} 37 | ``` 38 | 39 | 40 | Code span(s) 41 | ```python 42 | -------------------------------------------------------------------------------- /prompt_templates/span_highlight_fewshot_sf.j2: -------------------------------------------------------------------------------- 1 | You are an expert software developer. Please help identify the results of evaluating the CodeQL query titled "{{ query_name }}" on a code snippet. The results should be given as code spans or fragments (if any) from the code snippet. The description of the CodeQL query "{{ query_name }}" is - {{ description }} 2 | 3 | 4 | The results should consist of two parts: answer spans and supporting fact spans. If there are spans that match the query description, print them out as answer spans. Supporting fact spans are spans that provide additional evidence about the correctness of the answer spans. Always print one span per line. If no such spans exist, print N/A. 
5 | 
6 | 
7 | The following are some examples of code snippets with spans matching the query description, along with supporting facts if any.
8 | {{example_sf_description}}
9 | 
10 | {{example_a}}
11 | 
12 | 
13 | {{example_b}}
14 | 
15 | 
16 | Code snippet
17 | ```python
18 | {{ input_code }}
19 | ```
20 | 
21 | 
22 | Answer span(s)
23 | ```python
24 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | transformers==4.10
2 | tqdm
3 | absl-py
4 | scikit-learn
5 | datasets
6 | rank_bm25
-------------------------------------------------------------------------------- /resources/query_folderName_map.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/resources/query_folderName_map.pkl
-------------------------------------------------------------------------------- /resources/sampled_test_data.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/resources/sampled_test_data.pkl
-------------------------------------------------------------------------------- /resources/sampled_train_all_data.pkl: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/thepurpleowl/codequeries-benchmark/d07408316bf7bb00936901fae8fb013bfc20abdb/resources/sampled_train_all_data.pkl
-------------------------------------------------------------------------------- /setup.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # https://zenodo.org/api/records/7065336
4 | echo "Downloading Model checkpoints..."
5 | wget https://zenodo.org/api/files/f8ac69c4-6f2f-4da8-a29b-870fd4dbdc84/model_ckpts.zip
6 | mkdir finetuned_ckpts
7 | mv model_ckpts.zip finetuned_ckpts
8 | cd finetuned_ckpts
9 | unzip model_ckpts.zip
10 | cd ..
11 | 
12 | # https://zenodo.org/record/8002087
13 | echo "Downloading LLM experiment prompt and generations..."
14 | wget https://zenodo.org/record/8002087/files/llm-exp.zip
15 | unzip -q llm-exp.zip
16 | 
17 | echo "Downloading Model checkpoints trained with low amount of data..."
18 | wget https://zenodo.org/record/8002087/files/models-ckpt-low-data.zip  # uncommented: the mv/unzip steps below expect this archive
19 | mkdir model-ckpt-with-low-data
20 | mv models-ckpt-low-data.zip model-ckpt-with-low-data
21 | cd model-ckpt-with-low-data
22 | unzip models-ckpt-low-data.zip
23 | cd ..
24 | 
25 | # # To download pretrained CuBERT checkpoint, uncomment
26 | # echo "Downloading pretrained checkpoints for training..."
27 | # wget https://zenodo.org/record/8002087/files/pretrained_models.zip
28 | # unzip -q pretrained_models.zip
29 | 
30 | 
31 | echo "Installing requirements."
32 | pip3 install -r requirements.txt
33 | pip3 install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
34 | 
35 | echo "Preparing twostep sampled test data"
36 | python3 get_sampled_data.py
-------------------------------------------------------------------------------- /train_spanprediction.py: --------------------------------------------------------------------------------
1 | import torch
2 | import datasets
3 | import csv
4 | import pickle
5 | import os
6 | from pathlib import Path
7 | from tqdm import tqdm
8 | from transformers import AdamW, get_linear_schedule_with_warmup
9 | from utils import Cubert_Model, LR_SP, NUM_WARMUP_STEPS, BATCH, CSV_FIELDS, EPOCHS
10 | from utils import get_dataloader_input, get_dataloader
11 | from utils import eval_fn, train_fn, DEVICE
12 | 
13 | if not Path('model_op').exists():
14 |     os.makedirs('model_op')
15 | RESULTS_CSV_PATH = 'model_op/sp_results.csv'
16 | MODEL_PATH = "model_op/model_sp_latest"
17 | 
18 | 
19 | def train_cubert(train_data, dev_data):
20 |     (model_input_ids, model_segment_ids,
21 |      model_input_mask, model_labels_ids) = get_dataloader_input(train_data,
22 |                                                                 example_types_to_evaluate="all",
23 |                                                                 setting='ideal',
24 |                                                                 vocab_file="pretrained_model_configs/vocab.txt")
25 | 
26 |     train_data_loader, train_file_length = get_dataloader(
27 |         model_input_ids,
28 |         model_input_mask,
29 |         model_segment_ids,
30 |         model_labels_ids,
31 |         True
32 |     )
33 | 
34 |     (model_input_ids, model_segment_ids,
35 |      model_input_mask, model_labels_ids) = get_dataloader_input(dev_data,
36 |                                                                 example_types_to_evaluate="all",
37 |                                                                 setting='ideal',
38 |                                                                 vocab_file="pretrained_model_configs/vocab.txt")
39 | 
40 |     dev_data_loader, dev_file_length = get_dataloader(
41 |         model_input_ids,
42 |         model_input_mask,
43 |         model_segment_ids,
44 |         model_labels_ids
45 |     )
46 | 
47 |     device = torch.device(DEVICE)
48 |     model = Cubert_Model(mode='train')
49 |     model.to(device)
50 | 
51 |     optimizer = AdamW(model.parameters(), lr=LR_SP)
52 |     num_train_steps = int(
53 |         (train_file_length / BATCH) * EPOCHS
54 |     )
55 |     scheduler = get_linear_schedule_with_warmup(
56 |         optimizer, num_warmup_steps=NUM_WARMUP_STEPS,
57 |         num_training_steps=num_train_steps
58 |     )
59 | 
60 |     lowest_loss = float("inf")
61 | 
62 |     with open(RESULTS_CSV_PATH, "a") as f:
63 |         csvwriter = csv.writer(f)
64 |         csvwriter.writerow(CSV_FIELDS)
65 | 
66 |     for epoch in tqdm(range(EPOCHS)):
67 |         train_fn(train_data_loader, model,
68 |                  optimizer, device, scheduler)
69 | 
70 |         _, _, train_loss = eval_fn(
71 |             train_data_loader, model, device)
72 | 
73 |         _, _, dev_loss = eval_fn(
74 |             dev_data_loader, model, device)
75 | 
76 |         if dev_loss < lowest_loss:
77 |             torch.save(model.state_dict(), MODEL_PATH + '_best')
78 |             lowest_loss = dev_loss
79 | 
80 |         # torch.save(model.state_dict(), MODEL_PATH + '_' + str(epoch))
81 | 
82 |         with open(RESULTS_CSV_PATH, "a") as f:
83 |             results_row = [epoch, train_loss.item(), dev_loss.item()]
84 | 
85 |             csvwriter = csv.writer(f)
86 |             csvwriter.writerow(results_row)
87 | 
88 | 
89 | if __name__ == '__main__':
90 |     # # evaluation
91 |     # span_model = Cubert_Model()
92 |     # span_model.to(DEVICE)
93 |     # span_model.load_state_dict(torch.load("finetuned_ckpts/Cubert-1K", map_location=DEVICE))
94 |     with open('resources/ideal_TRAIN.pkl', 'rb') as f:
95 |         train_data = pickle.load(f)
96 |     dev_data = datasets.load_dataset("thepurpleowl/codequeries", "ideal", split=datasets.Split.VALIDATION)
97 | 
98 |     train_cubert(train_data, dev_data)
99 | 
--------------------------------------------------------------------------------