├── src
│   └── lexecutor
│       ├── __init__.py
│       ├── predictors
│       │   ├── DLUtil.py
│       │   ├── AsIs.py
│       │   ├── ValuePredictor.py
│       │   ├── codet5
│       │   │   ├── CodeT5.py
│       │   │   ├── CodeT5ValuePredictor.py
│       │   │   ├── ModelServer.py
│       │   │   ├── PrepareData.py
│       │   │   ├── InputFactory.py
│       │   │   └── FineTune.py
│       │   ├── NaiveValuePredictor.py
│       │   ├── codebert
│       │   │   ├── CodeBERT.py
│       │   │   ├── CodeBERTValuePredictor.py
│       │   │   ├── PrepareData.py
│       │   │   ├── FineTune.py
│       │   │   └── InputFactory.py
│       │   ├── PrepareFrequencyValueData.py
│       │   ├── RandomPredictor.py
│       │   ├── FrequencyValuePredictor.py
│       │   └── Type4PyValuePredictor.py
│       ├── Logging.py
│       ├── evaluation
│       │   ├── findSemanticsChangingCommit.sh
│       │   ├── RemoveLastLine.py
│       │   ├── RunPynguin.py
│       │   ├── RemoveDecorators.py
│       │   ├── GetWrappInfo.py
│       │   ├── FindSingleHunkCommits.py
│       │   ├── FindRefactoringCommits.py
│       │   ├── EvaluateModels.py
│       │   ├── CountTotalLines.py
│       │   ├── RunExperiments.py
│       │   ├── FunctionExtractor.py
│       │   ├── FunctionBodyExtractor.py
│       │   ├── CombineData.py
│       │   ├── AddFunctionInvocation.py
│       │   ├── CompareDatasetSizes.py
│       │   └── FunctionPairExtractor.py
│       ├── Util.py
│       ├── TraceEntries.py
│       ├── Hyperparams.py
│       ├── IIDs.py
│       ├── TraceWriter.py
│       ├── Instrument.py
│       ├── RuntimeStats.py
│       ├── Runtime.py
│       ├── ValueAbstraction.py
│       └── CodeRewriter.py
├── tests
│   ├── playground_mini.py
│   ├── search.py
│   ├── example_bug.py
│   ├── broken_commit_new.py
│   ├── broken_commit_old.py
│   ├── test_complete_code.py
│   ├── term_makeTerms.py
│   ├── pass.py
│   ├── furl_is_common_hostname.py
│   ├── keras_densenet.py
│   └── django_forms.py
├── .vscode
│   └── settings.json
├── requirements.txt
├── REQUIREMENTS.md
├── STATUS.md
├── .gitignore
├── get_traces.sh
├── setup.py
├── LICENSE
├── INSTALL.md
├── get_function_bodies_dataset.sh
├── get_stackoverflow_snippets_dataset.py
└── README.md

--------------------------------------------------------------------------------
/src/lexecutor/__init__.py:
--------------------------------------------------------------------------------
# empty

--------------------------------------------------------------------------------
/tests/playground_mini.py:
--------------------------------------------------------------------------------
a = 23
b = 42
x = a + b

--------------------------------------------------------------------------------
/tests/search.py:
--------------------------------------------------------------------------------
print(f"{' Node.js Output ':=^80}")

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "files.exclude": {
        "data": true
    }
}

--------------------------------------------------------------------------------
/src/lexecutor/predictors/DLUtil.py:
--------------------------------------------------------------------------------
import torch as t


dtype = t.float
device = "cuda" if t.cuda.is_available() else "cpu"
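`DLUtil` centralizes the device selection that the predictors below share. A minimal sketch of how downstream code uses it (illustrative only, not a file in this repository):

```python
import torch as t
from lexecutor.predictors.DLUtil import device

x = t.zeros(3, dtype=t.float)
x = x.to(device)  # lands on the GPU if available, otherwise on the CPU
```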
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
libcst
pandas
tables
pytest
pytest-xdist
gensim
flask
requests
torch
transformers
GitPython
beautifulsoup4

--------------------------------------------------------------------------------
/tests/example_bug.py:
--------------------------------------------------------------------------------
if (not has_min_size(all_data)):
    raise RuntimeError("not enough data")

train_len = 0.8 * len(all_data)

logger.info(f"Extracting training data with config {config_str}")

train_data = all_data[0:train_len]
print(train_data)

--------------------------------------------------------------------------------
/tests/broken_commit_new.py:
--------------------------------------------------------------------------------
if (not has_min_size(all_data)):
    raise RuntimeError("not enough data")

train_len = 0.8 * len(all_data)

logger.info(f"Extracting training data with config {config_str}")

train_data = all_data[0:train_len]
print(train_data)

--------------------------------------------------------------------------------
/tests/broken_commit_old.py:
--------------------------------------------------------------------------------
if (not has_min_size(all_data)):
    raise RuntimeError("not enough data")

train_len = round(0.8 * len(all_data))

logger.info(f"Extracting training data with config {config_str}")

train_data = all_data[0:train_len]
print(train_data)

--------------------------------------------------------------------------------
/tests/test_complete_code.py:
--------------------------------------------------------------------------------
from unittest import TestCase

class SimpleCodeTest(TestCase):
    def test_variable_read(self):
        x = 23
        y = x
        self.assertEqual(x, 23)  # don't instrument
        self.assertEqual(y, 23)  # don't instrument

--------------------------------------------------------------------------------
/REQUIREMENTS.md:
--------------------------------------------------------------------------------
# Requirements

The following requirements are needed to run the code in this repository:

* OS: Ubuntu 18.04.6 LTS
* Python 3.8
* screen (to run the experiments in the background)
* NVIDIA Tesla GPU (P100, T100, or T4) with memory >= 16GB

--------------------------------------------------------------------------------
/src/lexecutor/Logging.py:
--------------------------------------------------------------------------------
import logging


logging.basicConfig(format='%(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %I:%M:%S %p')

logger = logging.getLogger("LExecutor logger")
logger.setLevel(logging.INFO)

logger.info("Logging starts")

--------------------------------------------------------------------------------
/src/lexecutor/predictors/AsIs.py:
--------------------------------------------------------------------------------
# Predictor that never injects a value: each hook performs a bare raise,
# which aborts the execution instead of guessing.
class AsIs():
    def name(self, iid, name):
        raise

    def call(self, iid, fct, fct_name, *args, **kwargs):
        raise

    def attribute(self, iid, base, attr_name):
        raise

    def binary_operation(self, iid, left, operator, right):
        raise
--------------------------------------------------------------------------------
/STATUS.md:
--------------------------------------------------------------------------------
# Badges

* Evaluated - Reusable

Our artifacts are carefully documented and well-structured to the extent that reuse is facilitated.

* Available

Our artifacts are publicly available and a [link](https://zenodo.org/record/8270900) to this repository along with a unique identifier for the object is provided.

--------------------------------------------------------------------------------
/src/lexecutor/predictors/ValuePredictor.py:
--------------------------------------------------------------------------------
from abc import ABC


class ValuePredictor(ABC):
    def name(self, iid, name):
        pass

    def call(self, iid, fct, fct_name, *args, **kwargs):
        pass

    def attribute(self, iid, base, attr_name):
        pass

    def binary_operation(self, iid, left, operator, right):
        pass
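`ValuePredictor` is the abstract interface that every predictor in this package implements. A hypothetical custom predictor (shown only as a sketch, not part of the repository) subclasses it and returns a value from each hook:

```python
from lexecutor.predictors.ValuePredictor import ValuePredictor

class ZeroPredictor(ValuePredictor):
    # Hypothetical example: inject 0 for every missing value.
    def name(self, iid, name):
        return 0

    def call(self, iid, fct, fct_name, *args, **kwargs):
        return 0

    def attribute(self, iid, base, attr_name):
        return 0

    def binary_operation(self, iid, left, operator, right):
        return 0
```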
--------------------------------------------------------------------------------
/tests/term_makeTerms.py:
--------------------------------------------------------------------------------
# from https://github.com/mininet/mininet/blob/master/mininet/term.py

def fut():
    """Create terminals.
       nodes: list of Node objects
       title: base title for each
       returns: list of created tunnel/terminal processes"""
    terms = []
    for node in nodes:
        terms += makeTerm( node, title, term )
    return terms


if __name__ == "__main__":
    fut()

--------------------------------------------------------------------------------
/tests/pass.py:
--------------------------------------------------------------------------------
# LExecutor: DO NOT INSTRUMENT

from lexecutor.Runtime import _n_
from lexecutor.Runtime import _a_
from lexecutor.Runtime import _c_
from lexecutor.Runtime import _l_
class DummySocketManager(_n_(771458, "object", lambda: object)):
    _l_(771462)

    x=y
    _l_(771459)
    def __init__(self, config, logger):
        _l_(771460)

        pass
    def get_socket(self):
        _l_(771461)

        pass

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/findSemanticsChangingCommit.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Arguments
if [ "$#" -ne 1 ]; then
    echo "Must pass one argument: name of project to analyze (e.g., 'flask')"
    exit
fi
project=$1

# Step 1
# Step 2
# Step 3: lexecute

for f in `find data/function_pairs/${project} -name compare.py | xargs`
do
    for i in {1..5}
    do
        timeout 30 python $f
    done
done > out_${project}_randomized

--------------------------------------------------------------------------------
/tests/furl_is_common_hostname.py:
--------------------------------------------------------------------------------
# from https://github.com/gruns/furl/blob/master/furl/furl.py


def function_under_test():
    toks = hostname.split('.')
    if toks[-1] == '':  # Trailing '.' in a fully qualified domain name.
        toks.pop()

    for tok in toks:
        if is_valid_host.regex.search(tok) is not None:
            return False

    return '' not in toks  # Adjacent periods aren't allowed.


if __name__ == '__main__':
    function_under_test()

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
myenv*
*.swp
*~
iids*.json
trace_*.h5
__pycache__
workspace.code-workspace
*_instr.py
/data
/dist
*.egg-info
.vscode/launch.json
*.orig
training_loss.csv
train*.pt
validate*.pt
all_training_traces.txt
checkpoint-last
functions_under_test
bodies_under_test
metrics*.csv
trace*.txt
build
eval_examples.pkl
validation_acc.csv
tests/test.py
*.out
*.log
popular_projects_snippets_dataset
tmp
pynguin-report
flask_files.txt
pyrightconfig.json
out*
so_snippets_dataset/*
*_dataset.txt

--------------------------------------------------------------------------------
/src/lexecutor/predictors/codet5/CodeT5.py:
--------------------------------------------------------------------------------
from transformers import AutoTokenizer, T5ForConditionalGeneration, AdamW
from ...Logging import logger
from ..DLUtil import device


def load_CodeT5():
    logger.info("Loading pre-trained codet5-small")

    tokenizer = AutoTokenizer.from_pretrained('Salesforce/codet5-small')

    # logger.info(f"Special tokens: {tokenizer.all_special_tokens=}")
    # logger.info(f"Input ids of special tokens: {tokenizer.all_special_ids=}")

    model = T5ForConditionalGeneration.from_pretrained(
        'Salesforce/codet5-small')
    model.to(device)

    return tokenizer, model
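`load_CodeT5` hands back the tokenizer/model pair already moved to the shared device; callers such as `ModelServer` use it like this (sketch):

```python
from lexecutor.predictors.codet5.CodeT5 import load_CodeT5

tokenizer, model = load_CodeT5()  # downloads Salesforce/codet5-small on first use
```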
--------------------------------------------------------------------------------
/src/lexecutor/Util.py:
--------------------------------------------------------------------------------
from datetime import datetime


def gather_files(files_arg, suffix=".py"):
    if all([f.endswith(".txt") for f in files_arg]):
        files = []
        for f in files_arg:
            with open(f) as fp:
                for line in fp.readlines():
                    files.append(line.rstrip())
    else:
        for f in files_arg:
            if not f.endswith(suffix):
                raise Exception(f"Incorrect argument, expected {suffix} file: {f}")
        files = files_arg
    return files


def timestamp():
    epoch = datetime.utcfromtimestamp(0)
    now = datetime.now()
    return round((now-epoch).total_seconds()*1000000.0)
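`gather_files` accepts either concrete source files or `.txt` list files with one path per line; both calls below are equivalent sketches (the paths are hypothetical):

```python
from lexecutor.Util import gather_files

files = gather_files(["a.py", "b.py"])       # passed through unchanged
files = gather_files(["file_list.txt"])      # paths read from the list file
```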
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/RemoveLastLine.py:
--------------------------------------------------------------------------------
import argparse
from ..Util import gather_files

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files or .txt file with all file paths", nargs="+")

def remove_last_line(file_path):
    # Read the file
    with open(file_path, 'r') as file:
        lines = file.readlines()

    # Remove the last line
    if lines:
        lines = lines[:-1]

    # Save the modified content back to the file
    with open(file_path, 'w') as file:
        file.writelines(lines)

    print(f"The last line has been removed from {file_path}")

if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)

    for file in files:
        remove_last_line(file)

--------------------------------------------------------------------------------
/tests/keras_densenet.py:
--------------------------------------------------------------------------------
# from https://github.com/keras-team/keras/blob/master/keras/applications/densenet.py

# def transition_block(x, reduction, name):
#     """A transition block.
#     Args:
#       x: input tensor.
#       reduction: float, compression rate at transition layers.
#       name: string, block label.
#     Returns:
#       output tensor for the block.
#     """
bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
x = layers.BatchNormalization(
    axis=bn_axis, epsilon=1.001e-5, name=name + '_bn')(
        x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(
    int(backend.int_shape(x)[bn_axis] * reduction),
    1,
    use_bias=False,
    name=name + '_conv')(
        x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)

--------------------------------------------------------------------------------
/src/lexecutor/predictors/NaiveValuePredictor.py:
--------------------------------------------------------------------------------
from .ValuePredictor import ValuePredictor
from ..Logging import logger

class Toy:
    pass


class NaiveValuePredictor(ValuePredictor):
    def name(self, iid, name):
        v = Toy()
        logger.info(f"{iid}: Predicting for name {name}: {v}")
        return v

    def call(self, iid, fct, fct_name, *args, **kwargs):
        v = Toy()
        logger.info(f"{iid}: Predicting for call: {v}")
        return v

    def attribute(self, iid, base, attr_name):
        v = Toy()
        logger.info(f"{iid}: Predicting for attribute {attr_name}: {v}")
        return v

    def binary_operation(self, iid, left, operator, right):
        v = 3
        logger.info(f"{iid}: Predicting result of {operator} operation: {v}")
        return v

--------------------------------------------------------------------------------
/get_traces.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# $1: repo URL
# $2: directory to instrument in repository
# $3: test directory in repository

REPO_NAME=$(echo $1 | grep -o '[^/]*$')

mkdir ./data
mkdir ./data/repos

# Download repo
git -C ./data/repos clone $1

# Install requirements
cd ./data/repos/$REPO_NAME
python3 setup.py install

# Install additional requirements
# Rich
pip install commonmark
pip install pygments
pip install attr
# Requests
pip install trustme

# Instrument
cd ../../../

FILES_TO_INSTRUMENT=$(find ./data/repos/$REPO_NAME/$2 -type f -name "*.py")

python3 -m lexecutor.Instrument --files $FILES_TO_INSTRUMENT --iids iids.json

# Discard tests that cannot be executed
# Requests
# rm ./data/repos/$REPO_NAME/$3/conftest.py

# Run tests
cd ./data/repos/$REPO_NAME
pytest ./$3
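A hypothetical invocation of the script above; the repository URL and the two directory arguments are placeholders:

```bash
bash get_traces.sh https://github.com/psf/requests src/requests tests
```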
--------------------------------------------------------------------------------
/src/lexecutor/TraceEntries.py:
--------------------------------------------------------------------------------
import pandas as pd
from typing import List


class NameEntry(object):
    def __init__(self, iid, name, value):
        self.iid = iid
        self.name = name
        self.value = value


class CallEntry(object):
    def __init__(self, iid, fct_name, args: List[str], value):
        self.iid = iid
        self.fct_name = fct_name
        self.args = args
        self.value = value


class AttributeEntry(object):
    def __init__(self, iid, base, attr_name, value):
        self.iid = iid
        self.base = base
        self.attr_name = attr_name
        self.value = value


class BinOpEntry(object):
    def __init__(self, iid, left, operator, right, value):
        self.iid = iid
        self.left = left
        self.operator = operator
        self.right = right
        self.value = value

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name="lexecutor",
    version="0.0.1",
    author="Michael Pradel",
    author_email="michael@binaervarianz.de",
    description="Learning-guided execution",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/michaelpradel/LExecutor",
    project_urls={
        "Bug Tracker": "https://github.com/michaelpradel/LExecutor/issues",
    },
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    package_dir={"": "src"},
    packages=setuptools.find_packages(where="src"),
    python_requires=">=3.6",
    install_requires=[
        'libcst',
        'pandas',
        'torch',
        'tables',
    ],
)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright 2021 Michael Pradel

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

--------------------------------------------------------------------------------
/tests/django_forms.py:
--------------------------------------------------------------------------------
# from https://github.com/django/django/blob/main/django/forms/forms.py


def fut():
    # Collect fields from current class and remove them from attrs.
    attrs['declared_fields'] = {
        key: attrs.pop(key) for key, value in list(attrs.items())
        if isinstance(value, Field)
    }

    new_class = super().__new__(mcs, name, bases, attrs)

    # Walk through the MRO.
    declared_fields = {}
    for base in reversed(new_class.__mro__):
        # Collect fields from base class.
        if hasattr(base, 'declared_fields'):
            declared_fields.update(base.declared_fields)

        # Field shadowing.
        for attr, value in base.__dict__.items():
            if value is None and attr in declared_fields:
                declared_fields.pop(attr)

    new_class.base_fields = declared_fields
    new_class.declared_fields = declared_fields

    return new_class


if __name__ == '__main__':
    fut()
--------------------------------------------------------------------------------
/src/lexecutor/Hyperparams.py:
--------------------------------------------------------------------------------
class Hyperparams(object):
    iids_file = "iids.json"
    verbose = False

    # data deduplication
    # dedup = "name-value"
    dedup = "name-value-iid"

    # data splitting
    # split = "project"
    # split = "file"
    split = "mixed"

    value_abstraction = "fine-grained"
    # value_abstraction = "coarse-grained-deterministic"
    # value_abstraction = "coarse-grained-randomized"

    perc_train = 0.95

    # CodeT5 model
    max_output_length = 8

    # feedforward model
    token_emb_len = 100
    value_emb_len = 20
    max_call_args = 3
    joined_layer_len = 200
    intermediate_layer_len = 200

    # training
    epochs = 10
    # CodeT5
    batch_size_CodeT5 = 50
    # CodeBERT
    batch_size_CodeBERT = 13

    # experiments
    # dataset = "so_snippets"
    # dataset = "random_functions"
    dataset = "other"
    number_executions = 10

--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
# Installation Guide

Create and enter a virtual environment:

```
virtualenv -p /usr/bin/python3.8 myenv
source myenv/bin/activate
```

Install requirements:

```
pip install -r requirements.txt
```

Locally install the package in development/editable mode:

```
pip install -e ./
```

# Usage Guide

1. Instrument the Python files that will be LExecuted

2. Run the Python files instrumented in step 1

As a simple example, consider that the following code is in `./files/file.py`.

```python
if (not has_min_size(all_data)):
    raise RuntimeError("not enough data")

train_len = round(0.8 * len(all_data))

logger.info(f"Extracting training data with {config_str}")

train_data = all_data[0:train_len]
```

Then, to *LExecute* the code, do as follows:

1. Instrument the code:
```
python -m lexecutor.Instrument --files ./files/file.py
```

2. Run the instrumented code:
```
python ./files/file.py
```
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/RunPynguin.py:
--------------------------------------------------------------------------------
import argparse
import os
import subprocess
from ..Util import gather_files

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files with extracted functions or .txt file with all file paths", nargs="+")
parser.add_argument(
    "--dest", help="Destination directory for the tests", required=True)


if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)
    os.environ['PYNGUIN_DANGER_AWARE'] = 'x'

    pynguin_parameters = '--maximum_search_time 30 --seed 42 --max-attempts 10 --maximum_test_execution_timeout 10 --maximum_slicing_time 10 --test_execution_time_per_statement 1 --assertion-generation SIMPLE -v'

    for file in files:
        if file.startswith(os.getcwd()):
            file = file[len(os.getcwd())+1:]
        module_pynguin_path = file.replace("/", ".")[2:-3]
        print(f"Running Pynguin on {module_pynguin_path} with parameters {pynguin_parameters}")
        log_pynguin = subprocess.run(f"pynguin --project-path . --output-path {args.dest} --module-name {module_pynguin_path} {pynguin_parameters}".split(
            " "), capture_output=True, text=True, shell=False, timeout=60)

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/RemoveDecorators.py:
--------------------------------------------------------------------------------
import os
import csv
import argparse
import libcst as cst
import pandas as pd
from ..Util import gather_files

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files or .txt file with all file paths", nargs="+")

class DecoratorRemover(cst.CSTTransformer):
    def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
        # Remove decorators by replacing them with an empty list
        updated_node = updated_node.with_changes(decorators=[])
        return updated_node


if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)

    for file in files:
        with open(file + ".orig", "r") as fp:
            src = fp.read()
        ast = cst.parse_module(src)

        transformer = DecoratorRemover()
        updated_module = ast.visit(transformer)

        # Get the updated code
        updated_code = updated_module.code

        base_dir = file.split('functions_with_invocation')[0]
        file_name = file.split('functions_with_invocation')[1]
        outfile = base_dir + "functions_without_decorator" + file_name

        with open(outfile, "w") as f:
            f.write(updated_code)

--------------------------------------------------------------------------------
/src/lexecutor/IIDs.py:
--------------------------------------------------------------------------------
from collections import namedtuple
import os
from os import path
import json
from .Logging import logger


Location = namedtuple(
    "Location", ["file", "line", "column_start", "column_end"])


class IIDs:
    def __init__(self, file_path):
        if not path.exists(file_path):
            logger.info(f"Creating new iid file at {file_path}")
            self.next_iid = 1
            self._iid_to_location = {}
        else:
            with open(file_path, "r") as file:
                json_object = json.load(file)
                self.next_iid = json_object["next_iid"]
                self._iid_to_location = json_object["iid_to_location"]
        self.file_path = file_path

    def new(self, file, line, column_start, column_end):
        self._iid_to_location[self.next_iid] = Location(
            file, line, column_start, column_end)
        self.next_iid += 1
        return self.next_iid - 1

    def store(self):
        all_data = {
            "next_iid": self.next_iid,
            "iid_to_location": self._iid_to_location,
        }
        json_object = json.dumps(all_data, indent=2)
        with open(self.file_path, "w") as file:
            file.write(json_object)

    def line(self, iid):
        return self._iid_to_location[str(iid)][1]

    def location(self, iid):
        return Location(*self._iid_to_location[str(iid)])
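`IIDs` assigns stable instruction ids and persists them to a JSON file. A round-trip sketch (the file and location values are placeholders):

```python
from lexecutor.IIDs import IIDs

iids = IIDs("iids.json")            # creates a fresh map if the file is missing
iid = iids.new("demo.py", 3, 0, 5)  # returns the newly assigned id, starting at 1
iids.store()                        # writes next_iid and the id -> location map
```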
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/GetWrappInfo.py:
--------------------------------------------------------------------------------
import os
import csv
import argparse
import pandas as pd
from ..Util import gather_files

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files to get wrapp info or .txt file with all file paths", nargs="+")

def get_wrapp_info(file_path):
    wrapped = 0
    with open(file_path, "r") as file:
        lines = file.readlines()
        for line in lines:
            line = line.strip()
            if line.startswith("class Wrapper:"):
                wrapped = 1
                break
    return wrapped

def save_wrapp_info(files, wrapp_info):
    # Create CSV file and add header if it doesn't exist
    if not os.path.isfile('./wrapp_info.csv'):
        columns = ["file", "wrapped"]

        with open('./wrapp_info.csv', 'a') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerow(columns)

    df = pd.read_csv('./wrapp_info.csv')
    df_new_data = pd.DataFrame({
        'file': files,
        'wrapped': wrapp_info
    })
    df = pd.concat([df, df_new_data])
    df.to_csv('./wrapp_info.csv', index=False)

if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)

    wrapp_info = []
    for file_path in files:
        wrapp_info.append(get_wrapp_info(f"{file_path}"))
    save_wrapp_info(files, wrapp_info)

--------------------------------------------------------------------------------
/src/lexecutor/predictors/codebert/CodeBERT.py:
--------------------------------------------------------------------------------
from transformers import RobertaTokenizer, RobertaForMaskedLM
from ...Logging import logger
from ..DLUtil import device

def load_CodeBERT():
    logger.info("Loading pre-trained codebert-base-mlm")

    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base-mlm")

    value_tokens = [
        'None',
        'True',
        'False',
        'bool',
        'str_empty',
        'str_nonempty',
        'str',
        'int_neg',
        'int_zero',
        'int_pos',
        'int',
        'float_neg',
        'float_zero',
        'float_pos',
        'float',
        'list_empty',
        'list_nonempty',
        'list',
        'tuple_empty',
        'tuple_nonempty',
        'tuple',
        'set_empty',
        'set_nonempty',
        'set',
        'dict_empty',
        'dict_nonempty',
        'dict',
        'resource',
        'callable',
        'object'
    ]

    additional_tokens = ['', '', '', '']

    special_tokens_dict = {'additional_special_tokens': value_tokens + additional_tokens}
    tokenizer.add_special_tokens(special_tokens_dict)

    model = RobertaForMaskedLM.from_pretrained("microsoft/codebert-base-mlm")
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    return tokenizer, model
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/FindSingleHunkCommits.py:
--------------------------------------------------------------------------------
from git import Repo
import re

# Helper script to find commits that change a single Python file in a single hunk.
#
# Dump output into a file "out" and then open the commit links in a browser with:
# for l in `cat out | xargs`; do firefox $l; done


# url_prefix = "https://github.com/scrapy/scrapy/commit/"
# repo = Repo("data/repos/scrapy")

# url_prefix = "https://github.com/nvbn/thefuck/commit/"
# repo = Repo("data/repos/thefuck")

# url_prefix = "https://github.com/scikit-learn/scikit-learn/commit/"
# repo = Repo("data/repos/scikit-learn")

url_prefix = "https://github.com/psf/black/commit/"
repo = Repo("data/repos/black")

# url_prefix = "https://github.com/pallets/flask/commit/"
# repo = Repo("data/repos/flask")

# url_prefix = "https://github.com/pandas-dev/pandas/commit/"
# repo = Repo("data/repos/pandas")

commits = list(repo.iter_commits("main"))
nb_commits_match = 0
for c in commits:
    if len(c.parents) == 0:
        continue
    diff = c.parents[0].diff(c, create_patch=True)
    if len(diff) == 1 and diff[0].a_path and diff[0].a_path.endswith(".py"):
        diff_str = str(diff[0])
        matches = re.findall(r"@@", diff_str)
        if len(matches) == 2:
            # print(diff_str)
            # print("\n---------------------------\n")
            print(f"{url_prefix}{c.hexsha}")
            nb_commits_match += 1
print(f"{len(commits)} total commits, {nb_commits_match} matches")

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/FindRefactoringCommits.py:
--------------------------------------------------------------------------------
from git import Repo

# Helper script to find commits that are likely single-function refactorings.
#
# Dump output into a file "out" and then open the commit links in a browser with:
# for l in `cat out | xargs`; do firefox $l; done


# url_prefix = "https://github.com/scrapy/scrapy/commit/"
# repo = Repo("data/repos/scrapy")

# url_prefix = "https://github.com/nvbn/thefuck/commit/"
# repo = Repo("data/repos/thefuck")

# url_prefix = "https://github.com/scikit-learn/scikit-learn/commit/"
# repo = Repo("data/repos/scikit-learn")

# url_prefix = "https://github.com/psf/black/commit/"
# repo = Repo("data/repos/black")

# url_prefix = "https://github.com/pallets/flask/commit/"
# repo = Repo("data/repos/flask")

url_prefix = "https://github.com/pandas-dev/pandas/commit/"
repo = Repo("data/repos/pandas")

commits = list(repo.iter_commits("main"))
nb_commits_refactor = 0
nb_commits_match = 0
for c in commits:
    if "refactor" in c.message:
        nb_commits_refactor += 1
        diff = c.parents[0].diff(c, create_patch=True)
        if len(diff) == 1 and diff[0].a_path and diff[0].a_path.endswith(".py"):
            diff_str = str(diff[0])
            # heuristic check for single-function edits
            if diff_str.count("def ") <= 1:
                print(f"{url_prefix}{c.hexsha}")
                nb_commits_match += 1
print(f"{len(commits)} total commits, {nb_commits_refactor} refactorings, {nb_commits_match} matches")
--------------------------------------------------------------------------------
/src/lexecutor/predictors/PrepareFrequencyValueData.py:
--------------------------------------------------------------------------------
import argparse
import json
from collections import Counter
from .codet5.PrepareData import read_traces, clean_entries

parser = argparse.ArgumentParser()
parser.add_argument(
    "--traces", help="Trace file or .txt file(s) with all trace file paths to use",
    nargs="+", required=True)

def get_values_frequencies(trace_files):
    name_to_values = {}
    call_to_values = {}
    attribute_to_values = {}

    entries = read_traces(trace_files)
    clean_entries(entries)
    for index, entry in entries.iterrows():
        key = entry["name"]
        if entry["kind"] == "name":
            name_to_values.setdefault(key, Counter())[
                entry.value] += 1
        elif entry["kind"] == "call":
            call_to_values.setdefault(key, Counter())[
                entry.value] += 1
        elif entry["kind"] == "attribute":
            attribute_to_values.setdefault(key, Counter())[
                entry.value] += 1

    return {
        "name_to_values": name_to_values,
        "call_to_values": call_to_values,
        "attribute_to_values": attribute_to_values
    }

def store_values_frequencies(values_frequencies):
    with open("values_frequencies.json", "w") as outfile:
        json.dump(values_frequencies, outfile)

if __name__ == "__main__":
    args = parser.parse_args()
    values_frequencies = get_values_frequencies(args.traces)
    store_values_frequencies(values_frequencies)

--------------------------------------------------------------------------------
/src/lexecutor/predictors/RandomPredictor.py:
--------------------------------------------------------------------------------
from .ValuePredictor import ValuePredictor
from ..ValueAbstraction import DummyObject, DummyResource
from ..Logging import logger
import random

class RandomPredictor(ValuePredictor):
    def __init__(self):
        super().__init__()
        self.values = [
            None,
            True,
            False,
            "",
            "a",
            -1,
            0,
            1,
            -1.0,
            0.0,
            1.0,
            [],
            [DummyObject()],
            (),
            (DummyObject(),),
            set(),
            {DummyObject()},
            {},
            {"a": DummyObject()},
            DummyResource(),
            DummyObject,
            DummyObject()
        ]

    def get_random_value(self):
        random_index = random.randint(0, len(self.values) - 1)
        return self.values[random_index]

    def name(self, iid, name):
        v = self.get_random_value()
        logger.info(f"{iid}: Predicting with RandomPredictor for name {name}: {v}")
        return v

    def call(self, iid, fct, fct_name, *args, **kwargs):
        v = self.get_random_value()
        logger.info(f"{iid}: Predicting with RandomPredictor for call: {v}")
        return v

    def attribute(self, iid, base, attr_name):
        v = self.get_random_value()
        logger.info(f"{iid}: Predicting with RandomPredictor for attribute {attr_name}: {v}")
        return v
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/EvaluateModels.py:
--------------------------------------------------------------------------------
import argparse
import torch as t
from ..predictors.DLUtil import device
from ..predictors.codet5.FineTune import evaluate as evaluate_CodeT5, load_CodeT5
from ..predictors.codebert.FineTune import evaluate as evaluate_CodeBERT, load_CodeBERT

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_type", help="CodeT5 or CodeBERT", required=True)
parser.add_argument(
    "--value_abstraction", help="fine-grained or coarse-grained", required=True)
parser.add_argument(
    "--state_file", help=".bin file for validation", default="model.bin")
parser.add_argument(
    "--validate_tensors", help=".pt file for validation", default="validate.pt")

def evaluate_model(model_type, value_abstraction, state_file, validation_data_file):
    if model_type == "CodeT5":
        tokenizer, model = load_CodeT5()
        model.load_state_dict(t.load(state_file, map_location=device))
        topk_accuracies = evaluate_CodeT5(validation_data_file, model, tokenizer)
    else:
        tokenizer, model = load_CodeBERT()
        model.load_state_dict(t.load(state_file, map_location=device))
        topk_accuracies = evaluate_CodeBERT(validation_data_file, model, tokenizer)

    print("="*30)
    print(f"{model_type}-{value_abstraction}\n{topk_accuracies}")
    print("="*30)


if __name__ == "__main__":
    args = parser.parse_args()

    evaluate_model(args.model_type,
                   args.value_abstraction,
                   args.state_file,
                   args.validate_tensors)

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/CountTotalLines.py:
--------------------------------------------------------------------------------
import os
import csv
import argparse
import pandas as pd
from ..Util import gather_files
from ..Hyperparams import Hyperparams as params

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files to count lines or .txt file with all file paths", nargs="+")

def count_lines(file_path):
    total_lines = 0
    with open(file_path, "r") as file:
        lines = file.readlines()
        for line in lines:
            line = line.strip()
            if line.startswith("_l_("):
                total_lines += 1
    return total_lines

def save_total_lines(files, total_lines):
    # Create CSV file and add header if it doesn't exist
    if not os.path.isfile(f'./total_lines_{params.dataset}_dataset.csv'):
        columns = ["file", "total_lines"]

        with open(f'./total_lines_{params.dataset}_dataset.csv', 'a') as csvFile:
            writer = csv.writer(csvFile)
            writer.writerow(columns)

    df = pd.read_csv(f'./total_lines_{params.dataset}_dataset.csv')
    df_new_data = pd.DataFrame({
        'file': files,
        'total_lines': total_lines
    })
    df = pd.concat([df, df_new_data])
    df.to_csv(f'./total_lines_{params.dataset}_dataset.csv', index=False)

if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)

    total_lines = []
    for file_path in files:
        total_lines.append(count_lines(f"{file_path}"))
    save_total_lines(files, total_lines)

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/RunExperiments.py:
--------------------------------------------------------------------------------
import argparse
import os
import signal
import subprocess
from ..Util import gather_files
from ..Hyperparams import Hyperparams as params

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files or .txt file with all file paths", nargs="+")
parser.add_argument(
    "--tests", help="Run pytest tests", action="store_true")
parser.add_argument(
    "--log_dest_dir", help="Destination directory for the log files", required=True)


if __name__ == "__main__":
    args = parser.parse_args()

    files = gather_files(args.files)

    if args.tests:
        command = "pytest"
    else:
        command = "python3"

    # run the files (with a timeout)
    for file in files:
        for execution in range(1, params.number_executions+1):
            if params.dataset == "random_functions":
                project_name = file.split("/")[2]
                file_name = file.split("/")[4].split('.')[0]
                log_file = open(f"{args.log_dest_dir}/{project_name}_{file_name}_{str(execution)}.txt", "w")
            else:
                file_name = file.split("/")[2].split('.')[0]
                log_file = open(f"{args.log_dest_dir}/{file_name}_{str(execution)}.txt", "w")
            try:
                process = subprocess.Popen(
                    f"time {command} {file} {execution}", shell=True, start_new_session=True, stdout=log_file, stderr=log_file)
                process.wait(timeout=30)  # seconds
            except subprocess.TimeoutExpired:
                log_file.write("TimeLimit!!!!")
                os.killpg(os.getpgid(process.pid), signal.SIGTERM)
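A hypothetical invocation of the experiment runner above; the file list and log directory are placeholders:

```bash
mkdir -p logs
python -m lexecutor.evaluation.RunExperiments --files files_to_run.txt --log_dest_dir logs
```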
"__name__") else str(fct) 26 | if " " in fct_name: # some fcts don't have a proper name 27 | fct_name = fct_name.split(" ")[0] 28 | 29 | self._append(iid, fct_name, raw_value, "call") 30 | 31 | def append_attribute(self, iid, raw_base, attr_name, raw_value): 32 | self._append(iid, attr_name, raw_value, "attribute") 33 | 34 | def write_to_file(self): 35 | file_name = f"trace_{timestamp()}.h5" 36 | logger.info(f"Writing to {file_name}, and flushing buffer") 37 | 38 | df = pd.DataFrame(data=self.buffer, columns=column_names) 39 | df["iid"] = df["iid"].astype("int") 40 | df["name"] = df["name"].astype("str") 41 | df["value"] = df["value"].astype("str") 42 | df["kind"] = df["kind"].astype("str") 43 | df["info"] = df["info"].astype("str") 44 | 45 | logger.info(f"Deduplicating {len(df)} trace entries") 46 | df.drop_duplicates( 47 | subset=["iid", "name", "value", "kind"], inplace=True) 48 | logger.info(f"After deduplicating: {len(df)} trace entries") 49 | 50 | df.to_hdf(file_name, key="entries", complevel=9, complib="bzip2") 51 | 52 | self.buffer = [] 53 | -------------------------------------------------------------------------------- /get_function_bodies_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FUNCTION_BODIES="" 4 | 5 | declare -a PROJECTS=( 6 | "https://github.com/psf/black" 7 | "https://github.com/pallets/flask" 8 | "https://github.com/pandas-dev/pandas" 9 | "https://github.com/scrapy/scrapy" 10 | "https://github.com/tensorflow/tensorflow" 11 | ) 12 | 13 | mkdir popular_projects_snippets_dataset 14 | 15 | # extract function bodies from projects 16 | for project in ${PROJECTS[@]}; do 17 | REPO_NAME=$(echo $project | grep -o '[^/]*$') 18 | 19 | # delete repo in case it already exists 20 | rm -rf data/repos/$REPO_NAME 21 | # download repo 22 | git -C ./data/repos clone $project 23 | # create destination dir 24 | mkdir popular_projects_snippets_dataset/$REPO_NAME 25 | mkdir popular_projects_snippets_dataset/$REPO_NAME/functions 26 | mkdir popular_projects_snippets_dataset/$REPO_NAME/functions_with_invocation 27 | mkdir popular_projects_snippets_dataset/$REPO_NAME/bodies 28 | # extract function bodies 29 | if [ "$REPO_NAME" == "flask" ] || [ "$REPO_NAME" == "black" ] 30 | then 31 | FILES=$(find ./data/repos/$REPO_NAME/src/$REPO_NAME -type f -name "*.py") 32 | else 33 | FILES=$(find ./data/repos/$REPO_NAME/$REPO_NAME -type f -name "*.py") 34 | fi 35 | python -m lexecutor.evaluation.FunctionBodyExtractor --files $FILES --dest ./popular_projects_snippets_dataset/$REPO_NAME 36 | python -m lexecutor.evaluation.FunctionExtractor --files $FILES --dest ./popular_projects_snippets_dataset/$REPO_NAME 37 | # randomly select 200 function bodies 38 | FUNCTION_BODIES+="$(find ./popular_projects_snippets_dataset/$REPO_NAME/bodies -type f -name "*.py" | shuf -n 200) " 39 | done 40 | 41 | # save file paths to .txt files 42 | echo $FUNCTION_BODIES | tr ' ' '\n' > popular_projects_function_bodies_dataset.txt 43 | sed -e 's/bodies/functions/g' -e 's/body/function/g' popular_projects_function_bodies_dataset.txt > popular_projects_functions_dataset.txt 44 | sed -e 's/bodies/functions_with_invocation/g' popular_projects_function_bodies_dataset.txt > popular_projects_functions_with_invocation_dataset.txt 45 | 46 | python -m lexecutor.evaluation.AddFunctionInvocation --files popular_projects_functions_dataset.txt 47 | python -m lexecutor.evaluation.GetWrappInfo --files popular_projects_functions_dataset.txt 
--------------------------------------------------------------------------------
/get_function_bodies_dataset.sh:
--------------------------------------------------------------------------------
#!/bin/bash

FUNCTION_BODIES=""

declare -a PROJECTS=(
    "https://github.com/psf/black"
    "https://github.com/pallets/flask"
    "https://github.com/pandas-dev/pandas"
    "https://github.com/scrapy/scrapy"
    "https://github.com/tensorflow/tensorflow"
)

mkdir popular_projects_snippets_dataset

# extract function bodies from projects
for project in ${PROJECTS[@]}; do
    REPO_NAME=$(echo $project | grep -o '[^/]*$')

    # delete repo in case it already exists
    rm -rf data/repos/$REPO_NAME
    # download repo
    git -C ./data/repos clone $project
    # create destination dir
    mkdir popular_projects_snippets_dataset/$REPO_NAME
    mkdir popular_projects_snippets_dataset/$REPO_NAME/functions
    mkdir popular_projects_snippets_dataset/$REPO_NAME/functions_with_invocation
    mkdir popular_projects_snippets_dataset/$REPO_NAME/bodies
    # extract function bodies
    if [ "$REPO_NAME" == "flask" ] || [ "$REPO_NAME" == "black" ]
    then
        FILES=$(find ./data/repos/$REPO_NAME/src/$REPO_NAME -type f -name "*.py")
    else
        FILES=$(find ./data/repos/$REPO_NAME/$REPO_NAME -type f -name "*.py")
    fi
    python -m lexecutor.evaluation.FunctionBodyExtractor --files $FILES --dest ./popular_projects_snippets_dataset/$REPO_NAME
    python -m lexecutor.evaluation.FunctionExtractor --files $FILES --dest ./popular_projects_snippets_dataset/$REPO_NAME
    # randomly select 200 function bodies
    FUNCTION_BODIES+="$(find ./popular_projects_snippets_dataset/$REPO_NAME/bodies -type f -name "*.py" | shuf -n 200) "
done

# save file paths to .txt files
echo $FUNCTION_BODIES | tr ' ' '\n' > popular_projects_function_bodies_dataset.txt
sed -e 's/bodies/functions/g' -e 's/body/function/g' popular_projects_function_bodies_dataset.txt > popular_projects_functions_dataset.txt
sed -e 's/bodies/functions_with_invocation/g' popular_projects_function_bodies_dataset.txt > popular_projects_functions_with_invocation_dataset.txt

python -m lexecutor.evaluation.AddFunctionInvocation --files popular_projects_functions_dataset.txt
python -m lexecutor.evaluation.GetWrappInfo --files popular_projects_functions_dataset.txt

--------------------------------------------------------------------------------
/src/lexecutor/evaluation/FunctionExtractor.py:
--------------------------------------------------------------------------------
import argparse
from ..Util import gather_files
import libcst as cst
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files to extract from or .txt file with all file paths", nargs="+")
parser.add_argument(
    "--dest", help="Destination directory", required=True)


class ExtractorVisitor(cst.CSTTransformer):
    def __init__(self, dest_dir):
        self.dest_dir = dest_dir

        # continue numbering after the functions already in <dest>/functions
        existing_files = [f for f in os.listdir(os.path.join(dest_dir, "functions"))]
        self.next_id = 0
        while f"function_{self.next_id}.py" in existing_files:
            self.next_id += 1

    def set_source_file(self, file):
        self.file = file

    def leave_Param(self, node, updated_node):
        # remove parameter type annotation
        return updated_node.with_changes(annotation=None)

    def leave_FunctionDef(self, node, updated_node):
        info = f"# Extracted from {self.file}"

        if len(updated_node.params.params) and updated_node.params.params[0].name.value == "self":
            # wrap function into a class
            code = cst.Module([]).code_for_node(
                cst.ClassDef(
                    name=cst.Name(
                        value='Wrapper'
                    ),
                    body=cst.IndentedBlock(
                        body=[updated_node.with_changes(returns=None)]
                    )
                )
            )
        else:
            # remove return type annotation and save full function
            code = cst.Module([]).code_for_node(
                updated_node.with_changes(returns=None))

        outfile = os.path.join(
            f"{self.dest_dir}/functions", f"function_{self.next_id}.py")
        with open(outfile, "w") as f:
            f.write(info+"\n")
            f.write(code)

        self.next_id += 1

        return updated_node


if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)
    extractor = ExtractorVisitor(args.dest)
    for file in files:
        with open(file, "r") as fp:
            src = fp.read()
        ast = cst.parse_module(src)
        extractor.set_source_file(file)
        ast.visit(extractor)
--------------------------------------------------------------------------------
/src/lexecutor/predictors/FrequencyValuePredictor.py:
--------------------------------------------------------------------------------
from .ValuePredictor import ValuePredictor
from .NaiveValuePredictor import NaiveValuePredictor
from ..Logging import logger
from ..ValueAbstraction import restore_value
from random import choices
import json

class FrequencyValuePredictor(ValuePredictor):
    def __init__(self, values_frequencies_file):
        with open(f'{values_frequencies_file}', 'r') as openfile:
            values_frequencies = json.load(openfile)

        self.name_to_values = values_frequencies["name_to_values"]
        self.call_to_values = values_frequencies["call_to_values"]
        self.attribute_to_values = values_frequencies["attribute_to_values"]

        self.naive_predictor = NaiveValuePredictor()  # as a fallback

        self.total_predictions = 0
        self.frequency_based_predictions = 0

    def name(self, iid, name):
        counter = self.name_to_values.get(name)
        self.total_predictions += 1
        if counter is None:
            return self.naive_predictor.name(iid, name)
        else:
            self.frequency_based_predictions += 1
            v = choices(list(counter.keys()), list(counter.values()))[0]
            logger.info(f"{iid}: Predicting for name {name}: {v}")
            return restore_value(v)

    def call(self, iid, fct, fct_name, *args, **kwargs):
        counter = self.call_to_values.get(fct_name)
        self.total_predictions += 1
        if counter is None:
            # pass fct_name through to match NaiveValuePredictor.call's signature
            return self.naive_predictor.call(iid, fct, fct_name, *args, **kwargs)
        else:
            self.frequency_based_predictions += 1
            v = choices(list(counter.keys()), list(counter.values()))[0]
            logger.info(f"{iid}: Predicting for call: {v}")
            return restore_value(v)

    def attribute(self, iid, base, attr_name):
        counter = self.attribute_to_values.get(attr_name)
        self.total_predictions += 1
        if counter is None:
            return self.naive_predictor.attribute(iid, base, attr_name)
        else:
            self.frequency_based_predictions += 1
            v = choices(list(counter.keys()), list(counter.values()))[0]
            logger.info(f"{iid}: Predicting for attribute {attr_name}: {v}")
            return restore_value(v)

    def print_stats(self):
        print(f"{self.frequency_based_predictions}/{self.total_predictions} ({self.frequency_based_predictions/self.total_predictions if self.total_predictions > 0 else 0}) predictions were frequency based")
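A sketch of using the frequency-based predictor, assuming `values_frequencies.json` was produced beforehand by `PrepareFrequencyValueData`:

```python
from lexecutor.predictors.FrequencyValuePredictor import FrequencyValuePredictor

predictor = FrequencyValuePredictor("values_frequencies.json")
v = predictor.name(1, "counter")  # samples a recorded value weighted by frequency
predictor.print_stats()
```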
--------------------------------------------------------------------------------
/src/lexecutor/evaluation/FunctionBodyExtractor.py:
--------------------------------------------------------------------------------
import argparse
from ..Util import gather_files
import libcst as cst
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--files", help="Python files to extract from or .txt file with all file paths", nargs="+")
parser.add_argument(
    "--dest", help="Destination directory", required=True)


class ExtractorVisitor(cst.CSTTransformer):
    def __init__(self, dest_dir):
        self.dest_dir = dest_dir

        # continue numbering after the bodies already in <dest>/bodies
        existing_files = [f for f in os.listdir(os.path.join(dest_dir, "bodies"))]
        self.next_id = 0
        while f"body_{self.next_id}.py" in existing_files:
            self.next_id += 1

    def set_source_file(self, file):
        self.file = file

    def leave_Return(self, node, updated_node):
        args = []
        if node.value is not None:
            if type(node.value) is cst.Tuple:
                args = [cst.Arg(value=cst.Tuple(elements=node.value.elements))]
            else:
                args = [cst.Arg(value=node.value)]
        expr = cst.Expr(
            value=cst.Call(
                func=cst.Name("exit"),
                args=args
            )
        )
        return expr

    def leave_Yield(self, node, updated_node):
        args = []
        if node.value is not None:
            if type(node.value) is cst.Tuple:
                args = [cst.Arg(value=cst.Tuple(elements=node.value.elements))]
            elif type(node.value) is cst.From:
                args = [cst.Arg(value=node.value.item)]
            else:
                args = [cst.Arg(value=node.value)]
        expr = cst.Expr(
            value=cst.Call(
                func=cst.Name("exit"),
                args=args
            )
        )
        return expr

    def leave_FunctionDef(self, node, updated_node):
        info = f"# Extracted from {self.file}"

        # save function body
        body = [s for s in updated_node.body.body]
        body_code = cst.Module(body=body).code
        outfile = os.path.join(
            f"{self.dest_dir}/bodies", f"body_{self.next_id}.py")
        with open(outfile, "w") as f:
            f.write(info+"\n")
            f.write(body_code)

        self.next_id += 1

        return updated_node


if __name__ == "__main__":
    args = parser.parse_args()
    files = gather_files(args.files)
    extractor = ExtractorVisitor(args.dest)
    for file in files:
        with open(file, "r") as fp:
            src = fp.read()
        ast = cst.parse_module(src)
        extractor.set_source_file(file)
        ast.visit(extractor)

--------------------------------------------------------------------------------
/src/lexecutor/predictors/codet5/CodeT5ValuePredictor.py:
--------------------------------------------------------------------------------
from ..ValuePredictor import ValuePredictor
from ..DLUtil import device
from .ModelServer import ModelServer
from ...Logging import logger
import time
import requests
from requests.exceptions import ConnectionError
import subprocess
from ...ValueAbstraction import restore_value


class CodeT5ValuePredictor(ValuePredictor):
    def __init__(self, stats):
        self.stats = stats

    def _query_model(self, entry):
        def get(entry):
            raw_response = requests.get(
                "http://localhost:5000/query", params=entry)
            if raw_response.status_code != 200:
                raise RuntimeError(
                    f"Model server returned error code {raw_response.status_code}")
            return raw_response.json()

        response = None
        try:
            response = get(entry)
        except ConnectionError:
            # model server not yet running; start it
            logger.info("No model server running. Starting it now")
            server_log = open("model_server.log", "w")
            subprocess.Popen(
                "python -m lexecutor.predictors.codet5.ModelServer".split(" "),
                stderr=server_log, stdout=server_log)

            # try to connect until it's responding (or we give up)
            attempts = 0
            while attempts < 5:
                try:
                    response = get(entry)
                    logger.info("Model server is up and running")
                    break
                except ConnectionError:
                    time.sleep(5)  # seconds
                    attempts += 1

        if response is None:
            raise RuntimeError("Could not connect to model server")

        val_as_string = response["v"]
        val = restore_value(val_as_string)

        return val_as_string, val

    def name(self, iid, name):
        entry = {"iid": iid, "name": name, "kind": "name"}
        abstract_v, v = self._query_model(entry)
        logger.info(f"{iid}: Predicting for name {name}: {v}")
        self.stats.inject_value(
            iid, f"Inject {abstract_v} for variable {name}")
        return v

    def call(self, iid, fct, fct_name, *args, **kwargs):
        entry = {"iid": iid, "name": fct_name, "kind": "call"}
        abstract_v, v = self._query_model(entry)
        logger.info(f"{iid}: Predicting for call: {v}")
        self.stats.inject_value(
            iid, f"Inject {abstract_v} as return value of {fct_name}")
        return v

    def attribute(self, iid, base, attr_name):
        entry = {"iid": iid, "name": attr_name, "kind": "attribute"}
        abstract_v, v = self._query_model(entry)
        logger.info(f"{iid}: Predicting for attribute {attr_name}: {v}")
        self.stats.inject_value(
            iid, f"Inject {abstract_v} for attribute {attr_name}")
        return v
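The predictor above talks to `ModelServer` (defined below) over a small REST API. Once the server is running, the endpoint can also be queried directly; a sketch with placeholder query values:

```python
import requests

response = requests.get("http://localhost:5000/query",
                        params={"iid": 1, "name": "x", "kind": "name"})
print(response.json()["v"])  # abstract value predicted by CodeT5, e.g. "int_pos"
```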
13 | href=[] 14 | for i in soup.find_all("a",class_="s-link",href=True): 15 | href.append(i['href']) 16 | return href 17 | 18 | def add_prefix(hrefs_list): 19 | new_href=[] 20 | prefix='https://stackoverflow.com' 21 | for h in hrefs_list: 22 | new_href.append(prefix+h) 23 | return new_href 24 | 25 | def get_popular_python_questions(start_page, end_page, page_size): 26 | soups=[] 27 | for page in range(start_page, end_page + 1): 28 | request = requests.get( 29 | url = f'https://stackoverflow.com/questions/tagged/python?tab=votes&page={page}&pagesize={page_size}') 30 | soup = BeautifulSoup(request.text, "html.parser") 31 | soups.append(soup.find("div", id="questions")) 32 | 33 | hrefs=[] 34 | for soup in soups: 35 | hrefs.extend(get_hrefs(soup)) 36 | hrefs = add_prefix(hrefs) 37 | 38 | return hrefs 39 | 40 | def get_random_answer(question_url): 41 | request = requests.get(url = question_url) 42 | soup = BeautifulSoup(request.text, "html.parser") 43 | answers = soup.find_all("div", class_="answercell post-layout--right") 44 | random_index = random.randint(0, len(answers) - 1) 45 | return answers[random_index] 46 | 47 | def get_python_code(answer): 48 | code = "" 49 | code_blocks = answer.find_all("pre") 50 | for code_block in code_blocks: 51 | raw_code = code_block.find_all("code") 52 | for snippet in raw_code: 53 | for line in snippet.get_text().split('\n'): 54 | if not (line.startswith("...") or line.startswith("*") or line.startswith("/") or line.startswith("<") or line.startswith("-->")): 55 | if line.startswith(">>> "): 56 | code += line[4:] + "\n" 57 | elif line.startswith(">>>"): 58 | code += line[3:] + "\n" 59 | elif line.startswith("$"): 60 | code += line[2:] + "\n" 61 | else: 62 | code += line + "\n" 63 | return code 64 | 65 | if __name__ == "__main__": 66 | args = parser.parse_args() 67 | 68 | popular_python_questions = get_popular_python_questions(1, 20, 50) 69 | 70 | next_id = 1 71 | for question in popular_python_questions: 72 | found_snippet = False 73 | while not found_snippet: 74 | try: 75 | random_answer = get_random_answer(question) 76 | except ValueError: 77 | break 78 | 79 | code = get_python_code(random_answer) 80 | 81 | if code: 82 | found_snippet = True 83 | 84 | if found_snippet: 85 | outfile = os.path.join(args.dest_dir, f"snippet_{next_id}.py") 86 | info = f"# Extracted from {question}" 87 | with open(outfile, "w") as f: 88 | f.write(info+"\n") 89 | f.write(code) 90 | next_id += 1 -------------------------------------------------------------------------------- /src/lexecutor/predictors/codet5/ModelServer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch as t 3 | import numpy as np 4 | from flask import Flask, json, request 5 | import requests 6 | from ..DLUtil import device 7 | from ...Hyperparams import Hyperparams as params 8 | from ...IIDs import IIDs 9 | from .FineTune import load_CodeT5 10 | from .InputFactory import InputFactory 11 | from ...Logging import logger 12 | import logging 13 | 14 | # TODO auto-kill the server after some time of inactivity 15 | 16 | 17 | class ModelServer: 18 | def __init__(self): 19 | self._initialize_model() 20 | self._initialize_http_server() 21 | 22 | def _fetch_model(self, model_path): 23 | path_to_url = { 24 | "data/released_models/codet5_model_20230105_fine-grained.bin": "https://github.com/michaelpradel/LExecutor/releases/download/Models_20230105/codet5_model_20230105_fine-grained.bin", 25 | 
"data/released_models/codet5_model_20230105_coarse-grained.bin": "https://github.com/michaelpradel/LExecutor/releases/download/Models_20230105/codet5_model_20230105_coarse-grained.bin" 26 | } 27 | if Path(model_path).exists(): 28 | return 29 | Path.mkdir(Path(model_path).parent, parents=True, exist_ok=True) 30 | logger.info(f"Downloading model from {path_to_url[model_path]}") 31 | request = requests.get(path_to_url[model_path], allow_redirects=True) 32 | open(model_path, 'wb').write(request.content) 33 | 34 | def _initialize_model(self): 35 | logger.info("Loading CodeT5 model") 36 | self.tokenizer, self.model = load_CodeT5() 37 | 38 | if params.value_abstraction == "fine-grained": 39 | model_path = "data/released_models/codet5_model_20230105_fine-grained.bin" 40 | elif params.value_abstraction == "coarse-grained-deterministic" or params.value_abstraction == "coarse-grained-randomized": 41 | model_path = "data/released_models/codet5_model_20230105_coarse-grained.bin" 42 | self._fetch_model(model_path) 43 | self.model.load_state_dict(t.load(model_path, map_location=device)) 44 | 45 | iids = IIDs(params.iids_file) 46 | self.input_factory = InputFactory(iids, self.tokenizer) 47 | logger.info("CodeT5 model loaded") 48 | 49 | def _initialize_http_server(self): 50 | logger.info("Starting HTTP server") 51 | api = Flask(__name__) 52 | flask_log = logging.getLogger('werkzeug') 53 | flask_log.setLevel(logging.ERROR) 54 | 55 | @api.route('/query', methods=['GET']) 56 | def handle_query(): 57 | # reconstruct entry from REST API request 58 | entry = {"iid": request.args.get("iid"), 59 | "name": request.args.get("name"), 60 | "kind": request.args.get("kind")} 61 | 62 | # turn query into vectors 63 | input_ids, _ = self.input_factory.entry_to_inputs(entry) 64 | input_ids = [tensor.cpu() for tensor in input_ids] 65 | 66 | # query the model and decode the result 67 | with t.no_grad(): 68 | self.model.eval() 69 | generated_ids = self.model.generate( 70 | t.tensor(np.array([input_ids]), device=device), max_length=params.max_output_length) 71 | 72 | predicted_value = self.tokenizer.decode( 73 | generated_ids[0], skip_special_tokens=True) 74 | 75 | if params.verbose: 76 | if self.tokenizer.bos_token_id not in generated_ids or self.tokenizer.eos_token_id not in generated_ids[0]: 77 | print( 78 | f"Warning: CodeT5 likely produced a garbage value: {predicted_value}") 79 | 80 | # respond with a JSON object 81 | result = {"v": predicted_value} 82 | return json.dumps(result) 83 | 84 | api.run() 85 | 86 | 87 | 88 | if __name__ == "__main__": 89 | ModelServer() 90 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/Type4PyValuePredictor.py: -------------------------------------------------------------------------------- 1 | from .RandomPredictor import RandomPredictor 2 | from ..Logging import logger 3 | from ..ValueAbstraction import restore_value 4 | from ..IIDs import IIDs 5 | import requests 6 | 7 | class Type4PyValuePredictor(RandomPredictor): 8 | def __init__(self, code_snippet_file, stats): 9 | super().__init__() 10 | self.type_predictions = self._query_model(code_snippet_file) 11 | print(self.type_predictions) 12 | self.stats = stats 13 | 14 | def _query_model(self, code_snippet_file): 15 | with open(code_snippet_file + '.orig') as file: 16 | raw_response = requests.post( 17 | "http://localhost:5001/api/predict?tc=0", file.read()) 18 | if raw_response.status_code != 200: 19 | raise RuntimeError( 20 | f"Model server returned error code 
{raw_response.status_code}") 21 | return raw_response.json() 22 | 23 | def _get_abstract_value(self, name): 24 | abstract_value = None 25 | predicted_type = False # boolean aux var 26 | 27 | if self.type_predictions["response"]: 28 | # global var 29 | if name in self.type_predictions["response"]["variables"]: 30 | abstract_value = self.type_predictions["response"]["variables_p"][name][0][0].split('[')[0].lower() 31 | predicted_type = True 32 | else: 33 | # in function 34 | functions = self.type_predictions["response"]["funcs"] 35 | if self.type_predictions["response"]["classes"]: 36 | for class_ in self.type_predictions["response"]["classes"]: 37 | functions += class_["funcs"] 38 | 39 | for fct in functions: 40 | # variables 41 | if name in fct["variables"]: 42 | if fct["variables_p"][name] and fct["variables_p"][name][0]: 43 | abstract_value = fct["variables_p"][name][0][0].split('[')[0].lower() 44 | predicted_type = True 45 | break 46 | # parameters 47 | elif name in fct["params"]: 48 | if fct["params_p"][name] and fct["params_p"][name][0]: 49 | abstract_value = fct["params_p"][name][0][0].split('[')[0].lower() 50 | predicted_type = True 51 | break 52 | # return 53 | elif "ret_type_p" in fct and name == fct["name"]: 54 | if fct["ret_type_p"] and fct["ret_type_p"][0]: 55 | abstract_value = fct["ret_type_p"][0][0].split('[')[0].lower() 56 | predicted_type = True 57 | break 58 | 59 | return abstract_value, predicted_type 60 | 61 | def name(self, iid, name): 62 | abstract_v, predicted_type = self._get_abstract_value(name) 63 | if predicted_type: 64 | self.stats.type4py_predictions += 1 65 | v = restore_value(abstract_v) 66 | logger.info(f"{iid}: Predicting with Type4Py for name {name}: {v}") 67 | return v 68 | else: 69 | self.stats.random_predictions += 1 70 | super().name(iid, name) 71 | 72 | def call(self, iid, fct, fct_name, *args, **kwargs): 73 | abstract_v, predicted_type = self._get_abstract_value(fct_name) 74 | if predicted_type: 75 | self.stats.type4py_predictions += 1 76 | v = restore_value(abstract_v) 77 | logger.info(f"{iid}: Predicting with Type4Py for call {fct_name}: {v}") 78 | return v 79 | else: 80 | self.stats.random_predictions += 1 81 | super().call(iid, fct, fct_name, *args, **kwargs) 82 | 83 | def attribute(self, iid, base, attr_name): 84 | abstract_v, predicted_type = self._get_abstract_value(attr_name) 85 | if predicted_type: 86 | self.stats.type4py_predictions += 1 87 | v = restore_value(abstract_v) 88 | logger.info(f"{iid}: Predicting with Type4Py for attribute {attr_name}: {v}") 89 | return v 90 | else: 91 | self.stats.random_predictions += 1 92 | super().attribute(iid, base, attr_name) 93 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/codebert/CodeBERTValuePredictor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch as t 3 | import numpy as np 4 | from ..ValuePredictor import ValuePredictor 5 | from ..DLUtil import device 6 | from .CodeBERT import load_CodeBERT 7 | from .InputFactory import InputFactory 8 | from ...Logging import logger 9 | from transformers import pipeline 10 | import time 11 | import requests 12 | from requests.exceptions import ConnectionError 13 | import subprocess 14 | from ...ValueAbstraction import restore_value 15 | from ...Hyperparams import Hyperparams as params 16 | from ...IIDs import IIDs 17 | 18 | 19 | class CodeBERTValuePredictor(ValuePredictor): 20 | def __init__(self, stats): 21 | self.stats 
= stats 22 | 23 | # load model 24 | self.tokenizer, self.model = load_CodeBERT() 25 | if params.value_abstraction == "fine-grained": 26 | model_path = "data/released_models/codebert_model_20232906_fine-grained.bin" 27 | elif params.value_abstraction == "coarse-grained-deterministic" or params.value_abstraction == "coarse-grained-randomized": 28 | model_path = "data/released_models/codebert_model_20232906_coarse-grained.bin" 29 | self._fetch_model(model_path) 30 | self.model.load_state_dict(t.load( 31 | model_path, map_location=device)) 32 | self.model.to(device) 33 | logger.info("CodeBERT model loaded") 34 | 35 | self.iids = IIDs(params.iids_file) 36 | self.stats = stats 37 | self.input_factory = InputFactory(self.iids, self.tokenizer) 38 | 39 | def _fetch_model(self, model_path): 40 | path_to_url = { 41 | "data/released_models/codebert_model_20232906_fine-grained.bin": "https://github.com/michaelpradel/LExecutor/releases/download/Models_20230105/codebert_model_20232906_fine-grained.bin", 42 | "data/released_models/codebert_model_20232906_coarse-grained.bin": "https://github.com/michaelpradel/LExecutor/releases/download/Models_20230105/codebert_model_20232906_coarse-grained.bin" 43 | } 44 | if Path(model_path).exists(): 45 | return 46 | Path.mkdir(Path(model_path).parent, parents=True, exist_ok=True) 47 | logger.info(f"Downloading model from {path_to_url[model_path]}") 48 | request = requests.get(path_to_url[model_path], allow_redirects=True) 49 | open(model_path, 'wb').write(request.content) 50 | 51 | def _query_model(self, entry): 52 | # turn entry into vectors 53 | input_ids, _ = self.input_factory.entry_to_inputs(entry) 54 | input_ids = [tensor.to(device) for tensor in input_ids] 55 | 56 | # query the model and decode the result 57 | with t.no_grad(): 58 | self.model.eval() 59 | 60 | fill_mask = pipeline('fill-mask', model=self.model, tokenizer=self.tokenizer, device=0, framework="pt") 61 | 62 | # This is required because the fill-mask pipeline adds special tokens during encoding. 
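Decoding with the default skip_special_tokens=False renders <s>, </s> and <pad> literally, so they are stripped below before the pipeline re-encodes the string.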
63 | # If we use skip_special_tokens=True, <mask> is discarded as well 64 | INPUT = self.tokenizer.decode(input_ids) 65 | INPUT = INPUT.replace("<s>", "") 66 | INPUT = INPUT.replace("</s>", "") 67 | INPUT = INPUT.replace("<pad>", "") 68 | 69 | predictions = fill_mask(INPUT) 70 | 71 | val_as_string = predictions[0]['token_str'] 72 | val = restore_value(val_as_string) 73 | 74 | return val_as_string, val 75 | 76 | def name(self, iid, name): 77 | entry = {"iid": iid, "name": name, "kind": "name"} 78 | abstract_v, v = self._query_model(entry) 79 | logger.info(f"{iid}: Predicting for name {name}: {v}") 80 | self.stats.inject_value( 81 | iid, f"Inject {abstract_v} for variable {name}") 82 | return v 83 | 84 | def call(self, iid, fct, fct_name, *args, **kwargs): 85 | entry = {"iid": iid, "name": fct_name, "kind": "call"} 86 | abstract_v, v = self._query_model(entry) 87 | logger.info(f"{iid}: Predicting for call: {v}") 88 | self.stats.inject_value( 89 | iid, f"Inject {abstract_v} as return value of {fct_name}") 90 | return v 91 | 92 | def attribute(self, iid, base, attr_name): 93 | entry = {"iid": iid, "name": attr_name, "kind": "attribute"} 94 | abstract_v, v = self._query_model(entry) 95 | logger.info(f"{iid}: Predicting for attribute {attr_name}: {v}") 96 | self.stats.inject_value( 97 | iid, f"Inject {abstract_v} for attribute {attr_name}") 98 | return v 99 | -------------------------------------------------------------------------------- /src/lexecutor/Instrument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path 3 | import libcst as cst 4 | from .CodeRewriter import CodeRewriter 5 | from .IIDs import IIDs 6 | from .Util import gather_files 7 | import re 8 | from shutil import copyfile, move 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument( 13 | "--files", help="Python files to instrument or .txt file with all file paths", nargs="+") 14 | parser.add_argument( 15 | "--iids", help="JSON file with instruction IDs", default="iids.json") 16 | parser.add_argument( 17 | "--restore", help="Restores uninstrumented files from .py.orig files", action="store_true") 18 | parser.add_argument( 19 | "--line_coverage_instrumentation", help="Instruments files to calculate line coverage", action="store_true") 20 | parser.add_argument( 21 | "--validate", help="Validate syntactic correctness of the instrumented code (and skip a file if syntactically incorrect)", action="store_true") 22 | parser.add_argument( 23 | "--verbose", help="Print details, e.g., about exceptions during instrumentation", action="store_true") 24 | 25 | 26 | ignored_file_suffixes = [ 27 | "ansible/utils/collection_loader/_collection_finder.py", 28 | "ansible/constants.py", 29 | "django/db/models/expressions.py" 30 | ] 31 | 32 | 33 | def gather_accessed_names(ast_wrapper): 34 | scopes = set(ast_wrapper.resolve(cst.metadata.ScopeProvider).values()) 35 | ranges = ast_wrapper.resolve(cst.metadata.PositionProvider) 36 | used_names = set() 37 | for scope in scopes: 38 | for access in scope.accesses: 39 | name = access.node 40 | 41 | # check for reads of class variables defined in the same class 42 | # (we cannot wrap them into a lambda) 43 | if isinstance(scope, cst.metadata.ClassScope) and (all(ref.scope == scope for ref in access.referents)): 44 | continue 45 | 46 | used_names.add(name) 47 | 48 | return used_names 49 | 50 | 51 | def instrument_file(file_path, iids, line_coverage_instrumentation, validate): 52 | for suffix in ignored_file_suffixes: 53 | if 
file_path.endswith(suffix): 54 | print(f"{file_path} is on blacklist -- skipping it") 55 | return 56 | 57 | with open(file_path, "r") as file: 58 | src = file.read() 59 | 60 | if "LExecutor: DO NOT INSTRUMENT" in src: 61 | print(f"{file_path} is already instrumented -- skipping it") 62 | return 63 | 64 | ast = cst.parse_module(src) 65 | ast_wrapper = cst.metadata.MetadataWrapper(ast) 66 | accessed_names = gather_accessed_names(ast_wrapper) 67 | 68 | code_rewriter = CodeRewriter(file_path, iids, line_coverage_instrumentation, accessed_names) 69 | rewritten_ast = ast_wrapper.visit(code_rewriter) 70 | rewritten_code = "# LExecutor: DO NOT INSTRUMENT\n\n" + rewritten_ast.code 71 | 72 | if validate: 73 | try: 74 | cst.parse_module(rewritten_code) 75 | except Exception as e: 76 | print(f"Error while validating {file_path}. Ignoring this file.") 77 | if args.verbose: 78 | print(e) 79 | return 80 | 81 | copied_file_path = re.sub(r"\.py$", ".py.orig", file_path) 82 | copyfile(file_path, copied_file_path) 83 | 84 | with open(file_path, "w") as file: 85 | file.write(rewritten_code) 86 | 87 | 88 | def restore_file(file_path): 89 | orig_file_path = re.sub(r"\.py$", ".py.orig", file_path) 90 | if path.isfile(orig_file_path): 91 | move(orig_file_path, file_path) 92 | return True 93 | else: 94 | return False 95 | 96 | 97 | if __name__ == "__main__": 98 | args = parser.parse_args() 99 | files = gather_files(args.files) 100 | if not args.restore: 101 | print(f"Found {len(files)} file(s) to instrument") 102 | iids = IIDs(args.iids) 103 | for file_path in files: 104 | try: 105 | print(f"Instrumenting {file_path}") 106 | instrument_file(file_path, iids, args.line_coverage_instrumentation, args.validate) 107 | except Exception as e: 108 | print(f"Error while instrumenting {file_path}. 
Ignoring this file.") 109 | if args.verbose: 110 | print(e) 111 | iids.store() 112 | else: 113 | nb_restored = 0 114 | for file_path in files: 115 | if restore_file(file_path): 116 | nb_restored += 1 117 | print(f"Have restored {nb_restored} out of {len(files)} file(s)") 118 | -------------------------------------------------------------------------------- /src/lexecutor/RuntimeStats.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | from os import path 4 | import csv 5 | import time 6 | from .Logging import logger 7 | from .IIDs import IIDs 8 | from .Hyperparams import Hyperparams as param 9 | 10 | write_event_trace = True 11 | write_metrics = True 12 | 13 | 14 | class RuntimeStats: 15 | def __init__(self, execution): 16 | self.total_uses = 0 17 | self.guided_uses = 0 18 | 19 | self.covered_iids = set() 20 | self.executed_lines = [] 21 | 22 | if write_event_trace: 23 | self.event_trace = [] 24 | self.iids = IIDs(param.iids_file) 25 | 26 | self.random_predictions = 0 27 | self.type4py_predictions = 0 28 | 29 | self.execution = execution 30 | 31 | def cover_iid(self, iid): 32 | self.covered_iids.add(iid) 33 | if write_event_trace: 34 | self.event_trace.append(f"Line {self.iids.line(iid)}: Executed") 35 | 36 | def cover_line(self, iid): 37 | self.executed_lines.append(iid) 38 | logger.info(f"Line {self.iids.line(iid)}: Executed") 39 | 40 | def inject_value(self, iid, msg): 41 | if write_event_trace: 42 | self.event_trace.append( 43 | f"Line {self.iids.line(iid)}: {msg}") 44 | 45 | def uncaught_exception(self, iid, e): 46 | if write_event_trace: 47 | self.event_trace.append( 48 | f"Line {self.iids.line(iid)}: Uncaught exception {type(e)}\n{e}") 49 | 50 | def print(self): 51 | logger.info(f"Covered iids: {len(self.covered_iids)}") 52 | logger.info(f"Total uses: {self.total_uses}") 53 | logger.info(f"Guided uses : {self.guided_uses}/{self.total_uses}") 54 | 55 | def _save_summary_metrics(self, file, predictor_name, execution_time): 56 | if write_metrics: 57 | if param.dataset == "so_snippets": 58 | project_name = "" 59 | file_name = file.split("/")[2].split('.')[0] 60 | elif param.dataset == "random_functions": 61 | project_name = file.split("/")[2] 62 | file_name = file.split("/")[4].split('.')[0] 63 | else: 64 | project_name = "" 65 | file_name = file.split("/")[-1].split(".")[0] 66 | 67 | if predictor_name == 'CodeT5ValuePredictor' or predictor_name == 'CodeBERTValuePredictor': 68 | predictor_name = f'{predictor_name}_{param.value_abstraction}' 69 | 70 | # Create destination dir if it doesn't exist 71 | if not os.path.exists('./metrics'): 72 | os.makedirs('./metrics') 73 | if not os.path.exists(f'./metrics/{param.dataset}'): 74 | os.makedirs(f'./metrics/{param.dataset}') 75 | if not os.path.exists(f'./metrics/{param.dataset}/{predictor_name}'): 76 | os.makedirs(f'./metrics/{param.dataset}/{predictor_name}') 77 | if not os.path.exists(f'./metrics/{param.dataset}/{predictor_name}/raw'): 78 | os.makedirs(f'./metrics/{param.dataset}/{predictor_name}/raw') 79 | 80 | # Create CSV file and add header if it doesn't exist 81 | csv_file = f'./metrics/{param.dataset}/{predictor_name}/raw/metrics_{project_name}_{file_name}_{self.execution}.csv' 82 | if not os.path.isfile(csv_file): 83 | columns = ['file', 'predictor', 'covered_iids', 84 | 'total_uses', 'guided_uses', 'executed_lines', 85 | 'covered_lines', 'execution_time', 'random_predictions', 86 | 'type4py_predictions', 'execution'] 87 | 88 | with open(csv_file, 'a') as fp: 
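# the header row is only written when the CSV file is first created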
89 | writer = csv.writer(fp) 90 | writer.writerow(columns) 91 | print(f"Wrote metrics to {csv_file}") 92 | 93 | df = pd.read_csv(csv_file) 94 | df_new_data = pd.DataFrame({ 95 | 'file': [file], 96 | 'predictor': [predictor_name], 97 | 'covered_iids': [len(self.covered_iids)], 98 | 'total_uses': [self.total_uses], 99 | 'guided_uses': [self.guided_uses], 100 | 'executed_lines': [len(self.executed_lines)], 101 | 'covered_lines': [len(set(self.executed_lines))], 102 | 'execution_time': [execution_time], 103 | 'random_predictions': [self.random_predictions], 104 | 'type4py_predictions': [self.type4py_predictions], 105 | 'execution': [self.execution] 106 | }) 107 | df = pd.concat([df, df_new_data]) 108 | df.to_csv(csv_file, index=False) 109 | 110 | def _save_event_trace(self): 111 | with open("trace.txt", "w") as fp: 112 | fp.write("\n".join(self.event_trace)) 113 | 114 | def save(self, file, predictor_name, start_time): 115 | self._save_summary_metrics(file, predictor_name, time.time() - start_time) 116 | if write_event_trace: 117 | self._save_event_trace() 118 | -------------------------------------------------------------------------------- /src/lexecutor/evaluation/CombineData.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | 4 | folder_path = "./metrics/" 5 | 6 | # Get a list of used datasets 7 | datasets = os.listdir(folder_path) 8 | 9 | for dataset in datasets: 10 | # Get a list of used predictors 11 | predictors = os.listdir(folder_path + dataset) 12 | 13 | combined_df_for_dataset = pd.DataFrame() 14 | 15 | for predictor in predictors: 16 | # Get a list of raw metric files 17 | files = os.listdir(f'{folder_path}{dataset}/{predictor}/raw') 18 | 19 | all_executions_df = pd.DataFrame() 20 | 21 | for execution in range(1, 11): 22 | 23 | # Filter the files to include only the ones that match the pattern "metrics_x.csv" 24 | matching_files = [file for file in files if file.startswith("metrics_") and file.endswith(f"_{execution}.csv")] 25 | 26 | combined_df_for_predictor = pd.DataFrame() 27 | 28 | for file in matching_files: 29 | try: 30 | df = pd.read_csv(f'{folder_path}{dataset}/{predictor}/raw/{file}') 31 | combined_df_for_predictor = pd.concat([combined_df_for_predictor, df], ignore_index=True) 32 | except pd.errors.EmptyDataError: 33 | print(file) 34 | 35 | traced_files = combined_df_for_predictor.file.unique() 36 | 37 | for file in traced_files: 38 | indexes = combined_df_for_predictor.index[combined_df_for_predictor['file'] == file].tolist() 39 | for index in indexes[:-1]: 40 | combined_df_for_predictor = combined_df_for_predictor.drop(index=index) 41 | 42 | all_executions_df = pd.concat([all_executions_df, combined_df_for_predictor], ignore_index=True) 43 | 44 | combined_df_for_predictor = all_executions_df.groupby('file', as_index=False)[["covered_iids","total_uses","guided_uses","covered_lines","executed_lines", "execution_time", "random_predictions","type4py_predictions"]].mean() 45 | combined_df_for_predictor['predictor'] = [predictor] * len(combined_df_for_predictor) 46 | 47 | if predictor == 'PynguinTests': 48 | aux_df = pd.read_csv("wrapp_info.csv") 49 | 50 | combined_df_for_predictor = combined_df_for_predictor.merge(aux_df, on='file', how='left') 51 | combined_df_for_predictor['covered_lines'] = combined_df_for_predictor['covered_lines'] - combined_df_for_predictor['wrapped'] - 1 52 | combined_df_for_predictor['covered_lines'] = combined_df_for_predictor.apply(lambda x: x['covered_lines'] if x['covered_lines']>=0 else 
0, axis=1) 53 | 54 | combined_df_for_predictor['executed_lines'] = combined_df_for_predictor['executed_lines'] - combined_df_for_predictor['wrapped'] - 1 55 | combined_df_for_predictor['executed_lines'] = combined_df_for_predictor.apply(lambda x: x['executed_lines'] if x['executed_lines']>=0 else 0, axis=1) 56 | combined_df_for_predictor.drop(['wrapped'], inplace=True, axis=1) 57 | 58 | combined_df_for_predictor['file'] = combined_df_for_predictor['file'].str.replace("pynguin_tests/test_", "", regex=True) 59 | combined_df_for_predictor['file'] = combined_df_for_predictor['file'].str.replace("[_]", "/", regex=True) 60 | combined_df_for_predictor['file'] = combined_df_for_predictor['file'].str.replace("popular/projects/snippets/dataset", "popular_projects_snippets_dataset", regex=True) 61 | combined_df_for_predictor['file'] = combined_df_for_predictor['file'].str.replace("functions", "bodies", regex=True) 62 | combined_df_for_predictor['file'] = combined_df_for_predictor['file'].str.replace("function/", "body_", regex=True) 63 | 64 | elif predictor == 'Type4PyValuePredictor': 65 | aux_df = pd.read_csv("aux_data_functions_with_invocation_dataset.csv") 66 | 67 | combined_df_for_predictor = combined_df_for_predictor.merge(aux_df, on='file', how='left') 68 | combined_df_for_predictor['covered_lines'] = combined_df_for_predictor['covered_lines'] - combined_df_for_predictor['lines_to_discard'] 69 | combined_df_for_predictor['covered_lines'] = combined_df_for_predictor.apply(lambda x: x['covered_lines'] if x['covered_lines']>=0 else 0, axis=1) 70 | 71 | combined_df_for_predictor['executed_lines'] = combined_df_for_predictor['executed_lines'] - combined_df_for_predictor['lines_to_discard'] 72 | combined_df_for_predictor['executed_lines'] = combined_df_for_predictor.apply(lambda x: x['executed_lines'] if x['executed_lines']>=0 else 0, axis=1) 73 | combined_df_for_predictor.drop(['lines_to_discard'], inplace=True, axis=1) 74 | 75 | combined_df_for_predictor.to_csv(f'{folder_path}{dataset}/{predictor}/metrics.csv', index=False) 76 | 77 | combined_df_for_dataset = pd.concat([combined_df_for_dataset, combined_df_for_predictor], ignore_index=True) 78 | 79 | combined_df_for_dataset['file'] = combined_df_for_dataset['file'].str.replace("functions", "bodies", regex=True) 80 | combined_df_for_dataset['file'] = combined_df_for_dataset['file'].str.replace("function", "body", regex=True) 81 | combined_df_for_dataset.to_csv(f'{folder_path}metrics_{dataset}_dataset.csv', index=False) -------------------------------------------------------------------------------- /src/lexecutor/predictors/codebert/PrepareData.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import pandas as pd 4 | import numpy as np 5 | import torch as t 6 | from ...Logging import logger 7 | from ...Util import gather_files 8 | from .CodeBERT import load_CodeBERT 9 | from ...Hyperparams import Hyperparams as params 10 | from ...IIDs import IIDs 11 | from .InputFactory import InputFactory 12 | from ...ValueAbstraction import fine_to_coarse_grained 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--iids", help="JSON file with instruction IDs", required=True) 18 | parser.add_argument( 19 | "--traces", help="Trace file or .txt file(s) with all trace file paths to use", 20 | nargs="+", required=True) 21 | parser.add_argument( 22 | "--output_dir", help="directory to store tensors", required=True) 23 | parser.add_argument( 24 | 
"--output_suffix", help="Suffix to append to output file names (if nothing given: train.pt, validate.pt)") 25 | 26 | 27 | def read_traces(trace_files): 28 | logger.info("Loading trace files") 29 | df = pd.DataFrame(data=None) 30 | trace_files = gather_files(trace_files, suffix=".h5") 31 | for trace_file in trace_files: 32 | current_df = pd.read_hdf(trace_file, key="entries") 33 | df = pd.concat([df, current_df]) 34 | return df 35 | 36 | 37 | def abstract_trace_entries(entries): 38 | if params.value_abstraction.startswith("coarse-grained"): 39 | logger.info("Abstracting trace entries to use coarse-grained values") 40 | entries.replace({"value": fine_to_coarse_grained}, inplace=True) 41 | 42 | 43 | def dedup_trace_entries(entries): 44 | logger.info(f"Deduplicating {len(entries)} trace entries") 45 | if params.dedup == "name-value-iid": 46 | entries.drop_duplicates( 47 | subset=["iid", "name", "value", "kind"], inplace=True) 48 | elif params.dedup == "name-value": 49 | entries.drop_duplicates(subset=["name", "value"], inplace=True) 50 | else: 51 | raise ValueError(f"Unknown dedup mode: {params.dedup}") 52 | 53 | # TODO the following handles some bug in trace gathering, which should be fixed there 54 | entries.drop(entries[entries.name.astype(str).str.startswith( 55 | "MarkDecorator")].index, inplace=True) 56 | 57 | logger.info(f"After deduplicating: {len(entries)} trace entries") 58 | 59 | 60 | def clean_entries(entries): 61 | before_len = len(entries) 62 | # remove entries with invalid names (e.g. "functools.partial( --out_dir 14 | 2) Train LExecutor on increasingly large datasets: 15 | CompareDatasetSizes --train --in_dir [--size ] 16 | 3) Produce raw results (to be used for plotting, etc.): 17 | CompareDatasetSizes --stats --in_dir 18 | """ 19 | parser = argparse.ArgumentParser(description=description) 20 | parser.add_argument("--prepare", action="store_true") 21 | parser.add_argument( 22 | "--tensors", help=".pt files with data to use for training and validation (pass when using --prepare)", nargs="+") 23 | parser.add_argument( 24 | "--out_dir", help="output folder for re-arranged training and validation data (pass when using --prepare)") 25 | parser.add_argument("--train", action="store_true") 26 | parser.add_argument( 27 | "--in_dir", help="folder with training and validation produced with --prepare (pass when using --train)") 28 | parser.add_argument( 29 | "--size", help="fix the index of the run (optional; one index or comma-separated indices; pass when using --train)") 30 | parser.add_argument( 31 | "--stats", help="summarize results of the experiment", action="store_true") 32 | 33 | 34 | def load_data(tensor_files): 35 | print(f"Reading {len(tensor_files)} tensor files") 36 | tensors = [] 37 | for tensor_file in tensor_files: 38 | tensors.append(t.load(tensor_file)) 39 | combined = t.cat(tensors, 0) 40 | print(f"Combined tensors into one dataset of size {len(combined)}") 41 | return combined 42 | 43 | 44 | def rearrange_data(all_data): 45 | # compute indices for training (w/ increasing datasets) and validation 46 | shuffled_indices = t.randperm(len(all_data)) 47 | train_indices = shuffled_indices[:int(params.perc_train * len(all_data))] 48 | validate_indices = shuffled_indices[int(params.perc_train * len(all_data)):] 49 | one_portion_length = int(len(train_indices) / 10) 50 | increasing_train_data_subset_indices = [] 51 | for idx in range(0, 10): 52 | subset = train_indices[:one_portion_length * (idx + 1)] 53 | increasing_train_data_subset_indices.append(subset) 54 | 55 | # create 
tensors based on indices 56 | train_data_subsets = [] 57 | for subset_idx, subset in enumerate(increasing_train_data_subset_indices): 58 | next_subset = all_data.index_select(0, subset) 59 | train_data_subsets.append(next_subset) 60 | print(f"Subset {subset_idx} has {len(next_subset)} data points") 61 | validate_data = all_data.index_select(0, validate_indices) 62 | 63 | return train_data_subsets, validate_data 64 | 65 | 66 | def store_data(increasing_train_data_subsets, validate_data, out_dir): 67 | for idx, train_data_subset in enumerate(increasing_train_data_subsets): 68 | out_file = f"{out_dir}/train{idx}.pt" 69 | print(f"Storing training data subset with {len(train_data_subset)} data points into {out_file}") 70 | t.save(train_data_subset, out_file) 71 | t.save(validate_data, f"{out_dir}/validate.pt") 72 | print(f"Stored training and validation data into {out_dir}") 73 | 74 | 75 | def run_training_with_size(in_dir, idx): 76 | print(f"Training on subset {idx}") 77 | stats_dir = f"{in_dir}/stats{idx}" 78 | if not os.path.exists(stats_dir): 79 | os.makedirs(stats_dir) 80 | run(["python", "-m", "lexecutor.predictors.codet5.FineTune", 81 | "--train_tensors", f"{in_dir}/train{idx}.pt", 82 | "--validate_tensors", f"{in_dir}/validate.pt", 83 | "--output_dir", f"{in_dir}/model{idx}", 84 | "--stats_dir", stats_dir]) 85 | 86 | 87 | def print_stats(in_dir): 88 | total_value_use_pairs = 0 89 | 90 | validation_file = f"{in_dir}/validate.pt" 91 | nb_validation_pairs = len(load_data([validation_file])) 92 | print(f"Validation pairs: {nb_validation_pairs}") 93 | total_value_use_pairs = nb_validation_pairs 94 | 95 | train_sizes = [] 96 | accuracies = [] 97 | for i in range(10): 98 | training_file = f"{in_dir}/train{i}.pt" 99 | nb_training_pairs = len(load_data([training_file])) 100 | stats_dir = f"{in_dir}/stats{i}" 101 | accuracy_file = join(stats_dir, "validation_acc.csv") 102 | if exists(accuracy_file): # otherwise, experiment hasn't finished yet 103 | with open(accuracy_file, "r") as fp: 104 | accuracy_reader = csv.reader(fp) 105 | rows = list(accuracy_reader) 106 | if len(rows) == 6: 107 | # experiment has finished for this dataset size 108 | acc = rows[5][1] 109 | train_sizes.append(str(nb_training_pairs)) 110 | accuracies.append(acc) 111 | total_value_use_pairs += nb_training_pairs 112 | 113 | print(f"train_sizes = [{', '.join(train_sizes)}]") 114 | print(f"accuracies = [{', '.join(accuracies)}]") 115 | 116 | 117 | if __name__ == "__main__": 118 | args = parser.parse_args() 119 | if args.prepare: 120 | all_data = load_data(args.tensors) 121 | increasing_train_data_subsets, validate_data = rearrange_data(all_data) 122 | store_data(increasing_train_data_subsets, validate_data, args.out_dir) 123 | elif args.train: 124 | if args.size: 125 | if "," in args.size: 126 | sizes = [int(s) for s in args.size.split(",")] 127 | else: 128 | sizes = [int(args.size)] 129 | else: 130 | sizes = list(range(0, 10)) 131 | 132 | for idx in sizes: 133 | run_training_with_size(args.in_dir, idx) 134 | elif args.stats: 135 | print_stats(args.in_dir) 136 | 137 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/codet5/InputFactory.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import re 4 | import torch as t 5 | from ..DLUtil import dtype, device 6 | from ...Logging import logger 7 | from ...Hyperparams import Hyperparams as params 8 | 9 | 10 | # special tokens already provided by the 
tokenizer 11 | mask_token = "<extra_id_0>" 12 | kind_name_token = "<extra_id_1>" 13 | kind_call_token = "<extra_id_2>" 14 | kind_attribute_token = "<extra_id_3>" 15 | sep_token = "<extra_id_4>" 16 | 17 | 18 | class InputFactory(object): 19 | 20 | def __init__(self, iids, tokenizer): 21 | self.iids = iids 22 | self.tokenizer = tokenizer 23 | self.file_to_lines = {} 24 | self.file_to_tokenized_lines = {} 25 | 26 | self.kind_name_token_id = self.tokenizer( 27 | kind_name_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 28 | self.kind_call_token_id = self.tokenizer( 29 | kind_call_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 30 | self.kind_attribute_token_id = self.tokenizer( 31 | kind_attribute_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 32 | self.sep_token_id = self.tokenizer( 33 | sep_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 34 | 35 | def __tokenize_lines(self, file_name): 36 | if file_name in self.file_to_tokenized_lines: 37 | return self.file_to_lines[file_name], self.file_to_tokenized_lines[file_name] 38 | else: 39 | with open(file_name, "r") as f: 40 | lines = f.readlines() 41 | self.file_to_lines[file_name] = lines 42 | tokenized_lines = self.tokenizer( 43 | lines, return_attention_mask=False, add_special_tokens=False).input_ids 44 | self.file_to_tokenized_lines[file_name] = tokenized_lines 45 | 46 | if len(self.file_to_tokenized_lines) > 10000: # prevent OOM 47 | self.file_to_lines = {} 48 | self.file_to_tokenized_lines = {} 49 | 50 | return lines, tokenized_lines 51 | 52 | def _extract_context_window(self, token_ids, marker_token): 53 | # Get at most 512 tokens around the target token 54 | id_of_target_begin = self.tokenizer.encode(marker_token)[1] 55 | target_index = token_ids.index(id_of_target_begin) 56 | 57 | # fewer context before target 58 | if target_index < 255: 59 | previous_target_tokens = token_ids[0:target_index] 60 | after_target_tokens = token_ids[target_index:target_index + 61 | (512 - len(previous_target_tokens))] 62 | # fewer context after target 63 | elif target_index + 255 > len(token_ids): 64 | after_target_tokens = token_ids[target_index:] 65 | previous_target_tokens = token_ids[target_index - 66 | (512 - len(after_target_tokens)):target_index] 67 | # equal context before and after target 68 | else: 69 | previous_target_tokens = token_ids[target_index-255:target_index] 70 | after_target_tokens = token_ids[target_index:target_index+255] 71 | 72 | return previous_target_tokens, after_target_tokens 73 | 74 | 75 | def _encode_input(self, entry, location, lines, tokenized_lines): 76 | # format of input: 77 | # <bos> name <sep> kind <sep> pre-context <extra_id_0> post-context <eos> 78 | 79 | target_line = lines[location.line-1] 80 | 81 | start_index = location.column_start 82 | end_index = location.column_end 83 | 84 | modified_line = target_line[:start_index] + \ 85 | mask_token + target_line[end_index:] 86 | 87 | tokenized_target_line = self.tokenizer( 88 | modified_line, return_attention_mask=False, add_special_tokens=False).input_ids 89 | 90 | # store and later restore the original tokenized line 91 | # (because we use the same tokenized_lines for all entries) 92 | original_tokenized_target_line = tokenized_lines[location.line-1] 93 | tokenized_lines[location.line-1] = tokenized_target_line 94 | token_ids = list(itertools.chain(*tokenized_lines)) 95 | tokenized_lines[location.line-1] = original_tokenized_target_line 96 | 97 | name = entry["name"] 98 | name_ids = self.tokenizer(name, return_attention_mask=False, 
add_special_tokens=False).input_ids 100 | 101 | previous_target_tokens, after_target_tokens = self._extract_context_window( 102 | token_ids, mask_token) 103 | context_ids = previous_target_tokens + after_target_tokens 104 | 105 | # shrink context to fit everything (incl. the variable-sized name_ids) into 512 tokens 106 | while len(name_ids) + len(context_ids) + 5 > 512: 107 | context_ids = context_ids[1:-1] 108 | 109 | if entry["kind"] == "name": 110 | kind_token = self.kind_name_token_id 111 | elif entry["kind"] == "call": 112 | kind_token = self.kind_call_token_id 113 | elif entry["kind"] == "attribute": 114 | kind_token = self.kind_attribute_token_id 115 | 116 | input_ids = [self.tokenizer.bos_token_id] + \ 117 | name_ids + \ 118 | [self.sep_token_id, kind_token, self.sep_token_id] + \ 119 | context_ids + \ 120 | [self.tokenizer.eos_token_id] 121 | 122 | # Add padding 123 | if len(input_ids) < 512: 124 | input_ids = input_ids + \ 125 | (512 - len(input_ids)) * [self.tokenizer.pad_token_id] 126 | 127 | return input_ids 128 | 129 | def _encode_output(self, entry): 130 | # Create labels 131 | if not hasattr(entry, "value"): 132 | # during prediction 133 | value = "unknown" 134 | else: 135 | # during training 136 | assert entry["value"].startswith("@"), entry["value"] 137 | value = entry["value"][1:] 138 | 139 | label_ids = self.tokenizer( 140 | value, padding="max_length", max_length=params.max_output_length).input_ids 141 | return label_ids 142 | 143 | def entry_to_inputs(self, entry): 144 | location = self.iids.location(str(entry["iid"])) 145 | 146 | lines, tokenized_lines = self.__tokenize_lines(location.file+'.orig') 147 | 148 | input_ids = self._encode_input(entry, location, lines, tokenized_lines) 149 | label_ids = self._encode_output(entry) 150 | 151 | input_ids = t.tensor(input_ids, device='cpu') 152 | label_ids = t.tensor(label_ids, device='cpu') 153 | 154 | assert len(input_ids) == 512, len(input_ids) 155 | assert len(label_ids) == params.max_output_length, len(label_ids) 156 | return input_ids, label_ids 157 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/codebert/FineTune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import torch as t 5 | import csv 6 | import pandas as pd 7 | import numpy as np 8 | from torch.utils.data import DataLoader, TensorDataset 9 | from transformers import AdamW, pipeline 10 | from .CodeBERT import load_CodeBERT 11 | from ...Hyperparams import Hyperparams as params 12 | from ..DLUtil import device 13 | from ...Logging import logger 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--train_tensors", help=".pt files for training", default="train.pt") 19 | parser.add_argument( 20 | "--validate_tensors", help=".pt files for validation", default="validate.pt") 21 | parser.add_argument( 22 | "--output_dir", help="directory to store models", required=True) 23 | 24 | 25 | print_examples = True 26 | 27 | 28 | def evaluate(validate_tensors_path, model, tokenizer): 29 | validate_dataset = TensorDataset(t.load(validate_tensors_path)) 30 | validate_loader = DataLoader( 31 | validate_dataset, batch_size=params.batch_size_CodeBERT, drop_last=True) 32 | 33 | logger.info("Starting evaluation") 34 | logger.info(" Num examples = {}".format(len(validate_dataset))) 35 | logger.info(" Num batches = {}".format(len(validate_loader))) 36 | logger.info(" Batch size = 
{}".format(params.batch_size_CodeBERT)) 37 | 38 | k_max = 5 39 | k_to_all_accuracies = {k: [] for k in range(1, k_max+1)} 40 | all_inputs = [] 41 | all_labels = [] 42 | all_predictions = [] 43 | 44 | with t.no_grad(): 45 | model.eval() 46 | 47 | fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0, framework="pt") 48 | 49 | for batch_idx, batch in enumerate(validate_loader): 50 | print(f"Batch: {batch_idx}") 51 | batch = t.cat(batch) 52 | input_ids = batch[:, 0:512] 53 | input_ids = input_ids.to(device) 54 | label_ids = batch[:, 512:] 55 | label_ids = label_ids.to(device) 56 | 57 | labels = tokenizer.batch_decode( 58 | label_ids, skip_special_tokens=True) 59 | 60 | # combine top-most prediction obtained via normal sampling and top-2, top-3, etc. predictions obtained via top-p nucleus sampling 61 | # 1) top-most prediction obtained via normal sampling 62 | 63 | masked_index = t.nonzero(input_ids == tokenizer.mask_token_id, as_tuple=False) 64 | 65 | # This is required because the fill-mask pipeline adds special tokens during encoding. 66 | # If use skip_special_tokens=True, is discarded as well 67 | INPUT = tokenizer.batch_decode(input_ids) 68 | for i in range(len(INPUT)): 69 | input = INPUT[i].replace("", "") 70 | input = input.replace("", "") 71 | input = input.replace("", "") 72 | INPUT[i] = input 73 | 74 | predictions = fill_mask(INPUT) 75 | 76 | corrects = [1 for i in range( 77 | len(labels)) if label_ids[i][int(masked_index[i][1])] == predictions[i][0]['token']] 78 | top1_accuracy = float(len(corrects)) / len(labels) 79 | k_to_all_accuracies[1].append(top1_accuracy) 80 | 81 | # for debugging/eye-balling the results 82 | all_inputs.extend(INPUT) 83 | all_labels.extend(labels) 84 | all_predictions.extend(predictions) 85 | if print_examples: 86 | for label_idx, label in enumerate(labels): 87 | if random.uniform(0, 100) < 0.1: 88 | prediction = predictions[label_idx] 89 | logger.info( 90 | f"Label: {label}, Prediction: {prediction}") 91 | 92 | # 2) count correct predictions among different top-k (top-2, top-3, etc.) 
93 | k_to_corrects = {k: 0 for k in range(1, k_max+1)} 94 | # for top-1, use regular predictions from above 95 | k_to_corrects[1] = len(corrects) 96 | i = 0 97 | while i < len(labels): 98 | topk_predictions_for_example = [prediction['token'] for prediction in predictions[i][:k_max]] 99 | label_for_example = label_ids[i][int(masked_index[i][1])] 100 | for k in range(2, k_max+1): 101 | if label_for_example in topk_predictions_for_example[:k]: 102 | k_to_corrects[k] += 1 103 | i += 1 104 | 105 | # compute top-k accuracies 106 | for k, corrects in k_to_corrects.items(): 107 | accuracy = float(corrects) / len(labels) 108 | k_to_all_accuracies[k].append(accuracy) 109 | 110 | k_to_accuracy = {k: round( 111 | np.array(k_to_all_accuracies[k]).mean().item(), 4) for k in range(1, k_max+1)} 112 | logger.info( 113 | f"validation accuracy: {k_to_accuracy}") 114 | 115 | # for debugging 116 | logger.info("Storing examples in human-readable format") 117 | examples_df = pd.DataFrame( 118 | {"input": all_inputs, "label": all_labels, "prediction": all_predictions}) 119 | examples_df.to_pickle("./eval_examples.pkl") 120 | 121 | logger.info("Done with evaluation") 122 | return k_to_accuracy 123 | 124 | 125 | def save_model(model, output_dir, epoch): 126 | model_to_save = model.module if hasattr(model, "module") else model 127 | output_model_file = os.path.join( 128 | output_dir, f"pytorch_model_epoch{epoch}.bin") 129 | t.save(model_to_save.state_dict(), output_model_file) 130 | logger.info("Saved the last model into %s", output_model_file) 131 | 132 | 133 | if __name__ == "__main__": 134 | args = parser.parse_args() 135 | 136 | tokenizer, model = load_CodeBERT() 137 | 138 | train_dataset = TensorDataset(t.load(args.train_tensors)) 139 | train_loader = DataLoader( 140 | train_dataset, batch_size=params.batch_size_CodeBERT, drop_last=True) 141 | 142 | optim = AdamW(model.parameters(), lr=1e-5) 143 | 144 | logger.info(f"Starting training on {device}") 145 | logger.info(" Num examples = {}".format(len(train_dataset))) 146 | logger.info(" Batch size = {}".format(params.batch_size_CodeBERT)) 147 | logger.info(" Batch num = {}".format( 148 | len(train_dataset) / params.batch_size_CodeBERT)) 149 | logger.info(" Num epoch = {}".format(params.epochs)) 150 | 151 | if not os.path.exists(args.output_dir): 152 | os.makedirs(args.output_dir) 153 | 154 | df_training_loss = pd.DataFrame(columns=['batch', 'loss', 'epoch']) 155 | df_validation_acc = pd.DataFrame(columns=['epoch', 'val_accuracy']) 156 | 157 | for epoch in range(params.epochs): 158 | logger.info(f"Epoch {epoch}") 159 | 160 | for batch_idx, batch in enumerate(train_loader): 161 | batch = t.cat(batch) 162 | input_ids = batch[:, :512] 163 | input_ids = input_ids.to(device) 164 | labels = batch[:, 512:] 165 | labels = labels.to(device) 166 | 167 | model.train() 168 | optim.zero_grad() 169 | 170 | outputs = model(input_ids, labels=labels) 171 | 172 | loss = outputs.loss 173 | loss.backward() 174 | optim.step() 175 | 176 | logger.info( 177 | f" Training loss of batch {batch_idx}: {round(loss.item(), 4)}") 178 | 179 | # save training losses to file 180 | df_training_loss = pd.concat([df_training_loss, pd.DataFrame({ 181 | 'batch': [batch_idx], 182 | 'loss': [round(loss.item(), 4)], 183 | 'epoch': [epoch] 184 | })]) 185 | df_training_loss.to_csv('./training_loss.csv', index=False) 186 | 187 | accuracy = evaluate(args.validate_tensors, model, tokenizer) 188 | 189 | # save validation accuracies to file 190 | df_validation_acc = pd.concat([df_validation_acc, pd.DataFrame({ 
191 | "epoch": [epoch], 192 | "val_accuracy": [accuracy] 193 | })]) 194 | df_validation_acc.to_csv('./validation_acc.csv', index=False) 195 | 196 | save_model(model, args.output_dir, epoch) 197 | 198 | logger.info('Terminating training') 199 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/codet5/FineTune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import random 4 | import torch as t 5 | import csv 6 | import pandas as pd 7 | import numpy as np 8 | from torch.utils.data import DataLoader, TensorDataset 9 | from transformers import AdamW 10 | from .CodeT5 import load_CodeT5 11 | from ...Hyperparams import Hyperparams as params 12 | from ..DLUtil import device 13 | from ...Logging import logger 14 | 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | "--train_tensors", help=".pt files for training", default="train.pt") 19 | parser.add_argument( 20 | "--validate_tensors", help=".pt files for validation", default="validate.pt") 21 | parser.add_argument( 22 | "--output_dir", help="directory to store models", required=True) 23 | parser.add_argument( 24 | "--stats_dir", help="directory to store loss and accuracy results (default=current directory)", default=".") 25 | 26 | 27 | print_examples = True 28 | 29 | 30 | def evaluate(validate_tensors_path, model, tokenizer): 31 | validate_dataset = TensorDataset(t.load(validate_tensors_path)) 32 | validate_loader = DataLoader( 33 | validate_dataset, batch_size=params.batch_size_CodeT5, drop_last=True) 34 | 35 | logger.info("Starting evaluation") 36 | logger.info(" Num examples = {}".format(len(validate_dataset))) 37 | logger.info(" Num batches = {}".format(len(validate_loader))) 38 | logger.info(" Batch size = {}".format(params.batch_size_CodeT5)) 39 | 40 | k_max = 5 41 | k_to_all_accuracies = {k: [] for k in range(1, k_max+1)} 42 | all_inputs = [] 43 | all_labels = [] 44 | all_predictions = [] 45 | 46 | with t.no_grad(): 47 | model.eval() 48 | 49 | for batch_idx, batch in enumerate(validate_loader): 50 | batch = t.cat(batch) 51 | input_ids = batch[:, 0:512] 52 | input_ids = input_ids.to(device) 53 | label_ids = batch[:, 512:518] 54 | label_ids = label_ids.to(device) 55 | 56 | labels = tokenizer.batch_decode( 57 | label_ids, skip_special_tokens=True) 58 | 59 | # combine top-most prediction obtained via normal sampling and top-2, top-3, etc. predictions obtained via top-p nucleus sampling 60 | # 1) top-most prediction obtained via normal sampling 61 | generated_ids = model.generate( 62 | input_ids, max_length=params.max_output_length) 63 | predictions = tokenizer.batch_decode( 64 | generated_ids, skip_special_tokens=True) 65 | 66 | corrects = [1 for i in range( 67 | len(labels)) if labels[i] == predictions[i]] 68 | top1_accuracy = float(len(corrects)) / len(labels) 69 | k_to_all_accuracies[1].append(top1_accuracy) 70 | 71 | # for debugging/eye-balling the results 72 | all_inputs.extend(tokenizer.batch_decode( 73 | input_ids, skip_special_tokens=False)) 74 | all_labels.extend(labels) 75 | all_predictions.extend(predictions) 76 | if print_examples: 77 | for label_idx, label in enumerate(labels): 78 | if random.uniform(0, 100) < 0.1: 79 | prediction = predictions[label_idx] 80 | logger.info( 81 | f"Label: {label}, Prediction: {prediction}") 82 | 83 | # 2) top-2, top-3, etc. 
predictions obtained via top-p nucleus sampling (see https://huggingface.co/blog/how-to-generate) 84 | topk_generated_ids = model.generate( 85 | input_ids, max_length=params.max_output_length, 86 | do_sample=True, top_k=k_max, top_p=0.95, num_return_sequences=k_max) 87 | topk_predictions = tokenizer.batch_decode( 88 | topk_generated_ids, skip_special_tokens=True) 89 | 90 | # count correct predictions among different top-k 91 | k_to_corrects = {k: 0 for k in range(1, k_max+1)} 92 | # for top-1, use regular predictions from above 93 | k_to_corrects[1] = len(corrects) 94 | i = 0 95 | while i < len(topk_predictions): 96 | topk_predictions_for_example = topk_predictions[i:i+k_max] 97 | example_idx = int(i / k_max) 98 | label_for_example = labels[example_idx] 99 | for k in range(2, k_max+1): 100 | if label_for_example in topk_predictions_for_example[:k]: 101 | k_to_corrects[k] += 1 102 | i += k_max 103 | 104 | # compute top-k accuracies 105 | for k, corrects in k_to_corrects.items(): 106 | accuracy = float(corrects) / len(labels) 107 | k_to_all_accuracies[k].append(accuracy) 108 | 109 | k_to_accuracy = {k: round( 110 | np.array(k_to_all_accuracies[k]).mean().item(), 4) for k in range(1, k_max+1)} 111 | logger.info( 112 | f"validation accuracy: {k_to_accuracy}") 113 | 114 | # for debugging 115 | logger.info("Storing examples in human-readable format") 116 | examples_df = pd.DataFrame( 117 | {"input": all_inputs, "label": all_labels, "prediction": all_predictions}) 118 | examples_df.to_pickle("./eval_examples.pkl") 119 | 120 | logger.info("Done with evaluation") 121 | return k_to_accuracy 122 | 123 | 124 | def save_model(model, output_dir, epoch): 125 | model_to_save = model.module if hasattr(model, "module") else model 126 | output_model_file = os.path.join( 127 | output_dir, f"pytorch_model_epoch{epoch}.bin") 128 | t.save(model_to_save.state_dict(), output_model_file) 129 | logger.info("Saved the last model into %s", output_model_file) 130 | 131 | 132 | if __name__ == "__main__": 133 | args = parser.parse_args() 134 | 135 | tokenizer, model = load_CodeT5() 136 | 137 | train_dataset = TensorDataset(t.load(args.train_tensors)) 138 | train_loader = DataLoader( 139 | train_dataset, batch_size=params.batch_size_CodeT5, drop_last=True) 140 | 141 | optim = AdamW(model.parameters(), lr=1e-5) 142 | 143 | logger.info(f"Starting training on {device}") 144 | logger.info(" Num examples = {}".format(len(train_dataset))) 145 | logger.info(" Batch size = {}".format(params.batch_size_CodeT5)) 146 | logger.info(" Batch num = {}".format( 147 | len(train_dataset) / params.batch_size_CodeT5)) 148 | logger.info(" Num epoch = {}".format(params.epochs)) 149 | 150 | if not os.path.exists(args.output_dir): 151 | os.makedirs(args.output_dir) 152 | 153 | df_training_loss = pd.DataFrame(columns=['batch', 'loss', 'epoch']) 154 | df_validation_acc = pd.DataFrame(columns=['epoch', 'val_accuracy']) 155 | 156 | for epoch in range(params.epochs): 157 | logger.info(f"Epoch {epoch}") 158 | 159 | for batch_idx, batch in enumerate(train_loader): 160 | batch = t.cat(batch) 161 | input_ids = batch[:, 0:512] 162 | input_ids = input_ids.to(device) 163 | labels = batch[:, 512:518] 164 | labels = labels.to(device) 165 | 166 | model.train() 167 | optim.zero_grad() 168 | 169 | outputs = model(input_ids, labels=labels) 170 | 171 | loss = outputs.loss 172 | loss.backward() 173 | optim.step() 174 | 175 | logger.info( 176 | f" Training loss of batch {batch_idx}: {round(loss.item(), 4)}") 177 | 178 | # save training losses to file 179 | 
df_training_loss = pd.concat([df_training_loss, pd.DataFrame({ 180 | 'batch': [batch_idx], 181 | 'loss': [round(loss.item(), 4)], 182 | 'epoch': [epoch] 183 | })]) 184 | df_training_loss.to_csv(f"{args.stats_dir}/training_loss.csv", index=False) 185 | 186 | accuracy = evaluate(args.validate_tensors, model, tokenizer)[1] 187 | 188 | # save validation accuracies to file 189 | df_validation_acc = pd.concat([df_validation_acc, pd.DataFrame({ 190 | "epoch": [epoch], 191 | "val_accuracy": [accuracy] 192 | })]) 193 | df_validation_acc.to_csv(f"{args.stats_dir}/validation_acc.csv", index=False) 194 | 195 | save_model(model, args.output_dir, epoch) 196 | 197 | logger.info('Terminating training') 198 | -------------------------------------------------------------------------------- /src/lexecutor/predictors/codebert/InputFactory.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import json 3 | import re 4 | import torch as t 5 | from ..DLUtil import dtype, device 6 | from ...Logging import logger 7 | from ...Hyperparams import Hyperparams as params 8 | 9 | 10 | # special tokens 11 | mask_token = "<mask>" 12 | kind_name_token = "<name>" 13 | kind_call_token = "<call>" 14 | kind_attribute_token = "<attribute>" 15 | sep_token = "<sep>" 16 | 17 | 18 | class InputFactory(object): 19 | 20 | def __init__(self, iids, tokenizer): 21 | self.iids = iids 22 | self.tokenizer = tokenizer 23 | self.file_to_lines = {} 24 | self.file_to_tokenized_lines = {} 25 | 26 | self.kind_name_token_id = self.tokenizer( 27 | kind_name_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 28 | self.kind_call_token_id = self.tokenizer( 29 | kind_call_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 30 | self.kind_attribute_token_id = self.tokenizer( 31 | kind_attribute_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 32 | self.sep_token_id = self.tokenizer( 33 | sep_token, add_special_tokens=False, return_attention_mask=False).input_ids[0] 34 | 35 | def __tokenize_lines(self, file_name): 36 | if file_name in self.file_to_tokenized_lines: 37 | return self.file_to_lines[file_name], self.file_to_tokenized_lines[file_name] 38 | else: 39 | with open(file_name, "r") as f: 40 | lines = f.readlines() 41 | self.file_to_lines[file_name] = lines 42 | tokenized_lines = self.tokenizer( 43 | lines, return_attention_mask=False, add_special_tokens=False).input_ids 44 | self.file_to_tokenized_lines[file_name] = tokenized_lines 45 | 46 | if len(self.file_to_tokenized_lines) > 10000: # prevent OOM 47 | self.file_to_lines = {} 48 | self.file_to_tokenized_lines = {} 49 | 50 | return lines, tokenized_lines 51 | 52 | def _extract_context_window(self, token_ids, marker_token): 53 | # Get at most 512 tokens around the target token 54 | id_of_target_begin = self.tokenizer.encode(marker_token)[1] 55 | target_index = token_ids.index(id_of_target_begin) 56 | 57 | # fewer context before target 58 | if target_index < 255: 59 | previous_target_tokens = token_ids[0:target_index] 60 | after_target_tokens = token_ids[target_index:target_index + 61 | (512 - len(previous_target_tokens))] 62 | # fewer context after target 63 | elif target_index + 255 > len(token_ids): 64 | after_target_tokens = token_ids[target_index:] 65 | previous_target_tokens = token_ids[target_index - 66 | (512 - len(after_target_tokens)):target_index] 67 | # equal context before and after target 68 | else: 69 | previous_target_tokens = token_ids[target_index-255:target_index] 70 | 
after_target_tokens = token_ids[target_index:target_index+255] 71 | 72 | return previous_target_tokens, after_target_tokens 73 | 74 | def _encode_output(self, entry): 75 | # Create labels 76 | if not hasattr(entry, "value"): 77 | # during prediction 78 | value = "unknown" 79 | else: 80 | # during training 81 | assert entry["value"].startswith("@"), entry["value"] 82 | value = entry["value"][1:] 83 | 84 | label_ids = self.tokenizer( 85 | value, max_length=1, 86 | return_attention_mask=False, add_special_tokens=False).input_ids 87 | return label_ids 88 | 89 | def _encode_input_output(self, entry, location, lines, tokenized_lines): 90 | # format of input: 91 | # name <sep> kind <sep> pre-context <mask> post-context
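# For illustration (an assumed example; the concrete token strings follow the
# variable names defined at the top of this file): for an entry with name "x"
# and kind "name" occurring in the line `y = x + 1`, the encoded input
# corresponds roughly to:
#   <s> x <sep> <kind_name> <sep> ... y = <mask> + 1 ... </s>
# i.e., the used name, a kind marker, and the surrounding code with the use
# itself masked out.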
92 | 93 | target_line = lines[location.line-1] 94 | 95 | start_index = location.column_start 96 | end_index = location.column_end 97 | 98 | modified_line = target_line[:start_index] + \ 99 | mask_token + target_line[end_index:] 100 | 101 | tokenized_target_line = self.tokenizer( 102 | modified_line, return_attention_mask=False, add_special_tokens=False).input_ids 103 | 104 | # store and later restore the original tokenized line 105 | # (because we use the same tokenized_lines for all entries) 106 | original_tokenized_target_line = tokenized_lines[location.line-1] 107 | tokenized_lines[location.line-1] = tokenized_target_line 108 | token_ids = list(itertools.chain(*tokenized_lines)) 109 | tokenized_lines[location.line-1] = original_tokenized_target_line 110 | 111 | name = entry["name"] 112 | name_ids = self.tokenizer(name, return_attention_mask=False, 113 | add_special_tokens=False).input_ids 114 | 115 | mask_value_ids = self._encode_output(entry) 116 | previous_target_tokens, after_target_tokens = self._extract_context_window( 117 | token_ids, mask_token) 118 | 119 | # shrink context to fit everything (incl. the variable-sized name_ids) into 512 tokens 120 | while len(name_ids) + len(mask_value_ids) + len(previous_target_tokens) + len(after_target_tokens[1:]) + 5 > 512: 121 | previous_target_tokens = previous_target_tokens[1:] 122 | after_target_tokens = after_target_tokens[:-1] 123 | context_ids = previous_target_tokens + [self.tokenizer.encode(mask_token)[1]] + after_target_tokens[1:] 124 | 125 | if entry["kind"] == "name": 126 | kind_token = self.kind_name_token_id 127 | elif entry["kind"] == "call": 128 | kind_token = self.kind_call_token_id 129 | elif entry["kind"] == "attribute": 130 | kind_token = self.kind_attribute_token_id 131 | 132 | input_ids = [self.tokenizer.bos_token_id] + \ 133 | name_ids + \ 134 | [self.sep_token_id, kind_token, self.sep_token_id] + \ 135 | context_ids + \ 136 | [self.tokenizer.eos_token_id] 137 | 138 | # guarantee that the input fits 512 tokens even after decoding and encoding again 139 | while True: 140 | decoded_input = self.tokenizer.decode(input_ids) 141 | decoded_input = decoded_input.replace("<s>", "") 142 | decoded_input = decoded_input.replace("</s>", "") 143 | decoded_input = decoded_input.replace("<pad>", "") 144 | encoded_input = self.tokenizer.encode(decoded_input) 145 | if len(encoded_input[1:-1]) > 510: 146 | input_ids = [self.tokenizer.bos_token_id] + \ 147 | name_ids + \ 148 | [self.sep_token_id, kind_token, self.sep_token_id] + \ 149 | previous_target_tokens[1:0] + [self.tokenizer.encode(mask_token)[1]] + after_target_tokens[1:-1] + \ 150 | [self.tokenizer.eos_token_id] 151 | else: 152 | break 153 | 154 | # Add padding 155 | if len(input_ids) < 512: 156 | input_ids = input_ids + \ 157 | (512 - len(input_ids)) * [self.tokenizer.pad_token_id] 158 | 159 | output_ids = [self.tokenizer.bos_token_id] + \ 160 | name_ids + \ 161 | [self.sep_token_id, kind_token, self.sep_token_id] + \ 162 | previous_target_tokens + mask_value_ids + after_target_tokens[1:] + \ 163 | [self.tokenizer.eos_token_id] 164 | 165 | # Add padding 166 | if len(output_ids) < 512: 167 | output_ids = output_ids + \ 168 | (512 - len(output_ids)) * [self.tokenizer.pad_token_id] 169 | 170 | return input_ids, output_ids 171 | 172 | 173 | def entry_to_inputs(self, entry): 174 | location = self.iids.location(str(entry["iid"])) 175 | 176 | lines, tokenized_lines = self.__tokenize_lines(location.file+'.orig') 177 | 178 | input_ids, label_ids = self._encode_input_output(entry, location, lines, tokenized_lines) 179 | 180 | input_ids = t.tensor(input_ids, device='cpu') 181 | label_ids = t.tensor(label_ids, device='cpu') 182 | 183 | assert len(input_ids) == 512, len(input_ids) 184 | assert len(label_ids) == 512, len(label_ids) 185 | 186 | return input_ids, label_ids 187 | -------------------------------------------------------------------------------- /src/lexecutor/Runtime.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import sys 3 | import time 4 | from os import path 5 | from .Hyperparams import Hyperparams as params 6 | from .TraceWriter import TraceWriter 7 | from .ValueAbstraction import restore_value, DummyObject 8 | from .RuntimeStats import RuntimeStats 9 | from .Logging import logger 10 | 11 | 12 | logger.info("Runtime starting") 13 | 14 | # ------- begin: select mode ----- 15 | # mode = "RECORD" # record values and write into a trace file 16 | mode = "PREDICT" # predict and inject values if missing in execution 17 | # mode = "REPLAY" # replay a previously recorded trace (mostly for testing) 18 | # ------- end: select mode ------- 19 |
20 | file_type = "SOURCE" 21 | # file_type = "TESTE" 22 | 23 | if mode == "RECORD": 24 | trace = TraceWriter() 25 | atexit.register(lambda: trace.write_to_file()) 26 | runtime_stats = None 27 | elif mode == "PREDICT": 28 | # for running experiments 29 | if file_type == "SOURCE": 30 | file = sys.argv[0] 31 | execution = sys.argv[1] if len(sys.argv) > 1 else "" 32 | elif file_type == "TESTE": 33 | file = sys.argv[1] 34 | execution = sys.argv[2] if len(sys.argv) > 2 else "" 35 | 36 | runtime_stats = RuntimeStats(execution) 37 | atexit.register(runtime_stats.print) 38 | 39 | # from .predictors.AsIs import AsIs 40 | # predictor = AsIs() 41 | 42 | # from .predictors.NaiveValuePredictor import NaiveValuePredictor 43 | # predictor = NaiveValuePredictor() 44 | 45 | # from .predictors.RandomPredictor import RandomPredictor 46 | # predictor = RandomPredictor() 47 | 48 | # from .predictors.FrequencyValuePredictor import FrequencyValuePredictor 49 | # predictor = FrequencyValuePredictor("values_frequencies.json") 50 | 51 | from .predictors.codet5.CodeT5ValuePredictor import CodeT5ValuePredictor 52 | predictor = CodeT5ValuePredictor(runtime_stats) 53 | 54 | # from .predictors.codebert.CodeBERTValuePredictor import CodeBERTValuePredictor 55 | # predictor = CodeBERTValuePredictor(runtime_stats) 56 | 57 | # from .predictors.Type4PyValuePredictor import Type4PyValuePredictor 58 | # predictor = Type4PyValuePredictor(file, runtime_stats) 59 | 60 | start_time = time.time() 61 | if file_type == "TESTE": 62 | predictor_name = "PynguinTests" 63 | else: 64 | predictor_name = predictor.__class__.__name__ 65 | 66 | atexit.register(runtime_stats.save, file, predictor_name, start_time) 67 | elif mode == "REPLAY": 68 | with open("trace.out", "r") as file: 69 | trace = file.readlines() 70 | next_trace_idx = 0 71 | runtime_stats = None 72 | 73 | logger.info(f"### LExecutor running in {mode} mode ###") 74 | 75 | # map kind+name to predicted value to ensure consistent predictions for the same name 76 | kind_and_name_to_value = {} 77 | 78 |
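# For illustration: the hooks below are the entry points that instrumented
# code calls into. CodeRewriter (see src/lexecutor/CodeRewriter.py) turns a
# statement like `x = a + b.c` into roughly:
#   x = _n_(1, "a", lambda: a) + _a_(3, _n_(2, "b", lambda: b), "c")
# so each use of a name, attribute, or call result can be recorded (RECORD)
# or, if it is missing, filled in with a predicted value (PREDICT).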
79 | def _n_(iid, name, lambada): 80 | if params.verbose: 81 | logger.info(f"\nAt iid={iid}, looking up name '{name}'") 82 | 83 | if runtime_stats is not None: 84 | runtime_stats.total_uses += 1 85 | runtime_stats.cover_iid(iid) 86 | 87 | perform_fct = lambada 88 | 89 | def record_fct(v): 90 | trace.append_name(iid, name, v) 91 | 92 | def predict_fct(): 93 | key = f"name#{name}" 94 | if key in kind_and_name_to_value: 95 | return kind_and_name_to_value[key] 96 | else: 97 | v = predictor.name(iid, name) 98 | kind_and_name_to_value[key] = v 99 | return v 100 | 101 | return mode_branch(iid, perform_fct, record_fct, predict_fct, kind="name") 102 | 103 | 104 | def _c_(iid, fct, *args, **kwargs): 105 | if params.verbose: 106 | logger.info(f"\nAt iid={iid}, calling function {fct}") 107 | 108 | if runtime_stats is not None: 109 | runtime_stats.total_uses += 1 110 | runtime_stats.cover_iid(iid) 111 | 112 | def perform_fct(): 113 | return fct(*args, **kwargs) 114 | 115 | def record_fct(v): 116 | trace.append_call(iid, fct, args, kwargs, v) 117 | 118 | def predict_fct(): 119 | fct_name = fct.__name__ if hasattr(fct, "__name__") else str(fct) 120 | if " " in fct_name: # some fcts that don't have a proper name 121 | fct_name = fct_name.split(" ")[0] 122 | 123 | key = f"call#{fct_name}" 124 | if key in kind_and_name_to_value: 125 | return kind_and_name_to_value[key] 126 | else: 127 | v = predictor.call(iid, fct, fct_name, args, kwargs) 128 | kind_and_name_to_value[key] = v 129 | return v 130 | 131 | kind = "call_dummy" if fct is DummyObject else "call" 132 | return mode_branch(iid, perform_fct, record_fct, predict_fct, kind=kind) 133 | 134 | 135 | def _a_(iid, base, attr_name): 136 | if params.verbose: 137 | logger.info(f"\nAt iid={iid}, looking up attribute '{attr_name}'") 138 | 139 | if runtime_stats is not None: 140 | runtime_stats.total_uses += 1 141 | runtime_stats.cover_iid(iid) 142 | 143 | def perform_fct(): 144 | # return getattr(base, attr_name) 145 | # unmangle private attributes (code copied from DynaPyt) 146 | if (attr_name.startswith('__')) and (not attr_name.endswith('__')): 147 | if type(base).__name__ == 'type': 148 | parents = [base] 149 | else: 150 | parents = [type(base)] 151 | found = True 152 | while len(parents) > 0: 153 | found = True 154 | cur_par = parents.pop() 155 | try: 156 | cur_name = cur_par.__name__ 157 | cur_name = cur_name.lstrip('_') 158 | return getattr(base, '_'+cur_name+attr_name) 159 | except AttributeError: 160 | found = False 161 | parents.extend(list(cur_par.__bases__)) 162 | continue 163 | break 164 | if not found: 165 | raise AttributeError() 166 | else: 167 | return getattr(base, attr_name) 168 | 169 | def record_fct(v): 170 | trace.append_attribute(iid, base, attr_name, v) 171 | 172 | def predict_fct(): 173 | key = f"attribute#{attr_name}" 174 | if key in kind_and_name_to_value: 175 | return kind_and_name_to_value[key] 176 | else: 177 | v = predictor.attribute(iid, base, attr_name) 178 | kind_and_name_to_value[key] = v 179 | return v 180 | 181 | return mode_branch(iid, perform_fct, record_fct, predict_fct, kind="attribute") 182 | 183 | def _l_(iid): 184 | if runtime_stats is not None: 185 | runtime_stats.cover_line(iid) 186 | runtime_stats.save(file, predictor_name, start_time) 187 | 188 | def mode_branch(iid, perform_fct, record_fct, predict_fct, kind): 189 | if mode == "RECORD": 190 | v = perform_fct() 191 | record_fct(v) 192 | return v 193 | elif mode == "PREDICT": 194 | if kind == "call_dummy": 195 | # predict and inject a return value 196 | v = predict_fct() 197 | return v 198 | else: 199 | # try to perform the regular behavior and intervene in case of exceptions caused by missing values 200 | try: 201 | v = perform_fct() 202 | if params.verbose: 203 | logger.info("Found/computed/returned regular value") 204 | return v 205 | except Exception as e: 206 | if (type(e) == NameError and kind == "name") \ 207 | or (type(e) == AttributeError and kind == "attribute"): 208 | if params.verbose: 209 | logger.info( 210 | f"Catching '{type(e)}' during {kind} and calling predictor instead") 211 | v = predict_fct() 212 | runtime_stats.guided_uses += 1 213 | return v 214 | else: 215 | if params.verbose: 216 | logger.info( 217 | f"Exception '{type(e)}' not caught, re-raising") 218 | runtime_stats.uncaught_exception(iid, e) 219 | raise e 220 | elif mode == "REPLAY": 221 | # replay mode 222 | global next_trace_idx 223 | trace_line = trace[next_trace_idx].rstrip() 224 | next_trace_idx += 1 225 | segments = trace_line.split(" ") 226 | trace_iid = int(segments[0]) 227 | abstract_value = segments[-1] 228 | if iid != trace_iid: 229 | raise Exception( 230 | f"trace_iid={trace_iid} doesn't match execution iid={iid}") 231 | v = restore_value(abstract_value) 232 | return v 233 | else: 234 | raise Exception(f"Unexpected mode {mode}") 235 | -------------------------------------------------------------------------------- /src/lexecutor/ValueAbstraction.py: -------------------------------------------------------------------------------- 1 | from .Logging import logger 2 | from .Hyperparams import Hyperparams as params 3 | import random 4 |
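# For illustration: abstract_value(42) below returns ("@int_pos", "<class 'int'>"),
# abstract_value([]) returns ("@list_empty", "<class 'list'>"), and, under the
# fine-grained abstraction, restore_value("int_pos") maps back to the concrete
# value 1.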
5 | 6 | def abstract_value(value): 7 | t = type(value) 8 | # common primitive values 9 | if value is None: 10 | abstract_value = "@None" 11 | elif value is True: 12 | abstract_value = "@True" 13 | elif value is False: 14 | abstract_value = "@False" 15 | # strings 16 | elif t is str: 17 | if len(value) == 0: 18 | abstract_value = "@str_empty" 19 | else: 20 | abstract_value = "@str_nonempty" 21 | # built-in numeric types 22 | elif t is int: 23 | if value < 0: 24 | abstract_value = "@int_neg" 25 | elif value == 0: 26 | abstract_value = "@int_zero" 27 | else: 28 | abstract_value = "@int_pos" 29 | elif t is float: 30 | if value < 0: 31 | abstract_value = "@float_neg" 32 | elif value == 0: 33 | abstract_value = "@float_zero" 34 | else: 35 | abstract_value = "@float_pos" 36 | # built-in sequence types 37 | elif t is list: 38 | if len(value) == 0: 39 | abstract_value = "@list_empty" 40 | else: 41 | abstract_value = "@list_nonempty" 42 | elif t is tuple: 43 | if len(value) == 0: 44 | abstract_value = "@tuple_empty" 45 | else: 46 | abstract_value = "@tuple_nonempty" 47 | # built-in set and dict types 48 | elif t is set: 49 | if len(value) == 0: 50 | abstract_value = "@set_empty" 51 | else: 52 | abstract_value = "@set_nonempty" 53 | elif t is dict: 54 | if len(value) == 0: 55 | abstract_value = "@dict_empty" 56 | else: 57 | abstract_value = "@dict_nonempty" 58 | # functions and methods 59 | elif callable(value): 60 | if hasattr(value, "__enter__") and hasattr(value, "__exit__"): 61 | abstract_value = "@resource" 62 | else: 63 | abstract_value = "@callable" 64 | # all other types 65 | else: 66 | abstract_value = "@object" 67 | 68 | return abstract_value, str(t)[:20] 69 | 70 | 71 | class DummyResource(object): 72 | def __enter__(self): 73 | return self 74 | 75 | def __exit__(self, exc_type, exc_value, trace): 76 | return True 77 | 78 | 79 | class DummyObject(): 80 | def __init__(self, *a, **b): 81 | pass 82 | 83 | 84 | fine_to_coarse_grained = { 85 | "@None": "@None", 86 | "@True": "@bool", 87 | "@False": "@bool", 88 | "@str_empty": "@str", 89 | "@str_nonempty": "@str", 90 | "@int_neg": "@int", 91 | "@int_zero": "@int", 92 | "@int_pos": "@int", 93 | "@float_neg": "@float", 94 | "@float_zero": "@float", 95 | "@float_pos": "@float", 96 | "@list_empty": "@list", 97 | "@list_nonempty": "@list", 98 | "@tuple_empty": "@tuple", 99 | "@tuple_nonempty": "@tuple", 100 | "@set_empty": "@set", 101 | "@set_nonempty": "@set", 102 | "@dict_empty": "@dict", 103 | "@dict_nonempty": "@dict", 104 | "@resource": "@resource", 105 | "@callable": "@callable", 106 | "@object": "@object", 107 | } 108 | 109 | 110 | if params.value_abstraction.startswith("coarse-grained"): 111 | if params.value_abstraction == "coarse-grained-deterministic": 112 | def restore_value(abstract_value): 113 | # common primitive values 114 | if abstract_value == "None": 115 | return None 116 | elif abstract_value == "bool": 117 | return True 118 | # strings 119 | elif abstract_value == "str": 120 | return "a" 121 | # built-in numeric types 122 | elif abstract_value == "int": 123 | return 1 124 | elif abstract_value == "float": 125 | return 1.0 126 | # built-in sequence types 127 | elif abstract_value == "list": 128 | return [DummyObject()] 129 | elif abstract_value == "tuple": 130 | return (DummyObject(),) 131 | # built-in set and dict types 132 | elif abstract_value == "set": 133 | return {DummyObject()} 134 | elif abstract_value == "dict": 135 | return {"a": DummyObject()} 136 | # functions and
methods 137 | elif abstract_value == "resource": 138 | return DummyResource() 139 | elif abstract_value == "callable": 140 | return DummyObject 141 | elif abstract_value == "object": 142 | return DummyObject() 143 | # all other types 144 | else: 145 | logger.info("Unknown abstract value: %s", abstract_value) 146 | return DummyObject() 147 | elif params.value_abstraction == "coarse-grained-randomized": 148 | def restore_value(abstract_value): 149 | # common primitive values 150 | if abstract_value == "None": 151 | return None 152 | elif abstract_value == "bool": 153 | return random.choice([True, False]) 154 | # strings 155 | elif abstract_value == "str": 156 | return random.choice(["", "a"]) 157 | # built-in numeric types 158 | elif abstract_value == "int": 159 | return random.choice([-1, 0, 1]) 160 | elif abstract_value == "float": 161 | return random.choice([-1.0, 0.0, 1.0]) 162 | # built-in sequence types 163 | elif abstract_value == "list": 164 | return random.choice([[], [DummyObject()]]) 165 | elif abstract_value == "tuple": 166 | return random.choice([(), (DummyObject(),)]) 167 | # built-in set and dict types 168 | elif abstract_value == "set": 169 | return random.choice([{}, {DummyObject()}]) 170 | elif abstract_value == "dict": 171 | return random.choice([{}, {"a": DummyObject()}]) 172 | # functions and methods 173 | elif abstract_value == "resource": 174 | return DummyResource() 175 | elif abstract_value == "callable": 176 | return DummyObject 177 | elif abstract_value == "object": 178 | return DummyObject() 179 | # all other types 180 | else: 181 | logger.info("Unknown abstract value: %s", abstract_value) 182 | return DummyObject() 183 | 184 | elif params.value_abstraction == "fine-grained": 185 | def restore_value(abstract_value): 186 | # common primitive values 187 | if abstract_value == "None": 188 | return None 189 | elif abstract_value == "True": 190 | return True 191 | elif abstract_value == "False": 192 | return False 193 | # strings 194 | elif abstract_value == "str_empty": 195 | return "" 196 | elif abstract_value == "str_nonempty": 197 | return "a" 198 | # built-in numeric types 199 | elif abstract_value == "int_neg": 200 | return -1 201 | elif abstract_value == "int_zero": 202 | return 0 203 | elif abstract_value == "int_pos": 204 | return 1 205 | elif abstract_value == "float_neg": 206 | return -1.0 207 | elif abstract_value == "float_zero": 208 | return 0.0 209 | elif abstract_value == "float_pos": 210 | return 1.0 211 | # built-in sequence types 212 | elif abstract_value == "list_empty": 213 | return [] 214 | elif abstract_value == "list_nonempty": 215 | return [DummyObject()] 216 | elif abstract_value == "tuple_empty": 217 | return () 218 | elif abstract_value == "tuple_nonempty": 219 | return (DummyObject(),) 220 | # built-in set and dict types 221 | elif abstract_value == "set_empty": 222 | return set() 223 | elif abstract_value == "set_nonempty": 224 | return {DummyObject()} 225 | elif abstract_value == "dict_empty": 226 | return {} 227 | elif abstract_value == "dict_nonempty": 228 | return {"a": DummyObject()} 229 | # functions and methods 230 | elif abstract_value == "resource": 231 | return DummyResource() 232 | elif abstract_value == "callable": 233 | return DummyObject 234 | elif abstract_value == "object": 235 | return DummyObject() 236 | # all other types 237 | else: 238 | logger.info("Unknown abstract value: %s", abstract_value) 239 | return DummyObject() 240 | 241 | else: 242 | raise ValueError( 243 | f"Unknown setting for value_abstraction: 
{params.value_abstraction}") 244 | -------------------------------------------------------------------------------- /src/lexecutor/evaluation/FunctionPairExtractor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from git import Repo 3 | import re 4 | from collections import namedtuple 5 | from os.path import exists, isdir, join 6 | from os import mkdir 7 | import libcst as cst 8 | from typing import Optional 9 | 10 | # Helper script to find commits that modify a single function, 11 | # and to extract the pair of old+new function into separate files. 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--repo", help="Directory with a git repository", required=True) 16 | parser.add_argument( 17 | "--dest", help="Destination directory", required=True) 18 | 19 | 20 | CodeChange = namedtuple("CodeChange", ["old_commit", "new_commit", "file", "line"]) 21 | 22 | 23 | def find_code_changes(repo): 24 | print(f"Finding code changes in {repo}") 25 | try: 26 | commits = list(repo.iter_commits("main")) 27 | except: 28 | commits = list(repo.iter_commits("master")) 29 | 30 | print(f"Constructed list of {len(commits)} commits") 31 | code_changes = [] 32 | for c_idx, c in enumerate(commits): 33 | if len(c.parents) == 0: 34 | continue 35 | diff = c.parents[0].diff(c, create_patch=True) 36 | if len(diff) == 1 and diff[0].a_path and diff[0].a_path.endswith(".py") and diff[0].b_path and diff[0].b_path.endswith(".py"): 37 | diff_str = str(diff[0]) 38 | matches = re.findall(r"@@", diff_str) 39 | if len(matches) == 2: 40 | try: 41 | line_info = diff_str.split("@@")[1] 42 | line = int(line_info[line_info.find("-")+1:line_info.find(",")]) 43 | code_changes.append(CodeChange(c.parents[0].hexsha, c.hexsha, diff[0].a_path, line)) 44 | except: 45 | print(f"Error parsing diff for commit {c.hexsha} -- ignoring") 46 | if len(code_changes) == 1000: 47 | break 48 | if c_idx % 100 == 0: 49 | print(f"Processed {c_idx+1}/{len(commits)} commits. 
Found {len(code_changes)} relevant commits so far.") 50 | return code_changes 51 | 52 | 53 | class FunctionExtractor(cst.CSTTransformer): 54 | METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,) 55 | 56 | def __init__(self, line): 57 | self.line = line 58 | self.function = None 59 | self.is_method = False 60 | self.param_names = [] 61 | 62 | def leave_Param(self, node, updated_node): 63 | # remove parameter type annotation 64 | return updated_node.with_changes(annotation=None) 65 | 66 | def leave_FunctionDef(self, node, updated_node): 67 | if node.name.value == "__init__": 68 | # ignore constructors, because we want to compare return values 69 | return updated_node 70 | 71 | start = self.get_metadata(cst.metadata.PositionProvider, node).start 72 | end = self.get_metadata(cst.metadata.PositionProvider, node).end 73 | if start.line <= self.line <= end.line: 74 | self.function = updated_node.with_changes(returns=None) 75 | 76 | if len(node.params.params) > 0 and node.params.params[0].name.value == "self": 77 | self.is_method = True 78 | 79 | for param in node.params.params: 80 | self.param_names.append(param.name.value) 81 | 82 | return updated_node 83 | 84 | 85 | def extract_function(file, line) -> FunctionExtractor: 86 | tree = cst.parse_module(open(file).read()) 87 | tree = cst.MetadataWrapper(tree) 88 | extractor = FunctionExtractor(line) 89 | tree.visit(extractor) 90 | return extractor 91 | 92 | 93 | def write_function_to_file(fct, dest_dir, name_prefix, code_change): 94 | fct_code = cst.Module([]).code_for_node(fct) 95 | file_name = join(dest_dir, f"{name_prefix}.py") 96 | comment = f"# {code_change.old_commit} -- {code_change.new_commit} -- {code_change.file} -- {code_change.line}\n\n" 97 | all_code = comment + fct_code 98 | with open(file_name, "w") as f: 99 | f.write(all_code) 100 | 101 | 102 | def create_class_wrapper(fct_node, wrapper_name): 103 | fct_def_code = cst.Module([]).code_for_node( 104 | cst.ClassDef( 105 | name=cst.Name( 106 | value=wrapper_name 107 | ), 108 | body = cst.IndentedBlock([fct_node]) 109 | ) 110 | ) 111 | return fct_def_code 112 | 113 |
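# For illustration (a hypothetical method): given `def inc(self, x): return x + 1`
# inside a class, create_class_wrapper(fct_node, "Wrapper1") produces the code:
#     class Wrapper1:
#         def inc(self, x): return x + 1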
114 | def write_function_comparison_script(old_fct_extractor, new_fct_extractor, dest_dir, code_change): 115 | # create code that defines the functions/methods 116 | assert old_fct_extractor.is_method == new_fct_extractor.is_method 117 | if old_fct_extractor.is_method: 118 | # wrap function into a class 119 | old_fct_def_code = create_class_wrapper(old_fct_extractor.function, "Wrapper1") 120 | new_fct_def_code = create_class_wrapper(new_fct_extractor.function, "Wrapper2") 121 | fct_def_code = old_fct_def_code + "\n\n" + new_fct_def_code 122 | else: 123 | # change name of functions to distinguish old and new 124 | renamed_old_fct = old_fct_extractor.function.with_changes(name=cst.Name(value=old_fct_extractor.function.name.value + "_1")) 125 | renamed_new_fct = new_fct_extractor.function.with_changes(name=cst.Name(value=new_fct_extractor.function.name.value + "_2")) 126 | fct_def_code = cst.Module([]).code_for_node(renamed_old_fct) + "\n\n" + cst.Module([]).code_for_node(renamed_new_fct) 127 | 128 | # create code that calls and compares the two functions/methods 129 | main_code_template = """ 130 | def different(val1, val2): 131 | if type(val1) == Wrapper1 and type(val2) == Wrapper2: 132 | return False 133 | if type(val1) != type(val2): 134 | return True 135 | if type(val1) == list and type(val2) == list and len(val1) != len(val2): 136 | return True 137 | if type(val1) == dict and type(val2) == dict and len(val1) != len(val2): 138 | return True 139 | if type(val1) == set and type(val2) == set and len(val1) != len(val2): 140 | return True 141 | if type(val1) == tuple and type(val2) == tuple and len(val1) != len(val2): 142 | return True 143 | if type(val1) in [int, float, str, bool, type(None)] and type(val2) in [int, float, str, bool, type(None)]: 144 | return val1 != val2 145 | return False 146 | 147 | 148 | if __name__ == "__main__": 149 | import pathlib 150 | p = str(pathlib.Path(__file__).parent.resolve()) 151 | 152 | try: 153 | val1 = INVOCATION1 154 | val2 = INVOCATION2 155 | except Exception as e: 156 | print(p + ": Function(s) raised an exception: " + str(type(e)) + " -- " + str(e)) 157 | else: 158 | if different(val1, val2): 159 | print(p + ": Functions returned different values: " + str(val1) + " vs. " + str(val2)) 160 | else: 161 | print(p + ": Both functions returned the same value: " + str(val1)) 162 | 163 | """ 164 | 165 | if old_fct_extractor.is_method: 166 | main_code_template = main_code_template.replace("INVOCATION1", "Wrapper1()." + old_fct_extractor.function.name.value + "(" + ", ".join(old_fct_extractor.param_names[1:]) + ")") 167 | main_code_template = main_code_template.replace("INVOCATION2", "Wrapper2()." + new_fct_extractor.function.name.value + "(" + ", ".join(new_fct_extractor.param_names[1:]) + ")") 168 | else: 169 | main_code_template = main_code_template.replace("INVOCATION1", old_fct_extractor.function.name.value + "_1(" + ", ".join(old_fct_extractor.param_names) + ")") 170 | main_code_template = main_code_template.replace("INVOCATION2", new_fct_extractor.function.name.value + "_2(" + ", ".join(new_fct_extractor.param_names) + ")") 171 | 172 | comment = f"# {code_change.old_commit} -- {code_change.new_commit} -- {code_change.file} -- {code_change.line}\n\n" 173 | 174 | all_code = comment + fct_def_code + "\n\n" + main_code_template 175 | file_name = join(dest_dir, "compare.py") 176 | with open(file_name, "w") as f: 177 | f.write(all_code) 178 | 179 |
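# For illustration (a sketch with hypothetical names, not the output of a
# specific run): for a non-method pair extracted from a function `foo`, the
# generated compare.py roughly contains:
#     # <old commit> -- <new commit> -- <file> -- <line>
#     def foo_1(x): ...   # old version
#     def foo_2(x): ...   # new version
#     def different(val1, val2): ...
#     if __name__ == "__main__":
#         val1 = foo_1(x)
#         val2 = foo_2(x)
#         ...
# The undefined argument `x` is later resolved by running the instrumented
# script under LExecutor's value predictor (see RQ4 in the README below).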
180 | def extract_function_pair(repo, code_change, dest_dir): 181 | # get old function 182 | repo.git.checkout(code_change.old_commit) 183 | file_path = join(repo.working_tree_dir, code_change.file) 184 | old_function_extractor = extract_function(file_path, code_change.line) 185 | 186 | # get new function 187 | repo.git.checkout(code_change.new_commit) 188 | file_path = join(repo.working_tree_dir, code_change.file) 189 | new_function_extractor = extract_function(file_path, code_change.line) 190 | 191 | if old_function_extractor.function is None or new_function_extractor.function is None: 192 | return 193 | 194 | if old_function_extractor.is_method != new_function_extractor.is_method: 195 | return 196 | 197 | # write original functions into files 198 | write_function_to_file(old_function_extractor.function, dest_dir, "old", code_change) 199 | write_function_to_file(new_function_extractor.function, dest_dir, "new", code_change) 200 | 201 | # write both functions to a single file that invokes and compares them 202 | write_function_comparison_script(old_function_extractor, new_function_extractor, dest_dir, code_change) 203 | 204 | print(f"Extracted function pair to {dest_dir}") 205 | 206 | 207 | if __name__ == "__main__": 208 | args = parser.parse_args() 209 | if not exists(args.repo) or not isdir(args.repo): 210 | print(f"Invalid repo directory: {args.repo}") 211 | exit(1) 212 | if exists(args.dest) and not isdir(args.dest): 213 | print(f"Destination must be a directory: {args.dest}") 214 | exit(1) 215 | if not exists(args.dest): 216 | mkdir(args.dest) 217 | 218 | repo = Repo(args.repo) 219 | code_changes = find_code_changes(repo) 220 | print(f"{len(code_changes)} code changes found") 221 | for code_change_idx, code_change in enumerate(code_changes): 222 | dest_dir = join(args.dest, f"code_change_{code_change_idx}") 223 | if not exists(dest_dir): 224 | mkdir(dest_dir) 225 | 226 | try: 227 | extract_function_pair(repo, code_change, dest_dir) 228 | except: 229 | print(f"Something went wrong when extracting from code change {code_change_idx} -- ignoring") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LExecutor: Learning-Guided Execution 2 | 3 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.8270900.svg)](https://doi.org/10.5281/zenodo.8270900) 4 | 5 | This repository contains the implementation of LExecutor and supplementary material for the paper "LExecutor: Learning-Guided Execution" (FSE'23). 6 | 7 | Paper: https://arxiv.org/abs/2302.02343 8 | 9 | ## Getting Started Guide 10 | 11 | 1. Check that your setup meets the [REQUIREMENTS.md](REQUIREMENTS.md). 12 | 2. Follow the installation instructions in [INSTALL.md](INSTALL.md). 13 | 14 | ## Replication Guide 15 | 16 | To reproduce the results from the paper, follow these instructions. The results of these instructions are also provided in the [artifact](https://zenodo.org/record/8270900), so you can inspect them there and skip some of the steps below. 17 | 18 | First, install LExecutor using the instructions above. 19 | 20 | ### Accuracy of the Neural Model (RQ1) 21 | 22 | #### Value-use events dataset 23 | 24 | To gather a corpus of value-use events for training and evaluating the neural model, we proceed as follows: 25 | 26 | 1. Set the LExecutor mode to RECORD in `./src/lexecutor/Runtime.py` 27 | 28 | 2. Make `get_traces.sh` executable: 29 | ``` 30 | chmod +x get_traces.sh 31 | ``` 32 | 33 | 3. For every considered project, execute `get_traces.sh` providing the required arguments, e.g.: 34 | ``` 35 | ./get_traces.sh https://github.com/Textualize/rich rich tests 36 | ``` 37 | 38 | 4. Get the path of all the generated traces: 39 | ``` 40 | find ./data/repos/ -type f -name "trace_*.h5" > traces.txt 41 | ``` 42 | 43 | The output is stored as follows: the repositories with instrumented files and trace files are stored in `./data/repos`; the instruction ids are stored in `./iids.json`; the trace paths are stored in `./traces.txt`. 44 | 45 | #### Model training and validation 46 | 47 | Our current implementation integrates two pre-trained models, CodeT5 and CodeBERT, which we fine-tune for our prediction task as follows. 48 | 49 | ##### CodeT5 50 | 51 | 1. Create a folder to store the output: 52 | ``` 53 | mkdir ./data/codeT5_models_fine-grained 54 | ``` 55 | 56 | 2. Prepare the dataset: 57 | ``` 58 | python -m lexecutor.predictors.codet5.PrepareData \ 59 | --iids iids.json \ 60 | --traces traces.txt \ 61 | --output_dir ./data/codeT5_models_fine-grained 62 | ``` 63 | 64 | 3. Fine-tune the model: 65 | ``` 66 | python -m lexecutor.predictors.codet5.FineTune \ 67 | --train_tensors ./data/codeT5_models_fine-grained/train.pt \ 68 | --validate_tensors ./data/codeT5_models_fine-grained/validate.pt \ 69 | --output_dir ./data/codeT5_models_fine-grained \ 70 | --stats_dir ./data/codeT5_models_fine-grained 71 | ``` 72 | 73 | The output, i.e., the tensors, the models for every epoch, the training loss, and the validation accuracy, is stored in `./data/codeT5_models_fine-grained`. 74 |
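As a quick sanity check, a fine-tuned checkpoint can be reloaded afterwards. The following is a sketch, not part of the replication steps: it assumes the `Salesforce/codet5-small` base model and the epoch-2 checkpoint, whose file name follows the `pytorch_model_epoch<N>.bin` pattern used by `FineTune.py`:
```
import torch
from transformers import RobertaTokenizer, T5ForConditionalGeneration

tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-small")
model = T5ForConditionalGeneration.from_pretrained("Salesforce/codet5-small")
state_dict = torch.load("./data/codeT5_models_fine-grained/pytorch_model_epoch2.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
```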
75 | ##### CodeBERT 76 | 77 | 1. Create a folder to store the output: 78 | ``` 79 | mkdir ./data/codeBERT_models_fine-grained 80 | ``` 81 | 82 | 2. Prepare the dataset: 83 | ``` 84 | python -m lexecutor.predictors.codebert.PrepareData \ 85 | --iids iids.json \ 86 | --traces traces.txt \ 87 | --output_dir ./data/codeBERT_models_fine-grained 88 | ``` 89 | 90 | 3. Fine-tune the model: 91 | ``` 92 | python -m lexecutor.predictors.codebert.FineTune \ 93 | --train_tensors ./data/codeBERT_models_fine-grained/train.pt \ 94 | --validate_tensors ./data/codeBERT_models_fine-grained/validate.pt \ 95 | --output_dir ./data/codeBERT_models_fine-grained \ 96 | --stats_dir ./data/codeBERT_models_fine-grained 97 | ``` 98 | 99 | The output, i.e., the tensors, the models for every epoch, the training loss, and the validation accuracy, is stored in `./data/codeBERT_models_fine-grained`. 100 | 101 | By default, we train and use the models based on the fine-grained abstraction of values. To fine-tune the models based on the coarse-grained abstraction of values, set `value_abstraction` to `coarse-grained-deterministic` or `coarse-grained-randomized` in `./src/lexecutor/Hyperparams.py`. Then, replace `fine-grained` by `coarse-grained` in the steps 1-3 above. 102 |
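For reference, the relevant switch is a single field (a sketch; only the field that the steps above refer to is shown, and the surrounding structure of `Hyperparams.py` may differ):
```
# in ./src/lexecutor/Hyperparams.py
value_abstraction = "fine-grained"
# alternatives: "coarse-grained-deterministic", "coarse-grained-randomized"
```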
103 | ### Effectiveness at Covering Code and Efficiency at Guiding Executions (RQ2 and RQ3) 104 | 105 | #### Datasets 106 | 107 | ##### Open-source functions 108 | 109 | To gather a dataset of functions extracted from open-source Python projects, we proceed as follows: 110 | 111 | 1. Make `get_function_bodies_dataset.sh` executable: 112 | ``` 113 | chmod +x get_function_bodies_dataset.sh 114 | ``` 115 | 116 | 2. Execute `get_function_bodies_dataset.sh`: 117 | ``` 118 | ./get_function_bodies_dataset.sh 119 | ``` 120 | 121 | The output contains two extra versions of each function to fit the considered baseline approaches: 1) for functions that are methods, we wrap them in a `Wrapper` class, since otherwise we would not be able to run Pynguin on them; 2) we add a function invocation to each function so that it can be executed. This is required to run the code inside each function when running the baseline predictor based on Type4Py. 122 | 123 | The output is stored as follows: the repositories are stored in `./data/repos`; the randomly selected functions are stored in `./popular_projects_snippets_dataset`; the paths to the files in each version of the dataset are stored in `popular_projects_function_bodies_dataset.txt`, `popular_projects_functions_dataset.txt` and `popular_projects_functions_with_invocation_dataset.txt`. Finally, auxiliary information used to calculate line coverage afterwards is stored in `wrapp_info.csv` and `aux_data_functions_with_invocation_dataset.csv`. 124 | 125 | ##### Stack Overflow snippets 126 | 127 | To gather a dataset of code snippets from Stack Overflow, we proceed as follows: 128 | 129 | 1. Create a folder to store the code snippets: 130 | ``` 131 | mkdir so_snippets_dataset 132 | ``` 133 | 134 | 2. Get the code snippets: 135 | ``` 136 | python get_stackoverflow_snippets_dataset.py --dest_dir so_snippets_dataset 137 | ``` 138 | 139 | 3. Get the path of all the collected snippets: 140 | ``` 141 | find ./so_snippets_dataset -type f -name "*.py" > so_snippets_dataset.txt 142 | ``` 143 | 144 | The output is stored as follows: the code snippets from Stack Overflow are stored in `./so_snippets_dataset` and their paths are stored in `so_snippets_dataset.txt`. 145 | 146 | #### Data generation 147 | 148 | 1. Set the dataset under evaluation in `./src/lexecutor/Hyperparams.py` 149 | 150 | 2. Calculate the total number of lines in each file of the dataset under evaluation, e.g.: 151 | ``` 152 | python -m lexecutor.evaluation.CountTotalLines --files popular_projects_function_bodies_dataset.txt 153 | ``` 154 | 155 | 3. Instrument the files in the dataset under evaluation, e.g.: 156 | ``` 157 | python -m lexecutor.Instrument --files popular_projects_function_bodies_dataset.txt --iids iids.json 158 | ``` 159 | 160 | 4. Execute each predictor/baseline on the dataset under evaluation as follows: 161 | 162 | 1. Set `./src/lexecutor/Runtime.py` to use the desired predictor. Some predictors/baselines require additional steps: 163 | - For the predictors based on CodeT5 and CodeBERT, the value abstraction must also be set in `./src/lexecutor/Hyperparams.py` 164 | - For the predictor based on Type4Py, make sure that the docker image containing Type4Py's pre-trained model is running according to [this tutorial](https://github.com/saltudelft/type4py/wiki/Type4Py's-Local-Model) 165 | - For the Pynguin baseline, execute the following steps: 166 | 1. Create and enter a virtual environment for Python 3.10 (required by the newest Pynguin version): 167 | ``` 168 | python3.10 -m venv myenv_py3.10 169 | source myenv_py3.10/bin/activate 170 | ``` 171 | 172 | 2. Generate tests with Pynguin for the extracted functions: 173 | ``` 174 | mkdir pynguin_tests 175 | python -m lexecutor.evaluation.RunPynguin --files popular_projects_functions_dataset.txt --dest pynguin_tests 176 | ``` 177 | 178 | 3. Get the path of all the generated tests: 179 | ``` 180 | find ./pynguin_tests -type f -name "test_*.py" > pynguin_tests.txt 181 | ``` 182 | 183 | 4. Set the predictor to `AsIs` and the file_type to `TESTE` in `./src/lexecutor/Runtime.py` 184 | 185 | 2. Create a folder to store the log files, e.g.: 186 | ``` 187 | mkdir logs 188 | mkdir logs/popular_projects_functions_dataset 189 | mkdir logs/popular_projects_functions_dataset/RandomPredictor 190 | ``` 191 | 192 | 3. Execute `RunExperiments.py` with the required arguments, e.g.: 193 | ``` 194 | python -m lexecutor.evaluation.RunExperiments \ 195 | --files popular_projects_functions_dataset.txt \ 196 | --log_dest_dir logs/popular_projects_functions_dataset/RandomPredictor 197 | ``` 198 | 199 | For the Pynguin baseline, make sure to include `--tests` and give the path to the generated tests, i.e., `pynguin_tests.txt`, to `--files` when executing `RunExperiments.py`. 200 | 201 | 5. Process and combine the raw data generated: 202 | ``` 203 | python -m lexecutor.evaluation.CombineData 204 | ``` 205 | 206 | #### Data analysis and plots generation 207 | 208 | The code to get the plots for RQ2 and the table content for RQ3 is available at `./src/notebooks/analyze_code_coverage_effectiveness_and_efficiency.ipynb` 209 | 210 | ### Using LExecutor to Find Semantics-Changing Commits (RQ4) 211 | 212 | #### Pairs of old + new function from commits dataset 213 | 214 | To gather a corpus of pairs of old + new function from commits, we proceed as follows: 215 | 216 | 1.
Create a folder to store the function pairs for every considered project, e.g.: 217 | ``` 218 | mkdir data/function_pairs && mkdir data/function_pairs/flask 219 | ``` 220 | 221 | 2. For every considered project, execute `FunctionPairExtractor.py` providing the required arguments, e.g.: 222 | ``` 223 | python -m lexecutor.evaluation.FunctionPairExtractor \ 224 | --repo data/repos_with_commit_history/flask/ \ 225 | --dest data/function_pairs/flask/ 226 | ``` 227 | 228 | The output, i.e. the function pairs with code that invokes both functions and compares their return values, is stored in `compare.py` files under `data/function_pairs/` 229 | 230 | #### Finding semantics-changing commits 231 | 232 | 1. Instrument the code in the `compare.py` files, e.g.: 233 | ``` 234 | python -m lexecutor.Instrument --files `find data/function_pairs/flask -name compare.py | xargs` 235 | ``` 236 | 237 | 2. Run the instrumented code to compare its runtime behavior, e.g.: 238 | ``` 239 | for f in `find data/function_pairs/flask -name compare.py | xargs`; do timeout 30 python $f; done > out_flask 240 | ``` 241 | -------------------------------------------------------------------------------- /src/lexecutor/CodeRewriter.py: -------------------------------------------------------------------------------- 1 | import libcst as cst 2 | from libcst.metadata import ParentNodeProvider, PositionProvider 3 | 4 | 5 | class CodeRewriter(cst.CSTTransformer): 6 | 7 | METADATA_DEPENDENCIES = (ParentNodeProvider, PositionProvider,) 8 | 9 | ignored_names = ["True", "False", "None"] 10 | ignored_calls = ["super"] # special function names to not instrument 11 | 12 | def __init__(self, file_path, iids, line_coverage_instrumentation, used_names): 13 | self.file_path = file_path 14 | self.used_names = used_names 15 | self.iids = iids 16 | self.line_coverage_instrumentation = line_coverage_instrumentation 17 | 18 | self.instrument = True # turned off in special cases, e.g., inside nested f-strings 19 | 20 | self.quotation_char = '"' # flipped to "'" when inside an f-string with double quotes 21 | self.fstring_stack = [] 22 | 23 | def __create_iid(self, node): 24 | location = self.get_metadata(PositionProvider, node) 25 | line = location.start.line 26 | column_start = location.start.column 27 | column_end = location.end.column 28 | iid = self.iids.new(self.file_path, line, column_start, column_end) 29 | return iid 30 | 31 | def __create_name_call(self, node, updated_node): 32 | callee_name = cst.Name(value="_n_") 33 | iid = self.__create_iid(node) 34 | iid_arg = cst.Arg(value=cst.Integer(value=str(iid))) 35 | name_arg = cst.Arg(cst.SimpleString( 36 | value=f"{self.quotation_char}{node.value}{self.quotation_char}")) 37 | lambada = cst.Lambda(params=cst.Parameters( 38 | params=[]), body=updated_node) 39 | value_arg = cst.Arg(value=lambada) 40 | call = cst.Call(func=callee_name, args=[iid_arg, name_arg, value_arg]) 41 | return call 42 | 43 | def __ensure_generator_expr_have_parens(self, args): 44 | # make sure that generator expressions have parentheses if not the only argument 45 | updated_args = [] 46 | for arg in args: 47 | if (isinstance(arg.value, cst.GeneratorExp) 48 | and len(arg.value.lpar) == 0 49 | and len(arg.value.rpar) == 0): 50 | g = arg.value 51 | g_new = cst.GeneratorExp(elt=g.elt, 52 | for_in=g.for_in, 53 | lpar=[cst.LeftParen()], 54 | rpar=[cst.RightParen()]) 55 | updated_args.append(cst.Arg(value=g_new)) 56 | else: 57 | updated_args.append(arg) 58 | return updated_args 59 | 60 | def __get_callee_name_node(self, 
call_node): 61 | if isinstance(call_node.func, cst.Name): 62 | return call_node.func 63 | elif isinstance(call_node.func, cst.Attribute): 64 | return call_node.func.attr 65 | else: # everything else, e.g., cst.Subscript 66 | return call_node.func 67 | 68 | def __create_call_call(self, node, updated_node): 69 | callee_name = cst.Name(value="_c_") 70 | node_of_callee_name = self.__get_callee_name_node(node) 71 | iid = self.__create_iid(node_of_callee_name) 72 | iid_arg = cst.Arg(value=cst.Integer(value=str(iid))) 73 | fct_arg = cst.Arg(value=updated_node.func) 74 | all_args = [iid_arg, fct_arg] + \ 75 | self.__ensure_generator_expr_have_parens(updated_node.args) 76 | call = cst.Call(func=callee_name, args=all_args) 77 | return call 78 | 79 | def __create_attribute_call(self, node, updated_node): 80 | callee_name = cst.Name(value="_a_") 81 | assert type(node.attr) == cst.Name, type(node.attr) 82 | iid = self.__create_iid(node.attr) 83 | iid_arg = cst.Arg(value=cst.Integer(value=str(iid))) 84 | value_arg = cst.Arg(updated_node.value) 85 | attr_arg = cst.Arg(cst.SimpleString( 86 | value=f"{self.quotation_char}{node.attr.value}{self.quotation_char}")) 87 | call = cst.Call(func=callee_name, args=[iid_arg, value_arg, attr_arg]) 88 | return call 89 | 90 | def __create_line_call(self, node, updated_node): 91 | callee_name = cst.Name(value="_l_") 92 | iid = self.__create_iid(node) 93 | iid_arg = cst.Arg(value=cst.Integer(value=str(iid))) 94 | call = cst.Call(func=callee_name, args=[iid_arg]) 95 | return call 96 | 97 | def __create_line_call_stmt(self, node, updated_node): 98 | statement_call = self.__create_line_call(node, updated_node) 99 | stmt = cst.SimpleStatementLine(body=[cst.Expr(value=statement_call)], 100 | trailing_whitespace=cst.TrailingWhitespace( 101 | whitespace=cst.SimpleWhitespace(value='',) 102 | ), 103 | ) 104 | return stmt 105 | 106 | def __create_aux_stmt(self, updated_node, value): 107 | aux_stmt = cst.SimpleStatementLine( 108 | body=[ 109 | cst.Assign( 110 | targets=[ 111 | cst.AssignTarget( 112 | target=cst.Name(value='aux', lpar=[], rpar=[],), 113 | whitespace_before_equal=cst.SimpleWhitespace(value=' ',), 114 | whitespace_after_equal=cst.SimpleWhitespace(value=' ',),), 115 | ], 116 | value=value 117 | ) 118 | ], 119 | trailing_whitespace=updated_node.trailing_whitespace 120 | ) 121 | return aux_stmt 122 | 123 | 124 | def __update_indented_block(self, node, updated_node): 125 | stmt = self.__create_line_call_stmt(node, updated_node) 126 | body_content = [stmt, cst.Expr(cst.Newline())] 127 | body_content.extend(updated_node.body.body) 128 | new_body = cst.IndentedBlock(body=body_content) 129 | return updated_node.with_changes(body=new_body) 130 | 131 | def __create_import(self, name): 132 | module_name = cst.Attribute(value=cst.Name( 133 | value="lexecutor"), attr=cst.Name(value="Runtime")) 134 | fct_name = cst.Name(value=name) 135 | imp_alias = cst.ImportAlias(name=fct_name) 136 | imp = cst.ImportFrom(module=module_name, names=[imp_alias]) 137 | stmt = cst.SimpleStatementLine(body=[imp]) 138 | return stmt 139 | 140 | def __wrap_import(self, node, updated_node): 141 | statement_call = self.__create_line_call(node, updated_node) 142 | stmt = cst.SimpleStatementLine(body=[cst.Expr(value=statement_call)], 143 | trailing_whitespace=cst.TrailingWhitespace( 144 | whitespace=cst.SimpleWhitespace(value='',) 145 | ), 146 | ) 147 | body_content = [cst.SimpleStatementLine(body=[updated_node])] 148 | body_content.extend([stmt, cst.Expr(cst.Newline())]) 149 | 150 | try_stmt = 
cst.Try(body=cst.IndentedBlock( 151 | body=body_content), 152 | handlers=[cst.ExceptHandler(body=cst.IndentedBlock( 153 | body=[cst.SimpleStatementLine(body=[cst.Pass()])]), 154 | type=cst.Name(value="ImportError"))]) 155 | return try_stmt 156 | 157 | def __is_l_value(self, node): 158 | parent = self.get_metadata(ParentNodeProvider, node) 159 | 160 | # assignments to a single value 161 | if (type(parent) == cst.AssignTarget or 162 | type(parent) == cst.AnnAssign or 163 | type(parent) == cst.AugAssign): 164 | return True 165 | 166 | # multi-assignments 167 | if type(parent) == cst.Element: 168 | grand_parent = self.get_metadata(ParentNodeProvider, parent) 169 | if type(grand_parent) == cst.Tuple: 170 | grand_grand_parent = self.get_metadata( 171 | ParentNodeProvider, grand_parent) 172 | if (type(grand_grand_parent) == cst.AssignTarget or 173 | type(grand_grand_parent) == cst.AnnAssign or 174 | type(grand_grand_parent) == cst.AugAssign): 175 | return True 176 | 177 | return False 178 | 179 | def __is_ignored_call(self, call_node): 180 | if type(call_node.func) == cst.Name: 181 | return call_node.func.value in self.ignored_calls 182 | else: 183 | return False 184 | 185 | def visit_SimpleStatementLine(self, node): 186 | # don't visit lines marked with special comment 187 | c = node.trailing_whitespace.comment 188 | if c is not None and c.value == "# don't instrument": 189 | return False 190 | return True 191 | 192 | def visit_Import(self, node): 193 | # don't instrument imports, as we'll wrap them in try-except 194 | return False 195 | 196 | def visit_ImportFrom(self, node): 197 | # don't instrument imports, as we'll wrap them in try-except 198 | return False 199 | 200 | def visit_Del(self, node): 201 | # don't instrument delete statements, as "del" is not allowed on a call 202 | return False 203 | 204 | def visit_FormattedString(self, node): 205 | if node.start == 'f"' or node.start == 'fr"' or node.start == 'rf"': 206 | self.quotation_char = "'" 207 | self.fstring_stack.append(node) 208 | elif node.start == "f'" or node.start == "fr'" or node.start == "rf'": 209 | self.quotation_char = '"' 210 | self.fstring_stack.append(node) 211 | if len(self.fstring_stack) > 1: 212 | self.instrument = False 213 | return True 214 | 215 | def leave_FormattedString(self, node, updated_node): 216 | if self.fstring_stack and node == self.fstring_stack[-1]: 217 | # flip quotation character back 218 | if self.quotation_char == "'": 219 | self.quotation_char = '"' 220 | elif self.quotation_char == '"': 221 | self.quotation_char = "'" 222 | self.fstring_stack.pop() 223 | if len(self.fstring_stack) < 2: 224 | self.instrument = True 225 | return updated_node 226 | 227 | def leave_Call(self, node, updated_node): 228 | # rewrite Call nodes to intercept function calls 229 | if not self.__is_ignored_call(node) and not self.line_coverage_instrumentation: 230 | wrapped_call = self.__create_call_call(node, updated_node) 231 | return wrapped_call 232 | else: 233 | return updated_node 234 | 235 | def leave_Name(self, node, updated_node): 236 | if not self.instrument: 237 | return updated_node 238 | 239 | # rewrite Name nodes to intercept values they resolve to 240 | if node in self.used_names and node.value not in self.ignored_names and not self.line_coverage_instrumentation: 241 | wrapped_name = self.__create_name_call(node, updated_node) 242 | return wrapped_name 243 | else: 244 | return updated_node 245 | 246 | def leave_Attribute(self, node, updated_node): 247 | if not self.instrument: 248 | return updated_node 249 | 250 | if
not self.__is_l_value(node) and not self.line_coverage_instrumentation: 251 | wrapped_attribute = self.__create_attribute_call(node, updated_node) 252 | return wrapped_attribute 253 | else: 254 | return updated_node 255 | 256 | def leave_SimpleStatementLine(self, node, updated_node): 257 | if isinstance(node.body[0], cst.Expr): 258 | if isinstance(node.body[0].value, cst.SimpleString): 259 | if node.body[0].value.value.startswith('"""'): 260 | return updated_node 261 | 262 | statement_call = self.__create_line_call(node, updated_node) 263 | stmt = cst.SimpleStatementLine(body=[cst.Expr(value=statement_call)], 264 | trailing_whitespace=updated_node.trailing_whitespace) 265 | 266 | if isinstance(node.body[0], cst.Pass): 267 | return cst.FlattenSentinel([updated_node, stmt]) 268 | if isinstance(node.body[0], cst.Return): 269 | if node.body[0].value: 270 | value = updated_node.body[0].value 271 | else: 272 | value = cst.SimpleString(value='""',lpar=[],rpar=[],) 273 | aux_stmt = self.__create_aux_stmt(updated_node, value) 274 | new_return_content = [cst.Return(value=cst.Name(value='aux',lpar=[],rpar=[],), 275 | whitespace_after_return=cst.SimpleWhitespace(value=' ',), 276 | semicolon=cst.MaybeSentinel.DEFAULT,)] 277 | return cst.FlattenSentinel([aux_stmt, stmt, updated_node.with_changes(body=new_return_content)]) 278 | try: 279 | if isinstance(node.body[0], cst.Expr) and isinstance(node.body[0].value, cst.Call) and node.body[0].value.func.value == 'exit': 280 | if len(updated_node.body[0].value.args) < 3: 281 | value = cst.SimpleString(value='""',lpar=[],rpar=[],) 282 | else: 283 | value = updated_node.body[0].value.args[2] 284 | aux_stmt = self.__create_aux_stmt(updated_node, value) 285 | new_exit_content = [cst.Expr( 286 | value=cst.Call( 287 | func=cst.Name(value='exit',lpar=[],rpar=[],), 288 | args=[cst.Arg( 289 | value=cst.Name(value='aux',lpar=[],rpar=[],), 290 | keyword=None, 291 | equal=cst.MaybeSentinel.DEFAULT, 292 | comma=cst.MaybeSentinel.DEFAULT, 293 | star='', 294 | whitespace_after_star=cst.SimpleWhitespace(value='',), 295 | whitespace_after_arg=cst.SimpleWhitespace(value='',), 296 | ),],lpar=[],rpar=[], 297 | whitespace_after_func=cst.SimpleWhitespace(value='',), 298 | whitespace_before_args=cst.SimpleWhitespace(value='',),), 299 | semicolon=cst.MaybeSentinel.DEFAULT,)] 300 | return cst.FlattenSentinel([aux_stmt, stmt, updated_node.with_changes(body=new_exit_content)]) 301 | except Exception as e: 302 | print(e) 303 | if not self.instrument: 304 | return cst.FlattenSentinel([updated_node, stmt]) 305 | 306 | # surround imports with try-except; 307 | # cannot do this in leave_Import because we need to replace the import's parent node 308 | if isinstance(node.body[0], cst.Import) or isinstance(node.body[0], cst.ImportFrom): 309 | # don't wrap __future__ imports 310 | if not (isinstance(node.body[0], cst.ImportFrom) and 311 | node.body[0].module is not None and 312 | node.body[0].module.value == "__future__"): 313 | # don't try-except-pass wrap imports that are already surrounded by try-except (as they should sometimes fail) 314 | skip = False 315 | parent = self.get_metadata(ParentNodeProvider, node) 316 | if isinstance(parent, cst.IndentedBlock): 317 | grand_parent = self.get_metadata( 318 | ParentNodeProvider, parent) 319 | if isinstance(grand_parent, cst.Try): 320 | skip = True 321 | if not skip: 322 | wrapped_import = self.__wrap_import( 323 | node.body[0], updated_node.body[0]) 324 | return wrapped_import 325 | return cst.FlattenSentinel([updated_node, stmt]) 326 | 327 | def 
leave_For(self, node, updated_node): 328 | return self.__update_indented_block(node, updated_node) 329 | 330 | def leave_While(self, node, updated_node): 331 | return self.__update_indented_block(node, updated_node) 332 | 333 | def leave_FunctionDef(self, node, updated_node): 334 | return self.__update_indented_block(node, updated_node) 335 | 336 | def leave_ClassDef(self, node, updated_node): 337 | return self.__update_indented_block(node, updated_node) 338 | 339 | def leave_With(self, node, updated_node): 340 | return self.__update_indented_block(node, updated_node) 341 | 342 | def leave_If(self, node, updated_node): 343 | return self.__update_indented_block(node, updated_node) 344 | 345 | def leave_Elif(self, node, updated_node): 346 | return self.__update_indented_block(node, updated_node) 347 | 348 | def leave_Try(self, node, updated_node): 349 | return self.__update_indented_block(node, updated_node) 350 | 351 | def leave_ExceptHandler(self, node, updated_node): 352 | return self.__update_indented_block(node, updated_node) 353 | 354 | def leave_Finally(self, node, updated_node): 355 | return self.__update_indented_block(node, updated_node) 356 | 357 | def leave_Module(self, node, updated_node): 358 | if not self.instrument: 359 | return updated_node 360 | 361 | # check for "__future__" imports; they must remain at beginning of file 362 | target_idx = 0 # index to add our imports at 363 | new_body = [] 364 | for i in range(len(updated_node.body)): 365 | stmt = updated_node.body[i] 366 | new_body.append(stmt) 367 | 368 | if (isinstance(stmt, cst.SimpleStatementLine) 369 | and isinstance(stmt.body[0], cst.ImportFrom) 370 | and stmt.body[0].module.value == "__future__"): 371 | target_idx = i + 1 372 | 373 | # add our imports 374 | import_n = self.__create_import("_n_") 375 | import_a = self.__create_import("_a_") 376 | import_c = self.__create_import("_c_") 377 | import_l = self.__create_import("_l_") 378 | 379 | new_body = (list(new_body[:target_idx]) 380 | + [import_n, import_a, import_c, import_l] 381 | + list(new_body[target_idx:])) 382 | 383 | return updated_node.with_changes(body=new_body) 384 | --------------------------------------------------------------------------------