├── .gitattributes ├── .gitignore ├── CodeSyntax ├── CodeSyntax_java.zip ├── CodeSyntax_java_skip_semicolon.zip ├── CodeSyntax_java_with_new_lines.zip ├── CodeSyntax_python.zip └── CodeSyntax_python_with_new_lines.zip ├── LICENSE ├── README.md ├── data └── figures │ ├── English.pdf │ ├── Error Case Head 18-4 for body.pdf │ ├── Error Case Head 18-4 for body.png │ ├── German.pdf │ ├── Head 15-10 Assign target value.pdf │ ├── Head 15-10 Assign target value.png │ ├── Head 15-11 Call func args.pdf │ ├── Head 15-11 Call func args.png │ ├── Head 17-2 if else.pdf │ ├── Head 17-2 if else.png │ ├── Head 18-4 for body.pdf │ ├── Head 18-4 for body.png │ ├── Java Head 16-5 for body.pdf │ ├── Java Head 16-5 for body.png │ ├── Java Head 19-9 func args.pdf │ ├── Java Head 19-9 func args.png │ ├── Java Head 20-10 Assign target value.pdf │ ├── Java Head 20-10 Assign target value.png │ ├── Java Head 9-10 if else.pdf │ ├── Java Head 9-10 if else.png │ ├── NL PL dependency graph.pdf │ ├── Python Head 15-10 Assign target value.pdf │ ├── Python Head 15-10 Assign target value.png │ ├── Python Head 15-11 Call func args.pdf │ ├── Python Head 15-11 Call func args.png │ ├── Python Head 17-2 if else.pdf │ ├── Python Head 17-2 if else.png │ ├── Python Head 18-4 for body.pdf │ ├── Python Head 18-4 for body.png │ ├── code_if_else.PNG │ ├── error case Head 15-10 Assign target value.pdf │ ├── error case Head 15-10 Assign target value.png │ ├── error case Head 15-11 Call func args.pdf │ ├── error case Head 15-11 Call func args.png │ ├── error case Head 17-2 if else.pdf │ ├── error case Head 17-2 if else.png │ ├── error case Java Head 16-5 for body.pdf │ ├── error case Java Head 16-5 for body.png │ ├── error case Java Head 19-9 func args.pdf │ ├── error case Java Head 19-9 func args.png │ ├── error case Java Head 20-10 Assign target value.pdf │ ├── error case Java Head 20-10 Assign target value.png │ ├── error case Java Head 9-10 if else.pdf │ ├── error case Java Head 9-10 if else.png │ ├── 
java_test_any_semicolon_in_keywords.pdf │ ├── java_test_any_skip_semicolon.pdf │ ├── java_test_any_with_new_lines_semicolon_in_keywords.pdf │ ├── java_test_first_semicolon_in_keywords.pdf │ ├── java_test_last_semicolon_in_keywords.pdf │ ├── java_test_last_skip_semicolon.pdf │ ├── java_test_last_with_new_lines_semicolon_in_keywords.pdf │ ├── java_valid_any_semicolon_in_keywords.pdf │ ├── java_valid_any_skip_semicolon.pdf │ ├── java_valid_first_semicolon_in_keywords.pdf │ ├── java_valid_last_semicolon_in_keywords.pdf │ ├── java_valid_last_skip_semicolon.pdf │ ├── offset_distribution_NL.pdf │ ├── offset_distribution_end.pdf │ ├── offset_distribution_start.pdf │ ├── preview_NL.pdf │ ├── preview_NL_PL.pdf │ ├── preview_PL.pdf │ ├── python_test_any.pdf │ ├── python_test_any_with_new_line.pdf │ ├── python_test_first.pdf │ ├── python_test_last.pdf │ ├── python_test_last_with_new_line.pdf │ ├── python_valid_any.pdf │ ├── python_valid_first.pdf │ └── python_valid_last.pdf ├── evaluating_models ├── NL │ ├── analysis.ipynb │ ├── attention-analysis │ │ ├── LICENSE │ │ ├── README.md │ │ ├── bert │ │ │ ├── __init__.py │ │ │ ├── modeling.py │ │ │ └── tokenization.py │ │ ├── bpe_utils.py │ │ ├── extract_attention.py │ │ ├── preprocess_depparse.py │ │ └── utils.py │ ├── convert_dependency_English.py │ ├── convert_dependency_German.py │ ├── data │ │ ├── depparse_english │ │ │ ├── .gitignore │ │ │ ├── NL_codebert_topk_scores.pkl │ │ │ ├── bert_base_topk_scores.pkl │ │ │ ├── bert_base_topk_scores_cased.pkl │ │ │ ├── bert_base_topk_scores_offset_greedy2.pkl │ │ │ ├── bert_base_topk_scores_offset_greedy2_cased.pkl │ │ │ ├── bert_large_topk_scores.pkl │ │ │ ├── bert_large_topk_scores_cased.pkl │ │ │ ├── roberta_base_topk_scores.pkl │ │ │ ├── roberta_base_topk_scores_offset_greedy2.pkl │ │ │ └── roberta_large_topk_scores.pkl │ │ ├── depparse_german │ │ │ ├── .gitignore │ │ │ ├── dev.json │ │ │ ├── german_bert_topk_scores.pkl │ │ │ ├── german_topk_scores_offset_greedy2.pkl │ │ │ ├── 
german_xlmr_base_topk_scores.pkl │ │ │ └── german_xlmr_large_topk_scores.pkl │ │ ├── eng_news_txt_tbnk-ptb_revised │ │ │ └── .gitignore │ │ ├── ud-treebanks-v2.8_UD_German-HDT │ │ │ └── .gitignore │ │ └── wsj_dependency │ │ │ └── .gitignore │ ├── merge.py │ ├── preprocess_attn_NL_word_level_sorted.py │ ├── remove_uncommon_datapoints_NL.py │ ├── run_exp_roberta.py │ ├── stanford-parser │ │ └── .gitignore │ ├── topk_scores_attention_NL.py │ ├── topk_scores_baselines_NL.py │ └── treebank_to_dependency.sh └── PL │ ├── analysis.ipynb │ ├── attention-analysis │ ├── LICENSE │ ├── README.md │ ├── bert │ │ ├── __init__.py │ │ ├── modeling.py │ │ └── tokenization.py │ ├── bpe_utils.py │ ├── extract_attention.py │ ├── preprocess_depparse.py │ └── utils.py │ ├── case_study.ipynb │ ├── cubert │ ├── vocab.txt │ └── vocab_java.txt │ ├── data │ ├── CodeBERT_tokenized │ │ └── .gitignore │ ├── CuBERT_tokenized │ │ └── .gitignore │ ├── attention │ │ └── .gitignore │ ├── cubert_model_java │ │ ├── bert_config.json │ │ ├── readme.txt │ │ └── vocab.txt │ ├── cubert_model_python │ │ ├── bert_config.json │ │ ├── readme.txt │ │ └── vocab.txt │ └── scores │ │ └── scores.zip │ ├── preprocess_attn_java.py │ ├── preprocess_attn_python.py │ ├── remove_uncommon_datapoints.py │ ├── run_exp_codebert_java.py │ ├── run_exp_codebert_python.py │ ├── run_exp_java.sh │ ├── run_exp_python.sh │ ├── tokenize_and_align_codebert.py │ ├── tokenize_and_align_cubert_java.py │ ├── tokenize_and_align_cubert_python.py │ ├── topk_scores_baselines_java.py │ ├── topk_scores_baselines_python.py │ ├── topk_scores_codebert_attention_java.py │ ├── topk_scores_codebert_attention_python.py │ ├── topk_scores_cubert_attention_java.py │ └── topk_scores_cubert_attention_python.py └── generating_CodeSyntax ├── Java AST Parser ├── .classpath ├── .project ├── .settings │ └── org.eclipse.jdt.core.prefs ├── bin │ └── .gitignore ├── org.eclipse.core.contenttype_3.7.1000.v20210409-1722.jar ├── 
org.eclipse.core.jobs_3.11.0.v20210420-1453.jar ├── org.eclipse.core.resources_3.15.0.v20210521-0722.jar ├── org.eclipse.core.runtime_3.22.0.v20210506-1025.jar ├── org.eclipse.equinox.common_3.15.0.v20210518-0604.jar ├── org.eclipse.equinox.preferences_3.8.200.v20210212-1143.jar ├── org.eclipse.jdt.core_3.26.0.v20210609-0549.jar ├── org.eclipse.osgi_3.16.300.v20210525-1715.jar ├── org.eclipse.text_3.12.0.v20210512-1644.jar └── src │ └── generate_labels_java │ └── main.java ├── dataset.ipynb ├── deduplicated_java_code.pickle ├── deduplicated_python_code.pickle ├── generate_labels_java.py └── generate_labels_python.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # data files 3 | evaluating_models/PL/cubert/* 4 | !evaluating_models/PL/cubert/vocab.txt 5 | !evaluating_models/PL/cubert/vocab_java.txt 6 | evaluating_models/PL/data/CuBERT_tokenized/*.json 7 | evaluating_models/PL/data/CodeBERT_tokenized/*.json 8 | evaluating_models/PL/data/cubert_model_java/*ckpt* 9 | evaluating_models/PL/data/cubert_model_python/*ckpt* 10 | evaluating_models/PL/data/attention/*.pkl 11 | evaluating_models/PL/data/scores/*.pkl 12 | evaluating_models/NL/data/depparse_*/dev.json 13 | 14 | 15 | # large dataset files 16 | CodeSyntax/*.json 17 | generating_CodeSyntax/deduplicated*full* 18 | generating_CodeSyntax/Java AST Parser/java_node_start_end_position.txt 19 | generating_CodeSyntax/java/* 20 | generating_CodeSyntax/python/* 21 | 22 | 23 | .ipynb_checkpoints 24 | ignore/ 25 | 26 | 27 | # Byte-compiled / optimized / DLL files 28 | __pycache__/ 29 | *.py[cod] 30 | *$py.class 31 | 32 | # C extensions 33 | *.so 34 | 35 | # Distribution / packaging 36 | .Python 37 | build/ 38 | 
develop-eggs/ 39 | dist/ 40 | downloads/ 41 | eggs/ 42 | .eggs/ 43 | lib/ 44 | lib64/ 45 | parts/ 46 | sdist/ 47 | var/ 48 | wheels/ 49 | *.egg-info/ 50 | .installed.cfg 51 | *.egg 52 | MANIFEST 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .nox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # IPython 103 | profile_default/ 104 | ipython_config.py 105 | 106 | # pyenv 107 | .python-version 108 | 109 | # celery beat schedule file 110 | celerybeat-schedule 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | -------------------------------------------------------------------------------- /CodeSyntax/CodeSyntax_java.zip: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/CodeSyntax/CodeSyntax_java.zip -------------------------------------------------------------------------------- /CodeSyntax/CodeSyntax_java_skip_semicolon.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/CodeSyntax/CodeSyntax_java_skip_semicolon.zip -------------------------------------------------------------------------------- /CodeSyntax/CodeSyntax_java_with_new_lines.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/CodeSyntax/CodeSyntax_java_with_new_lines.zip -------------------------------------------------------------------------------- /CodeSyntax/CodeSyntax_python.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/CodeSyntax/CodeSyntax_python.zip -------------------------------------------------------------------------------- /CodeSyntax/CodeSyntax_python_with_new_lines.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/CodeSyntax/CodeSyntax_python_with_new_lines.zip -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, 
merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking Language Models for Code Syntax Understanding 2 | 3 | This is the repository for the EMNLP 2022 Findings paper "Benchmarking Language Models for Code Syntax Understanding." It contains: 4 | 1. The CodeSyntax dataset, a large-scale dataset of programs annotated with the syntactic relationships represented as edges in their corresponding abstract syntax trees (ASTs) (see folder CodeSyntax). 5 | 2. Code for building the CodeSyntax dataset (see folder generating_CodeSyntax). 6 | 3. Code for evaluating pre-trained language models on code and natural language syntax understanding tasks (see folder evaluating_models). 7 | 8 | 9 | ## The CodeSyntax dataset 10 | The folder CodeSyntax contains the following 5 dataset files in .zip compressed format. 
11 | * The default python and java dataset, which we reported main results on: 12 | * CodeSyntax_python.json 13 | * CodeSyntax_java.json 14 | * The modified versions of python and java dataset that are used to conduct ablation study in section 4.2: 15 | * CodeSyntax_python_with_new_lines.json and CodeSyntax_java_with_new_lines.json, where new line tokens are included in ground truth dependent nodes. 16 | * CodeSyntax_java_skip_semicolon.json, where semicolon tokens are removed from ground truth dependent nodes. 17 | 18 | Each dataset file is a list of maps (one map for each code sample) in .json format. The maps have the following keys: 19 | * "code": original source code 20 | * "tokens": a list of code tokens (source code tokenized by python tokenize or javalang module) 21 | * "id": id of the sample 22 | * "relns": a map from relation names to a list of relation edges. Each relation edge is represented by a list of the form [head_index, dependent_start_index, dependent_end_index] where head_index is the index of the head token and the dependent block starts at dependent_start_index and ends at dependent_end_index (inclusively). Note that the indices start at 0. 23 | * For example, consider the following program that contains only one assignment statement and a relation edge from the 1-st token to the 3-rd token: 24 |   {"code": "q = queue" 25 |   "tokens": ["q", "=", "queue"] 26 |   "id": 1 27 |   "relns": { 28 |    "Assign:target->value": [ 29 |     [0,2,2] 30 |    ] 31 |   }} 32 | 33 | ## Building the CodeSyntax dataset 34 | The folder generating_CodeSyntax contains our code to generate the CodeSyntax dataset labeled with relations. 35 | 1. Install requirements: 36 | * python 3.9 37 | * python packages: pip install pandas, ast, javalang, seaborn, matplotlib 38 | * jupyter notebook 39 | * Java SE 16 and Eclipse IDE 40 | 2. Get and process source code from code search net: 41 | * Run the dataset.ipynb notebook. 42 | 3. 
Deduplicate (remove the code samples used in CuBERT and CodeBERT pre-training). 43 | * Download CuBERT's pre-training dataset information [manifest.json](https://github.com/google-research/google-research/tree/master/cubert), place them in folder "Cubert Python" and "Cubert Java", 44 | * and then run the dataset.ipynb notebook. 45 | 4. Generate labels through AST parser: 46 | * For python, we use the ast package as AST parser and the tokenize package as our tokenizer: 47 | * python generate_labels_python.py 48 | * For java, we use org.eclipse.jdt.core.dom's AST parser to get ast node's start and end positions and then feed it to python to convert position to token index using javalang tokenizer (which is the tokenizer used in CuBERT): 49 | * Open Eclipse IDE. 50 | * Click on import projects->general->import existing projects into workspace and choose the root folder generating_CodeSyntax\Java AST Parser. 51 | * Build and run main.java. 52 | * python generate_labels_java.py 53 | * The generated dataset will be in the CodeSyntax folder. 54 | 5. Generate dataset statistics 55 | * Run the last section of dataset.ipynb notebook. 56 | 57 | 58 | ## To extend CodeSyntax to another language 59 | Follow the workflow discussed in the "building the CodeSyntax dataset" section. You need to find: a source code dataset, a tokenizer, and an AST parser for the target language, and substitube them into the existing framework. 60 | 61 | 62 | ## Evaluating language models 63 | The folder evaluating_models contains our code for evaluating models and plotting results. 64 | 65 | ### To reproduce our results on programming languages 66 | Go to the folder evaluating_models/PL/ and then: 67 | (Note that running pre-trained language models and storing attention weights requires a significant amount of memory and disk space. 
If you would like to download our extracted attention weights needed to run the notebooks, the weights for the first 1000 samples are available here: https://drive.google.com/file/d/169yaIMSrCnzGQuBSc5wYMJ0F0ScqBScs/view?usp=sharing. You can download and unzip them into the folder PL/data/attention and then skip to step 8 or 9.) 68 | 1. Install requirements: 69 | * To extract attention and evaluate models: 70 | * Download and install [CuBERT](https://github.com/google-research/google-research/tree/master/cubert), [CodeBERT](https://github.com/microsoft/CodeBERT) and corresponding dependencies following their instructions. Save pre-trained CuBERT models in the folder data/cubert_model_python and data/cubert_model_java. 71 | * Python packages: pip install transformers, tensorflow==1.15, torch, tensor2tensor, javalang, numpy, matplotlib 72 | * To plot results: 73 | * Jupyter Notebook 74 | * Python packages: pip install numpy, matplotlib 75 | 2. Unzip dataset in the CodeSyntax folder. 76 | 3. Tokenize source code (code -> CuBERT/CodeBERT subtokens) and generate subtoken-token alignment: 77 | * python tokenize_and_align_cubert_java.py 78 | * python tokenize_and_align_cubert_python.py 79 | * python tokenize_and_align_codebert.py 80 | 4. Run CuBERT/CodeBERT to extract attention and convert attention to word-level: 81 | * bash run_exp_cubert_python.sh 82 | If you save the model checkpoints at a different directory, you need to modify the --bert-dir argument in this .sh script. 83 | * bash run_exp_cubert_python.sh 84 | * python run_exp_codebert_python.py 85 | * python run_exp_codebert_java.py 86 | 5. Preprocess attention (get predictions in the format of token index, sorted by weights): 87 | * python preprocess_attn_python.py 88 | * python preprocess_attn_java.py 89 | * For CodeBERT, this step is already included in the previous step, i.e., run_exp_codebert_python.py. 90 | 6. Remove uncommon data points. CodeBERT tends to generate more subtokens. 
Sometimes CuBERT is able to process a sample (length <= 512), but CodeBERT can't (length > 512). 91 | * python remove_uncommon_datapoints.py 92 | 7. Generate top k scores for pre-trained models and baselines by running the following files: 93 | * For CuBERT on Python dataset: topk_scores_cubert_attention_python.py 94 | * For CuBERT on Java dataset: topk_scores_cubert_attention_java.py 95 | * For CodeBERT on Python dataset: topk_scores_codebert_attention_python.py 96 | * For CodeBERT on Java dataset: topk_scores_codebert_attention_java.py
111 | * [stanford parser](https://nlp.stanford.edu/software/lex-parser.shtml#Download) 112 | * [BERT](https://github.com/google-research/bert) 113 | * [RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/roberta) 114 | * [CodeBERT](https://github.com/microsoft/CodeBERT) 115 | * [Multilingual BERT](https://github.com/google-research/bert/blob/master/multilingual.md) 116 | * [XMR-RoBERTa](https://github.com/pytorch/fairseq/tree/main/examples/xlmr) 117 | * pip install transformers, torch, tensorflow==1.15, matplotlib 118 | 3. Run stanford parser to convert the treebank into dependency labels. The results will be in the folder data/wsj_dependency. 119 | * bash treebank_to_dependency.sh 120 | 4. Convert dependency labels to the format that attention-analysis/preprocess_depparse.py requires. 121 | * python convert_dependency_English.py 122 | * python convert_dependency_German.py 123 | * The results will be in the folder data/deparse_english and data/deparse_german 124 | * sample results: 125 |    Pierre 2-nn 126 |    Vinken 9-nsubj 127 |    , 2-punct 128 |    61 5-num 129 |    years 6-npadvmod 130 |    old 2-amod 131 |    , 2-punct 132 |    ... 133 | 4. Preprocess input data. 134 | * python ./attention-analysis/preprocess_depparse.py --data-dir data/depparse_english 135 | * python ./attention-analysis/preprocess_depparse.py --data-dir data/depparse_german 136 | * sample result of one sentence: 137 |    {"words": ["Pierre", "Vinken", ",", "61", "years", "old", ",", "will", "join", "the", "board", "as", "a", "nonexecutive", "director", "Nov.", "29", "."], 138 |    "relns": ["nn", "nsubj", "punct", "num", "npadvmod", "amod", "punct", "aux", "root", "det", "dobj", "prep", "det", "amod", "pobj", "tmod", "num", "punct"], 139 |    "heads": [2, 9, 2, 5, 6, 2, 2, 9, 0, 11, 9, 9, 15, 15, 12, 9, 16, 9]} 140 | 5. 
extract attention 141 | * python ./attention-analysis/extract_attention.py --preprocessed-data-file data/depparse_english/dev.json --bert-dir attention-analysis/cased_L-24_H-1024_A-16 --batch_size 4 --word_level --cased 142 | * python ./attention-analysis/extract_attention.py --preprocessed-data-file data/depparse_german/dev.json --bert-dir attention-analysis/multi_cased_L-12_H-768_A-12 --batch_size 4 --word_level --cased 143 | * If you placed pre-trained model checkpoint at a different location, please change the --bert-dir argument. For more information about extract_attention.py, please refer to https://github.com/clarkkev/attention-analysis. 144 | * For RoBERTa and XLM-RoBERTa: python run_exp_roberta.py 145 | 6. Preprocess attention (get predictions in the format of token index, sorted by weights): 146 | * preprocess_attn_NL_word_level_sorted.py for top k 147 | 7. Generate top k scores for attention and baseline: 148 | * topk_scores_attention_NL.py, topk_scores_baselines_NL.py 149 | 8. Plot results and create tables: 150 | * Run the notebook analysis.ipynb 151 | Note that the English treebank is not freely available, so we can not release the English dataset file dev.json. 152 | 153 | ## How to cite 154 | 155 | ```bibtex 156 | @inproceedings{shen-etal-2022-codesyntax, 157 | title = "Benchmarking Language Models for Code Syntax Understanding", 158 | author = "Da Shen and Xinyun Chen and Chenguang Wang and Koushik Sen and Dawn Song", 159 | booktitle = "Findings of the Association for Computational Linguistics: {EMNLP} 2022", 160 | year = "2022", 161 | publisher = "Association for Computational Linguistics" 162 | } 163 | ``` 164 | 165 | 166 | ## Acknowledgements 167 | * [What Does BERT Look At? 
An Analysis of BERT's Attention](https://github.com/clarkkev/attention-analysis) 168 | * [Code Search Net](https://github.com/github/CodeSearchNet) 169 | * [CuBERT](https://github.com/google-research/google-research/tree/master/cubert) 170 | * [CodeBERT](https://github.com/microsoft/CodeBERT) -------------------------------------------------------------------------------- /data/figures/English.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/English.pdf -------------------------------------------------------------------------------- /data/figures/Error Case Head 18-4 for body.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Error Case Head 18-4 for body.pdf -------------------------------------------------------------------------------- /data/figures/Error Case Head 18-4 for body.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Error Case Head 18-4 for body.png -------------------------------------------------------------------------------- /data/figures/German.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/German.pdf -------------------------------------------------------------------------------- /data/figures/Head 15-10 Assign target value.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 15-10 Assign target value.pdf 
-------------------------------------------------------------------------------- /data/figures/Head 15-10 Assign target value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 15-10 Assign target value.png -------------------------------------------------------------------------------- /data/figures/Head 15-11 Call func args.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 15-11 Call func args.pdf -------------------------------------------------------------------------------- /data/figures/Head 15-11 Call func args.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 15-11 Call func args.png -------------------------------------------------------------------------------- /data/figures/Head 17-2 if else.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 17-2 if else.pdf -------------------------------------------------------------------------------- /data/figures/Head 17-2 if else.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 17-2 if else.png -------------------------------------------------------------------------------- /data/figures/Head 18-4 for body.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 18-4 for body.pdf -------------------------------------------------------------------------------- /data/figures/Head 18-4 for body.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Head 18-4 for body.png -------------------------------------------------------------------------------- /data/figures/Java Head 16-5 for body.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 16-5 for body.pdf -------------------------------------------------------------------------------- /data/figures/Java Head 16-5 for body.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 16-5 for body.png -------------------------------------------------------------------------------- /data/figures/Java Head 19-9 func args.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 19-9 func args.pdf -------------------------------------------------------------------------------- /data/figures/Java Head 19-9 func args.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 19-9 func args.png -------------------------------------------------------------------------------- /data/figures/Java Head 20-10 Assign target value.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 20-10 Assign target value.pdf -------------------------------------------------------------------------------- /data/figures/Java Head 20-10 Assign target value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 20-10 Assign target value.png -------------------------------------------------------------------------------- /data/figures/Java Head 9-10 if else.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 9-10 if else.pdf -------------------------------------------------------------------------------- /data/figures/Java Head 9-10 if else.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Java Head 9-10 if else.png -------------------------------------------------------------------------------- /data/figures/NL PL dependency graph.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/NL PL dependency graph.pdf -------------------------------------------------------------------------------- /data/figures/Python Head 15-10 Assign target value.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 15-10 Assign target value.pdf 
-------------------------------------------------------------------------------- /data/figures/Python Head 15-10 Assign target value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 15-10 Assign target value.png -------------------------------------------------------------------------------- /data/figures/Python Head 15-11 Call func args.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 15-11 Call func args.pdf -------------------------------------------------------------------------------- /data/figures/Python Head 15-11 Call func args.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 15-11 Call func args.png -------------------------------------------------------------------------------- /data/figures/Python Head 17-2 if else.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 17-2 if else.pdf -------------------------------------------------------------------------------- /data/figures/Python Head 17-2 if else.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 17-2 if else.png -------------------------------------------------------------------------------- /data/figures/Python Head 18-4 for body.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 18-4 for body.pdf -------------------------------------------------------------------------------- /data/figures/Python Head 18-4 for body.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/Python Head 18-4 for body.png -------------------------------------------------------------------------------- /data/figures/code_if_else.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/code_if_else.PNG -------------------------------------------------------------------------------- /data/figures/error case Head 15-10 Assign target value.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 15-10 Assign target value.pdf -------------------------------------------------------------------------------- /data/figures/error case Head 15-10 Assign target value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 15-10 Assign target value.png -------------------------------------------------------------------------------- /data/figures/error case Head 15-11 Call func args.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 15-11 Call func args.pdf 
-------------------------------------------------------------------------------- /data/figures/error case Head 15-11 Call func args.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 15-11 Call func args.png -------------------------------------------------------------------------------- /data/figures/error case Head 17-2 if else.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 17-2 if else.pdf -------------------------------------------------------------------------------- /data/figures/error case Head 17-2 if else.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Head 17-2 if else.png -------------------------------------------------------------------------------- /data/figures/error case Java Head 16-5 for body.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 16-5 for body.pdf -------------------------------------------------------------------------------- /data/figures/error case Java Head 16-5 for body.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 16-5 for body.png -------------------------------------------------------------------------------- /data/figures/error case Java Head 19-9 func args.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 19-9 func args.pdf -------------------------------------------------------------------------------- /data/figures/error case Java Head 19-9 func args.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 19-9 func args.png -------------------------------------------------------------------------------- /data/figures/error case Java Head 20-10 Assign target value.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 20-10 Assign target value.pdf -------------------------------------------------------------------------------- /data/figures/error case Java Head 20-10 Assign target value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 20-10 Assign target value.png -------------------------------------------------------------------------------- /data/figures/error case Java Head 9-10 if else.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 9-10 if else.pdf -------------------------------------------------------------------------------- /data/figures/error case Java Head 9-10 if else.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/error case Java Head 9-10 if else.png -------------------------------------------------------------------------------- /data/figures/java_test_any_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_any_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_test_any_skip_semicolon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_any_skip_semicolon.pdf -------------------------------------------------------------------------------- /data/figures/java_test_any_with_new_lines_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_any_with_new_lines_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_test_first_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_first_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_test_last_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_last_semicolon_in_keywords.pdf 
-------------------------------------------------------------------------------- /data/figures/java_test_last_skip_semicolon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_last_skip_semicolon.pdf -------------------------------------------------------------------------------- /data/figures/java_test_last_with_new_lines_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_test_last_with_new_lines_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_valid_any_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_valid_any_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_valid_any_skip_semicolon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_valid_any_skip_semicolon.pdf -------------------------------------------------------------------------------- /data/figures/java_valid_first_semicolon_in_keywords.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_valid_first_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_valid_last_semicolon_in_keywords.pdf: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_valid_last_semicolon_in_keywords.pdf -------------------------------------------------------------------------------- /data/figures/java_valid_last_skip_semicolon.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/java_valid_last_skip_semicolon.pdf -------------------------------------------------------------------------------- /data/figures/offset_distribution_NL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/offset_distribution_NL.pdf -------------------------------------------------------------------------------- /data/figures/offset_distribution_end.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/offset_distribution_end.pdf -------------------------------------------------------------------------------- /data/figures/offset_distribution_start.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/offset_distribution_start.pdf -------------------------------------------------------------------------------- /data/figures/preview_NL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/preview_NL.pdf 
-------------------------------------------------------------------------------- /data/figures/preview_NL_PL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/preview_NL_PL.pdf -------------------------------------------------------------------------------- /data/figures/preview_PL.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/preview_PL.pdf -------------------------------------------------------------------------------- /data/figures/python_test_any.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_test_any.pdf -------------------------------------------------------------------------------- /data/figures/python_test_any_with_new_line.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_test_any_with_new_line.pdf -------------------------------------------------------------------------------- /data/figures/python_test_first.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_test_first.pdf -------------------------------------------------------------------------------- /data/figures/python_test_last.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_test_last.pdf 
-------------------------------------------------------------------------------- /data/figures/python_test_last_with_new_line.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_test_last_with_new_line.pdf -------------------------------------------------------------------------------- /data/figures/python_valid_any.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_valid_any.pdf -------------------------------------------------------------------------------- /data/figures/python_valid_first.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_valid_first.pdf -------------------------------------------------------------------------------- /data/figures/python_valid_last.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/data/figures/python_valid_last.pdf -------------------------------------------------------------------------------- /evaluating_models/NL/attention-analysis/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kevin Clark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /evaluating_models/NL/attention-analysis/README.md: -------------------------------------------------------------------------------- 1 | # This file is obtained from https://github.com/clarkkev/attention-analysis 2 | 3 | # BERT Attention Analysis 4 | 5 | This repository contains code for [What Does BERT Look At? An Analysis of BERT's Attention](https://arxiv.org/abs/1906.04341). 6 | It includes code for getting attention maps from BERT and writing them to disk, analyzing BERT's attention in general (sections 3 and 6 of the paper), and comparing its attention to dependency syntax (sections 4.2 and 5). 7 | We will add the code for the coreference resolution analysis (section 4.3 of the paper) soon! 
8 | 9 | ## Requirements 10 | For extracting attention maps from text: 11 | * [Tensorflow](https://www.tensorflow.org/) 12 | * [NumPy](http://www.numpy.org/) 13 | 14 | Additional requirements for the attention analysis: 15 | * [Jupyter](https://jupyter.org/https://jupyter.org/) 16 | * [MatplotLib](https://matplotlib.org/) 17 | * [seaborn](https://seaborn.pydata.org/index.html) 18 | * [scikit-learn](https://scikit-learn.org/) 19 | 20 | ## Attention Analysis 21 | `Syntax_Analysis.ipynb` and `General_Analysis.ipynb` 22 | contain code for analyzing BERT's attention, including reproducing the figures and tables in the paper. 23 | 24 | You can download the data needed to run the notebooks (including BERT attention maps on Wikipedia 25 | and the Penn Treebank) from [here](https://drive.google.com/open?id=1DEIBQIl0Q0az5ZuLoy4_lYabIfLSKBg-). However, note that the Penn Treebank annotations are not 26 | freely available, so the Penn Treebank data only includes dummy labels. 27 | If you want to run the analysis on your own data, you can use the scripts described below to extract BERT attention maps. 28 | 29 | ## Extracting BERT Attention Maps 30 | We provide a script for running BERT over text and writing the resulting 31 | attention maps to disk. 32 | The input data should be a [JSON](https://www.json.org/) file containing a 33 | list of dicts, each one corresponding to a single example to be passed in 34 | to BERT. Each dict must contain exactly one of the following fields: 35 | * `"text"`: A string. 36 | * `"words"`: A list of strings. Needed if you want word-level rather than 37 | token-level attention. 38 | * `"tokens"`: A list of strings corresponding to BERT wordpiece tokenization. 39 | 40 | If the present field is "tokens," the script expects [CLS]/[SEP] tokens 41 | to be already added; otherwise it adds these tokens to the 42 | beginning/end of the text automatically. 
43 | Note that if an example is longer than `max_sequence_length` tokens 44 | after BERT wordpiece tokenization, attention maps will not be extracted for it. 45 | Attention extraction adds two additional fields to each dict: 46 | * `"attns"`: A numpy array of size [num_layers, heads_per_layer, sequence_length, 47 | sequence_length] containing attention weights. 48 | * `"tokens"`: If `"tokens"` was not already provided for the example, the 49 | BERT-wordpiece-tokenized text (list of strings). 50 | 51 | Other fields already in the feature dicts will be preserved. For example 52 | if each dict has a `tags` key containing POS tags, they will stay 53 | in the data after attention extraction so they can be used when 54 | analyzing the data. 55 | 56 | Attention extraction is run with 57 | ``` 58 | python extract_attention.py --preprocessed_data_file --bert_dir 59 | ``` 60 | The following optional arguments can also be added: 61 | * `--max_sequence_length`: Maximum input sequence length after tokenization (default is 128). 62 | * `--batch_size`: Batch size when running BERT over examples (default is 16). 63 | * `--debug`: Use a tiny BERT model for fast debugging. 64 | * `--cased`: Do not lowercase the input text. 65 | * `--word_level`: Compute word-level instead of token-level attention (see Section 4.1 of the paper). 66 | 67 | The feature dicts with added attention maps (numpy arrays with shape [n_layers, n_heads_per_layer, n_tokens, n_tokens]) are written to `_attn.pkl` 68 | 69 | 70 | ## Pre-processing Scripts 71 | We include two pre-processing scripts for going from a raw data file to 72 | JSON that can be supplied to ``attention_extractor.py``. 73 | 74 | `preprocess_unlabeled.py` does BERT-pre-training-style preprocessing for unlabeled text 75 | (i.e, taking two consecutive text spans, truncating them so they are at most 76 | `max_sequence_length` tokens, and adding [CLS]/[SEP] tokens). 77 | Each line of the input data file 78 | should be one sentence. 
Documents should be separated by empty lines. 79 | Example usage: 80 | ``` 81 | python preprocess_unlabeled.py --data-file $ATTN_DATA_DIR/unlabeled.txt --bert-dir $ATTN_DATA_DIR/uncased_L-12_H-768_A-12 82 | ``` 83 | will create the file `$ATTN_DATA_DIR/unlabeled.json` containing pre-processed data. 84 | After pre-processing, you can run `extract_attention.py` to get attention maps, e.g., 85 | ``` 86 | python extract_attention.py --preprocessed-data-file $ATTN_DATA_DIR/unlabeled.json --bert-dir $ATTN_DATA_DIR/uncased_L-12_H-768_A-12 87 | ``` 88 | 89 | 90 | `preprocess_depparse.py` pre-processes dependency parsing data. 91 | Dependency parsing data should consist of two files `train.txt` and `dev.txt` under a common directory. 92 | Each line in the files should contain a word followed by a space followed by - 93 | (e.g., 0-root). Examples should be separated by empty lines. Example usage: 94 | ``` 95 | python preprocess_depparse.py --data-dir $ATTN_DATA_DIR/depparse 96 | ``` 97 | 98 | After pre-processing, you can run `extract_attention.py` to get attention maps, e.g., 99 | ``` 100 | python extract_attention.py --preprocessed-data-file $ATTN_DATA_DIR/depparse/dev.json --bert-dir $ATTN_DATA_DIR/uncased_L-12_H-768_A-12 --word_level 101 | ``` 102 | ## Computing Distances Between Attention Heads 103 | `head_distances.py` computes the average Jenson-Shannon divergence between the attention weights of all pairs of attention heads and writes the results to disk as a numpy array of shape [n_heads, n_heads]. These distances can be used to cluster BERT's attention heads (see Section 6 and Figure 6 of the paper; code for doing this clustering is in `General_Analysis.ipynb`). 
Example usage (requires that attention maps have already been extracted): 104 | ``` 105 | python head_distances.py --attn-data-file $ATTN_DATA_DIR/unlabeled_attn.pkl --outfile $ATTN_DATA_DIR/head_distances.pkl 106 | ``` 107 | 108 | ## Citation 109 | If you find the code or data helpful, please cite the original paper: 110 | 111 | ``` 112 | @inproceedings{clark2019what, 113 | title = {What Does BERT Look At? An Analysis of BERT's Attention}, 114 | author = {Kevin Clark and Urvashi Khandelwal and Omer Levy and Christopher D. Manning}, 115 | booktitle = {BlackBoxNLP@ACL}, 116 | year = {2019} 117 | } 118 | ``` 119 | 120 | ## Contact 121 | [Kevin Clark](https://cs.stanford.edu/~kevclark/) (@clarkkev). 122 | -------------------------------------------------------------------------------- /evaluating_models/NL/attention-analysis/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
# ---------------------------------------------------------------------------
# /evaluating_models/NL/attention-analysis/bpe_utils.py
# This file is obtained from https://github.com/clarkkev/attention-analysis
# ---------------------------------------------------------------------------

"""Going from BERT's bpe tokenization to word-level tokenization."""

import numpy as np


def tokenize_and_align(tokenizer, words, cased):
  """Given already-tokenized text (as a list of strings), returns a list of
  lists where each sub-list contains the BERT-token indices for the
  corresponding word.

  Args:
    tokenizer: a bert.tokenization.FullTokenizer.
    words: list of word strings (no [CLS]/[SEP]; they are added here).
    cased: if False, lowercase and strip accents before wordpiece splitting.

  Returns:
    List of lists of int token positions, one sub-list per word
    (including the added [CLS] and [SEP]).
  """
  # Imported lazily so the numpy-only helpers in this module can be used
  # without the project-local BERT code on the path.
  from bert import tokenization

  words = ["[CLS]"] + words + ["[SEP]"]
  basic_tokenizer = tokenizer.basic_tokenizer
  tokenized_words = []
  for word in words:
    word = tokenization.convert_to_unicode(word)
    word = basic_tokenizer._clean_text(word)
    if word == "[CLS]" or word == "[SEP]":
      # Special tokens must not be lowercased or wordpiece-split.
      word_toks = [word]
    else:
      if not cased:
        word = word.lower()
        word = basic_tokenizer._run_strip_accents(word)
      word_toks = basic_tokenizer._run_split_on_punc(word)

    tokenized_word = []
    for word_tok in word_toks:
      tokenized_word += tokenizer.wordpiece_tokenizer.tokenize(word_tok)
    tokenized_words.append(tokenized_word)

  # Assign consecutive token indices to each word's wordpieces.
  i = 0
  word_to_tokens = []
  for word in tokenized_words:
    tokens = []
    for _ in word:
      tokens.append(i)
      i += 1
    word_to_tokens.append(tokens)
  assert len(word_to_tokens) == len(words)

  return word_to_tokens


def get_word_word_attention(token_token_attention, words_to_tokens,
                            mode="first"):
  """Convert token-token attention to word-word attention (when tokens are
  derived from words using something like byte-pair encodings).

  Args:
    token_token_attention: 2-D array [n_tokens, n_tokens] of attention
      weights.
    words_to_tokens: list of lists; each sub-list holds the token indices
      belonging to one word (as produced by tokenize_and_align).
    mode: how to combine attention *from* a split word's tokens —
      "first" (keep first token's row), "mean", or "max".

  Returns:
    A new [n_words, n_words] numpy array.

  Raises:
    ValueError: if mode is not one of "first"/"mean"/"max".
  """
  word_word_attention = np.array(token_token_attention)
  not_word_starts = []
  for word in words_to_tokens:
    not_word_starts += word[1:]

  # Attention *to* a split word is the sum over its tokens' columns;
  # the sum is stored in the first token's column, the rest are deleted.
  for word in words_to_tokens:
    word_word_attention[:, word[0]] = word_word_attention[:, word].sum(axis=-1)
  word_word_attention = np.delete(word_word_attention, not_word_starts, -1)

  # Several options for combining attention maps for words that have been
  # split; the original paper uses "mean".
  for word in words_to_tokens:
    if mode == "first":
      pass
    elif mode == "mean":
      word_word_attention[word[0]] = np.mean(word_word_attention[word], axis=0)
    elif mode == "max":
      word_word_attention[word[0]] = np.max(word_word_attention[word], axis=0)
      # Re-normalize so the row is still a distribution after taking max.
      word_word_attention[word[0]] /= word_word_attention[word[0]].sum()
    else:
      raise ValueError("Unknown aggregation mode", mode)
  word_word_attention = np.delete(word_word_attention, not_word_starts, 0)

  return word_word_attention


def make_attn_word_level(data, tokenizer, cased):
  """In-place converts each example's "attns" from token-level to word-level.

  Args:
    data: list of feature dicts with "words", "tokens", and "attns"
      ([n_layers, n_heads, n_tokens, n_tokens]) keys.
    tokenizer: a bert.tokenization.FullTokenizer.
    cased: passed through to tokenize_and_align.
  """
  # Imported lazily for the same reason as in tokenize_and_align.
  import utils

  for features in utils.logged_loop(data):
    words_to_tokens = tokenize_and_align(tokenizer, features["words"], cased)
    assert sum(len(word) for word in words_to_tokens) == len(features["tokens"])
    features["attns"] = np.stack([[
        get_word_word_attention(attn_head, words_to_tokens)
        for attn_head in layer_attns] for layer_attns in features["attns"]])
| class Example(object): 17 | """Represents a single input sequence to be passed into BERT.""" 18 | 19 | def __init__(self, features, tokenizer, max_sequence_length,): 20 | self.features = features 21 | 22 | if "tokens" in features: 23 | self.tokens = features["tokens"] 24 | else: 25 | if "text" in features: 26 | text = features["text"] 27 | else: 28 | text = " ".join(features["words"]) 29 | self.tokens = ["[CLS]"] + tokenizer.tokenize(text) + ["[SEP]"] 30 | 31 | self.input_ids = tokenizer.convert_tokens_to_ids(self.tokens) 32 | self.segment_ids = [0] * len(self.tokens) 33 | self.input_mask = [1] * len(self.tokens) 34 | while len(self.input_ids) < max_sequence_length: 35 | self.input_ids.append(0) 36 | self.input_mask.append(0) 37 | self.segment_ids.append(0) 38 | 39 | 40 | def examples_in_batches(examples, batch_size): 41 | for i in utils.logged_loop(range(1 + ((len(examples) - 1) // batch_size))): 42 | yield examples[i * batch_size:(i + 1) * batch_size] 43 | 44 | 45 | class AttnMapExtractor(object): 46 | """Runs BERT over examples to get its attention maps.""" 47 | 48 | def __init__(self, bert_config_file, init_checkpoint, 49 | max_sequence_length=128, debug=False): 50 | make_placeholder = lambda name: tf.placeholder( 51 | tf.int32, shape=[None, max_sequence_length], name=name) 52 | self._input_ids = make_placeholder("input_ids") 53 | self._segment_ids = make_placeholder("segment_ids") 54 | self._input_mask = make_placeholder("input_mask") 55 | 56 | bert_config = modeling.BertConfig.from_json_file(bert_config_file) 57 | if debug: 58 | bert_config.num_hidden_layers = 3 59 | bert_config.hidden_size = 144 60 | self._attn_maps = modeling.BertModel( 61 | config=bert_config, 62 | is_training=False, 63 | input_ids=self._input_ids, 64 | input_mask=self._input_mask, 65 | token_type_ids=self._segment_ids, 66 | use_one_hot_embeddings=True).attn_maps 67 | 68 | if not debug: 69 | print("Loading BERT from checkpoint...") 70 | assignment_map, _ = 
modeling.get_assignment_map_from_checkpoint( 71 | tf.trainable_variables(), init_checkpoint) 72 | tf.train.init_from_checkpoint(init_checkpoint, assignment_map) 73 | 74 | def get_attn_maps(self, sess, examples): 75 | feed = { 76 | self._input_ids: np.vstack([e.input_ids for e in examples]), 77 | self._segment_ids: np.vstack([e.segment_ids for e in examples]), 78 | self._input_mask: np.vstack([e.input_mask for e in examples]) 79 | } 80 | return sess.run(self._attn_maps, feed_dict=feed) 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser(description=__doc__) 85 | parser.add_argument( 86 | "--preprocessed-data-file", required=True, 87 | help="Location of preprocessed data (JSON file); see the README for " 88 | "expected data format.") 89 | parser.add_argument("--bert-dir", required=True, 90 | help="Location of the pre-trained BERT model.") 91 | parser.add_argument("--cased", default=False, action='store_true', 92 | help="Don't lowercase the input.") 93 | parser.add_argument("--max_sequence_length", default=128, type=int, 94 | help="Maximum input sequence length after tokenization " 95 | "(default=128).") 96 | parser.add_argument("--batch_size", default=16, type=int, 97 | help="Batch size when running BERT (default=16).") 98 | parser.add_argument("--debug", default=False, action='store_true', 99 | help="Use tiny model for fast debugging.") 100 | parser.add_argument("--word_level", default=False, action='store_true', 101 | help="Get word-level rather than token-level attention.") 102 | args = parser.parse_args() 103 | 104 | print("Creating examples...") 105 | tokenizer = tokenization.FullTokenizer( 106 | vocab_file=os.path.join(args.bert_dir, "vocab.txt"), 107 | do_lower_case=not args.cased) 108 | examples = [] 109 | for features in utils.load_json(args.preprocessed_data_file): 110 | example = Example(features, tokenizer, args.max_sequence_length) 111 | if len(example.input_ids) <= args.max_sequence_length: 112 | examples.append(example) 113 | 114 | 
print("Building BERT model...") 115 | extractor = AttnMapExtractor( 116 | os.path.join(args.bert_dir, "bert_config.json"), 117 | os.path.join(args.bert_dir, "bert_model.ckpt"), 118 | args.max_sequence_length, args.debug 119 | ) 120 | 121 | print("Extracting attention maps...") 122 | feature_dicts_with_attn = [] 123 | with tf.Session() as sess: 124 | sess.run(tf.global_variables_initializer()) 125 | for batch_of_examples in examples_in_batches(examples, args.batch_size): 126 | attns = extractor.get_attn_maps(sess, batch_of_examples) 127 | for e, e_attn in zip(batch_of_examples, attns): 128 | seq_len = len(e.tokens) 129 | e.features["attns"] = e_attn[:, :, :seq_len, :seq_len].astype("float16") 130 | e.features["tokens"] = e.tokens 131 | feature_dicts_with_attn.append(e.features) 132 | 133 | if args.word_level: 134 | print("Converting to word-level attention...") 135 | bpe_utils.make_attn_word_level( 136 | feature_dicts_with_attn, tokenizer, args.cased) 137 | 138 | outpath = args.preprocessed_data_file.replace(".json", "") 139 | outpath += "_attn.pkl" 140 | print("Writing attention maps to {:}...".format(outpath)) 141 | utils.write_pickle(feature_dicts_with_attn, outpath) 142 | print("Done!") 143 | 144 | 145 | if __name__ == "__main__": 146 | main() -------------------------------------------------------------------------------- /evaluating_models/NL/attention-analysis/preprocess_depparse.py: -------------------------------------------------------------------------------- 1 | # This file is obtained from https://github.com/clarkkev/attention-analysis 2 | 3 | """Preprocesses dependency parsing data and writes the result as JSON.""" 4 | 5 | import argparse 6 | import os 7 | 8 | import utils 9 | 10 | 11 | def preprocess_depparse_data(raw_data_file): 12 | examples = [] 13 | with open(raw_data_file, encoding='utf-8') as f: 14 | current_example = {"words": [], "relns": [], "heads": []} 15 | for line in f: 16 | line = line.strip() 17 | if line: 18 | word, label = 
line.split() 19 | head, reln = label.split("-") 20 | head = int(head) 21 | current_example["words"].append(word) 22 | current_example["relns"].append(reln) 23 | current_example["heads"].append(head) 24 | else: 25 | examples.append(current_example) 26 | current_example = {"words": [], "relns": [], "heads": []} 27 | utils.write_json(examples, raw_data_file.replace(".txt", ".json")) 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser(description=__doc__) 32 | parser.add_argument( 33 | "--data-dir", required=True, 34 | help="The location of dependency parsing data. Should contain files " 35 | "train.txt and dev.txt. See the README for expected data format.") 36 | args = parser.parse_args() 37 | for split in ["train", "dev"]: 38 | print("Preprocessing {:} data...".format(split)) 39 | preprocess_depparse_data(os.path.join(args.data_dir, split + ".txt")) 40 | print("Done!") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /evaluating_models/NL/attention-analysis/utils.py: -------------------------------------------------------------------------------- 1 | # This file is obtained from https://github.com/clarkkev/attention-analysis 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import json 8 | import pickle 9 | import time 10 | 11 | import tensorflow as tf 12 | 13 | 14 | def load_json(path): 15 | with tf.gfile.GFile(path, 'r') as f: 16 | return json.load(f) 17 | 18 | 19 | def write_json(o, path): 20 | tf.gfile.MakeDirs(path.rsplit('/', 1)[0]) 21 | with tf.gfile.GFile(path, 'w') as f: 22 | json.dump(o, f, ensure_ascii=False) 23 | 24 | 25 | def load_pickle(path): 26 | with tf.gfile.GFile(path, 'rb') as f: 27 | return pickle.load(f) 28 | 29 | 30 | def write_pickle(o, path): 31 | if '/' in path: 32 | tf.gfile.MakeDirs(path.rsplit('/', 1)[0]) 33 | with tf.gfile.GFile(path, 'wb') as f: 34 | 
pickle.dump(o, f, -1) 35 | 36 | 37 | def logged_loop(iterable, n=None, **kwargs): 38 | if n is None: 39 | n = len(iterable) 40 | ll = LoopLogger(n, **kwargs) 41 | for i, elem in enumerate(iterable): 42 | ll.update(i + 1) 43 | yield elem 44 | 45 | 46 | class LoopLogger(object): 47 | """Class for printing out progress/ETA for a loop.""" 48 | 49 | def __init__(self, max_value=None, step_size=1, n_steps=25, print_time=True): 50 | self.max_value = max_value 51 | if n_steps is not None: 52 | self.step_size = max(1, max_value // n_steps) 53 | else: 54 | self.step_size = step_size 55 | self.print_time = print_time 56 | self.n = 0 57 | self.start_time = time.time() 58 | 59 | def step(self, values=None): 60 | self.update(self.n + 1, values) 61 | 62 | def update(self, i, values=None): 63 | self.n = i 64 | if self.n % self.step_size == 0 or self.n == self.max_value: 65 | if self.max_value is None: 66 | msg = 'On item ' + str(self.n) 67 | else: 68 | msg = '{:}/{:} = {:.1f}%'.format(self.n, self.max_value, 69 | 100.0 * self.n / self.max_value) 70 | if self.print_time: 71 | time_elapsed = time.time() - self.start_time 72 | time_per_step = time_elapsed / self.n 73 | msg += ', ELAPSED: {:.1f}s'.format(time_elapsed) 74 | msg += ', ETA: {:.1f}s'.format((self.max_value - self.n) 75 | * time_per_step) 76 | if values is not None: 77 | for k, v in values: 78 | msg += ' - ' + str(k) + ': ' + ('{:.4f}'.format(v) 79 | if isinstance(v, float) else str(v)) 80 | print(msg) 81 | -------------------------------------------------------------------------------- /evaluating_models/NL/convert_dependency_English.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | # script to convert stanford dependency to the format that preprocess_depparse.py requires 4 | examples=[] 5 | for i in range(1, 2455): 6 | try: 7 | with open("data/wsj_dependency/wsj_"+f'{i:04}'+".sd", "r") as f: 8 | lines = f.read().split('\n') 9 | examples.extend(lines) 10 | 
examples.append("") 11 | except Exception as e: 12 | print(e) 13 | 14 | 15 | regex = re.compile(r'([a-z]+)\(.+-(\d+), (.+)-\d+\)') 16 | formatted_examples = [] # a word followed by a space followed by index_of_head-dependency_label (e.g., 0-root) 17 | 18 | for line in examples: 19 | if line != "": 20 | mo = regex.search(line) 21 | label, index, word = mo.groups() 22 | formatted_examples.append(word+" "+index+"-"+label) 23 | else: 24 | formatted_examples.append("") 25 | 26 | 27 | with open("data/deparse_english/dev.txt", "w") as f: 28 | f.write("\n".join(formatted_examples)) -------------------------------------------------------------------------------- /evaluating_models/NL/convert_dependency_German.py: -------------------------------------------------------------------------------- 1 | # script to convert universal dependency to the format that preprocess_depparse.py requires 2 | examples=[] 3 | for partition in ["test", "dev", "train"]: 4 | with open("data/ud-treebanks-v2.8_UD_German-HDT/de_hdt-ud-"+partition+".conllu", "r", encoding='utf-8') as f: 5 | lines = f.read().split('\n') 6 | examples.extend(lines) 7 | 8 | 9 | formatted_examples = [] # a word followed by a space followed by index_of_head-dependency_label (e.g., 0-root) 10 | count = 0 11 | 12 | for line in examples: 13 | if line != "" and not line.startswith("#"): 14 | words = line.split("\t") 15 | word, index, label = words[1], words[6], words[7] 16 | formatted_examples.append(word+" "+index+"-"+label) 17 | elif line == "": 18 | count += 1 19 | formatted_examples.append("") 20 | 21 | print("number of examples:", count) 22 | 23 | 24 | with open("data/depparse_german/dev.txt", "w", encoding='utf-8') as f: 25 | f.write("\n".join(formatted_examples)) -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/.gitignore: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/NL_codebert_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/NL_codebert_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_base_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_base_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_base_topk_scores_cased.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_base_topk_scores_cased.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_base_topk_scores_offset_greedy2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_base_topk_scores_offset_greedy2.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_base_topk_scores_offset_greedy2_cased.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_base_topk_scores_offset_greedy2_cased.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_large_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_large_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/bert_large_topk_scores_cased.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/bert_large_topk_scores_cased.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/roberta_base_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/roberta_base_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_english/roberta_base_topk_scores_offset_greedy2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/roberta_base_topk_scores_offset_greedy2.pkl -------------------------------------------------------------------------------- 
/evaluating_models/NL/data/depparse_english/roberta_large_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_english/roberta_large_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_german/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_german/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_german/german_bert_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_german/german_bert_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_german/german_topk_scores_offset_greedy2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_german/german_topk_scores_offset_greedy2.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/depparse_german/german_xlmr_base_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_german/german_xlmr_base_topk_scores.pkl -------------------------------------------------------------------------------- 
/evaluating_models/NL/data/depparse_german/german_xlmr_large_topk_scores.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/depparse_german/german_xlmr_large_topk_scores.pkl -------------------------------------------------------------------------------- /evaluating_models/NL/data/eng_news_txt_tbnk-ptb_revised/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/eng_news_txt_tbnk-ptb_revised/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/data/ud-treebanks-v2.8_UD_German-HDT/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/ud-treebanks-v2.8_UD_German-HDT/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/data/wsj_dependency/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/data/wsj_dependency/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/merge.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | max_attn_data = [] 4 | for k in range(6): 5 | with open("data/depparse/dev_attn_sorted_roberta_large"+str(k*10000)+"_"+str((k+1)*10000)+".pkl", "rb") as f: 6 | max_attn_data.extend(pickle.load(f)) 7 | 8 | with open("data/depparse/dev_attn_sorted_roberta_large.pkl", "wb") as f: 9 | 
pickle.dump(max_attn_data,f) -------------------------------------------------------------------------------- /evaluating_models/NL/preprocess_attn_NL_word_level_sorted.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | 4 | 5 | # preprocess attention 6 | # find the token that is predicted to be head for each single attention head 7 | def load_pickle(fname): 8 | with open(fname, "rb") as f: 9 | return pickle.load(f) 10 | 11 | 12 | max_attn_data = [] 13 | # attn_data = load_pickle("data/depparse/dev_attn_bert_base.pkl") 14 | attn_data = load_pickle("data/depparse_german/dev_attn.pkl") 15 | for i, data in enumerate(attn_data): 16 | if i % 1000 == 0: 17 | print("processing example", i) 18 | # cls and sep have not been removed from word-level attention 19 | attn = data["attns"] 20 | attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 21 | attn = attn[:,:, 1:-1, 1:-1] 22 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 23 | max_attn_data.append({"words": data["words"], "max_attn": max_attn, "relns": data["relns"], "heads": data["heads"], "id": i}) 24 | 25 | # with open("data/depparse/dev_attn_sorted_bert_base.pkl", "wb") as f: 26 | with open("data/depparse_german/dev_attn_sorted_bert.pkl", "wb") as f: 27 | pickle.dump(max_attn_data,f) 28 | 29 | with open("data/depparse_german/dev_attn_sorted_bert_short.pkl", "wb") as f: 30 | pickle.dump(max_attn_data[0:1000],f) -------------------------------------------------------------------------------- /evaluating_models/NL/remove_uncommon_datapoints_NL.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | # remove uncommon data points 4 | 5 | with open("data/depparse_german/dev_attn_sorted_bert.pkl", "rb") as f: 6 | bert = pickle.load(f) 7 | with open("data/depparse_german/dev_attn_sorted_xlmr_base.pkl", "rb") as f: 8 | roberta = pickle.load(f) 9 | 10 | 11 | 
bert_ids = set() 12 | roberta_ids = set() 13 | for data in bert: 14 | bert_ids.add(data["id"]) 15 | for data in roberta: 16 | roberta_ids.add(data["id"]) 17 | 18 | # uncommon_ids = bert_ids.symmetric_difference(roberta_ids) 19 | # print("we have", len(uncommon_ids), "uncommon datapoints") 20 | 21 | for filename, all_data in [("dev_attn_sorted_bert_base_cased", bert), 22 | ("dev_attn_sorted_bert_base", None), 23 | ("dev_attn_sorted_bert_large_cased", None), 24 | ("dev_attn_sorted_bert_large", None), 25 | ("dev_attn_sorted_roberta_base", None), 26 | ("dev_attn_sorted_roberta_large", roberta)]: 27 | if all_data == None: 28 | with open("data/depparse/"+filename+".pkl", "rb") as f: 29 | all_data = pickle.load(f) 30 | common_data = [] 31 | for data in all_data: 32 | if data["id"] not in uncommon_ids: 33 | common_data.append(data) 34 | print(filename, "has", len(common_data), "datapoints") 35 | with open("data/depparse/"+filename+".pkl", "wb") as f: 36 | pickle.dump(common_data, f) -------------------------------------------------------------------------------- /evaluating_models/NL/run_exp_roberta.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaTokenizerFast, XLMRobertaTokenizerFast, XLMRobertaModel 6 | 7 | 8 | 9 | # extract attention and convert attention to word-level 10 | 11 | def align_codebert_tokens(codebert_encodings, words, id=-1): 12 | start, end = 0, 0 13 | result = [] 14 | # produce alignment by using start and end 15 | for word in words: 16 | end = start+len(word) 17 | codebert_token_index_first = codebert_encodings.char_to_token(start) 18 | codebert_token_index_last = codebert_encodings.char_to_token(end-1) 19 | result.append([*range(codebert_token_index_first, codebert_token_index_last+1)]) 20 | # print(repr(word), 
codebert_encodings.tokens()[codebert_token_index_first: codebert_token_index_last+1]) 21 | start = end + 1 22 | assert len(result) == len(words) # assert that every word is mapped to some codebert tokens 23 | for tokens in result: 24 | assert len(tokens) > 0 25 | tokens = [item for sublist in result for item in sublist] 26 | assert len(tokens) == len(set(tokens)) # assert that no codebert token is mapped twice 27 | return result 28 | 29 | def get_word_word_attention(token_token_attention, words_to_tokens, length, 30 | mode="mean"): 31 | """This function is adopted from paper "What Does BERT Look At? An Analysis of BERT's Attention" 32 | Convert token-token attention to word-word attention (when tokens are 33 | derived from words using something like byte-pair encodings).""" 34 | 35 | word_starts = set() 36 | for word in words_to_tokens: 37 | word_starts.add(word[0]) 38 | not_word_starts = [i for i in range(length) if i not in word_starts] 39 | 40 | # sum up the attentions for all tokens in a word that has been split 41 | for word in words_to_tokens: 42 | token_token_attention[:, word[0]] = token_token_attention[:, word].sum(axis=-1) 43 | token_token_attention = np.delete(token_token_attention, not_word_starts, -1) 44 | 45 | # several options for combining attention maps for words that have been split 46 | # we use "mean" in the paper 47 | for word in words_to_tokens: 48 | if mode == "first": 49 | pass 50 | elif mode == "mean": 51 | token_token_attention[word[0]] = np.mean(token_token_attention[word], axis=0) 52 | elif mode == "max": 53 | token_token_attention[word[0]] = np.max(token_token_attention[word], axis=0) 54 | token_token_attention[word[0]] /= token_token_attention[word[0]].sum() 55 | else: 56 | raise ValueError("Unknown aggregation mode", mode) 57 | token_token_attention = np.delete(token_token_attention, not_word_starts, 0) 58 | 59 | 60 | return token_token_attention 61 | 62 | 63 | def make_attn_word_level(alignment, attn, length): 64 | """This function 
is adopted from paper What Does BERT Look At? An Analysis of BERT's Attention""" 65 | return np.stack([[ 66 | get_word_word_attention(attn_head, alignment, length) 67 | for attn_head in layer_attns] for layer_attns in attn]) 68 | 69 | 70 | 71 | # run CodeBERT to extract attention and convert attention to word-level 72 | config = "roberta-large" # "xlm-roberta-large" "roberta-large" 73 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 74 | # tokenizer = RobertaTokenizerFast.from_pretrained(config) 75 | # model = RobertaModel.from_pretrained(config) 76 | tokenizer = XLMRobertaTokenizerFast.from_pretrained(config) 77 | model = XLMRobertaModel.from_pretrained(config) 78 | model.to(device) 79 | model.eval() 80 | 81 | with open("data/depparse_german/dev.json", 'r') as f: 82 | # with open("data/depparse_english/dev.json", 'r') as f: 83 | examples = json.load(f) 84 | 85 | max_attn_data = [] 86 | for i, example in enumerate(examples): 87 | if i%100 == 0: 88 | print(i) 89 | words = example["words"] 90 | encodings = tokenizer(" ".join(words)) 91 | if len(encodings.tokens()) < 512: 92 | input_tensor = torch.tensor(encodings['input_ids'], device=device).unsqueeze(0) 93 | outputs = model(input_tensor, output_attentions=True) 94 | attn = outputs.attentions # list of tensors of shape 1*12*num_of_tokens*num_of_tokens 95 | attn = np.vstack([layer.cpu().detach().numpy() for layer in attn]) # shape of 12*12*num_of_tokens*num_of_tokens 12 layers 12 heads 96 | attn = make_attn_word_level(align_codebert_tokens(encodings, words, id = i), attn, len(encodings.tokens())) 97 | # example['attns'] = attn 98 | 99 | # preprocess attention by sorting predictions based upon weights 100 | attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 101 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 102 | max_attn_data.append({"words": words, "max_attn": max_attn, "relns": example["relns"], "heads": example["heads"], "id": i}) 103 | 104 | 105 | 
106 | 107 | # with open("data/depparse_english/dev_attn_sorted_roberta_base.pkl", "wb") as f: 108 | with open("data/depparse_german/dev_attn_sorted_xlmr_base.pkl", "wb") as f: 109 | pickle.dump(max_attn_data,f) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /evaluating_models/NL/stanford-parser/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/NL/stanford-parser/.gitignore -------------------------------------------------------------------------------- /evaluating_models/NL/topk_scores_attention_NL.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | from collections import defaultdict 5 | 6 | 7 | 8 | def evaluate_single_head_topk_NL(max_attn_data, relns, attn_layer, attn_head, max_k = 20): 9 | scores = {} 10 | n_correct = {} 11 | n_incorrect = {} 12 | for reln in relns: 13 | scores[reln] = np.zeros((max_k+1), dtype = 'float') 14 | n_correct[reln] = np.zeros((max_k+1), dtype = 'int') 15 | n_incorrect[reln] = np.zeros((max_k+1), dtype = 'int') 16 | 17 | for example in max_attn_data: 18 | for dep_idx, reln, head_idx in zip(range(len(example["relns"])), example["relns"], example["heads"]): 19 | if reln in relns: 20 | for k in range(1, max_k+1): 21 | if k-1 < len(example["max_attn"][attn_layer][attn_head][dep_idx]): 22 | k_th_prediction = example["max_attn"][attn_layer][attn_head][dep_idx][k-1] 23 | if k_th_prediction == head_idx - 1: 24 | n_correct[reln][k:] = [c+1 for c in n_correct[reln][k:]] 25 | break 26 | else: 27 | n_incorrect[reln][k] += 1 28 | else: 29 | n_incorrect[reln][k:] += 1 30 | break 31 | 32 | for reln in relns: 33 | for k in range(1, max_k+1): 34 | if (n_correct[reln][k] + n_incorrect[reln][k]) == 0: 35 | scores[reln][k] = None 36 | else: 37 | 
scores[reln][k] = n_correct[reln][k] / float(n_correct[reln][k] + n_incorrect[reln][k]) 38 | return scores 39 | 40 | def get_relns_NL(dataset): 41 | relns = set() 42 | for example in dataset: 43 | for reln in example["relns"]: 44 | relns.add(reln) 45 | relns = list(relns) 46 | relns.sort() 47 | return relns 48 | 49 | # scores[reln][layer][head] 50 | def get_scores_NL(max_attn_data, relns, max_k=20): 51 | scores = {} 52 | n_correct = {} 53 | n_total = {} 54 | num_layer = max_attn_data[0]["max_attn"].shape[0] 55 | num_head = max_attn_data[0]["max_attn"].shape[1] 56 | for reln in relns: 57 | scores[reln] = np.zeros((num_layer, num_head, max_k+1), dtype = 'float') 58 | n_correct[reln] = np.zeros((num_layer, num_head, max_k+1), dtype = 'int') 59 | n_total[reln] = 0 60 | 61 | 62 | for i, example in enumerate(max_attn_data): 63 | if i % 1000 == 0: 64 | print("processing example", i) 65 | n_words = example["max_attn"].shape[3] 66 | for dep_idx, reln, head_idx in zip(range(len(example["relns"])), example["relns"], example["heads"]): 67 | n_total[reln] += 1 68 | for k in range(1, min(max_k+1, n_words+1)): 69 | k_th_prediction = example["max_attn"][:,:, dep_idx, k-1] 70 | n_correct[reln][:,:,k:][np.where(k_th_prediction == head_idx - 1)] += 1 71 | 72 | for reln in relns: 73 | if (n_total[reln]) == 0: 74 | scores[reln][:,:,:] = -100000 75 | else: 76 | for k in range(1, max_k+1): 77 | scores[reln][:,:,k] = n_correct[reln][:,:,k] / n_total[reln] 78 | 79 | return scores 80 | 81 | 82 | 83 | 84 | 85 | # average topk scores for each relationship and categories (word level) 86 | def get_avg_NL(scores, relns, max_k=20): 87 | reln_avg = [None]*(max_k+1) 88 | 89 | for k in range(1, (max_k+1)): 90 | sum, count = 0, 0 91 | for reln in relns: 92 | flatten_idx = np.argmax(scores[reln][:,:,k]) 93 | num_head = scores[reln].shape[1] 94 | # print(num_head) 95 | row = int(flatten_idx/num_head) 96 | col = flatten_idx % num_head 97 | sum += scores[reln][row][col][k] 98 | count += 1 99 | 
reln_avg[k] = sum/count 100 | return reln_avg 101 | 102 | def print_attn_table_NL(k, relns, scores): 103 | print("relationship\t\t accuracy\tlayer\thead") 104 | sum, count = 0, 0 105 | table = "" 106 | for reln in relns: 107 | flatten_idx = np.argmax(scores[reln][:,:,k]) 108 | num_head = scores[reln].shape[1] 109 | row = int(flatten_idx/num_head) 110 | col = flatten_idx % num_head 111 | table += reln.ljust(30) + str(round(scores[reln][row][col][k],3)).ljust(5) + "\t" + str(row) + "\t" + str(col) + '\n' 112 | sum += scores[reln][row][col][k] 113 | count += 1 114 | print(table) 115 | print("average of",count,"relations:", sum/count) 116 | 117 | 118 | 119 | # with open("data/depparse_english/dev_attn_sorted_bert_large.pkl", "rb") as f: 120 | # with open("data/depparse_german/dev_attn_sorted_xlmr_base.pkl", "rb") as f: 121 | with open("data/depparse_english/dev_attn_sorted_codebert.pkl", "rb") as f: 122 | max_attn_data = pickle.load(f) 123 | relns = get_relns_NL(max_attn_data) 124 | # relns = ["pobj"] 125 | print("relations", relns) 126 | 127 | 128 | 129 | scores = get_scores_NL(max_attn_data, relns, max_k=20) 130 | reln_avg = get_avg_NL(scores, relns, max_k=20) 131 | 132 | print_attn_table_NL(1, relns, scores) 133 | 134 | # with open("data/depparse_english/bert_large_topk_scores.pkl", "wb") as f: 135 | # with open("data/depparse_english/roberta_large_topk_scores.pkl", "wb") as f: 136 | # with open("data/depparse_german/german_xlmr_base_topk_scores.pkl", "wb") as f: 137 | with open("data/depparse_english/NL_codebert_topk_scores.pkl", "wb") as f: 138 | pickle.dump(scores, f) -------------------------------------------------------------------------------- /evaluating_models/NL/topk_scores_baselines_NL.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | from collections import defaultdict 5 | # find offset baseline top k score 6 | # similar to attenion heads, we consider each 
def get_baseline_topk_scores_NL(correctness, relns, max_k=20, min_offset=-10, max_offset=19):
    """Greedily select, per relation, the combination of offset baselines that
    maximizes top-k accuracy, for every k in 1..max_k.

    Args:
        correctness: dict reln -> list of num_predictors rows, where each row is
            a flat 0/1 list marking which labels that offset predicts correctly.
            Rows for negative offsets were appended at negative list indices by
            get_baseline_correctness_NL, so row i corresponds to offset i via
            Python/NumPy negative-index wrap-around.
        relns: list of dependency relation names to score.
        max_k: largest k to compute.
        min_offset, max_offset: inclusive range of offsets considered.

    Returns:
        dict reln -> list indexed by k of (score, combination) tuples, where
        combination is the list of signed offsets chosen so far. Index 0 keeps
        the (0, [0]) placeholder.
    """
    # This function selects the next best baseline by picking the one that gives the highest increase in score
    print("getting top k scores for each relation")

    num_predictors = max_offset - min_offset +1
    reln_scores_topk = {} # reln -> list of top k scores (index is k)
    for reln in relns:
        # Shared-tuple initialization is safe: tuples are immutable and every
        # entry for k >= 1 is overwritten below.
        reln_scores_topk[reln] = [(0, [0])]*(max_k+1)



    for reln in relns:
        # Shape: (num_predictors, num_labels); rows indexed by signed offset
        # (negative offsets wrap to the tail rows).
        reln_correctness = np.array(correctness[reln], dtype=bool)
        # Labels already predicted correctly by the baselines picked so far.
        topk = np.zeros((reln_correctness.shape[1]), dtype=bool)
        combination=[]
        # selects the next baseline by picking the one that gives the highest increase in score
        for k in range (1, (max_k+1)):
            # calculate single baseline score
            single_baseline_scores = [-1]*num_predictors
            for i in range(min_offset, max_offset+1):
                # find the marginal score for the relation at this offset
                offset_correctness = reln_correctness[i] & (~topk) # we only care about labels that we have not gotten correct
                score = np.count_nonzero(offset_correctness)/len(offset_correctness)
                # Negative i wraps to the tail of the list, matching the row
                # layout produced by get_baseline_correctness_NL.
                single_baseline_scores[i] = score

            # sort baselines; argsort()[-1] is the array position of the best one
            single_baseline_scores=np.array(single_baseline_scores)
            best_baseline=single_baseline_scores.argsort()[-1]
            # Convert an array position past max_offset back to its signed
            # (negative) offset so the reported combination is human-readable.
            best_baseline = -(num_predictors -best_baseline) if best_baseline > max_offset else best_baseline

            # add the best baseline to the combination
            combination.append(best_baseline)
            topk = topk | reln_correctness[best_baseline]
            score = np.count_nonzero(topk)/len(topk)
            reln_scores_topk[reln][k] = (score, combination.copy())

    return reln_scores_topk
reln_scores_topk_offset = get_baseline_topk_scores_NL(correctness, relns, max_k=20, min_offset=min_offset, max_offset=max_offset) 117 | scores_topk_offset = get_scores_topk(reln_scores_topk_offset, max_k=20) 118 | print(scores_topk_offset) 119 | 120 | with open("data/depparse_english/roberta_base_topk_scores_"+baseline_type+".pkl", "wb") as f: 121 | # with open("data/depparse_german/german_topk_scores_"+baseline_type+".pkl", "wb") as f: 122 | pickle.dump(reln_scores_topk_offset, f) -------------------------------------------------------------------------------- /evaluating_models/NL/treebank_to_dependency.sh: -------------------------------------------------------------------------------- 1 | for file in ../data/eng_news_txt_tbnk-ptb_revised/data/penntree/**/*.tree; do 2 | echo ${file:55:8} 3 | java -mx1g edu.stanford.nlp.trees.EnglishGrammaticalStructure -basic -treeFile ${file} > ../data/wsj_dependency/${file:55:8}.sd 4 | done -------------------------------------------------------------------------------- /evaluating_models/PL/attention-analysis/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Kevin Clark 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
* [Jupyter](https://jupyter.org/)
23 | 24 | You can download the data needed to run the notebooks (including BERT attention maps on Wikipedia 25 | and the Penn Treebank) from [here](https://drive.google.com/open?id=1DEIBQIl0Q0az5ZuLoy4_lYabIfLSKBg-). However, note that the Penn Treebank annotations are not 26 | freely available, so the Penn Treebank data only includes dummy labels. 27 | If you want to run the analysis on your own data, you can use the scripts described below to extract BERT attention maps. 28 | 29 | ## Extracting BERT Attention Maps 30 | We provide a script for running BERT over text and writing the resulting 31 | attention maps to disk. 32 | The input data should be a [JSON](https://www.json.org/) file containing a 33 | list of dicts, each one corresponding to a single example to be passed in 34 | to BERT. Each dict must contain exactly one of the following fields: 35 | * `"text"`: A string. 36 | * `"words"`: A list of strings. Needed if you want word-level rather than 37 | token-level attention. 38 | * `"tokens"`: A list of strings corresponding to BERT wordpiece tokenization. 39 | 40 | If the present field is "tokens," the script expects [CLS]/[SEP] tokens 41 | to be already added; otherwise it adds these tokens to the 42 | beginning/end of the text automatically. 43 | Note that if an example is longer than `max_sequence_length` tokens 44 | after BERT wordpiece tokenization, attention maps will not be extracted for it. 45 | Attention extraction adds two additional fields to each dict: 46 | * `"attns"`: A numpy array of size [num_layers, heads_per_layer, sequence_length, 47 | sequence_length] containing attention weights. 48 | * `"tokens"`: If `"tokens"` was not already provided for the example, the 49 | BERT-wordpiece-tokenized text (list of strings). 50 | 51 | Other fields already in the feature dicts will be preserved. 
For example 52 | if each dict has a `tags` key containing POS tags, they will stay 53 | in the data after attention extraction so they can be used when 54 | analyzing the data. 55 | 56 | Attention extraction is run with 57 | ``` 58 | python extract_attention.py --preprocessed_data_file --bert_dir 59 | ``` 60 | The following optional arguments can also be added: 61 | * `--max_sequence_length`: Maximum input sequence length after tokenization (default is 128). 62 | * `--batch_size`: Batch size when running BERT over examples (default is 16). 63 | * `--debug`: Use a tiny BERT model for fast debugging. 64 | * `--cased`: Do not lowercase the input text. 65 | * `--word_level`: Compute word-level instead of token-level attention (see Section 4.1 of the paper). 66 | 67 | The feature dicts with added attention maps (numpy arrays with shape [n_layers, n_heads_per_layer, n_tokens, n_tokens]) are written to `_attn.pkl` 68 | 69 | 70 | ## Pre-processing Scripts 71 | We include two pre-processing scripts for going from a raw data file to 72 | JSON that can be supplied to ``attention_extractor.py``. 73 | 74 | `preprocess_unlabeled.py` does BERT-pre-training-style preprocessing for unlabeled text 75 | (i.e, taking two consecutive text spans, truncating them so they are at most 76 | `max_sequence_length` tokens, and adding [CLS]/[SEP] tokens). 77 | Each line of the input data file 78 | should be one sentence. Documents should be separated by empty lines. 79 | Example usage: 80 | ``` 81 | python preprocess_unlabeled.py --data-file $ATTN_DATA_DIR/unlabeled.txt --bert-dir $ATTN_DATA_DIR/uncased_L-12_H-768_A-12 82 | ``` 83 | will create the file `$ATTN_DATA_DIR/unlabeled.json` containing pre-processed data. 
`head_distances.py` computes the average Jensen-Shannon divergence between the attention weights of all pairs of attention heads and writes the results to disk as a numpy array of shape [n_heads, n_heads].
Manning}, 115 | booktitle = {BlackBoxNLP@ACL}, 116 | year = {2019} 117 | } 118 | ``` 119 | 120 | ## Contact 121 | [Kevin Clark](https://cs.stanford.edu/~kevclark/) (@clarkkev). 122 | -------------------------------------------------------------------------------- /evaluating_models/PL/attention-analysis/bert/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
def get_word_word_attention(token_token_attention, words_to_tokens, length,
                            mode="mean"):
  """Convert token-token attention to word-word attention (when tokens are
  derived from words using something like byte-pair encodings).

  Args:
    token_token_attention: 2-D attention array over model-token positions.
    words_to_tokens: for each source-language word, the list of model-token
      indices it maps to; an empty list means the word produced no tokens.
    length: number of model-token positions (columns considered for dropping).
    mode: how to combine attention *from* the tokens of a split word:
      "first", "mean" (used in the paper), or "max".

  Returns:
    A word-by-word attention array; words with no mapped tokens get all-zero
    rows/columns so word indices stay aligned.

  NOTE(review): the first aggregation stage writes into the caller's array in
  place before np.delete makes a copy — callers should not reuse the
  passed-in attention map afterwards.
  """

  # Token indices that start a word; every other column/row is dropped later.
  word_starts = set()
  for word in words_to_tokens:
    if len(word) > 0:
      word_starts.add(word[0])
  not_word_starts = [i for i in range(length) if i not in word_starts]

  # find python tokens that are mapped to no cubert tokens
  not_mapped_tokens = [idx for idx, l in enumerate(words_to_tokens) if l ==[]]

  # sum up the attentions for all tokens in a word that has been split
  # (attention *to* a word is additive over its tokens)
  for word in words_to_tokens:
    if len(word) > 0:
      token_token_attention[:, word[0]] = token_token_attention[:, word].sum(axis=-1)
  token_token_attention = np.delete(token_token_attention, not_word_starts, -1)
  # do not delete python token that is not mapped to cubert tokens:
  # re-insert an all-zero column so column index == word index is preserved
  for idx in not_mapped_tokens:
    token_token_attention = np.insert(token_token_attention, idx, 0, axis=-1)

  # several options for combining attention maps for words that have been split
  # we use "mean" in the paper
  for word in words_to_tokens:
    if len(word) > 0:
      if mode == "first":
        pass  # keep the first token's attention distribution unchanged
      elif mode == "mean":
        token_token_attention[word[0]] = np.mean(token_token_attention[word], axis=0)
      elif mode == "max":
        token_token_attention[word[0]] = np.max(token_token_attention[word], axis=0)
        # max does not preserve a probability distribution; renormalize the row
        token_token_attention[word[0]] /= token_token_attention[word[0]].sum()
      else:
        raise ValueError("Unknown aggregation mode", mode)
  token_token_attention = np.delete(token_token_attention, not_word_starts, 0)
  # do not delete python token that is not mapped to cubert tokens:
  # zero rows mirror the zero columns inserted above
  for idx in not_mapped_tokens:
    token_token_attention = np.insert(token_token_attention, idx, 0, axis=0)

  return token_token_attention
# find top k args along axis 1
# if not enough values exist in each row, then return original array
def argtopk(A, k):
    """Return the top-k *values* of each row of the 2-D array ``A``.

    Despite the name, this gathers values (via ``np.argpartition`` indices),
    not indices. If ``A`` has fewer than ``k`` columns it is returned as-is.
    """
    if k >= A.shape[1]:
        return A
    top_k = np.argpartition(A, -k)[:, -k:]
    x = A.shape[0]
    return A[np.repeat(np.arange(x), k), top_k.ravel()].reshape(x, k)


def get_block_block_attention(word_word_attention, blocks_to_words, k,
                              mode="mean"):
    """Convert word-word attention to block-block attention.

    Args:
        word_word_attention: 2-D (n_words x n_words) attention array.
        blocks_to_words: list of (start, end) inclusive word-index ranges,
            one per block.
        k: number of top attention values to average in "topk" mode.
        mode: "mean", "topk", or "max" aggregation.

    Returns:
        (n_blocks x n_blocks) float16 array; each row is renormalized to sum
        to 1 unless it is all zeros.

    Raises:
        ValueError: for an unknown ``mode``.
    """
    if type(word_word_attention) != np.ndarray:
        word_word_attention = np.array(word_word_attention)
    block_block_attention = np.zeros(
        [word_word_attention.shape[0], len(blocks_to_words)], dtype=np.float16)

    # Stage 1: aggregate the attention each word pays *to* every block
    # (collapse columns), then renormalize each word's row.
    if mode == "topk":
        for i, block in enumerate(blocks_to_words):
            block_block_attention[:, i] = np.mean(
                argtopk(word_word_attention[:, block[0]:(block[1]+1)], k), axis=-1)
    elif mode == "mean":
        for i, block in enumerate(blocks_to_words):
            block_block_attention[:, i] = np.mean(
                word_word_attention[:, block[0]:(block[1]+1)], axis=-1)
    elif mode == "max":
        # BUG FIX: "max" previously had no stage-1 branch, so the column
        # aggregate stayed all zeros and the final result was garbage.
        for i, block in enumerate(blocks_to_words):
            block_block_attention[:, i] = np.max(
                word_word_attention[:, block[0]:(block[1]+1)], axis=-1)
    else:
        raise ValueError("Unknown aggregation mode", mode)
    for i in range(word_word_attention.shape[0]):
        row_sum = block_block_attention[i].sum()
        if row_sum != 0:
            block_block_attention[i] /= row_sum

    # Stage 2: aggregate the attention *from* every block (collapse rows).
    # "mean" is what the paper uses.
    word_word_attention = block_block_attention
    block_block_attention = np.zeros(
        [len(blocks_to_words), len(blocks_to_words)], dtype=np.float16)
    for i, block in enumerate(blocks_to_words):
        if mode == "mean" or mode == "topk":
            block_block_attention[i] = np.mean(
                word_word_attention[block[0]:(block[1]+1)], axis=0)
        elif mode == "max":
            # BUG FIX: originally read from the half-built output array
            # (block_block_attention) instead of the stage-1 result.
            block_block_attention[i] = np.max(
                word_word_attention[block[0]:(block[1]+1)], axis=0)
        else:
            raise ValueError("Unknown aggregation mode", mode)
        row_sum = block_block_attention[i].sum()
        if row_sum != 0:
            block_block_attention[i] /= row_sum

    return block_block_attention
def examples_in_batches(examples, batch_size):
  """Yield successive slices of ``examples`` of at most ``batch_size``,
  logging progress via utils.logged_loop."""
  for i in utils.logged_loop(range(1 + ((len(examples) - 1) // batch_size))):
    yield examples[i * batch_size:(i + 1) * batch_size]


class AttnMapExtractor(object):
  """Runs BERT over examples to get its attention maps."""

  def __init__(self, bert_config_file, init_checkpoint,
               max_sequence_length=128, debug=False):
    # Fixed-size int32 placeholders; inputs must be padded to
    # max_sequence_length before feeding (Example does this).
    make_placeholder = lambda name: tf.placeholder(
        tf.int32, shape=[None, max_sequence_length], name=name)
    self._input_ids = make_placeholder("input_ids")
    self._segment_ids = make_placeholder("segment_ids")
    self._input_mask = make_placeholder("input_mask")

    bert_config = modeling.BertConfig.from_json_file(bert_config_file)
    if debug:
      # Shrink the model drastically so a forward pass is cheap; weights are
      # left randomly initialized (no checkpoint load below).
      bert_config.num_hidden_layers = 3
      bert_config.hidden_size = 144
    # Build the graph once; .attn_maps exposes the per-layer attention
    # probability tensors from the modified modeling.BertModel.
    self._attn_maps = modeling.BertModel(
        config=bert_config,
        is_training=False,
        input_ids=self._input_ids,
        input_mask=self._input_mask,
        token_type_ids=self._segment_ids,
        use_one_hot_embeddings=True).attn_maps

    if not debug:
      print("Loading BERT from checkpoint...")
      assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
          tf.trainable_variables(), init_checkpoint)
      tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

  def get_attn_maps(self, sess, examples):
    """Run the graph on a batch of Example objects and return the attention
    maps. Assumes every example is already padded to max_sequence_length."""
    feed = {
        self._input_ids: np.vstack([e.input_ids for e in examples]),
        self._segment_ids: np.vstack([e.segment_ids for e in examples]),
        self._input_mask: np.vstack([e.input_mask for e in examples])
    }
    return sess.run(self._attn_maps, feed_dict=feed)
84 | parser = argparse.ArgumentParser(description=__doc__) 85 | parser.add_argument( 86 | "--preprocessed-data-file", required=True, 87 | help="Location of preprocessed data (JSON file); see the README for " 88 | "expected data format.") 89 | parser.add_argument("--bert-dir", required=True, 90 | help="Location of the pre-trained BERT model.") 91 | parser.add_argument("--cased", default=False, action='store_true', 92 | help="Don't lowercase the input.") 93 | parser.add_argument("--max_sequence_length", default=128, type=int, 94 | help="Maximum input sequence length after tokenization " 95 | "(default=128).") 96 | parser.add_argument("--batch_size", default=16, type=int, 97 | help="Batch size when running BERT (default=16).") 98 | parser.add_argument("--debug", default=False, action='store_true', 99 | help="Use tiny model for fast debugging.") 100 | parser.add_argument("--word_level", default=False, action='store_true', 101 | help="Get word-level rather than token-level attention.") 102 | args = parser.parse_args() 103 | 104 | print("Creating examples...") 105 | tokenizer = tokenization.FullTokenizer( 106 | vocab_file=os.path.join(args.bert_dir, "vocab.txt"), 107 | do_lower_case=not args.cased) 108 | examples = [] 109 | for features in utils.load_json(args.preprocessed_data_file): 110 | example = Example(features, tokenizer, args.max_sequence_length) 111 | if len(example.input_ids) <= args.max_sequence_length: 112 | examples.append(example) 113 | 114 | print("Building BERT model...") 115 | extractor = AttnMapExtractor( 116 | os.path.join(args.bert_dir, "bert_config.json"), 117 | os.path.join(args.bert_dir, "bert_model.ckpt"), 118 | args.max_sequence_length, args.debug 119 | ) 120 | 121 | print("Extracting attention maps...") 122 | feature_dicts_with_attn = [] 123 | with tf.Session() as sess: 124 | sess.run(tf.global_variables_initializer()) 125 | for batch_of_examples in examples_in_batches(examples, args.batch_size): 126 | attns = extractor.get_attn_maps(sess, 
def preprocess_depparse_data(raw_data_file):
  """Parse a dependency-parse text file and write it back out as JSON.

  Input format: one "word head-reln" pair per line (e.g. "dog 2-nsubj");
  examples are separated by blank lines. The JSON is written next to the
  input (".txt" replaced by ".json") via utils.write_json.

  Args:
    raw_data_file: path to the raw .txt file.
  """
  examples = []
  with open(raw_data_file, encoding='utf-8') as f:
    current_example = {"words": [], "relns": [], "heads": []}
    for line in f:
      line = line.strip()
      if line:
        word, label = line.split()
        # Split on the first "-" only so relation labels that themselves
        # contain hyphens survive intact.
        head, reln = label.split("-", 1)
        head = int(head)
        current_example["words"].append(word)
        current_example["relns"].append(reln)
        current_example["heads"].append(head)
      else:
        # Guard against consecutive blank lines producing empty examples.
        if current_example["words"]:
          examples.append(current_example)
        current_example = {"words": [], "relns": [], "heads": []}
  # BUG FIX: a final example not followed by a trailing blank line used to be
  # silently dropped.
  if current_example["words"]:
    examples.append(current_example)
  utils.write_json(examples, raw_data_file.replace(".txt", ".json"))
class LoopLogger(object):
  """Prints progress (and optionally elapsed/ETA time) for a long loop."""

  def __init__(self, max_value=None, step_size=1, n_steps=25, print_time=True):
    self.max_value = max_value
    # When a target number of log lines is given, derive the step size from
    # it; otherwise fall back to the explicit step_size argument.
    if n_steps is None:
      self.step_size = step_size
    else:
      self.step_size = max(1, max_value // n_steps)
    self.print_time = print_time
    self.n = 0
    self.start_time = time.time()

  def step(self, values=None):
    """Advance the counter by one and maybe print a progress line."""
    self.update(self.n + 1, values)

  def update(self, i, values=None):
    """Set the counter to ``i``; print on step boundaries or at the end."""
    self.n = i
    on_boundary = (self.n % self.step_size == 0) or (self.n == self.max_value)
    if not on_boundary:
      return
    if self.max_value is None:
      msg = 'On item ' + str(self.n)
    else:
      pct = 100.0 * self.n / self.max_value
      msg = '{:}/{:} = {:.1f}%'.format(self.n, self.max_value, pct)
      if self.print_time:
        elapsed = time.time() - self.start_time
        per_step = elapsed / self.n
        msg += ', ELAPSED: {:.1f}s'.format(elapsed)
        msg += ', ETA: {:.1f}s'.format((self.max_value - self.n) * per_step)
    if values is not None:
      for key, val in values:
        rendered = '{:.4f}'.format(val) if isinstance(val, float) else str(val)
        msg += ' - ' + str(key) + ': ' + rendered
    print(msg)
-------------------------------------------------------------------------------- /evaluating_models/PL/data/cubert_model_java/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "type_vocab_size": 2, 13 | "vocab_size": 50032, 14 | "max_position_embeddings": 512 15 | } -------------------------------------------------------------------------------- /evaluating_models/PL/data/cubert_model_java/readme.txt: -------------------------------------------------------------------------------- 1 | Save pre-trained CuBERT model in this folder. -------------------------------------------------------------------------------- /evaluating_models/PL/data/cubert_model_python/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "type_vocab_size": 2, 13 | "vocab_size": 49988, 14 | "max_position_embeddings": 512 15 | } -------------------------------------------------------------------------------- /evaluating_models/PL/data/cubert_model_python/readme.txt: -------------------------------------------------------------------------------- 1 | Save pre-trained CuBERT model in this folder. 
-------------------------------------------------------------------------------- /evaluating_models/PL/data/scores/scores.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/evaluating_models/PL/data/scores/scores.zip -------------------------------------------------------------------------------- /evaluating_models/PL/preprocess_attn_java.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | # preprocess attention 7 | # find the token that is predicted to be head for each single attention head 8 | def load_pickle(fname): 9 | with open(fname, "rb") as f: 10 | return pickle.load(f) 11 | 12 | 13 | for i in range(0, 14): 14 | max_attn_data = [] 15 | print("processing java_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn.pkl") 16 | attn_data = load_pickle("data/CuBERT_tokenized/java_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn.pkl") 17 | for data in attn_data: 18 | # if data["id"] in common_ids: 19 | # cls and sep are already removed from word-level attention 20 | attn = data["attns"] 21 | attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 22 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 23 | max_attn_data.append({"tokens": data["tokens"], "max_attn": max_attn, "id": data["id"]}) 24 | 25 | with open("data/attention/cubert_java_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl", "wb") as f: 26 | pickle.dump(max_attn_data,f) 27 | 28 | del max_attn_data 29 | del attn_data 30 | 31 | 32 | # merge into one file 33 | max_attn_data = [] 34 | for i in range(0, 14): 35 | print("loading java_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl") 36 | with open("data/attention/cubert_java_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl", "rb") as f: 37 | max_attn_data.extend(pickle.load(f)) 
38 | 39 | with open("data/attention/cubert_java_full_attn_sorted_valid.pkl", "wb") as f: 40 | pickle.dump([c for c in max_attn_data if c["id"]>=11009], f) 41 | 42 | with open("data/attention/cubert_java_full_attn_sorted_test.pkl", "wb") as f: 43 | pickle.dump([c for c in max_attn_data if c["id"]<11009], f) -------------------------------------------------------------------------------- /evaluating_models/PL/preprocess_attn_python.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | 5 | 6 | # preprocess attention 7 | # find the token that is predicted to be head for each single attention head 8 | def load_pickle(fname): 9 | with open(fname, "rb") as f: 10 | return pickle.load(f) 11 | 12 | 13 | for i in range(0, 19): 14 | max_attn_data = [] 15 | print("processing python_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn.pkl") 16 | attn_data = load_pickle("data/CuBERT_tokenized/python_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn.pkl") 17 | for data in attn_data: 18 | # if data["id"] in common_ids: 19 | # cls and sep are already removed from word-level attention 20 | attn = data["attns"] 21 | attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 22 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 23 | max_attn_data.append({"tokens": data["tokens"], "max_attn": max_attn, "id": data["id"]}) 24 | 25 | with open("data/attention/cubert_python_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl", "wb") as f: 26 | pickle.dump(max_attn_data,f) 27 | 28 | del max_attn_data 29 | del attn_data 30 | 31 | 32 | # merge into one file 33 | max_attn_data = [] 34 | for i in range(0, 19): 35 | print("loading python_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl") 36 | with open("data/attention/cubert_python_" + str(i*1000)+"_" + str((i+1)*1000) + "_attn_sorted.pkl", "rb") as f: 37 | max_attn_data.extend(pickle.load(f)) 38 | 39 | with 
open("data/attention/cubert_python_full_attn_sorted_valid.pkl", "wb") as f: 40 | pickle.dump([c for c in max_attn_data if c["id"]>=8680], f) 41 | 42 | with open("data/attention/cubert_python_full_attn_sorted_test.pkl", "wb") as f: 43 | pickle.dump([c for c in max_attn_data if c["id"]<8680], f) -------------------------------------------------------------------------------- /evaluating_models/PL/remove_uncommon_datapoints.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | language="java" 4 | 5 | 6 | print("loading") 7 | 8 | with open("data/attention/codebert_"+language+"_full_attn_sorted_test.pkl", "rb") as f: 9 | codebert_test = pickle.load(f) 10 | with open("data/attention/codebert_"+language+"_full_attn_sorted_valid.pkl", "rb") as f: 11 | codebert_valid = pickle.load(f) 12 | 13 | codebert_ids = set() 14 | cubert_ids = set() 15 | for dataset in [codebert_test, codebert_valid]: 16 | for data in dataset: 17 | codebert_ids.add(data["id"]) 18 | 19 | with open("data/attention/"+language+"_common_ids.pkl", "wb") as f: 20 | pickle.dump(codebert_ids, f) 21 | 22 | with open("data/attention/cubert_"+language+"_full_attn_sorted_test.pkl", "rb") as f: 23 | cubert_test = pickle.load(f) 24 | with open("data/attention/cubert_"+language+"_full_attn_sorted_valid.pkl", "rb") as f: 25 | cubert_valid = pickle.load(f) 26 | 27 | 28 | 29 | for dataset in [cubert_test, cubert_valid]: 30 | for data in dataset: 31 | cubert_ids.add(data["id"]) 32 | 33 | uncommon_ids = codebert_ids.symmetric_difference(cubert_ids) 34 | common_ids = codebert_ids.intersection(cubert_ids) 35 | print("we have", len(uncommon_ids), "uncommon datapoints") 36 | 37 | for filename, all_data in [("codebert_"+language+"_full_attn_sorted_test", codebert_test), 38 | ("codebert_"+language+"_full_attn_sorted_valid", codebert_valid), 39 | ("cubert_"+language+"_full_attn_sorted_test", cubert_test), 40 | ("cubert_"+language+"_full_attn_sorted_valid", cubert_valid)]: 
41 | common_data = [] 42 | for data in all_data: 43 | if data["id"] in common_ids: 44 | common_data.append(data) 45 | print(filename, "has", len(common_data), "datapoints") 46 | with open("data/attention/"+filename+"_common.pkl", "wb") as f: 47 | pickle.dump(common_data, f) -------------------------------------------------------------------------------- /evaluating_models/PL/run_exp_codebert_java.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaTokenizerFast 6 | 7 | 8 | 9 | # extract attention and convert attention to word-level 10 | 11 | def get_word_word_attention(token_token_attention, words_to_tokens): 12 | """Convert token-token attention to word-word attention (when tokens are 13 | derived from words using something like byte-pair encodings).""" 14 | 15 | codebert_tokens_length = token_token_attention.shape[0] 16 | word_starts = set() 17 | not_mapped_tokens = [] # store the python tokens that no codebert token is mapped to 18 | 19 | # special case: sometimes two python tokens are combined as one in codebert tokens 20 | # e.g. 
"):" 21 | # when we convert to word-level attention, we need to make a copy of the row becase 22 | # ")" and ":" are two different python tokens 23 | # we store such words in conflicting_heads 24 | conflicting_heads = {} 25 | for i in range(len(words_to_tokens)): 26 | word = words_to_tokens[i] 27 | if len(word) > 0: 28 | word_starts.add(word[0]) 29 | if i < len(words_to_tokens)-1 and len(words_to_tokens[i+1]) > 0 and word[0] == words_to_tokens[i+1][0]: 30 | conflicting_heads[i] = None 31 | not_mapped_tokens.append(i) 32 | 33 | not_word_starts = [i for i in range(codebert_tokens_length) if i not in word_starts] 34 | 35 | # find python tokens that are mapped to no cubert tokens 36 | not_mapped_tokens.extend([idx for idx, l in enumerate(words_to_tokens) if l ==[]]) 37 | not_mapped_tokens = sorted(not_mapped_tokens) 38 | 39 | # sum up the attentions for all tokens in a word that has been split 40 | for i, word in enumerate(words_to_tokens): 41 | if len(word) > 0: 42 | if i in conflicting_heads: 43 | conflicting_heads[i] = token_token_attention[:, word].sum(axis=-1) 44 | else: 45 | token_token_attention[:, word[0]] = token_token_attention[:, word].sum(axis=-1) 46 | token_token_attention = np.delete(token_token_attention, not_word_starts, -1) 47 | # do not delete python token that is not mapped to cubert tokens 48 | for idx in not_mapped_tokens: 49 | token_token_attention = np.insert(token_token_attention, idx, 0, axis=-1) 50 | # resolve the special case that two python tokens are combined as one in codebert tokens 51 | for i in conflicting_heads.keys(): 52 | token_token_attention[:, i] = conflicting_heads[i] 53 | 54 | # combining attention maps for words that have been split 55 | for i, word in enumerate(words_to_tokens): 56 | if len(word) > 0: 57 | if i in conflicting_heads: 58 | conflicting_heads[i] = np.mean(token_token_attention[word], axis=0) 59 | else: 60 | # mean 61 | token_token_attention[word[0]] = np.mean(token_token_attention[word], axis=0) 62 | # # max 63 | 
# token_token_attention[word[0]] = np.max(token_token_attention[word], axis=0) 64 | # token_token_attention[word[0]] /= token_token_attention[word[0]].sum() 65 | 66 | token_token_attention = np.delete(token_token_attention, not_word_starts, 0) 67 | # do not delete python token that is not mapped to cubert tokens 68 | for idx in not_mapped_tokens: 69 | token_token_attention = np.insert(token_token_attention, idx, 0, axis=0) 70 | # resolve the special case that two python tokens are combined as one in codebert tokens 71 | for i in conflicting_heads.keys(): 72 | token_token_attention[i] = conflicting_heads[i] 73 | 74 | return token_token_attention 75 | 76 | def make_attn_word_level(alignment, attn): 77 | return np.stack([[ 78 | get_word_word_attention(attn_head, alignment) 79 | for attn_head in layer_attns] for layer_attns in attn]) 80 | 81 | 82 | 83 | # run CodeBERT to extract attention and convert attention to word-level 84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 85 | tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") 86 | model = RobertaModel.from_pretrained("microsoft/codebert-base") 87 | model.to(device) 88 | model.eval() 89 | 90 | with open("data/CodeBERT_tokenized/java_full.json", 'r') as f: 91 | examples = json.load(f) 92 | 93 | 94 | max_attn_data = [] 95 | for example in examples: 96 | if example['id']%100 == 0: 97 | print(example['id']) 98 | # if example["id"] in common_ids: 99 | if len(example['input_ids']) < 512: 100 | outputs = model(torch.tensor(example['input_ids']).unsqueeze(0), output_attentions=True) 101 | attn = outputs.attentions # list of tensors of shape 1*12*num_of_tokens*num_of_tokens 102 | attn = np.vstack([layer.cpu().detach().numpy() for layer in attn]) # shape of 12*12*num_of_tokens*num_of_tokens 12 layers 12 heads 103 | attn = make_attn_word_level(example["alignment"], attn) 104 | # example['attns'] = attn 105 | 106 | # preprocess attention by sorting predictions based upon weights 107 | 
attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 108 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 109 | max_attn_data.append({"java_tokens": example["java_tokens"], "tokens": example["tokens"], 110 | "max_attn": max_attn, "id": example["id"]}) 111 | 112 | 113 | 114 | 115 | with open("data/attention/codebert_java_full_attn_sorted.pkl", "wb") as f: 116 | pickle.dump(max_attn_data,f) 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /evaluating_models/PL/run_exp_codebert_python.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import numpy as np 4 | import torch 5 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaTokenizerFast 6 | 7 | 8 | 9 | # extract attention and convert attention to word-level 10 | 11 | def get_word_word_attention(token_token_attention, words_to_tokens): 12 | """Convert token-token attention to word-word attention (when tokens are 13 | derived from words using something like byte-pair encodings).""" 14 | 15 | codebert_tokens_length = token_token_attention.shape[0] 16 | word_starts = set() 17 | not_mapped_tokens = [] # store the python tokens that no codebert token is mapped to 18 | 19 | # special case: sometimes two python tokens are combined as one in codebert tokens 20 | # e.g. 
"):" 21 | # when we convert to word-level attention, we need to make a copy of the row becase 22 | # ")" and ":" are two different python tokens 23 | # we store such words in conflicting_heads 24 | conflicting_heads = {} 25 | for i in range(len(words_to_tokens)): 26 | word = words_to_tokens[i] 27 | if len(word) > 0: 28 | word_starts.add(word[0]) 29 | if i < len(words_to_tokens)-1 and len(words_to_tokens[i+1]) > 0 and word[0] == words_to_tokens[i+1][0]: 30 | conflicting_heads[i] = None 31 | not_mapped_tokens.append(i) 32 | 33 | not_word_starts = [i for i in range(codebert_tokens_length) if i not in word_starts] 34 | 35 | # find python tokens that are mapped to no cubert tokens 36 | not_mapped_tokens.extend([idx for idx, l in enumerate(words_to_tokens) if l ==[]]) 37 | not_mapped_tokens = sorted(not_mapped_tokens) 38 | 39 | # sum up the attentions for all tokens in a word that has been split 40 | for i, word in enumerate(words_to_tokens): 41 | if len(word) > 0: 42 | if i in conflicting_heads: 43 | conflicting_heads[i] = token_token_attention[:, word].sum(axis=-1) 44 | else: 45 | token_token_attention[:, word[0]] = token_token_attention[:, word].sum(axis=-1) 46 | token_token_attention = np.delete(token_token_attention, not_word_starts, -1) 47 | # do not delete python token that is not mapped to cubert tokens 48 | for idx in not_mapped_tokens: 49 | token_token_attention = np.insert(token_token_attention, idx, 0, axis=-1) 50 | # resolve the special case that two python tokens are combined as one in codebert tokens 51 | for i in conflicting_heads.keys(): 52 | token_token_attention[:, i] = conflicting_heads[i] 53 | 54 | # combining attention maps for words that have been split 55 | for i, word in enumerate(words_to_tokens): 56 | if len(word) > 0: 57 | if i in conflicting_heads: 58 | conflicting_heads[i] = np.mean(token_token_attention[word], axis=0) 59 | else: 60 | # mean 61 | token_token_attention[word[0]] = np.mean(token_token_attention[word], axis=0) 62 | # # max 63 | 
# token_token_attention[word[0]] = np.max(token_token_attention[word], axis=0) 64 | # token_token_attention[word[0]] /= token_token_attention[word[0]].sum() 65 | 66 | token_token_attention = np.delete(token_token_attention, not_word_starts, 0) 67 | # do not delete python token that is not mapped to cubert tokens 68 | for idx in not_mapped_tokens: 69 | token_token_attention = np.insert(token_token_attention, idx, 0, axis=0) 70 | # resolve the special case that two python tokens are combined as one in codebert tokens 71 | for i in conflicting_heads.keys(): 72 | token_token_attention[i] = conflicting_heads[i] 73 | 74 | return token_token_attention 75 | 76 | def make_attn_word_level(alignment, attn): 77 | return np.stack([[ 78 | get_word_word_attention(attn_head, alignment) 79 | for attn_head in layer_attns] for layer_attns in attn]) 80 | 81 | 82 | 83 | # run CodeBERT to extract attention and convert attention to word-level 84 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 85 | tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base") 86 | model = RobertaModel.from_pretrained("microsoft/codebert-base") 87 | model.to(device) 88 | model.eval() 89 | 90 | 91 | with open("data/CodeBERT_tokenized/python_full.json", 'r') as f: 92 | examples = json.load(f) 93 | 94 | max_attn_data = [] 95 | for example in examples: 96 | if example['id']%100 == 0: 97 | print(example['id']) 98 | if len(example['input_ids']) < 512: 99 | outputs = model(torch.tensor(example['input_ids']).unsqueeze(0), output_attentions=True) 100 | attn = outputs.attentions # list of tensors of shape 1*12*num_of_tokens*num_of_tokens 101 | attn = np.vstack([layer.cpu().detach().numpy() for layer in attn]) # shape of 12*12*num_of_tokens*num_of_tokens 12 layers 12 heads 102 | attn = make_attn_word_level(example["alignment"], attn) 103 | # head_view(attn, example["python_tokens"]) 104 | # example['attns'] = attn 105 | 106 | # preprocess attention by sorting predictions based upon 
weights 107 | attn[:, :, range(attn.shape[2]), range(attn.shape[2])] = 0 108 | max_attn = np.flip(np.argsort(attn, axis=3),axis=3).astype(np.int16)[:,:,:,:20] 109 | max_attn_data.append({"python_tokens": example["python_tokens"], "tokens": example["tokens"], 110 | "max_attn": max_attn, "id": example["id"]}) 111 | 112 | 113 | 114 | with open("data/attention/codebert_python_full_attn_sorted.pkl", "wb") as f: 115 | pickle.dump(max_attn_data,f) 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /evaluating_models/PL/run_exp_java.sh: -------------------------------------------------------------------------------- 1 | python ./attention-analysis/extract_attention.py --preprocessed-data-file data/CuBERT_tokenized/java_0_1000.json --bert-dir data/cubert_model_java --max_sequence_length 512 --batch_size 4 --word_level 2 | sleep 1 3 | 4 | max=13 5 | for (( i=1; i <= $max; ++i )) 6 | do 7 | echo "${i}" 8 | python ./attention-analysis/extract_attention.py --preprocessed-data-file data/CuBERT_tokenized/java_${i}000_$((${i}+1))000.json --bert-dir data/cubert_model_java --max_sequence_length 512 --batch_size 4 --word_level 9 | sleep 1 10 | done -------------------------------------------------------------------------------- /evaluating_models/PL/run_exp_python.sh: -------------------------------------------------------------------------------- 1 | python ./attention-analysis/extract_attention.py --preprocessed-data-file data/CuBERT_tokenized/python_0_1000.json --bert-dir data/cubert_model_python --max_sequence_length 512 --batch_size 4 --word_level 2 | sleep 1 3 | 4 | max=18 5 | for (( i=1; i <= $max; ++i )) 6 | do 7 | echo "${i}" 8 | python ./attention-analysis/extract_attention.py --preprocessed-data-file data/CuBERT_tokenized/python_${i}000_$((${i}+1))000.json --bert-dir data/cubert_model_python --max_sequence_length 512 --batch_size 4 --word_level 9 | sleep 1 10 | done 
-------------------------------------------------------------------------------- /evaluating_models/PL/tokenize_and_align_codebert.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizer, RobertaConfig, RobertaModel, RobertaTokenizerFast 2 | import re 3 | import io 4 | import pickle 5 | import json 6 | import tokenize, javalang 7 | 8 | def convert_line_offset_to_char_offset(line_offset, line_to_length): 9 | line, offset = line_offset 10 | if line > 1: 11 | offset += sum(line_to_length[:line-1]) + line-1 12 | return offset 13 | 14 | # run tokenization and generate subtoken-token alignment 15 | # Align cubert sub-tokens with python token 16 | # Rewrite Attention-analysis’ tokenize_and_align(). 17 | # returns a list of lists where each sub-list 18 | # contains cubert-tokenized tokens for the correponding word. 19 | # each cubert-subtoken is represented by int starting from 0 20 | # e.g. [[1], [2, 3]] where special token class is always at index 0 21 | 22 | # By default, javalang tokenizer does not produce '\n' tokens while python tokenizer does. 23 | # For Java, if we want to add '\n' into tokens, we need to produce different alignments too. 
24 | 25 | def align_codebert_tokens(codebert_encodings, python_tokens, code, id=-1): 26 | # get start and end of python tokens 27 | line_to_length = [len(l) for l in code.split("\n")] # stores the number of chars in each line of source code 28 | python_token_range = [(convert_line_offset_to_char_offset(token.start, line_to_length), 29 | convert_line_offset_to_char_offset(token.end, line_to_length), 30 | token.string) for token in python_tokens] 31 | 32 | # produce alignment by using start and end 33 | result = [] 34 | for python_token in python_token_range: 35 | start, end, token = python_token 36 | codebert_token_index_first = codebert_encodings.char_to_token(start) # position in souce code -> token index 37 | codebert_token_index_last = codebert_encodings.char_to_token(end-1) 38 | if codebert_token_index_first != None and codebert_token_index_last!= None: 39 | result.append([*range(codebert_token_index_first, codebert_token_index_last+1)]) 40 | # print(repr(token), codebert_encodings.tokens()[codebert_token_index_first: codebert_token_index_last+1]) 41 | else: 42 | # not found 43 | if (not token.startswith('#') and not token.startswith('"""') and 44 | not token.startswith("'''") and token != "" and not token.isspace()): 45 | print("id =", id, ": python token not found: ", repr(token)) 46 | result.append([]) 47 | 48 | 49 | return result 50 | 51 | 52 | def convert_line_offset_to_char_offset_java(line_offset, line_to_length): 53 | line, offset = line_offset 54 | offset -= 1 55 | if line > 1: 56 | offset += sum(line_to_length[:line-1]) + line-1 57 | return offset 58 | 59 | def convert_char_offset_to_line_offset_java(char_offset, line_to_length): 60 | line = 0 61 | while char_offset > line_to_length[line]: 62 | char_offset -= line_to_length[line]+1 63 | line += 1 64 | return javalang.tokenizer.Position(line+1, char_offset+1) 65 | 66 | def tokenize_java_code(code, keep_string_only=False, add_new_lines=False): 67 | tokens = [token for token in 
javalang.tokenizer.tokenize(code)] 68 | 69 | if add_new_lines: 70 | line_to_length = [len(l) for l in code.split("\n")] # stores the number of chars in each line of source code 71 | token_range = [(convert_line_offset_to_char_offset_java(token.position, line_to_length), 72 | convert_line_offset_to_char_offset_java((token.position[0], token.position[1]+len(token.value)), line_to_length), 73 | token.value) for token in tokens] 74 | 75 | i = 0 76 | while i < len(tokens) -1: 77 | code_slice = code[token_range[i][1]:token_range[i+1][0]] 78 | matches = re.finditer("\n", code_slice) 79 | new_line_index=[match.start()+token_range[i][1] for match in matches] 80 | tokens = tokens[0:i+1] + [javalang.tokenizer.JavaToken('\n', 81 | convert_char_offset_to_line_offset_java(idx, line_to_length)) 82 | for idx in new_line_index] + tokens[i+1:] 83 | token_range = token_range[0:i+1] + [(i, i+1, '\n') for i in new_line_index] + token_range[i+1:] 84 | i+=1 85 | 86 | if keep_string_only: 87 | tokens = [token.value for token in tokens] 88 | return tokens 89 | 90 | # run tokenization and generate subtoken-token alignment 91 | # Align cubert sub-tokens with python token 92 | # Rewrite Attention-analysis’ tokenize_and_align(). 93 | # returns a list of lists where each sub-list 94 | # contains cubert-tokenized tokens for the correponding word. 95 | # each cubert-subtoken is represented by int starting from 0 96 | # e.g. 
[[1], [2, 3]] where special token class is always at index 0 97 | def align_codebert_tokens_java(codebert_encodings, python_tokens, code, id=-1): 98 | # get start and end of python tokens 99 | line_to_length = [len(l) for l in code.split("\n")] # stores the number of chars in each line of source code 100 | python_token_range = [(convert_line_offset_to_char_offset_java(token.position, line_to_length), 101 | convert_line_offset_to_char_offset_java((token.position[0], token.position[1]+len(token.value)), line_to_length), 102 | token.value) for token in python_tokens] 103 | 104 | # produce alignment by using start and end 105 | result = [] 106 | for python_token in python_token_range: 107 | start, end, token = python_token 108 | codebert_token_index_first = codebert_encodings.char_to_token(start) 109 | codebert_token_index_last = codebert_encodings.char_to_token(end-1) 110 | if codebert_token_index_first != None and codebert_token_index_last!= None: 111 | result.append([*range(codebert_token_index_first, codebert_token_index_last+1)]) 112 | # print(repr(token), codebert_encodings.tokens()[codebert_token_index_first: codebert_token_index_last+1]) 113 | else: 114 | # not found 115 | if (not token.isspace()): 116 | print("id =", id, ": java token not found: ", repr(token)) 117 | result.append([]) 118 | 119 | return result 120 | 121 | 122 | 123 | # cls code_tokens sep 124 | tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") 125 | 126 | # python 127 | with open('../../generating_CodeSyntax/deduplicated_python_code.pickle', 'rb') as f: 128 | sample = pickle.load(f) 129 | examples = [] 130 | for i in range(len(sample)): 131 | if i%1000 == 0: 132 | print("processing sample ", i) 133 | encodings = tokenizer(sample[i]) 134 | # class and sep special tokens are already added 135 | if len(encodings.tokens()) < 512: 136 | # generate alignment 137 | python_tokens = [] 138 | for token in tokenize.generate_tokens(io.StringIO(sample[i]).readline): 139 | 
python_tokens.append(token) 140 | alignment = align_codebert_tokens(encodings, python_tokens, sample[i], i) 141 | python_tokens = [token.string for token in python_tokens] 142 | examples.append({'input_ids': encodings['input_ids'], "tokens": encodings.tokens(), 143 | "id": i, "alignment": alignment, "python_tokens": python_tokens}) 144 | print(len(examples)) 145 | 146 | with open("data/CodeBERT_tokenized/python_full.json", 'w') as f: 147 | json.dump(examples, f, indent=2) 148 | 149 | 150 | # java 151 | # cls code_tokens sep 152 | add_new_lines = False 153 | output_file_name = "_with_new_lines" if add_new_lines else "" 154 | 155 | tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") 156 | 157 | with open('../../generating_CodeSyntax/deduplicated_java_code.pickle', 'rb') as f: 158 | sample = pickle.load(f) 159 | 160 | has_unicode = [] 161 | for i in range(len(sample)): 162 | if '\\u' in sample[i]: 163 | has_unicode.append(i) 164 | 165 | examples = [] 166 | for i in range(len(sample)): 167 | if i%1000 == 0: 168 | print("processing sample ", i) 169 | encodings = tokenizer(sample[i]) 170 | # class and sep special tokens are already added 171 | if i not in has_unicode and len(encodings.tokens()) < 512: 172 | # generate alignment 173 | python_tokens = tokenize_java_code(sample[i], keep_string_only=False, add_new_lines=add_new_lines) 174 | alignment = align_codebert_tokens_java(encodings, python_tokens, sample[i], i) 175 | python_tokens = [token.value for token in python_tokens] 176 | examples.append({'input_ids': encodings['input_ids'], "tokens": encodings.tokens(), 177 | "id": i, "alignment": alignment, "java_tokens": python_tokens}) 178 | print(len(examples)) 179 | 180 | with open("data/CodeBERT_tokenized/java_full"+output_file_name+".json", 'w') as f: 181 | json.dump(examples, f, indent=2) -------------------------------------------------------------------------------- /evaluating_models/PL/tokenize_and_align_cubert_java.py: 
-------------------------------------------------------------------------------- 1 | from cubert import python_tokenizer 2 | from cubert import java_tokenizer 3 | from cubert import code_to_subtokenized_sentences 4 | import re 5 | import io 6 | import javalang 7 | import tokenize 8 | from tensor2tensor.data_generators import text_encoder 9 | import pickle 10 | import json 11 | 12 | # Align cubert sub-tokens with java token 13 | # Rewrite Attention-analysis’ tokenize_and_align(). 14 | # returns a list of lists where each sub-list 15 | # contains cubert-tokenized tokens for the correponding word. 16 | # each cubert-subtoken is represented by int starting from 0 17 | # e.g. [[0, 1], [2]] 18 | 19 | # By default, javalang tokenizer does not produce '\n' tokens while python tokenizer does. 20 | # For Java, if we want to add '\n' into tokens, we need to produce different alignments too. 21 | 22 | token_map = {'0': '0_', '\n':"\\u\\u\\uNEWLINE\\u\\u\\u_", '\r\n':"\\u\\u\\uNEWLINE\\u\\u\\u_", 23 | '\n\t':"\\u\\u\\uNEWLINE\\u\\u\\u_"} 24 | ignore_set = set([]) 25 | new_line = set(['\n', '\r\n', '\n\t']) 26 | 27 | def convert_line_offset_to_char_offset_java(line_offset, line_to_length): 28 | line, offset = line_offset 29 | offset -= 1 30 | if line > 1: 31 | offset += sum(line_to_length[:line-1]) + line-1 32 | return offset 33 | 34 | def convert_char_offset_to_line_offset_java(char_offset, line_to_length): 35 | line = 0 36 | while char_offset > line_to_length[line]: 37 | char_offset -= line_to_length[line]+1 38 | line += 1 39 | return javalang.tokenizer.Position(line+1, char_offset+1) 40 | 41 | def tokenize_java_code(code, keep_string_only=False, add_new_lines=False): 42 | tokens = [token for token in javalang.tokenizer.tokenize(code)] 43 | 44 | if add_new_lines: 45 | line_to_length = [len(l) for l in code.split("\n")] # stores the number of chars in each line of source code 46 | token_range = [(convert_line_offset_to_char_offset_java(token.position, line_to_length), 47 | 
convert_line_offset_to_char_offset_java((token.position[0], token.position[1]+len(token.value)), line_to_length), 48 | token.value) for token in tokens] 49 | 50 | i = 0 51 | while i < len(tokens) -1: 52 | code_slice = code[token_range[i][1]:token_range[i+1][0]] 53 | matches = re.finditer("\n", code_slice) 54 | new_line_index=[match.start()+token_range[i][1] for match in matches] 55 | tokens = tokens[0:i+1] + [javalang.tokenizer.JavaToken('\n', 56 | convert_char_offset_to_line_offset_java(idx, line_to_length)) 57 | for idx in new_line_index] + tokens[i+1:] 58 | token_range = token_range[0:i+1] + [(i, i+1, '\n') for i in new_line_index] + token_range[i+1:] 59 | i+=1 60 | 61 | if keep_string_only: 62 | tokens = [token.value for token in tokens] 63 | return tokens 64 | 65 | def align_cubert_tokens_java(cubert_tokens, python_tokens, id=-1): 66 | result = [] 67 | current_idx = 0 68 | for python_token in python_tokens: 69 | # convert python token into cubert token 70 | if python_token in token_map.keys(): 71 | subtokenized_python_token = [[token_map[python_token], 'EOS']] 72 | else: 73 | subtokenized_python_token = ( 74 | code_to_subtokenized_sentences.code_to_cubert_sentences( 75 | code=python_token, 76 | initial_tokenizer=tokenizer, 77 | subword_tokenizer=subword_tokenizer)) 78 | # print(python_token, subtokenized_python_token) 79 | #ignore invlaid tokens and comments 80 | if len(subtokenized_python_token) >0 and subtokenized_python_token[0][0]!='#^_': 81 | subtokenized_python_token = subtokenized_python_token[0][0:-1] 82 | 83 | length = len(subtokenized_python_token) 84 | remaining_cubert_tokens = cubert_tokens[current_idx:] 85 | found = False 86 | 87 | if python_token in new_line: 88 | # print(repr(python_token), remaining_cubert_tokens[0:1]) 89 | if remaining_cubert_tokens[0] == "\\u\\u\\uNEWLINE\\u\\u\\u_" or remaining_cubert_tokens[0] == "\\u\\u\\uNL\\u\\u\\u_": 90 | result.append([*range(current_idx, current_idx+1)]) 91 | current_idx += 1 92 | found = True 93 | 
# NOTE(review): reconstructed from a flattened repository dump; original
# indentation was destroyed, so nesting below is inferred from the parallel
# Python version of this file — verify against the original source.
# --- tail of align_cubert_tokens_java(); the def line is above this chunk ---
            else:
                # Scan forward through the remaining CuBERT sub-tokens for the
                # first window matching this Java token's sub-tokenization.
                # The `endswith("u^_")` guard avoids matching inside a longer
                # sub-token sequence split by the subword vocabulary.
                for i in range(len(remaining_cubert_tokens) - length +1):
                    if remaining_cubert_tokens[i:i+length] == subtokenized_python_token and (i == 0 or not remaining_cubert_tokens[i-1].endswith("u^_")):
                        # print(python_token, 'skipping', i,'tokens.', repr(python_token), remaining_cubert_tokens[i:i+length])
                        # print('remaining tokens', remaining_cubert_tokens[i+length:i+length+5])
                        result.append([*range(current_idx+i, current_idx+i+length)])
                        current_idx += length+i
                        found = True
                        break
            if not found:
                # Unmatched tokens still get an (empty) alignment slot so the
                # output list stays parallel to the input token list.
                if python_token not in new_line:
                    print("id =", id, ": java token not found: ", repr(python_token))
                result.append([])

        else:
            # Token produced no sub-tokens (or was a comment): keep an empty slot.
            if python_token not in ignore_set and not python_token.isspace():
                print("ignoring java token: ", repr(python_token))
            result.append([])

    return result


# Module-level tokenizers shared by the functions above and the script below.
tokenizer = java_tokenizer.JavaTokenizer()
subword_tokenizer = text_encoder.SubwordTextEncoder("cubert/vocab_java.txt")



# save subtoken alignment to json
# Driver script: reads the deduplicated Java corpus, CuBERT-tokenizes each
# example, aligns sub-tokens to javalang tokens, and writes results in
# batches of 1000 examples per JSON file.
with open('../../generating_CodeSyntax/deduplicated_java_code.pickle', 'rb') as f:
    sample = pickle.load(f)
examples = []
add_new_lines = False
output_file_name = "_with_new_lines" if add_new_lines else ""

total_count=len(sample)
# NOTE(review): when total_count is an exact multiple of 1000 this yields one
# extra, empty trailing batch (an empty JSON file) — harmless but untidy.
batch_count=int(total_count/1000)+1
for b in range(0, batch_count):
    examples = []
    for i in range(1000*b, min(1000*(b+1),total_count)):
        try:
            cubert_tokens = (code_to_subtokenized_sentences.code_to_cubert_sentences(
                code=sample[i],
                initial_tokenizer=tokenizer,
                subword_tokenizer=subword_tokenizer))
            tokens = ["[CLS]_"]
            for sentence in cubert_tokens:
                tokens.extend(sentence)
            tokens.append("[SEP]_")
            # Skip examples longer than the 512-token model context.
            if len(tokens) < 512:
                # generate alignment
                python_tokens = tokenize_java_code(sample[i], keep_string_only=True, add_new_lines=add_new_lines)
                # This exact token sequence is what CuBERT emits when its
                # tokenizer fails on the input program.
                if tokens == ['[CLS]_', '\\u\\u\\u', 'ERROR', '\\u\\u\\u_', '\\u\\u\\uNEWLINE\\u\\u\\u_', '[SEP]_']:
                    print("skipping", i, "because of cubert tokenization error")
                    continue
                alignment = align_cubert_tokens_java(tokens, python_tokens, id=i)
                examples.append({"tokens": tokens, "id": i, "alignment": alignment, "java_tokens": python_tokens})
        except Exception as e:
            # Best-effort batch processing: log and continue on any failure.
            print("skipping", i, "because of", e)
    print("length of batch", b, ":", len(examples))

    with open("data/CuBERT_tokenized/java_"+str(1000*b)+"_"+str(1000*(b+1))+output_file_name+".json", 'w') as f:
        json.dump(examples, f, indent=2)

# ============================================================================
# File: evaluating_models/PL/tokenize_and_align_cubert_python.py
# ============================================================================
from cubert import python_tokenizer
from cubert import java_tokenizer
from cubert import code_to_subtokenized_sentences
import re
import io
import javalang
import tokenize
from tensor2tensor.data_generators import text_encoder
import pickle
import json

# Align cubert sub-tokens with python token
# Rewrite Attention-analysis' tokenize_and_align().
# returns a list of lists where each sub-list
# contains cubert-tokenized tokens for the corresponding word.
# each cubert-subtoken is represented by int starting from 0
# e.g. [[0, 1], [2]]

# By default, javalang tokenizer does not produce '\n' tokens while python tokenizer does.
# For Java, if we want to add '\n' into tokens, we need to produce different alignments too.
# NOTE(review): reconstructed from a flattened repository dump; indentation
# and intra-string whitespace were destroyed — verify against the original.

# Maps single Python tokens directly to their CuBERT pseudo-token, bypassing
# the subword tokenizer for brackets, indentation and newlines.
# NOTE(review): the six space-string keys below are DISTINCT runs of spaces
# (different indent widths) in the original file; the dump collapsed them so
# the exact widths cannot be recovered here — confirm against the source repo.
token_map = {'(': '(_', '[': '[_', "{":"{_", "}": "}_", ')': ')_', ']': ']_',
             ' ':'\\u\\u\\uINDENT\\u\\u\\u ', ' ':'\\u\\u\\uINDENT\\u\\u\\u ', ' ':'\\u\\u\\uINDENT\\u\\u\\u ',
             ' ':'\\u\\u\\uINDENT\\u\\u\\u ', ' ':'\\u\\u\\uINDENT\\u\\u\\u ',
             ' ':'\\u\\u\\uINDENT\\u\\u\\u ', '\n':"\\u\\u\\uNEWLINE\\u\\u\\u_", '\r\n':"\\u\\u\\uNEWLINE\\u\\u\\u_",
             '\n\t':"\\u\\u\\uNEWLINE\\u\\u\\u_"}
ignore_set = set([''])          # tokens silently dropped without a warning
new_line = set(['\n', '\r\n', '\n\t'])   # tokens treated as newlines

def align_cubert_tokens(cubert_tokens, python_tokens, id=-1):
    """Align CuBERT sub-tokens to Python tokenizer tokens.

    Returns a list parallel to python_tokens; each entry is the list of
    indices into cubert_tokens covering that token ([] when unmatched).
    NOTE(review): `id` shadows the builtin; kept for interface compatibility.
    """
    result = []
    current_idx = 0
    for python_token in python_tokens:
        # convert python token into cubert token
        if python_token in token_map.keys():
            subtokenized_python_token = [[token_map[python_token], 'EOS']]
        else:
            subtokenized_python_token = (
                code_to_subtokenized_sentences.code_to_cubert_sentences(
                    code=python_token,
                    initial_tokenizer=tokenizer,
                    subword_tokenizer=subword_tokenizer))
        # print(python_token, subtokenized_python_token)
        # ignore invalid tokens and comments
        if len(subtokenized_python_token) >0 and subtokenized_python_token[0][0]!='#^_':
            # Drop the trailing EOS marker from the first sentence.
            subtokenized_python_token = subtokenized_python_token[0][0:-1]

            length = len(subtokenized_python_token)
            remaining_cubert_tokens = cubert_tokens[current_idx:]
            found = False

            if python_token in new_line:
                # print(repr(python_token), remaining_cubert_tokens[0:1])
                # Newlines must match the next CuBERT token exactly (NEWLINE or NL).
                if remaining_cubert_tokens[0] == "\\u\\u\\uNEWLINE\\u\\u\\u_" or remaining_cubert_tokens[0] == "\\u\\u\\uNL\\u\\u\\u_":
                    result.append([*range(current_idx, current_idx+1)])
                    current_idx += 1
                    found = True
            else:
                # Scan forward for the first window matching this token's
                # sub-tokenization; the endswith("u^_") guard avoids matching
                # the middle of a longer split sub-token sequence.
                for i in range(len(remaining_cubert_tokens) - length +1):
                    if remaining_cubert_tokens[i:i+length] == subtokenized_python_token and (i == 0 or not remaining_cubert_tokens[i-1].endswith("u^_")):
                        # print('skipping', i,'tokens.', repr(python_token), remaining_cubert_tokens[i:i+length])
                        # print('remaining tokens', remaining_cubert_tokens[i+length:i+length+5])
                        result.append([*range(current_idx+i, current_idx+i+length)])
                        current_idx += length+i
                        found = True
                        break
            if not found:
                # Docstrings/comments/newlines are expected to go unmatched; only
                # warn for genuinely missing tokens. Keep an empty slot regardless.
                if (not python_token.startswith('"""') and not python_token.startswith("'''")
                    and not python_token.startswith("#") and python_token not in new_line):
                    print("id =", id, ": python token not found: ", repr(python_token))
                    print('remaining tokens', remaining_cubert_tokens[0:10])
                result.append([])

        else:
            # Token produced no sub-tokens (or was a comment): keep an empty slot.
            if python_token not in ignore_set and not python_token.startswith("#") and not python_token.isspace():
                print("ignoring python token: ", repr(python_token))
            result.append([])

    return result


# Module-level tokenizers shared by align_cubert_tokens() and the script below.
tokenizer = python_tokenizer.PythonTokenizer()
subword_tokenizer = text_encoder.SubwordTextEncoder("cubert/vocab.txt")



# save subtoken alignment to json
# Driver script: mirrors the Java variant — tokenize, align, dump in batches.
with open('../../generating_CodeSyntax/deduplicated_python_code.pickle', 'rb') as f:
    sample = pickle.load(f)
examples = []

total_count=len(sample)
# NOTE(review): exact multiples of 1000 produce one extra empty batch file.
batch_count=int(total_count/1000)+1
for b in range(0, batch_count):
    examples = []
    for i in range(1000*b, min(1000*(b+1),total_count)):
        try:
            cubert_tokens = (code_to_subtokenized_sentences.code_to_cubert_sentences(
                code=sample[i],
                initial_tokenizer=tokenizer,
                subword_tokenizer=subword_tokenizer))
            tokens = ["[CLS]_"]
            for sentence in cubert_tokens:
                tokens.extend(sentence)
            tokens.append("[SEP]_")
            # Skip examples longer than the 512-token model context.
            if len(tokens) < 512:
                # generate alignment
                python_tokens = []
                for token in tokenize.generate_tokens(io.StringIO(sample[i]).readline):
                    python_tokens.append(token.string)
                # Exact sequence CuBERT emits when its tokenizer fails.
                if tokens == ['[CLS]_', '\\u\\u\\u', 'ERROR', '\\u\\u\\u_', '\\u\\u\\uNEWLINE\\u\\u\\u_', '[SEP]_']:
                    print("skipping", i, "because of cubert tokenization error")
                    continue
                alignment = align_cubert_tokens(tokens, python_tokens, id=i)
                examples.append({"tokens": tokens, "id": i, "alignment": alignment, "python_tokens": python_tokens})
        except Exception as e:
            # Best-effort batch processing: log and continue on any failure.
            print("skipping", i, "because of", e)
    print("length of batch", b, ":", len(examples))

    with open("data/CuBERT_tokenized/python_"+str(1000*b)+"_"+str(1000*(b+1))+".json", 'w') as f:
        json.dump(examples, f, indent=2)

# ============================================================================
# File: evaluating_models/PL/topk_scores_codebert_attention_java.py
# ============================================================================
import json
import pickle
import numpy as np
from collections import defaultdict
import itertools

# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# (comment block duplicated in the original file)
def evaluate_predictor_topk(max_attn_data, dataset, reln, attn_layer, attn_head, max_k = 20, metric="first"):
    """Top-k accuracy of one attention head as a predictor for one relation.

    scores[k] = fraction of dependency pairs whose target appears among the
    head's top-k attention predictions; scores[0] is unused.
    NOTE(review): indentation was lost in the dump — this reconstruction
    assumes the `else` below pairs with the `if` (so each pair contributes to
    n_correct[k] or n_incorrect[k] at every rank k); confirm in the original.
    """
    scores = np.zeros((max_k+1), dtype = 'float')
    n_correct = [0]*(max_k+1)
    n_incorrect = [0]*(max_k+1)
    for index, example in enumerate(max_attn_data):
        i = example['id']
        if reln in dataset[i]["relns"]:
            for head_idx, dep_range_start_idx, dep_range_end_idx in dataset[i]["relns"][reln]:
                for k in range(1, max_k+1):
                    k_th_prediction = example["max_attn"][attn_layer][attn_head][head_idx][k-1]
                    if ((metric == "first" and k_th_prediction == dep_range_start_idx) or
                        (metric == "last" and k_th_prediction == dep_range_end_idx) or
                        (metric == "any" and k_th_prediction >= dep_range_start_idx and k_th_prediction <= dep_range_end_idx)):
                        # Correct at rank k counts as correct for all larger k.
                        n_correct[k:] = [c+1 for c in n_correct[k:]]
                        break
                    else:
                        n_incorrect[k] += 1
    for k in range(1, max_k+1):
        if (n_correct[k] + n_incorrect[k]) == 0:
            # NOTE(review): assigning None into a float ndarray stores NaN.
            scores[k] = None
        else:
            scores[k] = n_correct[k] / float(n_correct[k] + n_incorrect[k])
    return scores




def get_relns(dataset):
    """Collect the sorted list of relation names occurring in the dataset."""
    relns = set()
    for example in dataset:
        for reln in example["relns"].keys():
            relns.add(reln)
    relns = list(relns)
    relns.sort()
    return relns

# scores[reln][layer][head]
def get_scores(max_attn_data, dataset, relns, max_k=20, metric="first"):
    """Top-k accuracy for every (relation, layer, head); shapes read from data."""
    scores = {}
    for reln in relns:
        print("processing relationship: ", reln)
        num_layer = max_attn_data[0]["max_attn"].shape[0]
        num_head = max_attn_data[0]["max_attn"].shape[1]
        scores[reln] = np.zeros((num_layer, num_head, max_k+1), dtype = 'float')
        for layer in range(num_layer):
            for head in range(num_head):
                # if head == 0:
                #     print("layer: ", layer)
                scores[reln][layer][head] = evaluate_predictor_topk(max_attn_data, dataset, reln, layer, head, max_k, metric)
    return scores



# average topk scores for each relationship and categories (word level)
# Uses the module-level `categories` dict (defined further down in this file;
# resolved at call time). NOTE(review): `sum` shadows the builtin here.
def get_avg(scores, relns, max_k=20):
    reln_avg = [None]*(max_k+1)
    cat_avg = {}
    for cat, cat_relns in categories.items():
        cat_avg[cat] = [None]*(max_k+1)

    for k in range(1, (max_k+1)):
        sum, count = 0, 0
        for cat, cat_relns in categories.items():
            cat_sum, cat_count = 0, 0
            for cat_reln in cat_relns:
                for reln in relns:
                    if reln.startswith(cat_reln+":"):
                        # Best (layer, head) for this relation at rank k.
                        flatten_idx = np.argmax(scores[reln][:,:,k])
                        num_head = scores[reln].shape[1]
                        row = int(flatten_idx/num_head)
                        col = flatten_idx % num_head
                        sum += scores[reln][row][col][k]
                        count += 1
                        cat_sum += scores[reln][row][col][k]
                        cat_count += 1
            cat_avg[cat][k] = cat_sum/cat_count
        reln_avg[k] = sum/count
    return (reln_avg, cat_avg)


def print_attn_table(k, relns, scores):
    """Print per-relation best-head accuracy at rank k, grouped by category."""
    print("relationship\t\t accuracy\tlayer\thead")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(scores[reln][:,:,k])
                    num_head = scores[reln].shape[1]
                    row = int(flatten_idx/num_head)
                    col = flatten_idx % num_head
                    table += reln.ljust(30) + str(round(scores[reln][row][col][k],3)).ljust(5) + "\t" + str(row) + "\t" + str(col) + '\n'
                    sum += scores[reln][row][col][k]
                    count += 1
                    cate_sum += scores[reln][row][col][k]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)



def print_baseline_table(k, relns, reln_scores_topk):
    """Print per-relation baseline (offset predictor) accuracy at rank k."""
    print("relationship\t\t accuracy\toffset")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    table += reln.ljust(30) + str(round(reln_scores_topk[reln][k][0],3)).ljust(5) + "\t" + str(reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    sum += reln_scores_topk[reln][k][0]
                    count += 1
                    cate_sum += reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)

def print_attn_baseline_table(k, relns, attn_scores, baseline_reln_scores_topk):
    """Print attention vs. baseline accuracy side by side at rank k."""
    print("relationship\t\tattention\tbaseline\toffset")
    attn_sum, count, baseline_sum = 0, 0, 0
    table = ""
    table2 = "category\t\tattention\tbaseline\n"
    for cat, cate_relns in categories.items():
        table += "=========================" + cat.ljust(20,"=") + "================\n"
        attn_cate_sum, cate_count, baseline_cate_sum = 0, 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(attn_scores[reln][:,:,k])
                    num_head = attn_scores[reln].shape[1]
                    row = int(flatten_idx/num_head)
                    col = flatten_idx % num_head
                    table += reln.ljust(30) + str(round(attn_scores[reln][row][col][k],3)).ljust(5) + "\t" + str(round(baseline_reln_scores_topk[reln][k][0],3)).ljust(5) + "\t\t" + str(baseline_reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    attn_sum += attn_scores[reln][row][col][k]
                    baseline_sum += baseline_reln_scores_topk[reln][k][0]
                    count += 1
                    attn_cate_sum += attn_scores[reln][row][col][k]
                    baseline_cate_sum += baseline_reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(attn_cate_sum/cate_count,3)) + "\t"+str(round(baseline_cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("attention average of",count,"relations:", attn_sum/count)
    print("baseline average of",count,"relations:", baseline_sum/count)


# Java relation categories (see the Python variant for the Python set).
categories = {'Control Flow': ['If', 'For', 'While', 'Try', "Do", "Switch"],
              'Expressions': [ 'InfixExpr', 'Call', 'IfExp', 'Attribute', "InstanceofExpr"],
              'Expr-Subscripting': ['Subscript'],
              'Statements': ['Assign', "LabeledStatement"],
              'Vague': ['children']
              }


# One-off helper (commented out): split the sorted attention pickle into
# valid/test partitions by example id. (Continues in the next chunk.)
# max_attn_data = []
# with open("data/attention/codebert_java_full_attn_sorted.pkl", "rb") as f:
#     max_attn_data = pickle.load(f)
# with open("data/attention/codebert_java_full_attn_sorted_valid.pkl", "wb")
# NOTE(review): reconstructed from a flattened repository dump; indentation
# was destroyed — verify nesting against the original source.
# (continuation of the commented-out valid/test split helper from above)
#     as f:
#     pickle.dump([c for c in max_attn_data if c["id"]>=11009], f)
# with open("data/attention/codebert_java_full_attn_sorted_test.pkl", "wb") as f:
#     pickle.dump([c for c in max_attn_data if c["id"]<11009], f)


# Configuration: which variant of the Java CodeSyntax dataset to evaluate.
Skip_semicolon = False
add_new_lines = False
# The two variants are mutually exclusive.
assert not (add_new_lines == True and Skip_semicolon == True)


if add_new_lines:
    output_file_name = "_with_new_lines"
    dataset_filename="CodeSyntax_java_with_new_lines.json"
elif Skip_semicolon:
    output_file_name = "_skip_semicolon"
    dataset_filename="CodeSyntax_java_skip_semicolon.json"
else:
    output_file_name = ""
    dataset_filename = "CodeSyntax_java.json"

with open("../../CodeSyntax/"+dataset_filename, "r") as f:
    dataset = json.load(f)
relns = get_relns(dataset)
print("relations", relns)

# Evaluate every metric on both partitions and pickle the score tensors.
for metric in ["first", "any", "last"]:
    for partition in ["valid", "test"]:
        with open("data/attention/codebert_java_full_attn_sorted_"+partition+"_common.pkl", "rb") as f:
            max_attn_data = pickle.load(f)

        print(partition, metric)
        scores = get_scores(max_attn_data, dataset, relns, max_k=20, metric=metric)
        reln_avg, cat_avg = get_avg(scores, relns, max_k=20)

        print_attn_table(1, relns, scores)

        with open("data/scores/codebert_java_full_topk_scores_"+partition+"_"+metric+output_file_name+"_common.pkl", "wb") as f:
            pickle.dump(scores, f)

# ============================================================================
# File: evaluating_models/PL/topk_scores_codebert_attention_python.py
# (near-duplicate of the Java variant above; see notes there)
# ============================================================================
import json
import pickle
import numpy as np
from collections import defaultdict
import itertools

# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# (comment block duplicated in the original file)
def evaluate_predictor_topk(max_attn_data, dataset, reln, attn_layer, attn_head, max_k = 20, metric="first"):
    """Top-k accuracy of one attention head for one relation (see Java variant).

    NOTE(review): reconstruction assumes the `else` pairs with the `if` —
    confirm against the original indentation.
    """
    scores = np.zeros((max_k+1), dtype = 'float')
    n_correct = [0]*(max_k+1)
    n_incorrect = [0]*(max_k+1)
    for index, example in enumerate(max_attn_data):
        i = example['id']
        if reln in dataset[i]["relns"]:
            for head_idx, dep_range_start_idx, dep_range_end_idx in dataset[i]["relns"][reln]:
                for k in range(1, max_k+1):
                    k_th_prediction = example["max_attn"][attn_layer][attn_head][head_idx][k-1]
                    if ((metric == "first" and k_th_prediction == dep_range_start_idx) or
                        (metric == "last" and k_th_prediction == dep_range_end_idx) or
                        (metric == "any" and k_th_prediction >= dep_range_start_idx and k_th_prediction <= dep_range_end_idx)):
                        n_correct[k:] = [c+1 for c in n_correct[k:]]
                        break
                    else:
                        n_incorrect[k] += 1
    for k in range(1, max_k+1):
        if (n_correct[k] + n_incorrect[k]) == 0:
            # NOTE(review): assigning None into a float ndarray stores NaN.
            scores[k] = None
        else:
            scores[k] = n_correct[k] / float(n_correct[k] + n_incorrect[k])
    return scores




def get_relns(dataset):
    """Collect the sorted list of relation names occurring in the dataset."""
    relns = set()
    for example in dataset:
        for reln in example["relns"].keys():
            relns.add(reln)
    relns = list(relns)
    relns.sort()
    return relns

# scores[reln][layer][head]
def get_scores(max_attn_data, dataset, relns, max_k=20, metric="first"):
    """Top-k accuracy for every (relation, layer, head); shapes read from data."""
    scores = {}
    for reln in relns:
        print("processing relationship: ", reln)
        num_layer = max_attn_data[0]["max_attn"].shape[0]
        num_head = max_attn_data[0]["max_attn"].shape[1]
        scores[reln] = np.zeros((num_layer, num_head, max_k+1), dtype = 'float')
        for layer in range(num_layer):
            for head in range(num_head):
                # if head == 0:
                #     print("layer: ", layer)
                scores[reln][layer][head] = evaluate_predictor_topk(max_attn_data, dataset, reln, layer, head, max_k, metric)
    return scores

# Python relation categories (the Java variant uses a different set).
categories = {'Control Flow': ['If', 'For', 'While', 'Try'],
              'Expressions': ['BinOp', 'BoolOp', 'Compare', 'Call', 'IfExp', 'Attribute'],
              'Expr-Subscripting': ['Subscript'],
              'Statements': ['Assign', 'AugAssign'],
              'Vague': ['children']
              }


# average topk scores for each relationship and categories (word level)
# NOTE(review): `sum` shadows the builtin in the functions below.
def get_avg(scores, relns, max_k=20):
    reln_avg = [None]*(max_k+1)
    cat_avg = {}
    for cat, cat_relns in categories.items():
        cat_avg[cat] = [None]*(max_k+1)

    for k in range(1, (max_k+1)):
        sum, count = 0, 0
        for cat, cat_relns in categories.items():
            cat_sum, cat_count = 0, 0
            for cat_reln in cat_relns:
                for reln in relns:
                    if reln.startswith(cat_reln+":"):
                        # Best (layer, head) for this relation at rank k.
                        flatten_idx = np.argmax(scores[reln][:,:,k])
                        num_head = scores[reln].shape[1]
                        row = int(flatten_idx/num_head)
                        col = flatten_idx % num_head
                        sum += scores[reln][row][col][k]
                        count += 1
                        cat_sum += scores[reln][row][col][k]
                        cat_count += 1
            cat_avg[cat][k] = cat_sum/cat_count
        reln_avg[k] = sum/count
    return (reln_avg, cat_avg)


def print_attn_table(k, relns, scores):
    """Print per-relation best-head accuracy at rank k, grouped by category."""
    print("relationship\t\t accuracy\tlayer\thead")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(scores[reln][:,:,k])
                    num_head = scores[reln].shape[1]
                    row = int(flatten_idx/num_head)
                    col = flatten_idx % num_head
                    table += reln.ljust(30) + str(round(scores[reln][row][col][k],3)).ljust(5) + "\t" + str(row) + "\t" + str(col) + '\n'
                    sum += scores[reln][row][col][k]
                    count += 1
                    cate_sum += scores[reln][row][col][k]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)



def print_baseline_table(k, relns, reln_scores_topk):
    """Print per-relation baseline (offset predictor) accuracy at rank k."""
    print("relationship\t\t accuracy\toffset")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    table += reln.ljust(30) + str(round(reln_scores_topk[reln][k][0],3)).ljust(5) + "\t" + str(reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    sum += reln_scores_topk[reln][k][0]
                    count += 1
                    cate_sum += reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)

def print_attn_baseline_table(k, relns, attn_scores, baseline_reln_scores_topk):
    """Print attention vs. baseline accuracy side by side at rank k."""
    print("relationship\t\tattention\tbaseline\toffset")
    attn_sum, count, baseline_sum = 0, 0, 0
    table = ""
    table2 = "category\t\tattention\tbaseline\n"
    for cat, cate_relns in categories.items():
        table += "=========================" + cat.ljust(20,"=") + "================\n"
        attn_cate_sum, cate_count, baseline_cate_sum = 0, 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(attn_scores[reln][:,:,k])
                    num_head = attn_scores[reln].shape[1]
                    row = int(flatten_idx/num_head)
                    col = flatten_idx % num_head
                    table += reln.ljust(30) + str(round(attn_scores[reln][row][col][k],3)).ljust(5) + "\t" + str(round(baseline_reln_scores_topk[reln][k][0],3)).ljust(5) + "\t\t" + str(baseline_reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    attn_sum += attn_scores[reln][row][col][k]
                    baseline_sum += baseline_reln_scores_topk[reln][k][0]
                    count += 1
                    attn_cate_sum += attn_scores[reln][row][col][k]
                    baseline_cate_sum += baseline_reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(attn_cate_sum/cate_count,3)) + "\t"+str(round(baseline_cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("attention average of",count,"relations:", attn_sum/count)
    print("baseline average of",count,"relations:", baseline_sum/count)


# One-off helper (commented out): split the sorted attention pickle into
# valid/test partitions by example id.
# max_attn_data = []
# with open("data/attention/codebert_python_full_attn_sorted.pkl", "rb") as f:
#     max_attn_data = pickle.load(f)
# with open("data/attention/codebert_python_full_attn_sorted_valid.pkl", "wb") as f:
#     pickle.dump([c for c in max_attn_data if c["id"]>=8680], f)
# with open("data/attention/codebert_python_full_attn_sorted_test.pkl", "wb") as f:
#     pickle.dump([c for c in max_attn_data if c["id"]<8680], f)



# Configuration: which variant of the Python CodeSyntax dataset to evaluate.
add_new_lines=False
dataset_file_name = "CodeSyntax_python"
output_file_name = "" if not add_new_lines else "_with_new_lines"
if add_new_lines:
    dataset_file_name += "_with_new_lines"
with open("../../CodeSyntax/"+dataset_file_name+".json", 'r') as f:
    dataset = json.load(f)
relns = get_relns(dataset)
print("relations", relns)

# Evaluate every metric on both partitions and pickle the score tensors.
for metric in ["first", "any", "last"]:
    for partition in ["valid", "test"]:
        with open("data/attention/codebert_python_full_attn_sorted_"+partition+"_common.pkl", "rb") as f:
            max_attn_data = pickle.load(f)

        print(partition, metric)
        scores = get_scores(max_attn_data, dataset, relns, max_k=20, metric=metric)
        reln_avg, cat_avg = get_avg(scores, relns, max_k=20)

        print_attn_table(1, relns, scores)

        with open("data/scores/codebert_python_full_topk_scores_"+partition+"_"+metric+output_file_name+"_common.pkl", "wb") as f:
            pickle.dump(scores, f)

# ============================================================================
# File: evaluating_models/PL/topk_scores_cubert_attention_java.py
# (near-duplicate of the CodeBERT variants, but with layer/head counts
# hard-coded to CuBERT-large's 24 layers x 16 heads)
# ============================================================================
import json
import pickle
import numpy as np
from collections import defaultdict
import itertools

# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# (comment block duplicated in the original file)
def evaluate_predictor_topk(max_attn_data, dataset, reln, attn_layer, attn_head, max_k = 20, metric="first"):
    """Top-k accuracy of one attention head for one relation (see notes in the
    CodeBERT variant).

    NOTE(review): reconstruction assumes the `else` pairs with the `if` —
    confirm against the original indentation.
    """
    scores = np.zeros((max_k+1), dtype = 'float')
    n_correct = [0]*(max_k+1)
    n_incorrect = [0]*(max_k+1)
    for index, example in enumerate(max_attn_data):
        i = example['id']
        if reln in dataset[i]["relns"]:
            for head_idx, dep_range_start_idx, dep_range_end_idx in dataset[i]["relns"][reln]:
                for k in range(1, max_k+1):
                    k_th_prediction = example["max_attn"][attn_layer][attn_head][head_idx][k-1]
                    if ((metric == "first" and k_th_prediction == dep_range_start_idx) or
                        (metric == "last" and k_th_prediction == dep_range_end_idx) or
                        (metric == "any" and k_th_prediction >= dep_range_start_idx and k_th_prediction <= dep_range_end_idx)):
                        n_correct[k:] = [c+1 for c in n_correct[k:]]
                        break
                    else:
                        n_incorrect[k] += 1
    for k in range(1, max_k+1):
        if (n_correct[k] + n_incorrect[k]) == 0:
            # NOTE(review): assigning None into a float ndarray stores NaN.
            scores[k] = None
        else:
            scores[k] = n_correct[k] / float(n_correct[k] + n_incorrect[k])
    return scores
# NOTE(review): reconstructed from a flattened repository dump; indentation
# was destroyed — verify nesting against the original source.
# (continuation of topk_scores_cubert_attention_java.py)


def get_relns(dataset):
    """Collect the sorted list of relation names occurring in the dataset."""
    relns = set()
    for example in dataset:
        for reln in example["relns"].keys():
            relns.add(reln)
    relns = list(relns)
    relns.sort()
    return relns

# scores[reln][layer][head]
def get_scores(max_attn_data, dataset, relns, max_k=20, metric="first"):
    """Top-k accuracy for every (relation, layer, head).

    NOTE(review): 24 layers x 16 heads is hard-coded here (CuBERT-large),
    unlike the CodeBERT variants which read shapes from the data.
    """
    scores = {}
    for reln in relns:
        print("processing relationship: ", reln)
        scores[reln] = np.zeros((24, 16, max_k+1), dtype = 'float')
        for layer in range(24):
            for head in range(16):
                # if head == 0:
                #     print("layer: ", layer)
                scores[reln][layer][head] = evaluate_predictor_topk(max_attn_data, dataset, reln, layer, head, max_k, metric)
    return scores


# average topk scores for each relationship and categories (word level)
# Uses the module-level `categories` dict (defined further down; resolved at
# call time). NOTE(review): `sum` shadows the builtin in the functions below,
# and the 16-head count is hard-coded in the argmax unflattening.
def get_avg(scores, relns, max_k=20):
    reln_avg = [None]*(max_k+1)
    cat_avg = {}
    for cat, cat_relns in categories.items():
        cat_avg[cat] = [None]*(max_k+1)

    for k in range(1, (max_k+1)):
        sum, count = 0, 0
        for cat, cat_relns in categories.items():
            cat_sum, cat_count = 0, 0
            for cat_reln in cat_relns:
                for reln in relns:
                    if reln.startswith(cat_reln+":"):
                        # Best (layer, head) for this relation at rank k.
                        flatten_idx = np.argmax(scores[reln][:,:,k])
                        row = int(flatten_idx/16)
                        col = flatten_idx % 16
                        sum += scores[reln][row][col][k]
                        count += 1
                        cat_sum += scores[reln][row][col][k]
                        cat_count += 1
            cat_avg[cat][k] = cat_sum/cat_count
        reln_avg[k] = sum/count
    return (reln_avg, cat_avg)

def print_attn_table(k, relns, scores):
    """Print per-relation best-head accuracy at rank k, grouped by category."""
    print("relationship\t\t accuracy\tlayer\thead")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(scores[reln][:,:,k])
                    row = int(flatten_idx/16)
                    col = flatten_idx % 16
                    table += reln.ljust(30) + str(round(scores[reln][row][col][k],3)).ljust(5) + "\t" + str(row) + "\t" + str(col) + '\n'
                    sum += scores[reln][row][col][k]
                    count += 1
                    cate_sum += scores[reln][row][col][k]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)


def print_baseline_table(k, relns, reln_scores_topk):
    """Print per-relation baseline (offset predictor) accuracy at rank k."""
    print("relationship\t\t accuracy\toffset")
    sum, count = 0, 0
    table = ""
    table2 = "category\t\t average accuracy\n"
    for cat, cate_relns in categories.items():
        table += "===================" + cat.ljust(20,"=") + "==========\n"
        cate_sum, cate_count = 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    table += reln.ljust(30) + str(round(reln_scores_topk[reln][k][0],3)).ljust(5) + "\t" + str(reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    sum += reln_scores_topk[reln][k][0]
                    count += 1
                    cate_sum += reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("average of",count,"relations:", sum/count)

def print_attn_baseline_table(k, relns, attn_scores, baseline_reln_scores_topk):
    """Print attention vs. baseline accuracy side by side at rank k."""
    print("relationship\t\tattention\tbaseline\toffset")
    attn_sum, count, baseline_sum = 0, 0, 0
    table = ""
    table2 = "category\t\tattention\tbaseline\n"
    for cat, cate_relns in categories.items():
        table += "=========================" + cat.ljust(20,"=") + "================\n"
        attn_cate_sum, cate_count, baseline_cate_sum = 0, 0, 0
        for cate_reln in cate_relns:
            for reln in relns:
                if reln.startswith(cate_reln+":"):
                    flatten_idx = np.argmax(attn_scores[reln][:,:,k])
                    row = int(flatten_idx/16)
                    col = flatten_idx % 16
                    table += reln.ljust(30) + str(round(attn_scores[reln][row][col][k],3)).ljust(5) + "\t" + str(round(baseline_reln_scores_topk[reln][k][0],3)).ljust(5) + "\t\t" + str(baseline_reln_scores_topk[reln][k][1])[1:-1] + '\n'
                    attn_sum += attn_scores[reln][row][col][k]
                    baseline_sum += baseline_reln_scores_topk[reln][k][0]
                    count += 1
                    attn_cate_sum += attn_scores[reln][row][col][k]
                    baseline_cate_sum += baseline_reln_scores_topk[reln][k][0]
                    cate_count += 1
        table2 += cat.ljust(20) + "\t\t"+str(round(attn_cate_sum/cate_count,3)) + "\t"+str(round(baseline_cate_sum/cate_count,3)) + "\n"
    print(table)
    print(table2)
    print("attention average of",count,"relations:", attn_sum/count)
    print("baseline average of",count,"relations:", baseline_sum/count)


# Java relation categories.
categories = {'Control Flow': ['If', 'For', 'While', 'Try', "Do", "Switch"],
              'Expressions': [ 'InfixExpr', 'Call', 'IfExp', 'Attribute', "InstanceofExpr"],
              'Expr-Subscripting': ['Subscript'],
              'Statements': ['Assign', "LabeledStatement"],
              'Vague': ['children']
              }


# Configuration: which variant of the Java CodeSyntax dataset to evaluate.
Skip_semicolon = False
add_new_lines = False
# The two variants are mutually exclusive.
assert not (add_new_lines == True and Skip_semicolon == True)


attn_file_name = ""
if add_new_lines:
    attn_file_name = "_with_new_lines"
    output_file_name = "_with_new_lines"
    dataset_filename="CodeSyntax_java_with_new_lines.json"
elif Skip_semicolon:
    output_file_name = "_skip_semicolon"
    dataset_filename="CodeSyntax_java_skip_semicolon.json"
else:
    output_file_name = ""
    dataset_filename = "CodeSyntax_java.json"




with open("../../CodeSyntax/"+dataset_filename, "r") as f:
    dataset = json.load(f)
relns = get_relns(dataset)
print("relations", relns)

# Evaluate every metric on both partitions and pickle the score tensors.
for metric in ["first", "any", "last"]:
    for partition in ["valid", "test"]:
        with open("data/attention/cubert_java_full_attn_sorted_"+partition+attn_file_name+"_common.pkl", "rb") as f:
            max_attn_data = pickle.load(f)

        print(partition, metric)
        scores = get_scores(max_attn_data, dataset, relns, max_k=20, metric=metric)
        reln_avg, cat_avg = get_avg(scores, relns, max_k=20)

        print_attn_table(1, relns, scores)

        with open("data/scores/cubert_java_full_topk_scores_"+partition+"_"+metric+output_file_name+"_common.pkl", "wb") as f:
            pickle.dump(scores, f)

# ============================================================================
# File: evaluating_models/PL/topk_scores_cubert_attention_python.py
# (near-duplicate of the Java CuBERT variant; truncated in this chunk)
# ============================================================================
import json
import pickle
import numpy as np
from collections import defaultdict
import itertools

# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# Code for evaluating individual attention maps and baselines
# at word level
# metric: "first", "any", "last"
# (comment block duplicated in the original file)
def evaluate_predictor_topk(max_attn_data, dataset, reln, attn_layer, attn_head, max_k = 20, metric="first"):
    """Top-k accuracy of one attention head for one relation.

    NOTE(review): reconstruction assumes the `else` pairs with the `if` —
    confirm against the original indentation.
    """
    scores = np.zeros((max_k+1), dtype = 'float')
    n_correct = [0]*(max_k+1)
    n_incorrect = [0]*(max_k+1)
    for index, example in enumerate(max_attn_data):
        i = example['id']
        if reln in dataset[i]["relns"]:
            for head_idx, dep_range_start_idx, dep_range_end_idx in dataset[i]["relns"][reln]:
                for k in range(1, max_k+1):
                    k_th_prediction = example["max_attn"][attn_layer][attn_head][head_idx][k-1]
                    if ((metric == "first" and k_th_prediction == dep_range_start_idx) or
                        (metric == "last" and k_th_prediction == dep_range_end_idx) or
                        (metric == "any" and k_th_prediction >= dep_range_start_idx and k_th_prediction <= dep_range_end_idx)):
                        n_correct[k:] = [c+1 for c in n_correct[k:]]
                        break
                    else:
                        n_incorrect[k] += 1
    for k in range(1, max_k+1):
        if (n_correct[k] + n_incorrect[k]) == 0:
            # NOTE(review): assigning None into a float ndarray stores NaN.
            scores[k] = None
        else:
            scores[k] = n_correct[k] / float(n_correct[k] + n_incorrect[k])
    return scores




def get_relns(dataset):
    """Collect the sorted list of relation names occurring in the dataset."""
    relns = set()
    for example in dataset:
        for reln in example["relns"].keys():
            relns.add(reln)
    relns = list(relns)
    relns.sort()
    return relns

# scores[reln][layer][head]
def get_scores(max_attn_data, dataset, relns, max_k=20, metric="first"):
    """Top-k accuracy for every (relation, layer, head); 24x16 hard-coded."""
    scores = {}
    for reln in relns:
        print("processing relationship: ", reln)
        scores[reln] = np.zeros((24, 16, max_k+1), dtype = 'float')
        for layer in range(24):
            for head in range(16):
                # if head == 0:
                #     print("layer: ", layer)
                scores[reln][layer][head] = evaluate_predictor_topk(max_attn_data, dataset, reln, layer, head, max_k, metric)
    return scores

# Python relation categories.
categories = {'Control Flow': ['If', 'For', 'While', 'Try'],
              'Expressions': ['BinOp', 'BoolOp', 'Compare', 'Call', 'IfExp', 'Attribute'],
              'Expr-Subscripting': ['Subscript'],
              'Statements': ['Assign', 'AugAssign'],
              'Vague': ['children']
              }


# average topk scores for each relationship and categories (word level)
def get_avg(scores, relns, max_k=20):
    reln_avg = [None]*(max_k+1)
    cat_avg = {}
    for cat, cat_relns in categories.items():
        cat_avg[cat] = [None]*(max_k+1)

    for k in range(1, (max_k+1)):
        sum, count = 0, 0
        for cat, cat_relns in categories.items():
            cat_sum, cat_count = 0, 0
            for cat_reln in cat_relns:
                for reln in relns:
                    if reln.startswith(cat_reln+":"):
                        flatten_idx = np.argmax(scores[reln][:,:,k])
                        row = int(flatten_idx/16)
                        col = flatten_idx
# NOTE(review): this file is truncated here in the visible chunk — the
# remainder of get_avg() and the rest of the script lie outside this view.
% 16 87 | sum += scores[reln][row][col][k] 88 | count += 1 89 | cat_sum += scores[reln][row][col][k] 90 | cat_count += 1 91 | cat_avg[cat][k] = cat_sum/cat_count 92 | reln_avg[k] = sum/count 93 | return (reln_avg, cat_avg) 94 | 95 | def print_attn_table(k, relns, scores): 96 | print("relationship\t\t accuracy\tlayer\thead") 97 | sum, count = 0, 0 98 | table = "" 99 | table2 = "category\t\t average accuracy\n" 100 | for cat, cate_relns in categories.items(): 101 | table += "===================" + cat.ljust(20,"=") + "==========\n" 102 | cate_sum, cate_count = 0, 0 103 | for cate_reln in cate_relns: 104 | for reln in relns: 105 | if reln.startswith(cate_reln+":"): 106 | flatten_idx = np.argmax(scores[reln][:,:,k]) 107 | row = int(flatten_idx/16) 108 | col = flatten_idx % 16 109 | table += reln.ljust(30) + str(round(scores[reln][row][col][k],3)).ljust(5) + "\t" + str(row) + "\t" + str(col) + '\n' 110 | sum += scores[reln][row][col][k] 111 | count += 1 112 | cate_sum += scores[reln][row][col][k] 113 | cate_count += 1 114 | table2 += cat.ljust(20) + "\t\t"+str(round(cate_sum/cate_count,3)) + "\n" 115 | print(table) 116 | print(table2) 117 | print("average of",count,"relations:", sum/count) 118 | 119 | 120 | def print_baseline_table(k, relns, reln_scores_topk): 121 | print("relationship\t\t accuracy\toffset") 122 | sum, count = 0, 0 123 | table = "" 124 | table2 = "category\t\t average accuracy\n" 125 | for cat, cate_relns in categories.items(): 126 | table += "===================" + cat.ljust(20,"=") + "==========\n" 127 | cate_sum, cate_count = 0, 0 128 | for cate_reln in cate_relns: 129 | for reln in relns: 130 | if reln.startswith(cate_reln+":"): 131 | table += reln.ljust(30) + str(round(reln_scores_topk[reln][k][0],3)).ljust(5) + "\t" + str(reln_scores_topk[reln][k][1])[1:-1] + '\n' 132 | sum += reln_scores_topk[reln][k][0] 133 | count += 1 134 | cate_sum += reln_scores_topk[reln][k][0] 135 | cate_count += 1 136 | table2 += cat.ljust(20) + 
"\t\t"+str(round(cate_sum/cate_count,3)) + "\n" 137 | print(table) 138 | print(table2) 139 | print("average of",count,"relations:", sum/count) 140 | 141 | def print_attn_baseline_table(k, relns, attn_scores, baseline_reln_scores_topk): 142 | print("relationship\t\tattention\tbaseline\toffset") 143 | attn_sum, count, baseline_sum = 0, 0, 0 144 | table = "" 145 | table2 = "category\t\tattention\tbaseline\n" 146 | for cat, cate_relns in categories.items(): 147 | table += "=========================" + cat.ljust(20,"=") + "================\n" 148 | attn_cate_sum, cate_count, baseline_cate_sum = 0, 0, 0 149 | for cate_reln in cate_relns: 150 | for reln in relns: 151 | if reln.startswith(cate_reln+":"): 152 | flatten_idx = np.argmax(attn_scores[reln][:,:,k]) 153 | row = int(flatten_idx/16) 154 | col = flatten_idx % 16 155 | table += reln.ljust(30) + str(round(attn_scores[reln][row][col][k],3)).ljust(5) + "\t" + str(round(baseline_reln_scores_topk[reln][k][0],3)).ljust(5) + "\t\t" + str(baseline_reln_scores_topk[reln][k][1])[1:-1] + '\n' 156 | attn_sum += attn_scores[reln][row][col][k] 157 | baseline_sum += baseline_reln_scores_topk[reln][k][0] 158 | count += 1 159 | attn_cate_sum += attn_scores[reln][row][col][k] 160 | baseline_cate_sum += baseline_reln_scores_topk[reln][k][0] 161 | cate_count += 1 162 | table2 += cat.ljust(20) + "\t\t"+str(round(attn_cate_sum/cate_count,3)) + "\t"+str(round(baseline_cate_sum/cate_count,3)) + "\n" 163 | print(table) 164 | print(table2) 165 | print("attention average of",count,"relations:", attn_sum/count) 166 | print("baseline average of",count,"relations:", baseline_sum/count) 167 | 168 | 169 | 170 | 171 | 172 | add_new_lines = False 173 | 174 | 175 | 176 | dataset_file_name = "CodeSyntax_python" 177 | output_file_name = "" if not add_new_lines else "_with_new_lines" 178 | if add_new_lines: 179 | dataset_file_name += "_with_new_lines" 180 | with open("../../CodeSyntax/"+dataset_file_name+".json", 'r') as f: 181 | dataset = json.load(f) 
182 | relns = get_relns(dataset) 183 | print("relations", relns) 184 | 185 | for metric in ["first", "any", "last"]: 186 | for partition in ["valid", "test"]: 187 | with open("data/attention/cubert_python_full_attn_sorted_"+partition+"_common.pkl", "rb") as f: 188 | max_attn_data = pickle.load(f) 189 | 190 | print(partition, metric) 191 | scores = get_scores(max_attn_data, dataset, relns, max_k=20, metric=metric) 192 | reln_avg, cat_avg = get_avg(scores, relns, max_k=20) 193 | 194 | print_attn_table(1, relns, scores) 195 | 196 | with open("data/scores/cubert_python_full_topk_scores_"+partition+"_"+metric+output_file_name+"_common.pkl", "wb") as f: 197 | pickle.dump(scores, f) -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/.classpath: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/.project: -------------------------------------------------------------------------------- 1 | 2 | 3 | generate_labels_java_word_level_any_token 4 | 5 | 6 | 7 | 8 | 9 | org.eclipse.jdt.core.javabuilder 10 | 11 | 12 | 13 | 14 | 15 | org.eclipse.jdt.core.javanature 16 | 17 | 18 | -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/.settings/org.eclipse.jdt.core.prefs: -------------------------------------------------------------------------------- 1 | eclipse.preferences.version=1 2 | org.eclipse.jdt.core.compiler.codegen.inlineJsrBytecode=enabled 3 | org.eclipse.jdt.core.compiler.codegen.targetPlatform=16 4 | org.eclipse.jdt.core.compiler.codegen.unusedLocal=preserve 5 | org.eclipse.jdt.core.compiler.compliance=16 6 | org.eclipse.jdt.core.compiler.debug.lineNumber=generate 7 | 
org.eclipse.jdt.core.compiler.debug.localVariable=generate 8 | org.eclipse.jdt.core.compiler.debug.sourceFile=generate 9 | org.eclipse.jdt.core.compiler.problem.assertIdentifier=error 10 | org.eclipse.jdt.core.compiler.problem.enablePreviewFeatures=disabled 11 | org.eclipse.jdt.core.compiler.problem.enumIdentifier=error 12 | org.eclipse.jdt.core.compiler.problem.reportPreviewFeatures=warning 13 | org.eclipse.jdt.core.compiler.release=enabled 14 | org.eclipse.jdt.core.compiler.source=16 15 | -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/bin/.gitignore: -------------------------------------------------------------------------------- 1 | /generate_labels_java/ 2 | -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.core.contenttype_3.7.1000.v20210409-1722.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.core.contenttype_3.7.1000.v20210409-1722.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.core.jobs_3.11.0.v20210420-1453.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.core.jobs_3.11.0.v20210420-1453.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.core.resources_3.15.0.v20210521-0722.jar: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.core.resources_3.15.0.v20210521-0722.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.core.runtime_3.22.0.v20210506-1025.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.core.runtime_3.22.0.v20210506-1025.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.equinox.common_3.15.0.v20210518-0604.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.equinox.common_3.15.0.v20210518-0604.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.equinox.preferences_3.8.200.v20210212-1143.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.equinox.preferences_3.8.200.v20210212-1143.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.jdt.core_3.26.0.v20210609-0549.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.jdt.core_3.26.0.v20210609-0549.jar 
-------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.osgi_3.16.300.v20210525-1715.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.osgi_3.16.300.v20210525-1715.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/Java AST Parser/org.eclipse.text_3.12.0.v20210512-1644.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/Java AST Parser/org.eclipse.text_3.12.0.v20210512-1644.jar -------------------------------------------------------------------------------- /generating_CodeSyntax/deduplicated_java_code.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/deduplicated_java_code.pickle -------------------------------------------------------------------------------- /generating_CodeSyntax/deduplicated_python_code.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dashends/CodeSyntax/2d64fad934c3d5dbb859a8ef6764869b68036f18/generating_CodeSyntax/deduplicated_python_code.pickle -------------------------------------------------------------------------------- /generating_CodeSyntax/generate_labels_java.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | from collections import defaultdict 4 | import javalang 5 | import re 6 | 7 | def convert_line_offset_to_char_offset_java(line_offset, line_to_length): 8 | line, offset = 
line_offset 9 | offset -= 1 10 | if line > 1: 11 | offset += sum(line_to_length[:line-1]) + line-1 12 | return offset 13 | 14 | def convert_char_offset_to_line_offset_java(char_offset, line_to_length): 15 | line = 0 16 | while char_offset > line_to_length[line]: 17 | char_offset -= line_to_length[line]+1 18 | line += 1 19 | return javalang.tokenizer.Position(line+1, char_offset+1) 20 | 21 | def tokenize_java_code(code, keep_string_only=False, add_new_lines=False): 22 | tokens = [token for token in javalang.tokenizer.tokenize(code)] 23 | 24 | if add_new_lines: 25 | line_to_length = [len(l) for l in code.split("\n")] # stores the number of chars in each line of source code 26 | token_range = [(convert_line_offset_to_char_offset_java(token.position, line_to_length), 27 | convert_line_offset_to_char_offset_java((token.position[0], token.position[1]+len(token.value)), line_to_length), 28 | token.value) for token in tokens] 29 | 30 | i = 0 31 | while i < len(tokens) -1: 32 | code_slice = code[token_range[i][1]:token_range[i+1][0]] 33 | matches = re.finditer("\n", code_slice) 34 | new_line_index=[match.start()+token_range[i][1] for match in matches] 35 | tokens = tokens[0:i+1] + [javalang.tokenizer.JavaToken('\n', 36 | convert_char_offset_to_line_offset_java(idx, line_to_length)) 37 | for idx in new_line_index] + tokens[i+1:] 38 | token_range = token_range[0:i+1] + [(i, i+1, '\n') for i in new_line_index] + token_range[i+1:] 39 | i+=1 40 | 41 | if keep_string_only: 42 | tokens = [token.value for token in tokens] 43 | return tokens 44 | 45 | class Visit(): 46 | def __init__(self): 47 | self.result = defaultdict(lambda : []) # relation -> list of tuples [(head_idx, dep_idx), ...] 
48 | # self.op_dict = {ast.And: "and", ast.Or: "or", ast.Not: "not", ast.Invert:"~", ast.UAdd: "+", ast.USub: "-"} 49 | # self.all_relns = set() 50 | 51 | def set_code_and_tokens(self, i): 52 | self.code = examples[i]["code"] 53 | self.tokens = examples[i]["tokens"] 54 | 55 | def clear_result(self): 56 | self.result = defaultdict(lambda : []) 57 | 58 | def parse_positions(self, label_and_positions): 59 | for label_and_position in label_and_positions: 60 | reln, head_start, head_end, dep_start, dep_end = label_and_position.split(" ") 61 | self.add_idx_tuple(reln, int(head_start), int(head_end), int(dep_start), int(dep_end)) 62 | 63 | def add_idx_tuple(self, reln, head_start, head_end, dep_start, dep_end, head_first_token=True, dep_first_token=True): 64 | # create the tuple (head_idx, dep_idx) 65 | head_idx = self.get_idx(head_start, head_end, True) 66 | dep_start_idx = self.get_idx(dep_start, dep_end, True) 67 | dep_end_idx = self.get_idx(dep_start, dep_end, False) 68 | if (head_idx != None and dep_start_idx!= None and dep_end_idx != None and 69 | head_idx != dep_start_idx and head_idx > 0 and dep_start_idx > 0 70 | and head_idx < dep_start_idx and dep_start_idx<= dep_end_idx): 71 | self.result[reln].append((head_idx, dep_start_idx, dep_end_idx)) 72 | 73 | 74 | def get_matching_tokens_count(self, tokens, code): 75 | # return the number of tokens that starts in the beginning and are the same 76 | tokens2 = tokenize_java_code(code, keep_string_only=True, add_new_lines=add_new_lines) 77 | 78 | for i in range(min(len(tokens), len(tokens2))): 79 | if tokens[i] != tokens2[i]: 80 | # the previous i tokens are matched 81 | return i 82 | return min(len(tokens), len(tokens2)) 83 | 84 | def get_idx(self, start_position, end_position, first_token=True): 85 | # get the index of the first/last token within the range between start_position and end_position 86 | # if first_token == true, get the index for the first token 87 | 88 | # get token count in previous segments 89 | 
code_prev = self.code[0:start_position] + "_EndOfLineSymbol_" 90 | 91 | 92 | # only account for macthing tokens and ignore extra ending strings 93 | count_prev = self.get_matching_tokens_count(self.tokens, code_prev) 94 | 95 | # get token count in current segment 96 | if first_token: 97 | count_curr = 1 98 | # exclude '(' and ')' from results 99 | while self.tokens[count_prev+count_curr-1] in skip_set[0]: 100 | count_curr += 1 101 | # if we reach the end of the block, don't add this label 102 | if self.tokens[count_prev+count_curr-1] in skip_set[1]: 103 | return None 104 | else: 105 | code_curr = self.code[start_position:end_position] 106 | count_curr = self.get_matching_tokens_count(self.tokens[count_prev:], code_curr) 107 | # add new lines 108 | if add_new_lines and count_prev+count_curr < len(self.tokens) and self.tokens[count_prev+count_curr] == '\n': 109 | count_curr += 1 110 | # exclude '(' and ')' from results 111 | while self.tokens[count_prev+count_curr-1] in skip_set[1]: 112 | count_curr -= 1 113 | if self.tokens[count_prev+count_curr-1] in skip_set[0]: 114 | return None 115 | return count_prev+count_curr-1 116 | 117 | 118 | def get_label(i, visitor, positions): 119 | visitor.set_code_and_tokens(i) 120 | visitor.parse_positions(positions) 121 | result = dict(visitor.result) 122 | visitor.clear_result() 123 | return result 124 | 125 | 126 | 127 | 128 | 129 | 130 | Skip_semicolon = False 131 | add_new_lines = True 132 | assert not (add_new_lines == True and Skip_semicolon == True) 133 | 134 | 135 | 136 | with open('deduplicated_java_code.pickle', 'rb') as f: 137 | sample = pickle.load(f) 138 | 139 | examples = [] 140 | print("tokenizing java code") 141 | for i in range(0, len(sample)): 142 | java_tokens = tokenize_java_code(sample[i], keep_string_only=True, add_new_lines=add_new_lines) 143 | examples.append({"tokens": java_tokens, "id": i, "code": sample[i], 'relns': {}}) 144 | 145 | 146 | 147 | print("generating labels") 148 | error_count = 0 149 | 
total_count = 0 150 | v = Visit() 151 | 152 | # get start and end positions for each label and node 153 | file = open('Java AST Parser/java_node_start_end_position.txt',mode='r') 154 | positions = file.read().split("\n") 155 | file.close() 156 | id_to_positions = {} 157 | current_id = 0 158 | cache = [] 159 | for line in positions: 160 | if line.isdigit(): 161 | id_to_positions[current_id] = cache 162 | cache = [] 163 | current_id = int(line) 164 | elif line != "": 165 | cache.append(line) 166 | id_to_positions[current_id] = cache 167 | 168 | 169 | # generate labels 170 | skip_set = (set(["(","[","{","\n"]), set([")","]","}", ";"])) if Skip_semicolon else (set(["(","[","{","\n"]), set([")","]","}"])) 171 | time_prev = None 172 | for i in range(len(examples)): 173 | if i%100 == 0: 174 | print("processing", i) 175 | total_count += 1 176 | examples[i]['relns'] = get_label(i, v, id_to_positions[i]) # list of tuple [(head_idx, dep_idx), ...] 177 | 178 | 179 | 180 | 181 | print("error_count", error_count) 182 | print("total_count", total_count) 183 | 184 | output_file_name = "CodeSyntax_java" 185 | if add_new_lines: 186 | output_file_name += "_with_new_lines" 187 | elif Skip_semicolon: 188 | output_file_name += "_skip_semicolon" 189 | with open("../CodeSyntax/"+output_file_name+".json", 'w') as f: 190 | json.dump(examples, f, indent=2) --------------------------------------------------------------------------------