├── .gitignore ├── LICENSE ├── README.md ├── collector ├── .gitignore ├── pom.xml └── src │ └── main │ └── java │ └── org │ └── tseval │ ├── ExtractToken.java │ ├── MaskMethodNameVisitor.java │ ├── MethodDataCollector.java │ ├── MethodDataCollectorVisitor.java │ ├── data │ ├── MethodData.java │ ├── ProjectData.java │ └── RevisionIds.java │ └── util │ ├── AbstractConfig.java │ ├── BashUtils.java │ ├── NLPUtils.java │ └── Option.java └── python ├── .gitignore ├── exps ├── cg.yaml └── mn.yaml ├── prepare_conda_env.sh ├── requirements.txt ├── run.sh ├── switch-cuda.sh └── tseval ├── Environment.py ├── Macros.py ├── Plot.py ├── Table.py ├── Utils.py ├── __init__.py ├── collector ├── DataCollector.py ├── ExperimentsAnalyzer.py ├── MetricsCollector.py └── __init__.py ├── comgen ├── __init__.py ├── eval │ ├── CGEvalHelper.py │ ├── CGModelLoader.py │ ├── StandardSetup.py │ └── __init__.py └── model │ ├── CGModelBase.py │ ├── DeepComHybridESE19.py │ ├── TransformerACL20.py │ └── __init__.py ├── data ├── MethodData.py ├── RevisionIds.py └── __init__.py ├── eval ├── EvalHelper.py ├── EvalMetrics.py ├── EvalSetupBase.py └── __init__.py ├── main.py ├── metnam ├── __init__.py ├── eval │ ├── MNEvalHelper.py │ ├── MNModelLoader.py │ ├── StandardSetup.py │ └── __init__.py └── model │ ├── Code2SeqICLR19.py │ ├── Code2VecPOPL19.py │ ├── MNModelBase.py │ └── __init__.py └── util ├── ModelUtils.py ├── TrainConfig.py └── __init__.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Temp files 2 | *~ 3 | \#*\# 4 | .DS_Store 5 | 6 | # Temp directories 7 | 8 | /_downloads/ 9 | /_results/ 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 EngineeringSoftware 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Impact of Evaluation Methodologies on Code Summarization 2 | 3 | This repo hosts the code and data for the following ACL 2022 paper: 4 | 5 | Title: [Impact of Evaluation Methodologies on Code Summarization][paper-utcs] 6 | 7 | Authors: [Pengyu Nie](https://pengyunie.github.io/), [Jiyang Zhang](https://jiyangzhang.github.io/), [Junyi Jessy Li](https://jessyli.com/), [Raymond J. 
Mooney](https://www.cs.utexas.edu/~mooney/), [Milos Gligoric](https://users.ece.utexas.edu/~gligoric/) 8 | 9 | ```bibtex 10 | @inproceedings{NieETAL22EvalMethodologies, 11 | title = {Impact of Evaluation Methodologies on Code Summarization}, 12 | author = {Pengyu Nie and Jiyang Zhang and Junyi Jessy Li and Raymond J. Mooney and Milos Gligoric}, 13 | pages = {to appear}, 14 | booktitle = {Annual Meeting of the Association for Computational Linguistics}, 15 | year = {2022}, 16 | } 17 | ``` 18 | 19 | ## Introduction 20 | 21 | This repo contains the code and data for producing the experiments in 22 | [Impact of Evaluation Methodologies on Code 23 | Summarization][paper-utcs]. In this work, we study the impact of 24 | evaluation methodologies, i.e., the way people split datasets into 25 | training, validation, and test sets, in the field of code 26 | summarization. We introduce the time-segmented evaluation methodology, 27 | which is novel to the code summarization research community, and 28 | compare it with the mixed-project and cross-project methodologies that 29 | have been commonly used. 30 | 31 | The code includes: 32 | * a data collection tool for collecting (method, comment) pairs with 33 | timestamps. 34 | * a data processing pipeline for splitting a dataset following the 35 | three evaluation methodologies. 36 | * scripts for running four recent machine learning models for code 37 | summarization and comparing their results across methodologies. 38 | 39 | **How to...** 40 | * **reproduce the training and evaluation of ML models on our 41 | collected dataset**: [install dependencies][sec-dependency], [download 42 | all data][sec-downloads], and follow the instructions 43 | [here][sec-traineval]. 44 | * **reproduce our full study from scratch**: [install 45 | dependencies][sec-dependency], [download `_work/src`][sec-downloads] 46 | (the source code for the ML models used in our study), and follow 47 | the instructions to [collect data][sec-collect], [process 48 | data][sec-process], and [train and evaluate models][sec-traineval]. 49 | 50 | 51 | ## Table of Contents 52 | 53 | 1. [Dependency][sec-dependency] 54 | 2. [Data Downloads][sec-downloads] 55 | 3. [Code for Collecting Data][sec-collect] 56 | 4. [Code for Processing Data][sec-process] 57 | 5. [Code for Training and Evaluating Models][sec-traineval] 58 | 59 | ## Dependency 60 | [sec-dependency]: #dependency 61 | 62 | Our code requires the following hardware and software environments. 63 | 64 | * Operating system: Linux (tested on Ubuntu 20.04) 65 | * Minimum disk space: 4 GB 66 | * Python: 3.8 67 | * Java: 8 68 | * Maven: 3.6.3 69 | * Anaconda/Miniconda: appropriate versions for Python 3.8 or higher 70 | 71 | Additional requirements for training and evaluating ML models: 72 | 73 | * GPU: NVIDIA GTX 1080 or better 74 | * CUDA: 10.0 ~ 11.0 75 | * Disk space: 2 GB per trained model 76 | 77 | [Anaconda](https://www.anaconda.com/products/individual#Downloads) or 78 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is 79 | required for installing the other Python library dependencies. Once 80 | Anaconda/Miniconda is installed, you can use the following command to 81 | set up a virtual environment, named `tseval`, with the Python library 82 | dependencies installed: 83 | 84 | ``` 85 | cd python/ 86 | ./prepare_conda_env.sh 87 | ``` 88 | 89 | Then use `conda activate tseval` to activate the created virtual 90 | environment. 91 | 92 | The Java code `collector` will automatically be compiled as needed in 93 | our Python code. 
The Java library dependencies are automatically 94 | downloaded, by the Maven build system, during this process. 95 | 96 | 97 | ## Data Downloads 98 | [sec-downloads]: #data-downloads 99 | 100 | All our data is hosted on UTBox via [a shared folder](https://utexas.box.com/s/32qq85ttp9js0qqnv68ebhvr19ajo7v9). 101 | 102 | Data should be downloaded to this directory with the same directory 103 | structure (e.g., `_work/src` from the shared folder should be 104 | downloaded as `_work/src` under current directory). 105 | 106 | 107 | ## Code for Collecting Data 108 | [sec-collect]: #code-for-collecting-data 109 | 110 | ### Collect the list of popular Java projects on GitHub 111 | 112 | ``` 113 | python -m tseval.main collect_repos 114 | python -m tseval.main filter_repos 115 | ``` 116 | 117 | Results are generated to `results/repos/`: 118 | 119 | * `github-java-repos.json` is the full list of projects returned by 120 | the GitHub API. 121 | 122 | * `filtered-repos.json` is the list of projects filtered according to 123 | the conditions in our paper. 124 | 125 | * `*-logs.json` documents the time, configurations, and metrics of the 126 | collection/filtering of the list. 127 | 128 | Note that the list of projects may already differ from the list of 129 | projects we used, because old projects may be removed, and the 130 | ordering of projects may change. 131 | 132 | ### Collect raw dataset 133 | 134 | Requires the list of projects at `results/repos/filtered-repos.json` 135 | 136 | ``` 137 | python -m tseval.main collect_raw_data 138 | ``` 139 | 140 | Results are generated to `_raw_data/`. Each project's raw data 141 | is in a directory named `$user_$repo` (e.g., `apache_commons-codec`): 142 | 143 | * `method-data.json` is the list of method samples (includes code, API 144 | comments, etc.) extracted from the project at the selected revisions 145 | (at Jan 1st of 2018, 2019, 2020, 2021). 146 | 147 | * `revision-ids.json` is the mapping from revision to the method 148 | samples that are available at that revision. 149 | 150 | * `filtered-counters.json` is the count of samples discarded during 151 | collection according to our paper. 152 | 153 | * `log.txt` is the log of the collection. 154 | 155 | ## Code for Processing Data 156 | [sec-process]: #code-for-processing-data 157 | 158 | 159 | ### Process raw data to use our data structure (tseval.data.MethodData) 160 | 161 | Requires the raw data at `_raw_data/`. 162 | 163 | ``` 164 | python -m tseval.main process_raw_data 165 | ``` 166 | 167 | Results are generated to `_work/shared/`: 168 | 169 | * `*.jsonl` files are the dataset, where each file stores one field of 170 | all samples, and each line stores the field for one sample. 171 | 172 | * `filtered-counters.json` is the combined count of samples discarded 173 | during collection. 174 | 175 | 176 | ### Apply methodologies (non task-specific part) 177 | 178 | Requires the dataset at `_work/shared/`. 179 | 180 | ``` 181 | python -m tseval.main get_splits --seed=7 --split=Full 182 | ``` 183 | 184 | Results are generated to `_work/split/Full/`: 185 | 186 | * `$X-$Y.json`, where X in {MP, CP, T} and Y in {train, val, test_standard}; and 187 | `$X1-$X2-test_common.json`, where X1, X2 in {MP, CP, T}. 188 | 189 | - each file contains a list of ids. 190 | 191 | - MP = mixed-project; CP = cross-project; T = temporally. 192 | 193 | - train = training; val = validation; 194 | test_standard = standard test; test_common = common test. 
195 | 196 | ### Apply methodologies (task-specific part) 197 | 198 | Requires the dataset at `_work/shared/` and the splits at 199 | `_work/split/Full/`. 200 | 201 | From this point on, we define two variables to use in our commands: 202 | 203 | * `$task`: to indicate the targeted code summarization task. 204 | - CG: comment generation. 205 | - MN: method naming. 206 | 207 | * `$method`: to indicate the methodology used. 208 | - MP: mixed-project. 209 | - CP: cross-project. 210 | - T: temporally. 211 | 212 | ``` 213 | python -m tseval.main exp_prepare \ 214 | --task=$task \ 215 | --setup=StandardSetup \ 216 | --setup_name=$method \ 217 | --split_name=Full \ 218 | --split_type=$method 219 | # Example: python -m tseval.main exp_prepare \ 220 | # --task=CG \ 221 | # --setup=StandardSetup \ 222 | # --setup_name=T \ 223 | # --split_name=Full \ 224 | # --split_type=T 225 | ``` 226 | 227 | Results are generated to `_work/$task/setup/$method/`: 228 | 229 | * `data/` contains the dataset (jsonl files) and splits (ids in 230 | Train/Val/TestT/TestC sets). 231 | 232 | * `setup_config.json` documents the configurations of this 233 | methodology. 234 | 235 | ## Code for Training and Evaluating Models 236 | [sec-traineval]: #code-for-training-and-evaluating-models 237 | 238 | ### Prepare the Python environments for ML models 239 | 240 | Requires Anaconda/Miniconda, and the models' source code at `_work/src/`. 241 | 242 | ``` 243 | python -m tseval.main prepare_envs --which=$model_cls 244 | # Example: python -m tseval.main prepare_envs --which=TransformerACL20 245 | ``` 246 | 247 | Where the `$model_cls` for each model can be looked up in this table 248 | (Transformer and Seq2Seq use the same model class and 249 | environment): 250 | 251 | | $task | $model_cls | Model | 252 | |:------|:-------------------|:--------------| 253 | | CG | DeepComHybridESE19 | DeepComHybrid | 254 | | CG | TransformerACL20 | Transformer | 255 | | CG | TransformerACL20 | Seq2Seq | 256 | | MN | Code2VecPOPL19 | Code2Vec | 257 | | MN | Code2SeqICLR19 | Code2Seq | 258 | 259 | The name of the conda environment created is `tseval-$task-$model_cls`. 260 | 261 | ### Train ML models under a methodology 262 | 263 | Requires the dataset at `_work/$task/setup/$method/`, and 264 | activating the right conda environment 265 | (`conda activate tseval-$task-$model_cls`). 266 | 267 | ``` 268 | python -m tseval.main exp_train \ 269 | --task=$task \ 270 | --setup_name=$method \ 271 | --model_name=$model_cls \ 272 | --exp_name=$exp_name \ 273 | --seed=$seed \ 274 | $model_args 275 | # Example: python -m tseval.main exp_train \ 276 | # --task=CG \ 277 | # --setup_name=T \ 278 | # --model_name=TransformerACL20 \ 279 | # --exp_name=Transformer \ 280 | # --seed=4182 281 | ``` 282 | 283 | Where `$exp_name` is the name of the output directory; `$seed` is the 284 | random seed (integer) to control the random process in the experiments 285 | (the `--seed=$seed` argument can be omitted for a random run using the 286 | current timestamp as seed); `$model_args` specifies optional additional 287 | arguments for the model, which can be looked up in the following table: 288 | 289 | | $task | Model | $model_args | 290 | |:------|:--------------|:---------------| 291 | | CG | DeepComHybrid | (empty) | 292 | | CG | Transformer | (empty) | 293 | | CG | Seq2Seq | --use_rnn=True | 294 | | MN | Code2Vec | (empty) | 295 | | MN | Code2Seq | (empty) | 296 | 297 | Results are generated to `_work/$task/exp/$method/$exp_name/`: 298 | 299 | * `model/`: the trained model. 
300 | 301 | * Other files document the configurations for initializing and 302 | training the model. 303 | 304 | ### Evaluate ML models 305 | 306 | Requires the dataset at `_work/$task/setup/$method/`, the trained 307 | model at `_work/$task/exp/$method/$exp_name/`, and activating the 308 | right conda environment (`conda activate tseval-$task-$model_cls`). 309 | 310 | ``` 311 | for action in val test_standard test_common; do 312 | python -m tseval.main exp_eval \ 313 | --task=$task \ 314 | --setup_name=$method \ 315 | --exp_name=$exp_name \ 316 | --action=$action 317 | done 318 | # Example: for action in val test_standard test_common; do 319 | # python -m tseval.main exp_eval \ 320 | # --task=CG \ 321 | # --setup_name=T \ 322 | # --exp_name=Transformer \ 323 | # --action=$action 324 | #done 325 | ``` 326 | 327 | Results are generated to `_work/$task/result/$method/$exp_name/`: 328 | 329 | * `$X_predictions.jsonl`: the predictions. 330 | 331 | * `$X_golds.jsonl`: the golds (ground truths). 332 | 333 | * `$X_eval_time.jsonl`: the time taken for the evaluation. 334 | 335 | * Where $X in {val, test_standard, test_common-$method-$method1 (where $method1 != $method)}. 336 | 337 | 338 | ### Compute automatic metrics 339 | 340 | Requires the evaluation results at 341 | `_work/$task/result/$method/$exp_name/`, and the use of the `tseval` 342 | environment (`conda activate tseval`). 343 | 344 | ``` 345 | for action in val test_standard test_common; do 346 | python -m tseval.main exp_compute_metrics \ 347 | --task=$task \ 348 | --setup_name=$method \ 349 | --exp_name=$exp_name \ 350 | --action=$action 351 | done 352 | # Example: for action in val test_standard test_common; do 353 | # python -m tseval.main exp_compute_metrics \ 354 | # --task=CG \ 355 | # --setup_name=T \ 356 | # --exp_name=Transformer \ 357 | # --action=$action 358 | #done 359 | ``` 360 | 361 | Results are generated to `_work/$task/metric/$method/$exp_name/`: 362 | 363 | * `$X_metrics.json`: the averages of the automatic metrics. 364 | 365 | * `$X_metrics_list.pkl`: the (compressed) list of automatic metrics per sample. 366 | 367 | * Where $X in {val, test_standard, test_common-$method-$method1 (where $method1 != $method)}. 
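As an example of consuming these outputs, the averaged metrics of one experiment can be compared across methodologies with a short script (a minimal sketch, assuming each `$X_metrics.json` file is a flat JSON dict keyed by metric name; the metric names per task, e.g., `bleu`, `meteor`, `rouge_l_f`, and `exact_match` for CG, are listed in `python/exps/cg.yaml` and `python/exps/mn.yaml`):

```python
import json
from pathlib import Path

task, exp_name = "CG", "Transformer"
metrics_of_interest = ["bleu", "meteor", "rouge_l_f", "exact_match"]

for method in ["MP", "CP", "T"]:
    path = Path(f"_work/{task}/metric/{method}/{exp_name}/test_standard_metrics.json")
    if not path.exists():
        continue  # skip methodologies that have not been evaluated yet
    metrics = json.loads(path.read_text())
    print(method, {m: metrics.get(m) for m in metrics_of_interest})
```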
368 | 369 | 370 | 371 | [paper-arxiv]: https://arxiv.org/abs/2108.09619 372 | [paper-utcs]: https://www.cs.utexas.edu/users/ai-lab/downloadPublication.php?filename=http://www.cs.utexas.edu/users/ml/papers/nie.acl2022.pdf&pubid=127948 373 | -------------------------------------------------------------------------------- /collector/.gitignore: -------------------------------------------------------------------------------- 1 | # Additional IntelliJ 2 | *.iml 3 | .idea/ 4 | 5 | 6 | ### IntelliJ 7 | 8 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 9 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 10 | 11 | # User-specific stuff 12 | .idea/**/workspace.xml 13 | .idea/**/tasks.xml 14 | .idea/**/usage.statistics.xml 15 | .idea/**/dictionaries 16 | .idea/**/shelf 17 | 18 | # Generated files 19 | .idea/**/contentModel.xml 20 | 21 | # Sensitive or high-churn files 22 | .idea/**/dataSources/ 23 | .idea/**/dataSources.ids 24 | .idea/**/dataSources.local.xml 25 | .idea/**/sqlDataSources.xml 26 | .idea/**/dynamic.xml 27 | .idea/**/uiDesigner.xml 28 | .idea/**/dbnavigator.xml 29 | 30 | # Gradle 31 | .idea/**/gradle.xml 32 | .idea/**/libraries 33 | 34 | # Gradle and Maven with auto-import 35 | # When using Gradle or Maven with auto-import, you should exclude module files, 36 | # since they will be recreated, and may cause churn. Uncomment if using 37 | # auto-import. 38 | # .idea/modules.xml 39 | # .idea/*.iml 40 | # .idea/modules 41 | # *.iml 42 | # *.ipr 43 | 44 | # CMake 45 | cmake-build-*/ 46 | 47 | # Mongo Explorer plugin 48 | .idea/**/mongoSettings.xml 49 | 50 | # File-based project format 51 | *.iws 52 | 53 | # IntelliJ 54 | out/ 55 | 56 | # mpeltonen/sbt-idea plugin 57 | .idea_modules/ 58 | 59 | # JIRA plugin 60 | atlassian-ide-plugin.xml 61 | 62 | # Cursive Clojure plugin 63 | .idea/replstate.xml 64 | 65 | # Crashlytics plugin (for Android Studio and IntelliJ) 66 | com_crashlytics_export_strings.xml 67 | crashlytics.properties 68 | crashlytics-build.properties 69 | fabric.properties 70 | 71 | # Editor-based Rest Client 72 | .idea/httpRequests 73 | 74 | # Android studio 3.1+ serialized cache file 75 | .idea/caches/build_file_checksums.ser 76 | 77 | 78 | ### Java 79 | 80 | # Compiled class file 81 | *.class 82 | 83 | # Log file 84 | *.log 85 | 86 | # BlueJ files 87 | *.ctxt 88 | 89 | # Mobile Tools for Java (J2ME) 90 | .mtj.tmp/ 91 | 92 | # Package Files # 93 | *.jar 94 | *.war 95 | *.nar 96 | *.ear 97 | *.zip 98 | *.tar.gz 99 | *.rar 100 | 101 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 102 | hs_err_pid* 103 | 104 | 105 | ### Maven 106 | 107 | target/ 108 | pom.xml.tag 109 | pom.xml.releaseBackup 110 | pom.xml.versionsBackup 111 | pom.xml.next 112 | release.properties 113 | dependency-reduced-pom.xml 114 | buildNumber.properties 115 | .mvn/timing.properties 116 | # https://github.com/takari/maven-wrapper#usage-without-binary-jar 117 | .mvn/wrapper/maven-wrapper.jar 118 | -------------------------------------------------------------------------------- /collector/pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | org.tseval 5 | collector 6 | jar 7 | 0.1-dev 8 | collector 9 | http://maven.apache.org 10 | 11 | 12 | 1.8 13 | 1.8 14 | 15 | 16 | 17 | 18 | 19 | org.apache.commons 20 | commons-text 21 | 1.8 22 | 23 | 24 | 25 | com.google.code.gson 26 | gson 27 | 2.9.0 28 | 29 | 30 | 31 | com.github.javaparser 
32 | javaparser-symbol-solver-core 33 | 3.15.12 34 | 35 | 36 | 37 | edu.stanford.nlp 38 | stanford-corenlp 39 | 4.4.0 40 | 41 | 42 | 43 | junit 44 | junit 45 | 4.12 46 | test 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | org.apache.maven.plugins 56 | maven-shade-plugin 57 | 3.2.1 58 | 59 | 60 | package 61 | 62 | shade 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/ExtractToken.java: -------------------------------------------------------------------------------- 1 | package org.tseval; 2 | 3 | import com.github.javaparser.ParseProblemException; 4 | import com.github.javaparser.StaticJavaParser; 5 | import com.github.javaparser.ast.Node; 6 | import com.github.javaparser.ast.body.MethodDeclaration; 7 | import com.github.javaparser.ast.type.ClassOrInterfaceType; 8 | import com.github.javaparser.printer.PrettyPrinterConfiguration; 9 | import com.google.gson.Gson; 10 | import com.google.gson.GsonBuilder; 11 | import com.google.gson.JsonArray; 12 | import com.google.gson.JsonDeserializer; 13 | import com.google.gson.JsonObject; 14 | import com.google.gson.JsonParseException; 15 | import com.google.gson.JsonSerializer; 16 | import com.google.gson.reflect.TypeToken; 17 | 18 | import java.io.BufferedReader; 19 | import java.io.BufferedWriter; 20 | import java.io.IOException; 21 | import java.lang.reflect.InvocationTargetException; 22 | import java.lang.reflect.Method; 23 | import java.nio.file.Files; 24 | import java.nio.file.Path; 25 | import java.nio.file.Paths; 26 | import java.util.Arrays; 27 | import java.util.LinkedList; 28 | import java.util.List; 29 | 30 | /** 31 | * Utility class for extracting tokens from methods. 32 | * Used for tokenization and calculating metrics. 33 | */ 34 | public class ExtractToken { 35 | 36 | /** 37 | * Input data format class for {@link ExtractToken}. 
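* Each element of the input JSON array is an object with an integer "index" field and a string "code" field (matching the deserializer below).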
38 | */ 39 | public static class InputIndexAndCode { 40 | public int index; 41 | public String code; 42 | 43 | public static final JsonDeserializer sDeserializer = getDeserializer(); 44 | public static JsonDeserializer getDeserializer() { 45 | return (json, type, context) -> { 46 | try { 47 | InputIndexAndCode obj = new InputIndexAndCode(); 48 | JsonObject jObj = json.getAsJsonObject(); 49 | obj.index = jObj.get("index").getAsInt(); 50 | obj.code = jObj.get("code").getAsString(); 51 | return obj; 52 | } catch (IllegalStateException e) { 53 | throw new JsonParseException(e); 54 | } 55 | }; 56 | } 57 | } 58 | 59 | public static class OutputData { 60 | public int index; 61 | public List tokens; 62 | 63 | public static final JsonSerializer sSerializer = getSerializer(); 64 | public static JsonSerializer getSerializer() { 65 | return (d, type, jsonSerializationContext) -> { 66 | JsonObject jObj = new JsonObject(); 67 | jObj.addProperty("index", d.index); 68 | JsonArray jTokens = new JsonArray(); 69 | for (String token: d.tokens) { 70 | jTokens.add(token); 71 | } 72 | jObj.add("tokens", jTokens); 73 | return jObj; 74 | }; 75 | } 76 | } 77 | 78 | // Gson: For json (de)serialization 79 | private static final Gson GSON = new GsonBuilder() 80 | .disableHtmlEscaping() 81 | .registerTypeAdapter(InputIndexAndCode.class, InputIndexAndCode.sDeserializer) 82 | .registerTypeAdapter(OutputData.class, OutputData.sSerializer) 83 | .create(); 84 | 85 | // For JavaParser pretty-printing AST to code without comments 86 | private static final PrettyPrinterConfiguration METHOD_PPRINT_CONFIG = new PrettyPrinterConfiguration(); 87 | static { 88 | METHOD_PPRINT_CONFIG 89 | .setPrintJavadoc(false) 90 | .setPrintComments(false); 91 | } 92 | 93 | private static Node parseWhatever(String code) { 94 | // First try to parse as Compilation Unit 95 | try { 96 | return StaticJavaParser.parse(code); 97 | } catch (ParseProblemException ignored) { 98 | } 99 | 100 | // Then, try several types 101 | List> possibleTypes = Arrays.asList( 102 | MethodDeclaration.class, 103 | ClassOrInterfaceType.class 104 | ); 105 | for (Class t : possibleTypes) { 106 | try { 107 | Method parseMethod = StaticJavaParser.class.getMethod("parse" + t.getSimpleName(), String.class); 108 | Node n = (Node) parseMethod.invoke(null, code); 109 | return n; 110 | } catch (NoSuchMethodException | IllegalAccessException e) { 111 | throw new RuntimeException(e); 112 | } catch (InvocationTargetException ignored) { 113 | } 114 | } 115 | 116 | // If all fails, return null 117 | return null; 118 | } 119 | 120 | 121 | /** 122 | * Main entry point. 123 | * 124 | * @param args expect exactly two arguments: the input file path, the output file path. 125 | */ 126 | public static void main(String... 
args) { 127 | // Load arguments 128 | if (args.length != 2) { 129 | throw new RuntimeException("Args: inputPath outputPath"); 130 | } 131 | Path inputPath = Paths.get(args[0]); 132 | Path outputPath = Paths.get(args[1]); 133 | 134 | // Load inputs 135 | List inputList; 136 | try (BufferedReader r = Files.newBufferedReader(inputPath)) { 137 | inputList = GSON.fromJson( 138 | r, 139 | TypeToken.getParameterized( 140 | TypeToken.get(List.class).getType(), 141 | TypeToken.get(InputIndexAndCode.class).getType() 142 | ).getType() 143 | ); 144 | } catch (IOException e) { 145 | throw new RuntimeException(e); 146 | } 147 | 148 | // Get tokens 149 | List outputList = new LinkedList<>(); 150 | for (InputIndexAndCode input: inputList) { 151 | OutputData output = new OutputData(); 152 | Node parsed = parseWhatever(input.code); 153 | 154 | if (parsed == null) { 155 | // Return empty list, representing the code is non-parsable 156 | output.index = input.index; 157 | output.tokens = new LinkedList<>(); 158 | } else { 159 | List tokens = new LinkedList<>(); 160 | parsed.getTokenRange().get() 161 | .forEach(t -> { 162 | if (!t.getCategory().isWhitespaceOrComment()) { 163 | tokens.add(t.getText()); 164 | } 165 | }); 166 | output.index = input.index; 167 | output.tokens = tokens; 168 | } 169 | 170 | outputList.add(output); 171 | } 172 | 173 | // Save outputs 174 | try (BufferedWriter w = Files.newBufferedWriter(outputPath)) { 175 | w.write(GSON.toJson( 176 | outputList, 177 | TypeToken.getParameterized( 178 | TypeToken.get(List.class).getType(), 179 | TypeToken.get(OutputData.class).getType() 180 | ).getType() 181 | )); 182 | } catch (IOException e) { 183 | throw new RuntimeException(e); 184 | } 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/MaskMethodNameVisitor.java: -------------------------------------------------------------------------------- 1 | package org.tseval; 2 | 3 | import com.github.javaparser.ast.body.MethodDeclaration; 4 | import com.github.javaparser.ast.expr.MethodCallExpr; 5 | import com.github.javaparser.ast.expr.SimpleName; 6 | import com.github.javaparser.ast.visitor.ModifierVisitor; 7 | import com.github.javaparser.ast.visitor.Visitable; 8 | 9 | public class MaskMethodNameVisitor extends ModifierVisitor { 10 | 11 | public static final MaskMethodNameVisitor sVisitor = new MaskMethodNameVisitor(); 12 | 13 | private static final SimpleName MASK = new SimpleName(""); 14 | 15 | public static class Context { 16 | public String name; 17 | public Context(String name) { 18 | this.name = name; 19 | } 20 | } 21 | 22 | @Override 23 | public Visitable visit(MethodDeclaration n, Context arg) { 24 | MethodDeclaration ret = (MethodDeclaration) super.visit(n, arg); 25 | 26 | if (n.getNameAsString().equals(arg.name)) { 27 | ret.setName(MASK.clone()); 28 | } 29 | 30 | return ret; 31 | } 32 | 33 | @Override 34 | public Visitable visit(MethodCallExpr n, Context arg) { 35 | MethodCallExpr ret = (MethodCallExpr) super.visit(n, arg); 36 | 37 | if (n.getNameAsString().equals(arg.name)) { 38 | ret.setName(MASK.clone()); 39 | } 40 | 41 | return ret; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/MethodDataCollector.java: -------------------------------------------------------------------------------- 1 | package org.tseval; 2 | 3 | import com.github.javaparser.ParseProblemException; 4 | import com.github.javaparser.StaticJavaParser; 5 
| import com.github.javaparser.ast.CompilationUnit; 6 | import com.google.gson.Gson; 7 | import com.google.gson.GsonBuilder; 8 | import com.google.gson.stream.JsonWriter; 9 | import org.tseval.MethodDataCollectorVisitor.FilteredReason; 10 | import org.tseval.data.MethodData; 11 | import org.tseval.data.RevisionIds; 12 | import org.tseval.util.BashUtils; 13 | import org.tseval.util.AbstractConfig; 14 | import org.tseval.util.Option; 15 | 16 | import java.io.BufferedWriter; 17 | import java.io.IOException; 18 | import java.nio.file.Files; 19 | import java.nio.file.Path; 20 | import java.nio.file.Paths; 21 | import java.nio.file.StandardOpenOption; 22 | import java.util.Arrays; 23 | import java.util.Comparator; 24 | import java.util.HashMap; 25 | import java.util.LinkedList; 26 | import java.util.List; 27 | import java.util.Map; 28 | import java.util.Objects; 29 | import java.util.stream.Collectors; 30 | 31 | public class MethodDataCollector { 32 | 33 | public static class Config extends AbstractConfig { 34 | 35 | // A list of shas, separated by space 36 | @Option public String shas; 37 | @Option public Path projectDir; 38 | @Option public Path outputDir; 39 | @Option public Path logFile; 40 | 41 | public boolean repOk() { 42 | if (projectDir == null || outputDir == null || logFile == null) { 43 | return false; 44 | } 45 | 46 | if (shas == null || shas.length() == 0) { 47 | return false; 48 | } 49 | 50 | return true; 51 | } 52 | } 53 | public static Config sConfig; 54 | 55 | public static final Gson GSON; 56 | public static final Gson GSON_NO_PPRINT; 57 | static { 58 | GsonBuilder gsonBuilder = new GsonBuilder() 59 | .disableHtmlEscaping() 60 | .serializeNulls() 61 | .registerTypeAdapter(MethodData.class, MethodData.sSerDeser) 62 | .registerTypeAdapter(RevisionIds.class, RevisionIds.sSerializer); 63 | GSON_NO_PPRINT = gsonBuilder.create(); 64 | gsonBuilder.setPrettyPrinting(); 65 | GSON = gsonBuilder.create(); 66 | } 67 | 68 | private static Map sMethodDataIdHashMap = new HashMap<>(); 69 | private static int sCurrentMethodDataId = 0; 70 | private static Map> sFileCache = new HashMap<>(); 71 | private static Map sFilteredCounters = MethodDataCollectorVisitor.initFilteredCounters(); 72 | 73 | private static JsonWriter sMethodDataWriter; 74 | private static JsonWriter sMethodProjectRevisionWriter; 75 | 76 | public static void main(String... 
args) { 77 | if (args.length != 1) { 78 | System.err.println("Exactly one argument, the path to the json config, is required"); 79 | System.exit(-1); 80 | } 81 | 82 | sConfig = AbstractConfig.load(Paths.get(args[0]), Config.class); 83 | collect(); 84 | } 85 | 86 | public static void collect() { 87 | try { 88 | // Init the writers for saving 89 | sMethodDataWriter = GSON.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("method-data.json"))); 90 | sMethodDataWriter.beginArray(); 91 | sMethodProjectRevisionWriter = GSON_NO_PPRINT.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("revision-ids.json"))); 92 | sMethodProjectRevisionWriter.beginArray(); 93 | 94 | // Collect for each sha (chronological order) 95 | for (String sha: sConfig.shas.split(" ")) { 96 | log("Sha " + sha); 97 | collectSha(sha); 98 | } 99 | 100 | // Save filtered counters 101 | JsonWriter filteredCountersWriter = GSON.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("filtered-counters.json"))); 102 | filteredCountersWriter.beginObject(); 103 | for (FilteredReason fr : FilteredReason.values()) { 104 | filteredCountersWriter.name(fr.getKey()); 105 | filteredCountersWriter.value(sFilteredCounters.get(fr)); 106 | } 107 | filteredCountersWriter.endObject(); 108 | filteredCountersWriter.close(); 109 | 110 | // Close writers 111 | sMethodDataWriter.endArray(); 112 | sMethodDataWriter.close(); 113 | sMethodProjectRevisionWriter.endArray(); 114 | sMethodProjectRevisionWriter.close(); 115 | } catch (IOException e) { 116 | throw new RuntimeException(e); 117 | } 118 | } 119 | 120 | public static void log(String msg) { 121 | if (sConfig.logFile != null) { 122 | try (BufferedWriter fw = Files.newBufferedWriter(sConfig.logFile, StandardOpenOption.APPEND, StandardOpenOption.CREATE)) { 123 | fw.write("[" + Thread.currentThread().getId() + "]" + msg + "\n"); 124 | // (new Throwable()).printStackTrace(fos); 125 | } catch (IOException e) { 126 | System.err.println("Couldn't log to " + sConfig.logFile); 127 | System.exit(-1); 128 | } 129 | } 130 | } 131 | 132 | private static void collectSha(String sha) throws IOException { 133 | // Check out the sha 134 | BashUtils.run("cd " + sConfig.projectDir + " && git checkout -f " + sha, 0); 135 | 136 | // Find all java files 137 | List javaFiles = Files.walk(sConfig.projectDir) 138 | .filter(Files::isRegularFile) 139 | .filter(p -> p.toString().endsWith(".java")) 140 | .sorted(Comparator.comparing(Object::toString)) 141 | .collect(Collectors.toList()); 142 | log("In revision " + sha +", got " + javaFiles.size() + " files to parse"); 143 | 144 | // For each java file, parse and get methods 145 | MethodDataCollectorVisitor visitor = new MethodDataCollectorVisitor(); 146 | List idsRevision = new LinkedList<>(); 147 | int parseErrorCount = 0; 148 | int filteredCount = 0; 149 | int reuseFileCount = 0; 150 | int parseFileCount = 0; 151 | for (Path javaFile : javaFiles) { 152 | // Skip parsing identical files, just add the ids 153 | int fileHash = getFileHash(javaFile); 154 | List idsFile = sFileCache.get(fileHash); 155 | 156 | if (idsFile == null) { 157 | // Actually parse this file and collect ids 158 | idsFile = new LinkedList<>(); 159 | String path = sConfig.projectDir.relativize(javaFile).toString(); 160 | 161 | MethodDataCollectorVisitor.Context context = new MethodDataCollectorVisitor.Context(); 162 | try { 163 | CompilationUnit cu = StaticJavaParser.parse(javaFile); 164 | cu.accept(visitor, context); 165 | } catch (ParseProblemException e) { 166 | 
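// This file could not be parsed by JavaParser: count it as a parse error and skip it (none of its methods are collected)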
++parseErrorCount; 167 | continue; 168 | } 169 | 170 | for (FilteredReason fr : FilteredReason.values()) { 171 | sFilteredCounters.compute(fr, (k, v) -> v + context.filteredCounters.get(k)); 172 | filteredCount += context.filteredCounters.get(fr); 173 | } 174 | 175 | for (MethodData methodData : context.methodDataList) { 176 | // Reuse (for duplicate data) or allocate the data id 177 | methodData.path = path; 178 | int methodId = addMethodData(methodData); 179 | idsFile.add(methodId); 180 | } 181 | 182 | // Update file cache 183 | sFileCache.put(fileHash, idsFile); 184 | ++parseFileCount; 185 | } else { 186 | ++reuseFileCount; 187 | } 188 | 189 | idsRevision.addAll(idsFile); 190 | } 191 | 192 | // Create and save the RevisionIds for this revision 193 | RevisionIds revisionIds = new RevisionIds(); 194 | revisionIds.revision = sha; 195 | revisionIds.methodIds = idsRevision; 196 | addRevisionIds(revisionIds); 197 | 198 | log("Parsed " + parseFileCount + " files. " + 199 | "Reused " + reuseFileCount + " files. " + 200 | "Parsing error for " + parseErrorCount + " files. " + 201 | "Ignored " + filteredCount + " methods. " + 202 | "Total collected " + sMethodDataIdHashMap.size() + " methods."); 203 | } 204 | 205 | 206 | private static int getFileHash(Path javaFile) throws IOException { 207 | // Hash both the path and the content 208 | return Objects.hash(javaFile.toString(), Arrays.hashCode(Files.readAllBytes(javaFile))); 209 | } 210 | 211 | private static int addMethodData(MethodData methodData) { 212 | // Don't duplicate previously appeared methods (keys: path, code, comment) 213 | int hash = Objects.hash(methodData.path, methodData.code, methodData.comment); 214 | Integer prevMethodDataId = sMethodDataIdHashMap.get(hash); 215 | if (prevMethodDataId != null) { 216 | // If this method data already existed before, retrieve its id 217 | return prevMethodDataId; 218 | } else { 219 | // Allocate a new id and save this data to the hash map 220 | methodData.id = sCurrentMethodDataId; 221 | ++sCurrentMethodDataId; 222 | sMethodDataIdHashMap.put(hash, methodData.id); 223 | 224 | // Save the method data 225 | GSON.toJson(methodData, MethodData.class, sMethodDataWriter); 226 | return methodData.id; 227 | } 228 | } 229 | 230 | private static void addRevisionIds(RevisionIds revisionIds) { 231 | // Directly write to file 232 | GSON_NO_PPRINT.toJson(revisionIds, RevisionIds.class, sMethodProjectRevisionWriter); 233 | } 234 | } 235 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/MethodDataCollectorVisitor.java: -------------------------------------------------------------------------------- 1 | package org.tseval; 2 | 3 | import com.github.javaparser.ast.PackageDeclaration; 4 | import com.github.javaparser.ast.body.AnnotationDeclaration; 5 | import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration; 6 | import com.github.javaparser.ast.body.EnumDeclaration; 7 | import com.github.javaparser.ast.body.MethodDeclaration; 8 | import com.github.javaparser.ast.body.Parameter; 9 | import com.github.javaparser.ast.body.TypeDeclaration; 10 | import com.github.javaparser.ast.visitor.VoidVisitorAdapter; 11 | import com.github.javaparser.javadoc.Javadoc; 12 | import com.github.javaparser.javadoc.JavadocBlockTag; 13 | import com.github.javaparser.javadoc.description.JavadocDescription; 14 | import com.github.javaparser.javadoc.description.JavadocDescriptionElement; 15 | import 
com.github.javaparser.javadoc.description.JavadocInlineTag; 16 | import com.github.javaparser.printer.PrettyPrinterConfiguration; 17 | import org.apache.commons.lang3.tuple.Pair; 18 | import org.tseval.data.MethodData; 19 | import org.tseval.util.NLPUtils; 20 | 21 | import java.util.HashMap; 22 | import java.util.LinkedList; 23 | import java.util.List; 24 | import java.util.Map; 25 | 26 | public class MethodDataCollectorVisitor extends VoidVisitorAdapter { 27 | 28 | private static final int METHOD_LENGTH_MAX = 10_000; 29 | 30 | public static class Context { 31 | String packageName = ""; 32 | String className = ""; 33 | List methodDataList = new LinkedList<>(); 34 | Map filteredCounters = initFilteredCounters(); 35 | } 36 | 37 | public enum FilteredReason { 38 | CODE_TOO_LONG("code_too_long"), 39 | CODE_NON_ENGLISH("code_non_english"), 40 | COMMENT_NON_ENGLISH("comment_non_english"), 41 | EMPTY_BODY("empty_body"), 42 | EMPTY_COMMENT_SUMMARY("empty_comment_summary"), 43 | EMPTY_COMMENT("empty_comment"); 44 | 45 | private final String key; 46 | FilteredReason(String key) { 47 | this.key = key; 48 | } 49 | 50 | public String getKey() { 51 | return key; 52 | } 53 | } 54 | 55 | public static Map initFilteredCounters() { 56 | Map filteredCounters = new HashMap<>(); 57 | for (FilteredReason fr : FilteredReason.values()) { 58 | filteredCounters.put(fr, 0); 59 | } 60 | return filteredCounters; 61 | } 62 | 63 | private static PrettyPrinterConfiguration METHOD_PPRINT_CONFIG = new PrettyPrinterConfiguration(); 64 | static { 65 | METHOD_PPRINT_CONFIG 66 | .setPrintJavadoc(false) 67 | .setPrintComments(false); 68 | } 69 | 70 | @Override 71 | public void visit(ClassOrInterfaceDeclaration n, Context context) { 72 | commonVisitTypeDeclaration(n, context); 73 | super.visit(n, context); 74 | } 75 | 76 | @Override 77 | public void visit(AnnotationDeclaration n, Context context) { 78 | commonVisitTypeDeclaration(n, context); 79 | super.visit(n, context); 80 | } 81 | 82 | @Override 83 | public void visit(EnumDeclaration n, Context context) { 84 | commonVisitTypeDeclaration(n, context); 85 | super.visit(n, context); 86 | } 87 | 88 | public void commonVisitTypeDeclaration(TypeDeclaration n, Context context) { 89 | // Update context class name 90 | context.className = n.getNameAsString(); 91 | } 92 | 93 | @Override 94 | public void visit(PackageDeclaration n, Context context) { 95 | // Update context package name 96 | context.packageName = n.getNameAsString(); 97 | super.visit(n, context); 98 | } 99 | 100 | private String dollaryClassName(TypeDeclaration n) { 101 | if (n.isNestedType()) { 102 | return dollaryClassName((TypeDeclaration) n.getParentNode().get()) + "$" + n.getNameAsString(); 103 | } 104 | return n.getNameAsString(); 105 | } 106 | 107 | @Override 108 | public void visit(MethodDeclaration n, Context context) { 109 | MethodData methodData = new MethodData(); 110 | 111 | if (!n.getBody().isPresent() || n.getBody().get().isEmpty()) { 112 | // Ignore if no/empty method body 113 | context.filteredCounters.compute(FilteredReason.EMPTY_BODY, (k, v) -> v+1); 114 | return; 115 | } 116 | 117 | methodData.name = n.getNameAsString(); 118 | 119 | if (n.getJavadoc().isPresent()) { 120 | Javadoc javadoc = n.getJavadoc().get(); 121 | methodData.comment = javadoc.toText(); 122 | 123 | if (!NLPUtils.isValidISOLatin(methodData.comment)) { 124 | // Ignore if comment is not English 125 | context.filteredCounters.compute(FilteredReason.COMMENT_NON_ENGLISH, (k, v) -> v+1); 126 | return; 127 | } 128 | 129 | 
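// The comment summary is the first sentence of the Javadoc description, with inline tags (e.g., {@code ...}) replaced by their text content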
methodData.commentSummary = NLPUtils.getFirstSentence(javadocDescToTextNoInlineTags(javadoc.getDescription())).orElse(null); 130 | 131 | if (methodData.commentSummary == null) { 132 | // Ignore if the comment summary is empty 133 | context.filteredCounters.compute(FilteredReason.EMPTY_COMMENT_SUMMARY, (k, v) -> v+1); 134 | return; 135 | } 136 | } else { 137 | // Ignore if comment is empty 138 | context.filteredCounters.compute(FilteredReason.EMPTY_COMMENT, (k, v) -> v+1); 139 | return; 140 | } 141 | 142 | methodData.code = n.toString(METHOD_PPRINT_CONFIG); 143 | 144 | // Ignore if method is too long 145 | if (methodData.code.length() > METHOD_LENGTH_MAX) { 146 | context.filteredCounters.compute(FilteredReason.CODE_TOO_LONG, (k, v) -> v+1); 147 | return; 148 | } 149 | 150 | // Ignore if method is not English 151 | if (!NLPUtils.isValidISOLatin(methodData.code)) { 152 | context.filteredCounters.compute(FilteredReason.CODE_NON_ENGLISH, (k, v) -> v+1); 153 | return; 154 | } 155 | 156 | MethodDeclaration maskedMD = n.clone(); 157 | maskedMD.accept(MaskMethodNameVisitor.sVisitor, new MaskMethodNameVisitor.Context(methodData.name)); 158 | methodData.codeMasked = maskedMD.toString(METHOD_PPRINT_CONFIG); 159 | 160 | try { 161 | methodData.cname = dollaryClassName((TypeDeclaration) n.getParentNode().get()); 162 | } 163 | catch (Exception e) { 164 | methodData.cname = context.className; 165 | } 166 | 167 | if (!context.packageName.isEmpty()) { 168 | methodData.qcname = context.packageName + "." + methodData.cname; 169 | } else { 170 | methodData.qcname = methodData.cname; 171 | } 172 | 173 | methodData.ret = n.getType().asString(); 174 | for (Parameter param : n.getParameters()) { 175 | methodData.params.add(Pair.of(param.getType().asString(), param.getNameAsString())); 176 | } 177 | 178 | context.methodDataList.add(methodData); 179 | 180 | // Note: no recursive visiting the body of this method 181 | } 182 | 183 | static String javadocDescToTextNoInlineTags(JavadocDescription desc) { 184 | StringBuilder sb = new StringBuilder(); 185 | for (JavadocDescriptionElement e : desc.getElements()) { 186 | if (e instanceof JavadocInlineTag) { 187 | sb.append(((JavadocInlineTag) e).getContent()); 188 | } else { 189 | sb.append(e.toText()); 190 | } 191 | } 192 | return sb.toString(); 193 | } 194 | } 195 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/data/MethodData.java: -------------------------------------------------------------------------------- 1 | package org.tseval.data; 2 | 3 | import com.google.gson.JsonArray; 4 | import com.google.gson.JsonDeserializationContext; 5 | import com.google.gson.JsonDeserializer; 6 | import com.google.gson.JsonElement; 7 | import com.google.gson.JsonObject; 8 | import com.google.gson.JsonParseException; 9 | import com.google.gson.JsonSerializationContext; 10 | import com.google.gson.JsonSerializer; 11 | import org.apache.commons.lang3.tuple.Pair; 12 | 13 | import java.lang.reflect.Type; 14 | import java.util.ArrayList; 15 | import java.util.List; 16 | 17 | public class MethodData { 18 | 19 | public int id; 20 | public String prj; 21 | 22 | public String name; 23 | public String code; 24 | public String codeMasked; 25 | public String comment; 26 | public String commentSummary; 27 | public String cname; 28 | public String qcname; 29 | public String path; 30 | public String ret; 31 | public List> params = new ArrayList<>(); 32 | 33 | // Serialization 34 | public static JsonSerializer sSerializer = 
getSerializer(); 35 | 36 | public static JsonSerializer getSerializer() { 37 | return (obj, type, jsonSerializationContext) -> { 38 | JsonObject jObj = new JsonObject(); 39 | 40 | jObj.addProperty("id", obj.id); 41 | jObj.addProperty("prj", obj.prj); 42 | jObj.addProperty("name", obj.name); 43 | jObj.addProperty("code", obj.code); 44 | jObj.addProperty("code_masked", obj.codeMasked); 45 | jObj.addProperty("comment", obj.comment); 46 | jObj.addProperty("comment_summary", obj.commentSummary); 47 | jObj.addProperty("cname", obj.cname); 48 | jObj.addProperty("qcname", obj.qcname); 49 | jObj.addProperty("path", obj.path); 50 | jObj.addProperty("ret", obj.ret); 51 | JsonArray aParams = new JsonArray(); 52 | for (Pair param : obj.params) { 53 | JsonArray aParam = new JsonArray(); 54 | aParam.add(param.getLeft()); 55 | aParam.add(param.getRight()); 56 | aParams.add(aParam); 57 | } 58 | jObj.add("params", aParams); 59 | 60 | return jObj; 61 | }; 62 | } 63 | 64 | // Deserialization 65 | public static final JsonDeserializer sDeserializer = getDeserializer(); 66 | 67 | public static JsonDeserializer getDeserializer() { 68 | return (json, type, context) -> { 69 | try { 70 | MethodData obj = new MethodData(); 71 | 72 | JsonObject jObj = json.getAsJsonObject(); 73 | obj.id = jObj.get("id").getAsInt(); 74 | obj.prj = jObj.get("prj").getAsString(); 75 | obj.name = jObj.get("name").getAsString(); 76 | obj.code = jObj.get("code").getAsString(); 77 | obj.codeMasked = jObj.get("code_masked").getAsString(); 78 | obj.comment = jObj.get("comment").getAsString(); 79 | obj.commentSummary = jObj.get("comment_summary").getAsString(); 80 | obj.cname = jObj.get("cname").getAsString(); 81 | obj.qcname = jObj.get("qcname").getAsString(); 82 | obj.path = jObj.get("path").getAsString(); 83 | obj.ret = jObj.get("ret").getAsString(); 84 | JsonArray aParams = jObj.getAsJsonArray("params"); 85 | for (JsonElement aParamElem : aParams) { 86 | JsonArray aParam = aParamElem.getAsJsonArray(); 87 | obj.params.add(Pair.of(aParam.get(0).getAsString(), aParam.get(1).getAsString())); 88 | } 89 | 90 | return obj; 91 | } catch (IllegalStateException e) { 92 | throw new JsonParseException(e); 93 | } 94 | }; 95 | } 96 | 97 | // Ser+Deser 98 | public static SerDeser sSerDeser = new SerDeser(); 99 | private static class SerDeser implements JsonSerializer, JsonDeserializer { 100 | @Override 101 | public MethodData deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException { 102 | return sDeserializer.deserialize(json, typeOfT, context); 103 | } 104 | 105 | @Override 106 | public JsonElement serialize(MethodData src, Type typeOfSrc, JsonSerializationContext context) { 107 | return sSerializer.serialize(src, typeOfSrc, context); 108 | } 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/data/ProjectData.java: -------------------------------------------------------------------------------- 1 | package org.tseval.data; 2 | 3 | import com.google.gson.JsonDeserializer; 4 | import com.google.gson.JsonElement; 5 | import com.google.gson.JsonObject; 6 | import com.google.gson.JsonParseException; 7 | 8 | import java.util.HashMap; 9 | import java.util.LinkedList; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | public class ProjectData { 14 | 15 | public String name; 16 | public String url; 17 | public List revisions = new LinkedList<>(); 18 | public Map> parentRevisions = new HashMap<>(); 19 | public Map 
yearRevisions = new HashMap<>(); 20 | 21 | // Deserialization 22 | public static final JsonDeserializer sDeserializer = getDeserializer(); 23 | 24 | public static JsonDeserializer getDeserializer() { 25 | return (json, type, context) -> { 26 | try { 27 | ProjectData obj = new ProjectData(); 28 | 29 | JsonObject jObj = json.getAsJsonObject(); 30 | obj.name = jObj.get("name").getAsString(); 31 | obj.url = jObj.get("url").getAsString(); 32 | 33 | // revisions 34 | for (JsonElement eRevision : jObj.get("revisions").getAsJsonArray()) { 35 | obj.revisions.add(eRevision.getAsString()); 36 | } 37 | 38 | // parent revisions 39 | JsonObject jObjParentRevisions = jObj.get("parent_revisions").getAsJsonObject(); 40 | for (Map.Entry entry : jObjParentRevisions.entrySet()) { 41 | List parentRevisions = new LinkedList<>(); 42 | for (JsonElement eParent : entry.getValue().getAsJsonArray()) { 43 | parentRevisions.add(eParent.getAsString()); 44 | } 45 | obj.parentRevisions.put(entry.getKey(), parentRevisions); 46 | } 47 | 48 | // year revisions 49 | JsonObject jObjYearRevisions = jObj.get("year_revisions").getAsJsonObject(); 50 | for (Map.Entry entry : jObjYearRevisions.entrySet()) { 51 | obj.yearRevisions.put(entry.getKey(), entry.getValue().getAsString()); 52 | } 53 | 54 | 55 | return obj; 56 | } catch (IllegalStateException e) { 57 | throw new JsonParseException(e); 58 | } 59 | }; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/data/RevisionIds.java: -------------------------------------------------------------------------------- 1 | package org.tseval.data; 2 | 3 | import com.google.gson.JsonArray; 4 | import com.google.gson.JsonObject; 5 | import com.google.gson.JsonSerializer; 6 | 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | 10 | public class RevisionIds { 11 | 12 | public String revision; 13 | public List methodIds = new LinkedList<>(); 14 | 15 | // Serialization 16 | public static JsonSerializer sSerializer = getSerializer(); 17 | 18 | public static JsonSerializer getSerializer() { 19 | return (obj, type, jsonSerializationContext) -> { 20 | JsonObject jObj = new JsonObject(); 21 | 22 | jObj.addProperty("revision", obj.revision); 23 | 24 | JsonArray aMethodIds = new JsonArray(); 25 | for (int methodId : obj.methodIds) { 26 | aMethodIds.add(methodId); 27 | } 28 | jObj.add("method_ids", aMethodIds); 29 | 30 | return jObj; 31 | }; 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/util/AbstractConfig.java: -------------------------------------------------------------------------------- 1 | package org.tseval.util; 2 | 3 | import com.google.gson.JsonElement; 4 | import com.google.gson.JsonParser; 5 | import org.apache.commons.lang3.tuple.Pair; 6 | 7 | import java.io.FileNotFoundException; 8 | import java.io.FileReader; 9 | import java.io.Reader; 10 | import java.lang.reflect.Field; 11 | import java.nio.file.Path; 12 | import java.nio.file.Paths; 13 | import java.util.Collections; 14 | import java.util.HashMap; 15 | import java.util.Map; 16 | import java.util.function.Function; 17 | 18 | public abstract class AbstractConfig { 19 | 20 | public boolean repOk() { 21 | return true; 22 | } 23 | 24 | public void autoInfer() {} 25 | 26 | private static final Map, Function> SUPPORTED_CLASSES_2_READERS; 27 | static { 28 | Map, Function> supportedClasses2Readers = new HashMap<>(); 29 | supportedClasses2Readers.put(Boolean.TYPE, 
JsonElement::getAsBoolean); 30 | supportedClasses2Readers.put(Boolean.class, JsonElement::getAsBoolean); 31 | supportedClasses2Readers.put(Integer.TYPE, JsonElement::getAsInt); 32 | supportedClasses2Readers.put(Integer.class, JsonElement::getAsInt); 33 | supportedClasses2Readers.put(Double.TYPE, JsonElement::getAsDouble); 34 | supportedClasses2Readers.put(Double.class, JsonElement::getAsDouble); 35 | supportedClasses2Readers.put(String.class, JsonElement::getAsString); 36 | supportedClasses2Readers.put(Path.class, e -> Paths.get(e.getAsString())); 37 | SUPPORTED_CLASSES_2_READERS = Collections.unmodifiableMap(supportedClasses2Readers); 38 | } 39 | 40 | public static T load(Path configPath, Class clz) { 41 | try { 42 | return load(new FileReader(configPath.toFile()), clz); 43 | } catch (FileNotFoundException e) { 44 | throw new RuntimeException("Could not load config " + configPath, e); 45 | } 46 | } 47 | 48 | /** 49 | * Loads the config of type {@code T} from the reader, which should provide a json dict. 50 | * @param reader the reader that provides the config 51 | * @param clz T.class 52 | * @param the type of the config 53 | * @return the loaded config 54 | */ 55 | public static T load(Reader reader, Class clz) { 56 | try { 57 | T config = clz.newInstance(); 58 | 59 | Map, Field>> options = new HashMap<>(); 60 | for (Field f : clz.getFields()) { 61 | if (f.getAnnotation(Option.class) != null) { 62 | Class fType = f.getType(); 63 | if (!SUPPORTED_CLASSES_2_READERS.containsKey(fType)) { 64 | throw new RuntimeException("Unsupported option type " + fType + ", for field " + f.getName() + " in class " + clz); 65 | } 66 | options.put(f.getName(), Pair.of(fType, f)); 67 | } 68 | } 69 | 70 | JsonElement configJson = JsonParser.parseReader(reader); 71 | if (configJson.isJsonObject()) { 72 | for (Map.Entry entry : configJson.getAsJsonObject().entrySet()) { 73 | Pair, Field> p = options.get(entry.getKey()); 74 | if (p != null) { 75 | p.getRight().set(config, SUPPORTED_CLASSES_2_READERS.get(p.getLeft()).apply(entry.getValue())); 76 | } else { 77 | throw new RuntimeException("Unknown config key " + entry.getKey()); 78 | } 79 | } 80 | } else { 81 | throw new RuntimeException("Input should be a json dict"); 82 | } 83 | 84 | config.autoInfer(); 85 | if (!config.repOk()) { 86 | throw new RuntimeException("Config invalid"); 87 | } 88 | 89 | return config; 90 | } catch (RuntimeException e) { 91 | throw new RuntimeException("Malformed config", e); 92 | } catch (IllegalAccessException | InstantiationException e) { 93 | throw new RuntimeException("The config class " + clz + " is not correctly set up", e); 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/util/BashUtils.java: -------------------------------------------------------------------------------- 1 | package org.tseval.util; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.IOException; 5 | import java.io.InputStreamReader; 6 | import java.nio.file.Path; 7 | import java.nio.file.Paths; 8 | import java.util.stream.Collectors; 9 | 10 | public class BashUtils { 11 | 12 | public static class RunResult { 13 | public final int exitCode; 14 | public final String stdout; 15 | public final String stderr; 16 | 17 | public RunResult(int exitCode, String stdout, String stderr) { 18 | this.exitCode = exitCode; 19 | this.stdout = stdout; 20 | this.stderr = stderr; 21 | } 22 | } 23 | 24 | public static RunResult run(String cmd) { 25 | return run(cmd, null); 26 | } 27 
| 28 | public static RunResult run(String cmd, Integer expectedReturnCode) { 29 | try { 30 | Runtime rt = Runtime.getRuntime(); 31 | String[] commands = {"/bin/bash", "-c", cmd}; 32 | Process proc = rt.exec(commands); 33 | proc.waitFor(); 34 | 35 | String stdout = new BufferedReader(new InputStreamReader(proc.getInputStream())).lines().collect(Collectors.joining("\n")); 36 | String stderr = new BufferedReader(new InputStreamReader(proc.getErrorStream())).lines().collect(Collectors.joining("\n")); 37 | int exitCode = proc.exitValue(); 38 | 39 | // Check expected return code 40 | if (expectedReturnCode != null && exitCode != expectedReturnCode) { 41 | // TODO: implement print limit for stdout & stderr, dump to file if they ex 42 | throw new RuntimeException("Expected " + expectedReturnCode + " but returned " + exitCode + " while executing bash command '" + cmd + "'.\n" + 43 | "stdout: " + stdout + "\n" + 44 | "stderr: " + stderr); 45 | } 46 | return new RunResult(exitCode, stdout, stderr); 47 | } catch (IOException | InterruptedException e) { 48 | throw new RuntimeException(e); 49 | } 50 | } 51 | 52 | public static Path getTempDir() { 53 | return Paths.get(run("mktemp -d", 0).stdout.trim()); 54 | } 55 | 56 | public static Path getTempFile() { 57 | return Paths.get(run("mktemp", 0).stdout.trim()); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/util/NLPUtils.java: -------------------------------------------------------------------------------- 1 | package org.tseval.util; 2 | 3 | import edu.stanford.nlp.simple.Document; 4 | 5 | import java.nio.charset.Charset; 6 | import java.util.Optional; 7 | 8 | public class NLPUtils { 9 | 10 | public static Optional getFirstSentence(String str) { 11 | if (str.trim().isEmpty()) { 12 | return Optional.empty(); 13 | } 14 | 15 | try { 16 | Document document = new Document(str); 17 | return Optional.of(document.sentence(0).toString()); 18 | } catch(Exception e) { 19 | System.err.println("Cannot get first sentence of: " + str); 20 | return Optional.empty(); 21 | } 22 | } 23 | 24 | public static boolean isValidISOLatin(String s) { 25 | return Charset.forName("US-ASCII").newEncoder().canEncode(s); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /collector/src/main/java/org/tseval/util/Option.java: -------------------------------------------------------------------------------- 1 | package org.tseval.util; 2 | 3 | import java.lang.annotation.ElementType; 4 | import java.lang.annotation.Retention; 5 | import java.lang.annotation.RetentionPolicy; 6 | import java.lang.annotation.Target; 7 | 8 | @Retention(RetentionPolicy.RUNTIME) 9 | @Target(ElementType.FIELD) 10 | public @interface Option { 11 | } 12 | -------------------------------------------------------------------------------- /python/.gitignore: -------------------------------------------------------------------------------- 1 | # Log & debug files 2 | /ml-logs/ 3 | /debug/ 4 | 5 | 6 | # Additional Pycharm 7 | *.iml 8 | .idea/ 9 | 10 | 11 | ### Pycharm 12 | 13 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 14 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 15 | 16 | # User-specific stuff 17 | .idea/**/workspace.xml 18 | .idea/**/tasks.xml 19 | .idea/**/usage.statistics.xml 20 | .idea/**/dictionaries 21 | .idea/**/shelf 22 | 23 | # Generated files 24 | .idea/**/contentModel.xml 
25 | 26 | # Sensitive or high-churn files 27 | .idea/**/dataSources/ 28 | .idea/**/dataSources.ids 29 | .idea/**/dataSources.local.xml 30 | .idea/**/sqlDataSources.xml 31 | .idea/**/dynamic.xml 32 | .idea/**/uiDesigner.xml 33 | .idea/**/dbnavigator.xml 34 | 35 | # Gradle 36 | .idea/**/gradle.xml 37 | .idea/**/libraries 38 | 39 | # Gradle and Maven with auto-import 40 | # When using Gradle or Maven with auto-import, you should exclude module files, 41 | # since they will be recreated, and may cause churn. Uncomment if using 42 | # auto-import. 43 | # .idea/modules.xml 44 | # .idea/*.iml 45 | # .idea/modules 46 | # *.iml 47 | # *.ipr 48 | 49 | # CMake 50 | cmake-build-*/ 51 | 52 | # Mongo Explorer plugin 53 | .idea/**/mongoSettings.xml 54 | 55 | # File-based project format 56 | *.iws 57 | 58 | # IntelliJ 59 | out/ 60 | 61 | # mpeltonen/sbt-idea plugin 62 | .idea_modules/ 63 | 64 | # JIRA plugin 65 | atlassian-ide-plugin.xml 66 | 67 | # Cursive Clojure plugin 68 | .idea/replstate.xml 69 | 70 | # Crashlytics plugin (for Android Studio and IntelliJ) 71 | com_crashlytics_export_strings.xml 72 | crashlytics.properties 73 | crashlytics-build.properties 74 | fabric.properties 75 | 76 | # Editor-based Rest Client 77 | .idea/httpRequests 78 | 79 | # Android studio 3.1+ serialized cache file 80 | .idea/caches/build_file_checksums.ser 81 | 82 | 83 | ### Python 84 | 85 | # Byte-compiled / optimized / DLL files 86 | __pycache__/ 87 | *.py[cod] 88 | *$py.class 89 | 90 | # C extensions 91 | *.so 92 | 93 | # Distribution / packaging 94 | .Python 95 | build/ 96 | develop-eggs/ 97 | dist/ 98 | downloads/ 99 | eggs/ 100 | .eggs/ 101 | lib/ 102 | lib64/ 103 | parts/ 104 | sdist/ 105 | var/ 106 | wheels/ 107 | pip-wheel-metadata/ 108 | share/python-wheels/ 109 | *.egg-info/ 110 | .installed.cfg 111 | *.egg 112 | MANIFEST 113 | 114 | # PyInstaller 115 | # Usually these files are written by a python script from a template 116 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 117 | *.manifest 118 | *.spec 119 | 120 | # Installer logs 121 | pip-log.txt 122 | pip-delete-this-directory.txt 123 | 124 | # Unit test / coverage reports 125 | htmlcov/ 126 | .tox/ 127 | .nox/ 128 | .coverage 129 | .coverage.* 130 | .cache 131 | nosetests.xml 132 | coverage.xml 133 | *.cover 134 | *.py,cover 135 | .hypothesis/ 136 | .pytest_cache/ 137 | 138 | # Translations 139 | *.mo 140 | *.pot 141 | 142 | # Django stuff: 143 | *.log 144 | local_settings.py 145 | db.sqlite3 146 | db.sqlite3-journal 147 | 148 | # Flask stuff: 149 | instance/ 150 | .webassets-cache 151 | 152 | # Scrapy stuff: 153 | .scrapy 154 | 155 | # Sphinx documentation 156 | docs/_build/ 157 | 158 | # PyBuilder 159 | target/ 160 | 161 | # Jupyter Notebook 162 | .ipynb_checkpoints 163 | 164 | # IPython 165 | profile_default/ 166 | ipython_config.py 167 | 168 | # pyenv 169 | .python-version 170 | 171 | # pipenv 172 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 173 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 174 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 175 | # install all needed dependencies. 
176 | #Pipfile.lock 177 | 178 | # celery beat schedule file 179 | celerybeat-schedule 180 | 181 | # SageMath parsed files 182 | *.sage.py 183 | 184 | # Environments 185 | .env 186 | .venv 187 | env/ 188 | venv/ 189 | ENV/ 190 | env.bak/ 191 | venv.bak/ 192 | 193 | # Spyder project settings 194 | .spyderproject 195 | .spyproject 196 | 197 | # Rope project settings 198 | .ropeproject 199 | 200 | # mkdocs documentation 201 | /site 202 | 203 | # mypy 204 | .mypy_cache/ 205 | .dmypy.json 206 | dmypy.json 207 | 208 | # Pyre type checker 209 | .pyre/ 210 | -------------------------------------------------------------------------------- /python/exps/cg.yaml: -------------------------------------------------------------------------------- 1 | task: CG 2 | setups: 3 | - MP 4 | - CP 5 | - T 6 | num_trials: 3 7 | metrics: 8 | - exact_match 9 | - bleu 10 | - meteor 11 | - rouge_l_f 12 | - rouge_l_p 13 | - rouge_l_r 14 | models: 15 | - name: DeepComHybridESE19 16 | exps: 17 | - DeepComHybridESE19-1 18 | - DeepComHybridESE19-2 19 | - DeepComHybridESE19-3 20 | - name: RNNBaseline 21 | exps: 22 | - RNNBaseline-1 23 | - RNNBaseline-2 24 | - RNNBaseline-3 25 | - name: TransformerACL20 26 | exps: 27 | - TransformerACL20-1 28 | - TransformerACL20-2 29 | - TransformerACL20-3 30 | table_args: 31 | metrics: 32 | - bleu 33 | - meteor 34 | - rouge_l_f 35 | - exact_match 36 | plot_args: 37 | metrics: 38 | bleu: BLEU 39 | meteor: METEOR 40 | rouge_l_f: ROUGE-L 41 | exact_match: EM 42 | metrics_percent: 43 | bleu: True 44 | meteor: True 45 | rouge_l_f: True 46 | exact_match: True 47 | models: 48 | DeepComHybridESE19: DeepComHybrid 49 | RNNBaseline: Seq2Seq 50 | TransformerACL20: Transformer 51 | -------------------------------------------------------------------------------- /python/exps/mn.yaml: -------------------------------------------------------------------------------- 1 | task: MN 2 | setups: 3 | - MP 4 | - CP 5 | - T 6 | num_trials: 3 7 | metrics: 8 | - exact_match 9 | - set_match_f 10 | - set_match_p 11 | - set_match_r 12 | models: 13 | - name: Code2VecPOPL19 14 | exps: 15 | - Code2VecPOPL19-1 16 | - Code2VecPOPL19-2 17 | - Code2VecPOPL19-3 18 | - name: Code2SeqICLR19 19 | exps: 20 | - Code2SeqICLR19-1 21 | - Code2SeqICLR19-2 22 | - Code2SeqICLR19-3 23 | table_args: 24 | metrics: 25 | - set_match_p 26 | - set_match_r 27 | - set_match_f 28 | - exact_match 29 | plot_args: 30 | metrics: 31 | set_match_p: Precision 32 | set_match_r: Recall 33 | set_match_f: F1 34 | exact_match: EM 35 | metrics_percent: 36 | set_match_p: True 37 | set_match_r: True 38 | set_match_f: True 39 | exact_match: True 40 | models: 41 | Code2VecPOPL19: Code2Vec 42 | Code2SeqICLR19: Code2Seq 43 | -------------------------------------------------------------------------------- /python/prepare_conda_env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | _DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) 4 | 5 | # ========== Prepare Conda Environments 6 | 7 | DEFAULT_CONDA_PATH="$HOME/opt/anaconda3/etc/profile.d/conda.sh" 8 | PYTORCH_VERSION=1.7.1 9 | TORCHVISION_VERSION=0.8.2 10 | 11 | function get_conda_path() { 12 | local conda_exe=$(which conda) 13 | if [[ -z ${conda_exe} ]]; then 14 | echo "Fail to detect conda! Have you installed Anaconda/Miniconda?" 
1>&2 15 | exit 1 16 | fi 17 | 18 | echo "$(dirname ${conda_exe})/../etc/profile.d/conda.sh" 19 | } 20 | 21 | function get_cuda_version() { 22 | local nvidia_smi_exe=$(which nvidia-smi) 23 | if [[ -z ${nvidia_smi_exe} ]]; then 24 | echo "cpu" 25 | else 26 | local cuda_version_number="$(nvcc -V | grep "release" | sed -E "s/.*release ([^,]+),.*/\1/")" 27 | case $cuda_version_number in 28 | 10.0*) 29 | echo "cu100";; 30 | 10.1*) 31 | echo "cu101";; 32 | 10.2*) 33 | echo "cu102";; 34 | 11.0*) 35 | echo "cu110";; 36 | *) 37 | echo "Unsupported cuda version $cuda_version_number!" 1>&2 38 | exit 1 39 | esac 40 | fi 41 | } 42 | 43 | 44 | function prepare_conda_env() { 45 | ### Preparing the base environment "tseval" 46 | local env_name=${1:-tseval}; shift 47 | local conda_path=${1:-$(get_conda_path)}; shift 48 | local cuda_version=${1:-$(get_cuda_version)}; shift 49 | 50 | echo ">>> Preparing conda environment \"${env_name}\", for cuda version: ${cuda_version}; conda at ${conda_path}" 51 | 52 | # Preparation 53 | set -e 54 | set -x 55 | source ${conda_path} 56 | conda env remove --name $env_name 57 | conda create --name $env_name python=3.8 pip -y 58 | conda activate $env_name 59 | 60 | # PyTorch 61 | local cuda_toolkit=""; 62 | case $cuda_version in 63 | cpu) 64 | cuda_toolkit=cpuonly;; 65 | cu100) 66 | cuda_toolkit="cudatoolkit=10.0";; 67 | cu101) 68 | cuda_toolkit="cudatoolkit=10.1";; 69 | cu102) 70 | cuda_toolkit="cudatoolkit=10.2";; 71 | cu110) 72 | cuda_toolkit="cudatoolkit=11.0";; 73 | *) 74 | echo "Unexpected cuda version $cuda_version!" 1>&2 75 | exit 1 76 | esac 77 | 78 | conda install -y pytorch=${PYTORCH_VERSION} torchvision=${TORCHVISION_VERSION} ${cuda_toolkit} -c pytorch 79 | 80 | # Other libraries 81 | pip install -r requirements.txt 82 | } 83 | 84 | 85 | prepare_conda_env "$@" 86 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | javalang~=0.13.0 2 | keras~=2.3.1 3 | nltk~=3.5 4 | recordclass==0.13.2 5 | rouge==1.0.0 6 | scikit-learn~=0.22.1 7 | seaborn==0.11.1 8 | seutil==0.5.7 9 | # tensorflow~=2.1.0 10 | torch==1.7.1 11 | torchtext==0.8.1 12 | tqdm~=4.54.1 13 | -------------------------------------------------------------------------------- /python/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script documents the exact procedures we use to get the 4 | # dataset, get the models running, and collect the results. 5 | 6 | # Each function is a group of commands and a later function usually 7 | # requires the execution of all the preceding functions. 8 | 9 | # The commands within each function should always be executed one 10 | # after another sequentially unless only partial functionality is wanted.
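# Example invocations (illustrative; each name is a function defined below, dispatched via the main function at the bottom of this file):
#   ./run.sh collect_repos
#   ./run.sh metric_dataset
#   ./run.sh cg_run_model TransformerACL20 MP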
11 | 12 | 13 | _DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd ) 14 | 15 | 16 | # ========== Metrics, Tables, and Plots 17 | 18 | function metric_dataset() { 19 | python -m tseval.main collect_metrics --which=raw-data-filtered 20 | python -m tseval.main collect_metrics --which=split-dataset --split=Full 21 | 22 | for task in CG MN; do 23 | for setup in CP MP T; do 24 | python -m tseval.main collect_metrics --which=setup-dataset --task=$task --setup=$setup 25 | done 26 | done 27 | 28 | for task in CG MN; do 29 | for setup in CP MP T; do 30 | python -m tseval.main collect_metrics --which=setup-dataset-leak --task=$task --setup=$setup 31 | done 32 | done 33 | } 34 | 35 | 36 | function tables_dataset() { 37 | python -m tseval.main make_tables --which=dataset-metrics 38 | } 39 | 40 | 41 | function plots_dataset() { 42 | for task in CG MN; do 43 | python -m tseval.main make_plots --which=dataset-metrics-dist --task=$task 44 | done 45 | } 46 | 47 | 48 | function analyze_exps() { 49 | for exps in cg mn; do 50 | python -m tseval.main analyze_extract_metrics --exps=./exps/${exps}.yaml 51 | python -m tseval.main analyze_sign_test --exps=./exps/${exps}.yaml 52 | done 53 | } 54 | 55 | 56 | function tables_exps() { 57 | for exps in cg mn; do 58 | python -m tseval.main analyze_make_tables --exps=./exps/${exps}.yaml 59 | done 60 | } 61 | 62 | 63 | function plots_exps() { 64 | for exps in cg mn; do 65 | python -m tseval.main analyze_make_plots --exps=./exps/${exps}.yaml 66 | done 67 | } 68 | 69 | 70 | function extract_data_sim() { 71 | for exps in cg mn; do 72 | python -m tseval.main analyze_extract_data_similarities --exps=./exps/${exps}.yaml 73 | done 74 | } 75 | 76 | 77 | function analyze_near_duplicates() { 78 | for exps in cg mn; do 79 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 80 | --config=same_code \ 81 | --code_sim=1 --nl_sim=1.1 82 | 83 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 84 | --config=same_nl \ 85 | --code_sim=1.1 --nl_sim=1 86 | 87 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 88 | --config=sim_90 \ 89 | --code_sim=0.9 --nl_sim=0.9 90 | done 91 | } 92 | 93 | 94 | function analyze_near_duplicates_only_tables_plots() { 95 | for exps in cg mn; do 96 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 97 | --config=same_code \ 98 | --code_sim=1 --nl_sim=1.1 --only_tables_plots 99 | 100 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 101 | --config=same_nl \ 102 | --code_sim=1.1 --nl_sim=1 --only_tables_plots 103 | 104 | python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \ 105 | --config=sim_90 \ 106 | --code_sim=0.9 --nl_sim=0.9 --only_tables_plots 107 | done 108 | } 109 | 110 | 111 | # ========== Data collection 112 | 113 | function collect_repos() { 114 | python -m tseval.main collect_repos 115 | python -m tseval.main filter_repos 116 | } 117 | 118 | 119 | function collect_raw_data() { 120 | python -m tseval.main collect_raw_data 121 | python -m tseval.main process_raw_data 122 | } 123 | 124 | 125 | # ========== Eval preparation 126 | 127 | function prepare_envs() { 128 | # Require tseval conda env first, prepared using prepare_conda_env 129 | python -m tseval.main prepare_envs 130 | } 131 | 132 | 133 | function prepare_splits() { 134 | python -m tseval.main get_splits --seed=7 --split=Debug --debug 135 | python -m tseval.main get_splits --seed=7 --split=Full 136 | } 137 | 138 | 139 | function cg_debug_workflow() 
{ 140 | python -m tseval.main exp_prepare --task=CG --setup=StandardSetup --setup_name=Debug --split_name=Debug --split_type=MP 141 | python -m tseval.main exp_train --task=CG --setup_name=Debug --model_name=TransformerACL20 --exp_name=TransformerACL20 142 | for action in test_common val test_standard; do 143 | python -m tseval.main exp_eval --task=CG --setup_name=Debug --exp_name=TransformerACL20 --action=$action 144 | done 145 | for action in test_common val test_standard; do 146 | python -m tseval.main exp_compute_metrics --task=CG --setup_name=Debug --exp_name=TransformerACL20 --action=$action 147 | done 148 | } 149 | 150 | 151 | function cg_prepare_setups() { 152 | for split_type in MP CP T; do 153 | python -m tseval.main exp_prepare --task=CG --setup=StandardSetup --setup_name=$split_type --split_name=Full --split_type=$split_type 154 | done 155 | } 156 | 157 | 158 | function cg_debug_model() { 159 | local model=$1; shift 160 | local args="$@"; shift 161 | echo "Arguments to model $model: $args" 162 | 163 | set -e 164 | set -x 165 | python -m tseval.main exp_train --task=CG --setup_name=Debug --exp_name=$model\ 166 | --model_name=$model $args 167 | for action in test_common val test_standard; do 168 | python -m tseval.main exp_eval --task=CG --setup_name=Debug --exp_name=$model --action=$action 169 | python -m tseval.main exp_compute_metrics --task=CG --setup_name=Debug --exp_name=$model --action=$action 170 | done 171 | } 172 | 173 | 174 | function cg_run_model() { 175 | local model=$1; shift 176 | local setup=$1; shift 177 | local args="$@"; shift 178 | echo "Arguments to model $model: $args" 179 | 180 | set -e 181 | set -x 182 | python -m tseval.main exp_train --task=CG --setup_name=$setup --exp_name=$model\ 183 | --model_name=$model $args 184 | for action in test_common val test_standard; do 185 | python -m tseval.main exp_eval --task=CG --setup_name=$setup --exp_name=$model --action=$action 186 | python -m tseval.main exp_compute_metrics --task=CG --setup_name=$setup --exp_name=$model --action=$action 187 | done 188 | } 189 | 190 | 191 | function mn_debug_workflow() { 192 | python -m tseval.main exp_prepare --task=MN --setup=StandardSetup --setup_name=Debug --split_name=Debug --split_type=MP 193 | python -m tseval.main exp_train --task=MN --setup_name=Debug --model_name=Code2SeqICLR19 --exp_name=Code2SeqICLR19 194 | for action in test_common val test_standard; do 195 | python -m tseval.main exp_eval --task=MN --setup_name=Debug --exp_name=Code2SeqICLR19 --action=$action 196 | done 197 | for action in test_common val test_standard; do 198 | python -m tseval.main exp_compute_metrics --task=MN --setup_name=Debug --exp_name=Code2SeqICLR19 --action=$action 199 | done 200 | } 201 | 202 | 203 | function mn_prepare_setups() { 204 | for split_type in MP CP T; do 205 | python -m tseval.main exp_prepare --task=MN --setup=StandardSetup --setup_name=$split_type --split_name=Full --split_type=$split_type 206 | done 207 | } 208 | 209 | 210 | function mn_debug_model() { 211 | local model=$1; shift 212 | local args="$@"; shift 213 | echo "Arguments to model $model: $args" 214 | 215 | set -e 216 | set -x 217 | python -m tseval.main exp_train --task=MN --setup_name=Debug --exp_name=$model\ 218 | --model_name=$model $args 219 | for action in test_common val test_standard; do 220 | python -m tseval.main exp_eval --task=MN --setup_name=Debug --exp_name=$model --action=$action 221 | python -m tseval.main exp_compute_metrics --task=MN --setup_name=Debug --exp_name=$model --action=$action 222 | done 223 
| } 224 | 225 | 226 | function mn_run_model() { 227 | local model=$1; shift 228 | local setup=$1; shift 229 | local args="$@"; shift 230 | echo "Arguments to model $model: $args" 231 | 232 | set -e 233 | set -x 234 | python -m tseval.main exp_train --task=MN --setup_name=$setup --exp_name=$model\ 235 | --model_name=$model $args 236 | for action in test_common val test_standard; do 237 | python -m tseval.main exp_eval --task=MN --setup_name=$setup --exp_name=$model --action=$action 238 | python -m tseval.main exp_compute_metrics --task=MN --setup_name=$setup --exp_name=$model --action=$action 239 | done 240 | } 241 | 242 | 243 | 244 | 245 | # ========== 246 | # Main function -- program entry point 247 | # This script can be executed as ./run.sh the_function_to_run 248 | 249 | function main() { 250 | local action=${1:?Need Argument}; shift 251 | 252 | ( cd ${_DIR} 253 | $action "$@" 254 | ) 255 | } 256 | 257 | main "$@" 258 | -------------------------------------------------------------------------------- /python/switch-cuda.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) 2018 Patrick Hohenecker 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | # author: Patrick Hohenecker 24 | # version: 2018.1 25 | # date: May 15, 2018 26 | 27 | 28 | set -e 29 | 30 | 31 | # ensure that the script has been sourced rather than just executed 32 | if [[ "${BASH_SOURCE[0]}" = "${0}" ]]; then 33 | echo "Please use 'source' to execute switch-cuda.sh!" 34 | exit 1 35 | fi 36 | 37 | INSTALL_FOLDER="/usr/local" # the location to look for CUDA installations at 38 | TARGET_VERSION=${1} # the target CUDA version to switch to (if provided) 39 | 40 | # if no version to switch to has been provided, then just print all available CUDA installations 41 | if [[ -z ${TARGET_VERSION} ]]; then 42 | echo "The following CUDA installations have been found (in '${INSTALL_FOLDER}'):" 43 | ls -l "${INSTALL_FOLDER}" | egrep -o "cuda-[0-9]+\\.[0-9]+$" | while read -r line; do 44 | echo "* ${line}" 45 | done 46 | set +e 47 | return 48 | # otherwise, check whether there is an installation of the requested CUDA version 49 | elif [[ ! -d "${INSTALL_FOLDER}/cuda-${TARGET_VERSION}" ]]; then 50 | echo "No installation of CUDA ${TARGET_VERSION} has been found!" 
51 | set +e 52 | return 53 | fi 54 | 55 | # the path of the installation to use 56 | cuda_path="${INSTALL_FOLDER}/cuda-${TARGET_VERSION}" 57 | 58 | # filter out those CUDA entries from the PATH that are not needed anymore 59 | path_elements=(${PATH//:/ }) 60 | new_path="${cuda_path}/bin" 61 | for p in "${path_elements[@]}"; do 62 | if [[ ! ${p} =~ ^${INSTALL_FOLDER}/cuda ]]; then 63 | new_path="${new_path}:${p}" 64 | fi 65 | done 66 | 67 | # filter out those CUDA entries from the LD_LIBRARY_PATH that are not needed anymore 68 | ld_path_elements=(${LD_LIBRARY_PATH//:/ }) 69 | new_ld_path="${cuda_path}/lib64:${cuda_path}/extras/CUPTI/lib64" 70 | for p in "${ld_path_elements[@]}"; do 71 | if [[ ! ${p} =~ ^${INSTALL_FOLDER}/cuda ]]; then 72 | new_ld_path="${new_ld_path}:${p}" 73 | fi 74 | done 75 | 76 | # update environment variables 77 | export CUDA_HOME="${cuda_path}" 78 | export CUDA_ROOT="${cuda_path}" 79 | export LD_LIBRARY_PATH="${new_ld_path}" 80 | export PATH="${new_path}" 81 | 82 | echo "Switched to CUDA ${TARGET_VERSION}." 83 | 84 | set +e 85 | return 86 | -------------------------------------------------------------------------------- /python/tseval/Environment.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from seutil import BashUtils, IOUtils, LoggingUtils, MiscUtils 4 | 5 | from tseval.Macros import Macros 6 | 7 | 8 | class Environment: 9 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) 10 | 11 | # ---------- 12 | # Environment variables 13 | # ---------- 14 | is_debug: bool = False 15 | random_seed: int = 1 16 | is_parallel: bool = False 17 | 18 | # ---------- 19 | # Conda & CUDA 20 | # ---------- 21 | conda_env = "tseval" 22 | 23 | @classmethod 24 | def get_conda_init_path(cls) -> str: 25 | which_conda = BashUtils.run("which conda").stdout.strip() 26 | if len(which_conda) == 0: 27 | raise RuntimeError(f"Cannot detect conda environment!") 28 | return str(Path(which_conda).parent.parent/"etc"/"profile.d"/"conda.sh") 29 | 30 | conda_init_path_cached = None 31 | 32 | @MiscUtils.classproperty 33 | def conda_init_path(cls): 34 | if cls.conda_init_path_cached is None: 35 | cls.conda_init_path_cached = cls.get_conda_init_path() 36 | return cls.conda_init_path_cached 37 | 38 | @classmethod 39 | def get_cuda_version(cls) -> str: 40 | which_nvidia_smi = BashUtils.run("which nvidia-smi").stdout.strip() 41 | if len(which_nvidia_smi) == 0: 42 | return "cpu" 43 | else: 44 | cuda_version_number = BashUtils.run(r'nvcc -V | grep "release" | sed -E "s/.*release ([^,]+),.*/\1/"').stdout.strip() 45 | if cuda_version_number.startswith("10.0"): 46 | return "cu100" 47 | elif cuda_version_number.startswith("10.1"): 48 | return "cu101" 49 | elif cuda_version_number.startswith("10.2"): 50 | return "cu102" 51 | elif cuda_version_number.startswith("11.0"): 52 | return "cu110" 53 | else: 54 | raise RuntimeError(f"Unsupported cuda version {cuda_version_number}!") 55 | 56 | cuda_version_cached = None 57 | 58 | @MiscUtils.classproperty 59 | def cuda_version(cls): 60 | if cls.cuda_version_cached is None: 61 | cls.cuda_version_cached = cls.get_cuda_version() 62 | return cls.cuda_version_cached 63 | 64 | @classmethod 65 | def get_cuda_toolkit_spec(cls): 66 | cuda_version = cls.cuda_version 67 | if cuda_version == "cpu": 68 | return "cpuonly" 69 | elif cuda_version == "cu100": 70 | return "cudatoolkit=10.0" 71 | elif cuda_version == "cu101": 72 | return "cudatoolkit=10.1" 73 | elif cuda_version == "cu102": 74 | return
"cudatoolkit=10.2" 75 | elif cuda_version == "cu110": 76 | return "cudatoolkit=11.0" 77 | else: 78 | raise RuntimeError(f"Unexpected cuda version {cuda_version}!") 79 | 80 | # ---------- 81 | # Tools 82 | # ---------- 83 | 84 | collector_installed = False 85 | collector_jar = str(Macros.collector_dir / "target" / f"collector-{Macros.collector_version}.jar") 86 | 87 | @classmethod 88 | def require_collector(cls): 89 | if cls.is_parallel: 90 | return 91 | if not cls.collector_installed: 92 | cls.logger.info("Require collector, installing ...") 93 | with IOUtils.cd(Macros.collector_dir): 94 | BashUtils.run(f"mvn clean install -DskipTests", expected_return_code=0) 95 | cls.collector_installed = True 96 | else: 97 | cls.logger.debug("Require collector, and already installed") 98 | return 99 | -------------------------------------------------------------------------------- /python/tseval/Macros.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import os 4 | from pathlib import Path 5 | 6 | 7 | class Macros: 8 | this_dir: Path = Path(os.path.dirname(os.path.realpath(__file__))) 9 | python_dir: Path = this_dir.parent 10 | project_dir: Path = python_dir.parent 11 | paper_dir: Path = project_dir / "papers" / "acl22" 12 | 13 | collector_dir: Path = project_dir / "collector" 14 | collector_version = "0.1-dev" 15 | 16 | results_dir: Path = project_dir / "results" 17 | raw_data_dir: Path = project_dir / "_raw_data" 18 | work_dir: Path = project_dir / "_work" 19 | repos_downloads_dir: Path = project_dir / "_downloads" 20 | 21 | train = "train" 22 | val = "val" 23 | test = "test" 24 | test_common = "test_common" 25 | test_standard = "test_standard" 26 | 27 | mixed_prj = "MP" 28 | cross_prj = "CP" 29 | temporally = "T" 30 | split_types = [mixed_prj, cross_prj, temporally] 31 | 32 | @classmethod 33 | def get_pairwise_split_types_with(cls, split: str) -> List[Tuple[str, str]]: 34 | pairs = [] 35 | before = True 36 | for s in cls.split_types: 37 | if s == split: 38 | before = False 39 | else: 40 | if before: 41 | pairs.append((s, split)) 42 | else: 43 | pairs.append((split, s)) 44 | return pairs 45 | 46 | com_gen = "CG" 47 | met_nam = "MN" 48 | 49 | tasks = ["CG", "MN"] 50 | -------------------------------------------------------------------------------- /python/tseval/Plot.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | from typing import * 4 | 5 | import matplotlib as mpl 6 | import pandas as pd 7 | import seaborn as sns 8 | from matplotlib import pyplot as plt 9 | from seutil import IOUtils, latex, LoggingUtils 10 | 11 | from tseval.Environment import Environment 12 | from tseval.Macros import Macros 13 | 14 | 15 | class Plot: 16 | 17 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO) 18 | 19 | @classmethod 20 | def init_plot_libs(cls): 21 | # Initialize seaborn 22 | sns.set() 23 | sns.set_palette("Dark2") 24 | sns.set_context("paper") 25 | # Set matplotlib fonts 26 | mpl.rcParams["axes.titlesize"] = 24 27 | mpl.rcParams["axes.labelsize"] = 24 28 | mpl.rcParams["font.size"] = 18 29 | mpl.rcParams["xtick.labelsize"] = 24 30 | mpl.rcParams["xtick.major.size"] = 14 31 | mpl.rcParams["xtick.minor.size"] = 14 32 | mpl.rcParams["ytick.labelsize"] = 24 33 | mpl.rcParams["ytick.major.size"] = 14 34 | mpl.rcParams["ytick.minor.size"] = 14 35 | mpl.rcParams["legend.fontsize"] = 18 36 | 
mpl.rcParams["legend.title_fontsize"] = 18 37 | # print(mpl.rcParams) 38 | 39 | def __init__(self): 40 | self.plots_dir: Path = Macros.paper_dir / "figs" 41 | IOUtils.mk_dir(self.plots_dir) 42 | self.init_plot_libs() 43 | 44 | def make_plots(self, options: dict): 45 | which = options.pop("which") 46 | 47 | if which == "dataset-metrics-dist": 48 | self.dataset_metrics_dist( 49 | task=options["task"], 50 | ) 51 | else: 52 | self.logger.warning(f"Unknown plot {which}") 53 | 54 | def dataset_metrics_dist(self, task: str): 55 | plots_sub_dir = self.plots_dir / f"dataset-{task}" 56 | plots_sub_dir_rel = str(plots_sub_dir.relative_to(Macros.paper_dir)) 57 | IOUtils.rm_dir(plots_sub_dir) 58 | plots_sub_dir.mkdir(parents=True) 59 | 60 | # Load metrics list into a DataFrame 61 | label_x = "code" 62 | max_x = 200 63 | if task == "CG": 64 | label_y = "comment" 65 | max_y = 60 66 | else: 67 | label_y = "name" 68 | max_y = 8 69 | lod: List[dict] = [] 70 | seen_split_combination = set() 71 | for setup in Macros.split_types: 72 | metrics_list = IOUtils.load(Macros.results_dir / "metrics" / f"setup-dataset-metrics-list_{task}_{setup}.pkl", IOUtils.Format.pkl) 73 | for sn in [Macros.train, Macros.val, Macros.test_standard]: 74 | for i, (x, y) in enumerate(zip( 75 | metrics_list[f"{sn}_len-{label_x}"], metrics_list[f"{sn}_len-{label_y}"], 76 | )): 77 | lod.append({ 78 | "i": i, 79 | "set_name": f"{sn}-{setup}", 80 | label_x: x, 81 | label_y: y, 82 | }) 83 | 84 | for s1, s2 in Macros.get_pairwise_split_types_with(setup): 85 | if (s1, s2) in seen_split_combination: 86 | continue 87 | for i, (x, y) in enumerate(zip( 88 | metrics_list[f"{Macros.test_common}-{s1}-{s2}_len-{label_x}"], 89 | metrics_list[f"{Macros.test_common}-{s1}-{s2}_len-{label_y}"], 90 | )): 91 | lod.append({ 92 | "i": i, 93 | "set_name": f"{Macros.test_common}-{s1}-{s2}", 94 | label_x: x, 95 | label_y: y, 96 | }) 97 | seen_split_combination.add((s1, s2)) 98 | 99 | df = pd.DataFrame(lod) 100 | 101 | # Make plots 102 | for sn, df_sn in df.groupby("set_name", as_index=False): 103 | sn: str 104 | if sn in [f"{x}-{Macros.temporally}" for x in [Macros.train, Macros.val, Macros.test_standard]] + [f"{Macros.test_common}-{Macros.cross_prj}-{Macros.temporally}"]: 105 | display_xlabel = "len(" + label_x + ")" 106 | else: 107 | display_xlabel = None 108 | if sn.startswith(Macros.train): 109 | display_ylabel = "len(" + label_y + ")" 110 | else: 111 | display_ylabel = None 112 | 113 | fig = sns.jointplot( 114 | data=df_sn, 115 | x=label_x, y=label_y, 116 | kind="hist", 117 | xlim=(0, max_x), 118 | ylim=(0, max_y), 119 | height=6, 120 | ratio=3, 121 | space=.01, 122 | joint_kws=dict( 123 | bins=(12, min(12, max_y-1)), 124 | binrange=((0, max_x), (0, max_y)), 125 | pmax=.5, 126 | ), 127 | color="royalblue", 128 | ) 129 | fig.set_axis_labels(display_xlabel, display_ylabel) 130 | plt.tight_layout() 131 | fig.savefig(plots_sub_dir / f"{sn}.pdf") 132 | 133 | # Generate a tex file that organizes the plots 134 | f = latex.File(plots_sub_dir / f"plot.tex") 135 | f.append(r"\begin{center}") 136 | f.append(r"\begin{footnotesize}") 137 | f.append(r"\begin{tabular}{|l|c|c|c || l|c|}") 138 | f.append(r"\hline") 139 | f.append(r"& \textbf{ATrain} & \textbf{\AVal} & \textbf{\ATestS} & & \textbf{\ATestC} \\") 140 | 141 | for sn_l, (s1_r, s2_r) in zip(Macros.split_types, itertools.combinations(Macros.split_types, 2)): 142 | f.append(r"\hline") 143 | f.append(latex.Macro(f"TH-ds-{sn_l}").use()) 144 | for sn in [Macros.train, Macros.val, Macros.test_standard]: 145 | 
f.append(r" & \begin{minipage}{.18\textwidth}\includegraphics[width=\textwidth]{" 146 | + f"{plots_sub_dir_rel}/{sn}-{sn_l}" 147 | + r"}\end{minipage}") 148 | f.append(" & " + latex.Macro(f"TH-ds-{s1_r}-{s2_r}").use()) 149 | f.append(r" & \begin{minipage}{.18\textwidth}\includegraphics[width=\textwidth]{" 150 | + f"{plots_sub_dir_rel}/{Macros.test_common}-{s1_r}-{s2_r}" 151 | + r"}\end{minipage}") 152 | f.append(r"\\") 153 | f.append(r"\hline") 154 | f.append(r"\end{tabular}") 155 | f.append(r"\end{footnotesize}") 156 | f.append(r"\end{center}") 157 | f.save() 158 | -------------------------------------------------------------------------------- /python/tseval/Table.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import itertools 3 | from pathlib import Path 4 | 5 | from seutil import IOUtils, latex, LoggingUtils 6 | from seutil.latex.Macro import Macro 7 | 8 | from tseval.Environment import Environment 9 | from tseval.Macros import Macros 10 | 11 | 12 | class Table: 13 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO) 14 | 15 | COLSEP = "COLSEP" 16 | ROWSEP = "ROWSEP" 17 | 18 | def __init__(self): 19 | self.tables_dir: Path = Macros.paper_dir / "tables" 20 | IOUtils.mk_dir(self.tables_dir) 21 | 22 | self.metrics_dir: Path = Macros.results_dir / "metrics" 23 | return 24 | 25 | def make_tables(self, options): 26 | which = options.pop("which") 27 | if which == "dataset-metrics": 28 | self.make_numbers_dataset_metrics() 29 | self.make_table_dataset_metrics_small() 30 | for task in [Macros.com_gen, Macros.met_nam]: 31 | self.make_table_dataset_metrics(task) 32 | else: 33 | self.logger.warning(f"Unknown table name {which}") 34 | 35 | def make_numbers_dataset_metrics(self): 36 | file = latex.File(self.tables_dir / f"numbers-dataset-metrics.tex") 37 | 38 | dataset_filtered_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"raw-data-filtered.json", IOUtils.Format.json) 39 | for k, v in dataset_filtered_metrics.items(): 40 | file.append_macro(latex.Macro(f"ds-filter-{k}", f"{v:,d}")) 41 | 42 | dataset_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"split-dataset-metrics_Full.json", IOUtils.Format.json) 43 | for k, v in dataset_metrics.items(): 44 | fmt = f",d" if type(v) == int else f",.2f" 45 | file.append_macro(latex.Macro(f"ds-{k}", f"{v:{fmt}}")) 46 | 47 | for task in [Macros.com_gen, Macros.met_nam]: 48 | for split in Macros.split_types: 49 | setup_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"setup-dataset-metrics_{task}_{split}.json", IOUtils.Format.json) 50 | for k, v in setup_metrics.items(): 51 | fmt = f",d" if type(v) == int else f",.2f" 52 | if k.startswith("all_"): 53 | # Skip all_ metrics 54 | continue 55 | else: 56 | # Replace train/val/test_standard to train-split/val-split/test_standard-split 57 | for x in [Macros.train, Macros.val, Macros.test_standard]: 58 | if k.startswith(f"{x}_"): 59 | k = f"{x}-{split}_" + k[len(f"{x}_"):] 60 | break 61 | if f"ds-{task}-{k}" in file.macros_indexed: 62 | continue 63 | file.append_macro(latex.Macro(f"ds-{task}-{k}", f"{v:{fmt}}")) 64 | 65 | file.save() 66 | 67 | def make_table_dataset_metrics(self, task: str): 68 | f = latex.File(self.tables_dir / f"table-dataset-metrics-{task}.tex") 69 | 70 | metric_2_th = collections.OrderedDict() 71 | # metric_2_th["num-proj"] = r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-project}}" 72 | metric_2_th["num-data"] = 
r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-data}}" 73 | metric_2_th["sep-data"] = self.ROWSEP 74 | metric_2_th["len-code-AVG"] = r"& \UseMacro{TH-ds-len-code-avg}" 75 | # metric_2_th["len-code-MODE"] = r"& \UseMacro{TH-ds-len-code-mode}" 76 | # metric_2_th["len-code-MEDIAN"] = r"& \UseMacro{TH-ds-len-code-median}" 77 | metric_2_th["len-code-le-100"] = r"& \UseMacro{TH-ds-len-code-le100}" 78 | metric_2_th["len-code-le-150"] = r"& \UseMacro{TH-ds-len-code-le150}" 79 | metric_2_th["len-code-le-200"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-code}} & \UseMacro{TH-ds-len-code-le200}" 80 | if task == Macros.com_gen: 81 | metric_2_th["sep-cg"] = self.ROWSEP 82 | metric_2_th["len-comment-AVG"] = r"& \UseMacro{TH-ds-len-comment-avg}" 83 | # metric_2_th["len-comment-MODE"] = r"& \UseMacro{TH-ds-len-comment-mode}" 84 | # metric_2_th["len-comment-MEDIAN"] = r"& \UseMacro{TH-ds-len-comment-median}" 85 | metric_2_th["len-comment-le-20"] = r"& \UseMacro{TH-ds-len-comment-le20}" 86 | metric_2_th["len-comment-le-30"] = r"& \UseMacro{TH-ds-len-comment-le30}" 87 | metric_2_th["len-comment-le-50"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-comment}} & \UseMacro{TH-ds-len-comment-le50}" 88 | if task == Macros.met_nam: 89 | metric_2_th["sep-mn"] = self.ROWSEP 90 | metric_2_th["len-name-AVG"] = r"& \UseMacro{TH-ds-len-name-avg}" 91 | # metric_2_th["len-name-MODE"] = r"& \UseMacro{TH-ds-len-name-mode}" 92 | # metric_2_th["len-name-MEDIAN"] = r"& \UseMacro{TH-ds-len-name-median}" 93 | metric_2_th["len-name-le-2"] = r"& \UseMacro{TH-ds-len-name-le2}" 94 | metric_2_th["len-name-le-3"] = r"& \UseMacro{TH-ds-len-name-le3}" 95 | metric_2_th["len-name-le-6"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-name}} & \UseMacro{TH-ds-len-name-le6}" 96 | 97 | cols = sum( 98 | [ 99 | [f"{s1}-{s2}" for s1 in [Macros.train, Macros.val, Macros.test_standard]] + [self.COLSEP] 100 | for s2 in [Macros.mixed_prj, Macros.cross_prj, Macros.temporally] 101 | ] 102 | , [], 103 | ) + [f"{Macros.test_common}-{x}-{y}" 104 | for x, y in itertools.combinations([Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 2)] 105 | 106 | # Header 107 | f.append(r"\begin{table*}[t]") 108 | f.append(r"\begin{footnotesize}") 109 | f.append(r"\begin{center}") 110 | table_name = f"dataset-metrics-{task}" 111 | 112 | f.append( 113 | r"\begin{tabular}{ l@{\hspace{2pt}}c@{\hspace{2pt}} | " 114 | r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} " 115 | r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} " 116 | r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} " 117 | r"r@{\hspace{4pt}}r@{\hspace{4pt}}r }" 118 | ) 119 | 120 | f.append(r"\toprule") 121 | 122 | # Line 1 123 | f.append( 124 | r"\multicolumn{2}{c|}{}" 125 | r" & \multicolumn{3}{c}{\UseMacro{TH-ds-MP}} &" 126 | r" & \multicolumn{3}{c}{\UseMacro{TH-ds-CP}} &" 127 | r" & \multicolumn{3}{c}{\UseMacro{TH-ds-T}} &" 128 | r" & \UseMacro{TH-ds-MP-CP} & \UseMacro{TH-ds-MP-T} & \UseMacro{TH-ds-CP-T}" 129 | r" \\\cline{3-5}\cline{7-9}\cline{11-13}\cline{15-17}" 130 | ) 131 | 132 | # Line 2 133 | f.append( 134 | r"\multicolumn{2}{c|}{\multirow{-2}{*}{\THDSStat}}" 135 | r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &" 136 | r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &" 137 | r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &" 138 | r" & \multicolumn{3}{c}{\UseMacro{TH-ds-test_common}} \\" 139 | ) 140 | 141 | f.append(r"\midrule") 142 | 143 | for metric, 
row_th in metric_2_th.items(): 144 | if row_th == self.ROWSEP: 145 | f.append(r"\midrule") 146 | continue 147 | 148 | f.append(row_th) 149 | 150 | for col in cols: 151 | if col == self.COLSEP: 152 | f.append(" & ") 153 | continue 154 | f.append(" & " + latex.Macro(f"ds-{task}-{col}_{metric}").use()) 155 | 156 | f.append(r"\\") 157 | 158 | # Footer 159 | f.append(r"\bottomrule") 160 | f.append(r"\end{tabular}") 161 | f.append(r"\end{center}") 162 | f.append(r"\end{footnotesize}") 163 | f.append(r"\vspace{" + latex.Macro(f"TV-{table_name}").use() + "}") 164 | f.append(r"\caption{" + latex.Macro(f"TC-{table_name}").use() + r"}") 165 | f.append(r"\end{table*}") 166 | 167 | f.save() 168 | 169 | def make_table_dataset_metrics_small(self): 170 | f = latex.File(self.tables_dir / f"table-dataset-metrics-small.tex") 171 | 172 | # Header 173 | f.append(r"\begin{table}[t]") 174 | f.append(r"\begin{footnotesize}") 175 | f.append(r"\begin{center}") 176 | table_name = f"dataset-metrics-small" 177 | 178 | f.append(r"\begin{tabular}{ @{} l | c @{\hspace{5pt}} r @{\hspace{5pt}} r @{\hspace{5pt}} r @{\hspace{3pt}}c@{\hspace{3pt}} c @{\hspace{5pt}} r@{} }") 179 | f.append(r"\toprule") 180 | f.append(r"\textbf{Task} & & \textbf{\ATrain} & \textbf{\AVal} & \textbf{\ATestS} & & & \textbf{\ATestC} \\") 181 | 182 | for task in [Macros.com_gen, Macros.met_nam]: 183 | 184 | f.append(r"\midrule") 185 | 186 | for i, (m, p) in enumerate(zip( 187 | [Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 188 | [f"{x}-{y}" for x, y in itertools.combinations([Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 2)], 189 | )): 190 | if i == 2: 191 | f.append(r"\multirow{-3}{*}{\rotatebox[origin=c]{90}{" + latex.Macro(f"TaskM_{task}").use() + r"}}") 192 | 193 | f.append(" & " + latex.Macro(f"TH-ds-{m}").use()) 194 | for sn in [Macros.train, Macros.val, Macros.test_standard]: 195 | f.append(" & " + latex.Macro(f"ds-{task}-{sn}-{m}_num-data").use()) 196 | f.append(r" & \tikz[remember picture, baseline] \node[inner sep=2pt, outer sep=0, yshift=1ex] (" + task + m + r"-base) {\phantom{XX}};") 197 | f.append(" & " + latex.Macro(f"TH-ds-{p}").use()) 198 | f.append(" & " + latex.Macro(f"ds-{task}-{Macros.test_common}-{p}_num-data").use()) 199 | f.append(r"\\") 200 | 201 | f.append(r"\bottomrule") 202 | f.append(r"\end{tabular}") 203 | 204 | f.append(r"\begin{tikzpicture}[remember picture, overlay, thick]") 205 | for task in [Macros.com_gen, Macros.met_nam]: 206 | for r, (l1, l2) in zip( 207 | [Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 208 | [ 209 | (Macros.mixed_prj, Macros.cross_prj), 210 | (Macros.mixed_prj, Macros.temporally), 211 | (Macros.cross_prj, Macros.temporally), 212 | ] 213 | ): 214 | f.append(r"\draw[->] (" + task + l1 + r"-base.west) .. controls ($(" + task + r + r"-base.east) - (1em,0)$) .. (" + task + r + r"-base.east);") 215 | f.append(r"\draw (" + task + l2 + r"-base.west) .. controls ($(" + task + r + r"-base.east) - (1em,0)$) .. 
(" + task + r + r"-base.east);") 216 | 217 | f.append(r"\end{tikzpicture}") 218 | 219 | f.append(r"\end{center}") 220 | f.append(r"\end{footnotesize}") 221 | f.append(r"\vspace{" + latex.Macro(f"TV-{table_name}").use() + "}") 222 | f.append(r"\caption{" + latex.Macro(f"TC-{table_name}").use() + "}") 223 | f.append(r"\end{table}") 224 | 225 | f.save() 226 | -------------------------------------------------------------------------------- /python/tseval/Utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from inspect import Parameter 3 | from pathlib import Path 4 | from typing import * 5 | 6 | import copy 7 | import numpy as np 8 | import typing_inspect 9 | from scipy import stats 10 | 11 | 12 | class Utils: 13 | 14 | @classmethod 15 | def get_option_as_boolean(cls, options, opt, default=False, pop=False) -> bool: 16 | if opt not in options: 17 | return default 18 | else: 19 | # Due to limitations of CliUtils... 20 | value = options.get(opt, "false") 21 | if pop: 22 | del options[opt] 23 | return str(value).lower() != "false" 24 | 25 | @classmethod 26 | def get_option_as_list(cls, options, opt, default=None, pop=False) -> list: 27 | if opt not in options: 28 | return copy.deepcopy(default) 29 | else: 30 | lst = options[opt] 31 | if pop: 32 | del options[opt] 33 | if not isinstance(lst, list): 34 | lst = [lst] 35 | return lst 36 | 37 | @classmethod 38 | def get_option_and_pop(cls, options, opt, default=None) -> any: 39 | if opt in options: 40 | return options.pop(opt) 41 | else: 42 | return copy.deepcopy(default) 43 | 44 | # Summaries 45 | SUMMARIES_FUNCS: Dict[str, Callable[[Union[list, np.ndarray]], Union[int, float]]] = { 46 | "AVG": lambda l: np.mean(l) if len(l) > 0 else np.NaN, 47 | "SUM": lambda l: sum(l) if len(l) > 0 else np.NaN, 48 | "MAX": lambda l: max(l) if len(l) > 0 else np.NaN, 49 | "MIN": lambda l: min(l) if len(l) > 0 else np.NaN, 50 | "MEDIAN": lambda l: np.median(l) if len(l) > 0 and np.NaN not in l else np.NaN, 51 | "STDEV": lambda l: np.std(l) if len(l) > 0 else np.NaN, 52 | "MODE": lambda l: stats.mode(l).mode[0].item() if len(l) > 0 else np.NaN, 53 | "CNT": lambda l: len(l), 54 | } 55 | 56 | SUMMARIES_PRESERVE_INT: Dict[str, bool] = { 57 | "AVG": False, 58 | "SUM": True, 59 | "MAX": True, 60 | "MIN": True, 61 | "MEDIAN": False, 62 | "STDEV": False, 63 | "MODE": True, 64 | "CNT": True, 65 | } 66 | 67 | @classmethod 68 | def parse_cmd_options_for_type( 69 | cls, 70 | options: dict, 71 | typ: type, 72 | excluding_params: List[str] = None, 73 | ) -> Tuple[dict, dict, list]: 74 | """ 75 | Parses the commandline options (got from seutil.CliUtils) based on the parameters 76 | and their types specified in typ.__init__. 77 | 78 | :param options: the commandline options got from seutil.CliUtils. 79 | :param typ: the type to initialize. 80 | :param excluding_params: the list of parameters that are not expected to be 81 | passed from commandline, by default ["self"]. 82 | :return: two dictionaries and a list: 83 | a dictionary with options that can be sent to typ.__init__; 84 | a dictionary that contains the remaining options; 85 | a list of any missing options required by the typ.__init__. 
86 | """ 87 | if excluding_params is None: 88 | excluding_params = ["self"] 89 | 90 | accepted_options = {} 91 | unk_options = options 92 | missing_options = [] 93 | 94 | signature = inspect.signature(typ.__init__) 95 | for param in signature.parameters.values(): 96 | if param.name in excluding_params: 97 | continue 98 | 99 | if param.kind == Parameter.POSITIONAL_ONLY \ 100 | or param.kind == Parameter.VAR_KEYWORD \ 101 | or param.kind == Parameter.VAR_POSITIONAL: 102 | raise AssertionError(f"Class {typ.__name__} should not have '/', '**' '*'" 103 | f" parameters in order to be configured from commandline") 104 | 105 | if param.name not in unk_options: 106 | if param.default == Parameter.empty: 107 | missing_options.append(param.name) 108 | continue 109 | else: 110 | # No need to insert anything to model_options 111 | continue 112 | 113 | if param.annotation == bool: 114 | accepted_options[param.name] = Utils.get_option_as_boolean(unk_options, param.name, pop=True) 115 | elif typing_inspect.get_origin(param.annotation) == list: 116 | accepted_options[param.name] = Utils.get_option_as_list(unk_options, param.name, pop=True) 117 | elif typing_inspect.get_origin(param.annotation) == set: 118 | accepted_options[param.name] = set(Utils.get_option_as_list(unk_options, param.name, pop=True)) 119 | elif typing_inspect.get_origin(param.annotation) == tuple: 120 | accepted_options[param.name] = tuple(Utils.get_option_as_list(unk_options, param.name, pop=True)) 121 | else: 122 | accepted_options[param.name] = unk_options.pop(param.name) 123 | 124 | return accepted_options, unk_options, missing_options 125 | 126 | @classmethod 127 | def expect_dir_or_suggest_dvc_pull(cls, path: Path): 128 | if not path.is_dir(): 129 | dvc_file = path.parent / (path.name+".dvc") 130 | if dvc_file.exists(): 131 | print(f"{path} does not exist, but {dvc_file} exists. You probably want to dvc pull that file first?") 132 | print(f"# DVC command to run:") 133 | print(f" dvc pull {dvc_file}") 134 | raise AssertionError(f"{path} does not exist but can be pulled from dvc.\n dvc pull {dvc_file}") 135 | else: 136 | raise AssertionError(f"{path} does not exist.") 137 | 138 | @classmethod 139 | def suggest_dvc_add(cls, *paths: Path) -> str: 140 | s = f"# DVC commands:\n" 141 | s += f" dvc add " 142 | for path in paths: 143 | s += str(path) 144 | s += "\n" 145 | return s 146 | -------------------------------------------------------------------------------- /python/tseval/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | module_root = os.path.dirname(os.path.realpath(__file__)) + "/.." 
5 | if module_root not in sys.path: 6 | sys.path.insert(0, module_root) 7 | 8 | # Remove temporary names 9 | del os 10 | del sys 11 | del module_root 12 | -------------------------------------------------------------------------------- /python/tseval/collector/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/collector/__init__.py -------------------------------------------------------------------------------- /python/tseval/comgen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/comgen/__init__.py -------------------------------------------------------------------------------- /python/tseval/comgen/eval/CGEvalHelper.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from seutil import IOUtils, LoggingUtils 5 | 6 | from tseval.comgen.eval import get_setup_cls 7 | from tseval.eval.EvalSetupBase import EvalSetupBase 8 | from tseval.Macros import Macros 9 | from tseval.Utils import Utils 10 | 11 | logger = LoggingUtils.get_logger(__name__) 12 | 13 | 14 | class CGEvalHelper: 15 | 16 | def __init__(self): 17 | self.work_subdir: Path = Macros.work_dir / "CG" 18 | 19 | def exp_prepare(self, setup: str, setup_name: str, **cmd_options): 20 | # Clean the setup dir 21 | setup_dir: Path = self.work_subdir / "setup" / setup_name 22 | IOUtils.rm_dir(setup_dir) 23 | setup_dir.mkdir(parents=True) 24 | 25 | # Initialize setup 26 | setup_cls = get_setup_cls(setup) 27 | setup_options, unk_options, missing_options = Utils.parse_cmd_options_for_type( 28 | cmd_options, 29 | setup_cls, 30 | ["self", "work_dir", "work_subdir", "setup_name"], 31 | ) 32 | if len(missing_options) > 0: 33 | raise KeyError(f"Missing options: {missing_options}") 34 | if len(unk_options) > 0: 35 | logger.warning(f"Unrecognized options: {unk_options}") 36 | setup_obj: EvalSetupBase = setup_cls( 37 | work_dir=Macros.work_dir, 38 | work_subdir=self.work_subdir, 39 | setup_name=setup_name, 40 | **setup_options, 41 | ) 42 | 43 | # Save setup configs 44 | setup_options["setup"] = setup 45 | IOUtils.dump(setup_dir / "setup_config.json", setup_options, IOUtils.Format.jsonNoSort) 46 | 47 | # Prepare data 48 | setup_obj.prepare() 49 | 50 | # Print dvc commands 51 | print(Utils.suggest_dvc_add(setup_obj.setup_dir)) 52 | 53 | def load_setup(self, setup_dir: Path, setup_name: str) -> EvalSetupBase: 54 | """ 55 | Loads the setup from its save directory, overriding its saved setup_name with the given one.
56 | """ 57 | config = IOUtils.load(setup_dir / "setup_config.json", IOUtils.Format.json) 58 | setup_cls = get_setup_cls(config.pop("setup")) 59 | setup_obj = setup_cls(work_dir=Macros.work_dir, work_subdir=self.work_subdir, setup_name=setup_name, **config) 60 | return setup_obj 61 | 62 | def exp_train( 63 | self, 64 | setup_name: str, 65 | exp_name: str, 66 | model_name: str, 67 | cont_train: bool, 68 | no_save: bool, 69 | **cmd_options, 70 | ): 71 | # Load saved setup 72 | setup_dir = self.work_subdir / "setup" / setup_name 73 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 74 | setup = self.load_setup(setup_dir, setup_name) 75 | 76 | if not cont_train: 77 | # Delete existing trained model 78 | IOUtils.rm_dir(setup.get_exp_dir(exp_name)) 79 | 80 | # Invoke training 81 | setup.train(exp_name, model_name, cont_train, no_save, **cmd_options) 82 | 83 | # Print dvc commands 84 | print(Utils.suggest_dvc_add(setup.get_exp_dir(exp_name))) 85 | 86 | def exp_eval( 87 | self, 88 | setup_name: str, 89 | exp_name: str, 90 | action: Optional[str], 91 | gpu_id: int = 0, 92 | ): 93 | # Load saved setup 94 | setup_dir = self.work_subdir / "setup" / setup_name 95 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 96 | setup = self.load_setup(setup_dir, setup_name) 97 | 98 | # Invoke eval 99 | setup.eval(exp_name, action, gpu_id=gpu_id) 100 | 101 | # Print dvc commands 102 | print(Utils.suggest_dvc_add(setup.get_result_dir(exp_name))) 103 | 104 | def exp_compute_metrics( 105 | self, 106 | setup_name: str, 107 | exp_name: str, 108 | action: Optional[str] = None, 109 | ): 110 | # Load saved setup 111 | setup_dir = self.work_subdir / "setup" / setup_name 112 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 113 | setup = self.load_setup(setup_dir, setup_name) 114 | 115 | # Invoke eval 116 | setup.compute_metrics(exp_name, action) 117 | 118 | # Print dvc commands 119 | print(Utils.suggest_dvc_add(setup.get_metric_dir(exp_name))) 120 | -------------------------------------------------------------------------------- /python/tseval/comgen/eval/CGModelLoader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from seutil import IOUtils, LoggingUtils 4 | 5 | from tseval.comgen.model import get_model_cls 6 | from tseval.comgen.model.CGModelBase import CGModelBase 7 | from tseval.Utils import Utils 8 | 9 | logger = LoggingUtils.get_logger(__name__) 10 | 11 | 12 | class CGModelLoader: 13 | 14 | @classmethod 15 | def init_or_load_model( 16 | cls, 17 | model_name: str, 18 | exp_dir: Path, 19 | cont_train: bool, 20 | no_save: bool, 21 | cmd_options: dict, 22 | ) -> CGModelBase: 23 | model_cls = get_model_cls(model_name) 24 | model_work_dir = exp_dir / "model" 25 | 26 | if cont_train and model_work_dir.is_dir() and not no_save: 27 | # Restore model name 28 | loaded_model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt) 29 | if model_name != loaded_model_name: 30 | raise ValueError(f"Contradicting model name (saved: {model_name}; new {loaded_model_name})") 31 | 32 | # Warning about any additional command line arguments 33 | if len(cmd_options) > 0: 34 | logger.warning(f"These options will not be used in cont_train mode: {cmd_options}") 35 | 36 | # Load existing model 37 | model: CGModelBase = model_cls.load(model_work_dir) 38 | else: 39 | 40 | if not no_save: 41 | exp_dir.mkdir(parents=True, exist_ok=True) 42 | 43 | # Save model name 44 | IOUtils.dump(exp_dir / "model_name.txt", model_name, IOUtils.Format.txt) 45 | 46 | # Prepare 
directory for model 47 | IOUtils.rm(model_work_dir) 48 | model_work_dir.mkdir(parents=True) 49 | 50 | # Initialize the model, using command line arguments 51 | model_options, unk_options, missing_options = Utils.parse_cmd_options_for_type( 52 | cmd_options, 53 | model_cls, 54 | ["self", "model_work_dir"], 55 | ) 56 | if len(missing_options) > 0: 57 | raise KeyError(f"Missing options: {missing_options}") 58 | if len(unk_options) > 0: 59 | logger.warning(f"Unrecognized options: {unk_options}") 60 | 61 | model: CGModelBase = model_cls(model_work_dir=model_work_dir, no_save=no_save, **model_options) 62 | 63 | if not no_save: 64 | # Save model configs 65 | IOUtils.dump(exp_dir / "model_config.json", model_options, IOUtils.Format.jsonNoSort) 66 | return model 67 | 68 | @classmethod 69 | def load_model(cls, exp_dir: Path) -> CGModelBase: 70 | """ 71 | Loads a trained model from exp_dir. Gets the model name from model_name.txt. 72 | """ 73 | Utils.expect_dir_or_suggest_dvc_pull(exp_dir) 74 | model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt) 75 | model_cls = get_model_cls(model_name) 76 | model_dir = exp_dir / "model" 77 | return model_cls.load(model_dir) 78 | -------------------------------------------------------------------------------- /python/tseval/comgen/eval/StandardSetup.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import time 4 | from pathlib import Path 5 | from typing import Dict, List 6 | 7 | import nltk 8 | import numpy as np 9 | from seutil import IOUtils, LoggingUtils 10 | from tqdm import tqdm 11 | 12 | from tseval.comgen.eval.CGModelLoader import CGModelLoader 13 | from tseval.comgen.model.CGModelBase import CGModelBase 14 | from tseval.data.MethodData import MethodData 15 | from tseval.eval.EvalMetrics import EvalMetrics 16 | from tseval.eval.EvalSetupBase import EvalSetupBase 17 | from tseval.Macros import Macros 18 | from tseval.util.ModelUtils import ModelUtils 19 | from tseval.util.TrainConfig import TrainConfig 20 | from tseval.Utils import Utils 21 | 22 | logger = LoggingUtils.get_logger(__name__) 23 | 24 | 25 | class StandardSetup(EvalSetupBase): 26 | 27 | # Validation set on self's split type 28 | EVAL_VAL = Macros.val 29 | # Test_standard set on self's split type 30 | EVAL_TESTS = Macros.test_standard 31 | # Test_common sets, pairwise between self's split type and the other split types 32 | EVAL_TESTC = Macros.test_common 33 | 34 | EVAL_ACTIONS = [EVAL_VAL, EVAL_TESTS, EVAL_TESTC] 35 | DEFAULT_EVAL_ACTION = EVAL_TESTC 36 | 37 | def __init__( 38 | self, 39 | work_dir: Path, 40 | work_subdir: Path, 41 | setup_name: str, 42 | split_name: str, 43 | split_type: str, 44 | ): 45 | super().__init__(work_dir, work_subdir, setup_name) 46 | self.split_name = split_name 47 | self.split_type = split_type 48 | 49 | def prepare(self) -> None: 50 | # Check and prepare directories 51 | split_dir = self.get_split_dir(self.split_name) 52 | Utils.expect_dir_or_suggest_dvc_pull(self.shared_data_dir) 53 | Utils.expect_dir_or_suggest_dvc_pull(split_dir) 54 | IOUtils.rm_dir(self.data_dir) 55 | self.data_dir.mkdir(parents=True) 56 | 57 | # Copy split indexes 58 | all_indexes = [] 59 | for sn in [Macros.train, Macros.val, Macros.test_standard]: 60 | ids = IOUtils.load(split_dir / f"{self.split_type}-{sn}.json", IOUtils.Format.json) 61 | all_indexes += ids 62 | IOUtils.dump(self.data_dir / f"split_{sn}.json", ids, IOUtils.Format.json) 63 | for s1, s2 in
Macros.get_pairwise_split_types_with(self.split_type): 64 | ids = IOUtils.load(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", IOUtils.Format.json) 65 | all_indexes += ids 66 | IOUtils.dump( 67 | self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", 68 | ids, 69 | IOUtils.Format.json, 70 | ) 71 | all_indexes = list(sorted(set(all_indexes))) 72 | 73 | # Load raw data 74 | tbar = tqdm() 75 | dataset: List[MethodData] = MethodData.load_dataset( 76 | self.shared_data_dir, 77 | expected_ids=all_indexes, 78 | only=["code", "comment_summary", "misc"], 79 | tbar=tbar, 80 | ) 81 | tbar.close() 82 | 83 | # Subtokenize code and comments 84 | tbar = tqdm() 85 | 86 | tbar.set_description("Subtokenizing") 87 | tbar.reset(len(dataset)) 88 | orig_code_list = [] 89 | tokenized_comment_list = [] 90 | for d in dataset: 91 | d.fill_none() 92 | d.misc["orig_code"] = d.code 93 | d.misc["orig_comment_summary"] = d.comment_summary 94 | orig_code_list.append(d.code) 95 | tokenized_comment_list.append(nltk.word_tokenize(d.comment_summary)) 96 | tbar.update(1) 97 | 98 | tokenized_code_list = ModelUtils.tokenize_javaparser_batch(orig_code_list, dup_share=False, tbar=tbar) 99 | 100 | tbar.set_description("Subtokenizing") 101 | tbar.reset(len(dataset)) 102 | for d, tokenized_code, tokenized_comment in zip(dataset, tokenized_code_list, tokenized_comment_list): 103 | d.code, d.misc["code_src_ids"] = ModelUtils.subtokenize_batch(tokenized_code) 104 | d.comment_summary, d.misc["comment_summary_src_ids"] = ModelUtils.subtokenize_batch(tokenized_comment) 105 | 106 | # convert comment to lower case 107 | d.comment_summary = [t.lower() for t in d.comment_summary] 108 | tbar.update(1) 109 | tbar.close() 110 | 111 | # Clean eval ids 112 | indexed_dataset = {d.id: d for d in dataset} 113 | for sn in [Macros.val, Macros.test_standard] + [f"{Macros.test_common}-{x}-{y}" for x, y in Macros.get_pairwise_split_types_with(self.split_type)]: 114 | eval_ids = IOUtils.load(self.data_dir / f"split_{sn}.json", IOUtils.Format.json) 115 | IOUtils.dump(self.data_dir / f"split_{sn}.json", self.clean_eval_set(indexed_dataset, eval_ids), IOUtils.Format.json) 116 | 117 | # Save dataset 118 | MethodData.save_dataset(dataset, self.data_dir) 119 | 120 | def clean_eval_set(self, indexed_dataset: Dict[int, MethodData], eval_ids: List[int]) -> List[int]: 121 | """ 122 | Keeps the eval set absolutely clean by: 123 | - Removing duplicate (comment, code) pairs; 124 | - Removing pseudo-empty comments; 125 | indexed_dataset should already have been subtokenized. 126 | """ 127 | seen_data = set() 128 | clean_eval_ids = [] 129 | for i in eval_ids: 130 | data = indexed_dataset[i] 131 | 132 | # Remove comment like "."
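# (a pseudo-empty comment: its subtokenized summary is the single token ".", which carries no information)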
133 | if len(data.comment_summary) == 1 and data.comment_summary[0] == ".": 134 | continue 135 | 136 | # Remove duplicate (comment, code) pairs 137 | data_key = (tuple(data.code), tuple(data.comment_summary)) 138 | if data_key in seen_data: 139 | continue 140 | else: 141 | seen_data.add(data_key) 142 | 143 | clean_eval_ids.append(i) 144 | return clean_eval_ids 145 | 146 | def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None: 147 | # Init or load model 148 | exp_dir = self.get_exp_dir(exp_name) 149 | train_config = TrainConfig.get_train_config_from_cmd_options(options) 150 | model = CGModelLoader.init_or_load_model(model_name, exp_dir, cont_train, no_save, options) 151 | if not no_save: 152 | IOUtils.dump(exp_dir / "train_config.jsonl", [IOUtils.jsonfy(train_config)], IOUtils.Format.jsonList, append=True) 153 | 154 | # Load data 155 | tbar = tqdm(desc="Loading data") 156 | dataset = MethodData.load_dataset(self.data_dir, tbar=tbar) 157 | indexed_dataset = {d.id: d for d in dataset} 158 | 159 | tbar.set_description("Loading data | take indexes") 160 | tbar.reset(2) 161 | 162 | train_ids = IOUtils.load(self.data_dir / f"split_{Macros.train}.json", IOUtils.Format.json) 163 | train_dataset = [indexed_dataset[i] for i in train_ids] 164 | tbar.update(1) 165 | 166 | val_ids = IOUtils.load(self.data_dir / f"split_{Macros.val}.json", IOUtils.Format.json) 167 | val_dataset = [indexed_dataset[i] for i in val_ids] 168 | tbar.update(1) 169 | 170 | tbar.close() 171 | 172 | # Train model 173 | start = time.time() 174 | model.train(train_dataset, val_dataset, resources_path=self.data_dir, train_config=train_config) 175 | end = time.time() 176 | 177 | if not no_save: 178 | model.save() 179 | IOUtils.dump(exp_dir / "train_time.json", end - start, IOUtils.Format.json) 180 | 181 | def eval_one(self, exp_name: str, eval_ids: List[int], prefix: str, indexed_dataset: Dict[int, MethodData], model: CGModelBase, gpu_id: int = 0): 182 | # Prepare output directory 183 | result_dir = self.get_result_dir(exp_name) 184 | result_dir.mkdir(parents=True, exist_ok=True) 185 | 186 | # Prepare eval data (remove target) 187 | eval_dataset = [indexed_dataset[i] for i in eval_ids] 188 | golds = [] 189 | for d in eval_dataset: 190 | golds.append(d.comment_summary) 191 | d.comment_summary = ["dummy"] 192 | d.misc["orig_comment_summary"] = "dummy" 193 | d.misc["comment_summary_src_ids"] = [0] 194 | 195 | # Perform batched queries 196 | tbar = tqdm(desc=f"Predicting | {prefix}") 197 | eval_start = time.time() 198 | predictions = model.batch_predict(eval_dataset, tbar=tbar, gpu_id=gpu_id) 199 | eval_end = time.time() 200 | tbar.close() 201 | 202 | eval_time = eval_end - eval_start 203 | 204 | # Save predictions & golds 205 | IOUtils.dump(result_dir / f"{prefix}_predictions.jsonl", predictions, IOUtils.Format.jsonList) 206 | IOUtils.dump(result_dir / f"{prefix}_golds.jsonl", golds, IOUtils.Format.jsonList) 207 | IOUtils.dump(result_dir / f"{prefix}_eval_time.json", eval_time, IOUtils.Format.json) 208 | 209 | def eval(self, exp_name: str, action: str = None, gpu_id: int = 0) -> None: 210 | if action is None: 211 | action = self.DEFAULT_EVAL_ACTION 212 | if action not in self.EVAL_ACTIONS: 213 | raise RuntimeError(f"Unknown eval action {action}") 214 | 215 | # Load eval data 216 | tbar = tqdm(desc="Loading data") 217 | dataset = MethodData.load_dataset(self.data_dir, tbar=tbar) 218 | indexed_dataset = {d.id: d for d in dataset} 219 | tbar.close() 220 | 221 | # Load model 222 | exp_dir = 
self.get_exp_dir(exp_name) 223 | model: CGModelBase = CGModelLoader.load_model(exp_dir) 224 | if not model.is_train_finished(): 225 | logger.warning(f"Model not finished training, at {exp_dir}") 226 | 227 | # Invoke eval_one with specific data ids 228 | if action in [self.EVAL_VAL, self.EVAL_TESTS]: 229 | self.eval_one( 230 | exp_name, 231 | IOUtils.load(self.data_dir / f"split_{action}.json", IOUtils.Format.json), 232 | action, 233 | indexed_dataset, 234 | model, 235 | gpu_id=gpu_id, 236 | ) 237 | elif action == self.EVAL_TESTC: 238 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type): 239 | self.eval_one( 240 | exp_name, 241 | IOUtils.load(self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", IOUtils.Format.json), 242 | f"{Macros.test_common}-{s1}-{s2}", 243 | copy.deepcopy(indexed_dataset), 244 | model, 245 | gpu_id=gpu_id, 246 | ) 247 | else: 248 | raise RuntimeError(f"Unknown action {action}") 249 | 250 | def compute_metrics_one(self, exp_name: str, prefix: str): 251 | # Prepare output directory 252 | metric_dir = self.get_metric_dir(exp_name) 253 | metric_dir.mkdir(parents=True, exist_ok=True) 254 | 255 | # Load golds and predictions 256 | result_dir = self.get_result_dir(exp_name) 257 | Utils.expect_dir_or_suggest_dvc_pull(result_dir) 258 | golds = IOUtils.load(result_dir / f"{prefix}_golds.jsonl", IOUtils.Format.jsonList) 259 | predictions = IOUtils.load(result_dir / f"{prefix}_predictions.jsonl", IOUtils.Format.jsonList) 260 | 261 | metrics_list: Dict[str, List] = collections.defaultdict(list) 262 | metrics_list["exact_match"] = EvalMetrics.batch_exact_match(golds, predictions) 263 | metrics_list["token_acc"] = EvalMetrics.batch_token_acc(golds, predictions) 264 | metrics_list["bleu"] = EvalMetrics.batch_bleu(golds, predictions) 265 | rouge_l_res = EvalMetrics.batch_rouge_l(golds, predictions) 266 | metrics_list["rouge_l_f"] = [x["f"] for x in rouge_l_res] 267 | metrics_list["rouge_l_p"] = [x["p"] for x in rouge_l_res] 268 | metrics_list["rouge_l_r"] = [x["r"] for x in rouge_l_res] 269 | metrics_list["meteor"] = EvalMetrics.batch_meteor(golds, predictions) 270 | set_match_res = EvalMetrics.batch_set_match(golds, predictions) 271 | metrics_list["set_match_f"] = [x["f"] for x in set_match_res] 272 | metrics_list["set_match_p"] = [x["p"] for x in set_match_res] 273 | metrics_list["set_match_r"] = [x["r"] for x in set_match_res] 274 | 275 | # Take average 276 | metrics = {} 277 | for k, l in metrics_list.items(): 278 | metrics[k] = np.mean(l).item() 279 | 280 | # Save metrics 281 | IOUtils.dump(metric_dir / f"{prefix}_metrics.json", metrics, IOUtils.Format.jsonNoSort) 282 | IOUtils.dump(metric_dir / f"{prefix}_metrics.txt", [f"{k}: {v}" for k, v in metrics.items()], IOUtils.Format.txtList) 283 | IOUtils.dump(metric_dir / f"{prefix}_metrics_list.pkl", metrics_list, IOUtils.Format.pkl) 284 | 285 | def compute_metrics(self, exp_name: str, action: str = None) -> None: 286 | if action is None: 287 | action = self.DEFAULT_EVAL_ACTION 288 | if action not in self.EVAL_ACTIONS: 289 | raise RuntimeError(f"Unknown eval action {action}") 290 | 291 | if action in [self.EVAL_VAL, self.EVAL_TESTS]: 292 | self.compute_metrics_one( 293 | exp_name, 294 | action, 295 | ) 296 | elif action == self.EVAL_TESTC: 297 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type): 298 | self.compute_metrics_one( 299 | exp_name, 300 | f"{Macros.test_common}-{s1}-{s2}", 301 | ) 302 | else: 303 | raise RuntimeError(f"Unknown action {action}") 304 | 
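To see how the pieces of `StandardSetup` fit together, the following is a minimal sketch of driving the comment-generation pipeline end to end from Python. In the repo this wiring is normally done through `tseval/main.py` (`exp_prepare`/`exp_train`/`exp_eval`/`exp_compute_metrics`) and the task-specific eval helpers; the directory, split, setup, and experiment names below, as well as the choice of split type, are illustrative assumptions rather than values taken from the repo.

```python
from pathlib import Path

from tseval.comgen.eval.StandardSetup import StandardSetup
from tseval.Macros import Macros

# Hypothetical directory and names, for illustration only; prepare() assumes
# that the shared data and the split files (see EvalHelper.get_splits) exist.
work_dir = Path("_work")
setup = StandardSetup(
    work_dir=work_dir,
    work_subdir=work_dir / "CG",
    setup_name="CG-default",
    split_name="default",
    split_type=Macros.split_types[0],  # one of the three split types
)

setup.prepare()  # copy split ids, subtokenize code/comments, clean eval sets

setup.train(
    exp_name="TransformerACL20-run0",
    model_name="TransformerACL20",
    cont_train=False,
    no_save=False,
    seed=42,  # extra options are forwarded to the model constructor
)

# EVAL_TESTS evaluates on the standard test set of this setup's split type;
# EVAL_TESTC (the default) evaluates on the pairwise common test sets.
setup.eval("TransformerACL20-run0", action=StandardSetup.EVAL_TESTS, gpu_id=0)
setup.compute_metrics("TransformerACL20-run0", action=StandardSetup.EVAL_TESTS)
```

Note that `eval_one` overwrites each evaluated data point's `comment_summary` with a dummy value before querying the model, so the references survive only in the `*_golds.jsonl` files written next to the predictions; `compute_metrics_one` accordingly reads golds from the result directory rather than from `data_dir`.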
-------------------------------------------------------------------------------- /python/tseval/comgen/eval/__init__.py: -------------------------------------------------------------------------------- 1 | def get_setup_cls(name: str) -> type: 2 | if name == "StandardSetup": 3 | from tseval.comgen.eval.StandardSetup import StandardSetup 4 | return StandardSetup 5 | else: 6 | raise ValueError(f"No setup with name {name}") 7 | -------------------------------------------------------------------------------- /python/tseval/comgen/model/CGModelBase.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from pathlib import Path 3 | from typing import List, Optional, Tuple 4 | 5 | from seutil import IOUtils 6 | from tqdm import tqdm 7 | 8 | from tseval.data.MethodData import MethodData 9 | from tseval.util.TrainConfig import TrainConfig 10 | 11 | 12 | class CGModelBase: 13 | 14 | def __init__(self, model_work_dir: Path, no_save: bool = False): 15 | self.model_work_dir = model_work_dir 16 | self.no_save = no_save 17 | 18 | @abc.abstractmethod 19 | def train( 20 | self, 21 | train_dataset: List[MethodData], 22 | val_dataset: List[MethodData], 23 | resources_path: Optional[Path] = None, 24 | train_config: Optional[TrainConfig] = None, 25 | ): 26 | """ 27 | Trains the model. 28 | 29 | :param train_dataset: training set. 30 | :param val_dataset: validation set. 31 | :param resources_path: path to resources that could be shared by multiple models' training processes, 32 | e.g., pre-trained embeddings. 33 | """ 34 | raise NotImplementedError 35 | 36 | @abc.abstractmethod 37 | def is_train_finished(self) -> bool: 38 | raise NotImplementedError 39 | 40 | @abc.abstractmethod 41 | def predict( 42 | self, 43 | data: MethodData, 44 | gpu_id: int = 0, 45 | ) -> List[str]: 46 | """ 47 | Predicts the comment summary given the context in data (e.g., the method's 48 | subtokenized code). 49 | :param data: the input data, whose comment_summary is to be predicted. 50 | :return: a list of predicted comment summary tokens. 51 | """ 52 | raise NotImplementedError 53 | 54 | def batch_predict( 55 | self, 56 | dataset: List[MethodData], 57 | tbar: Optional[tqdm] = None, 58 | gpu_id: int = 0, 59 | ) -> List[List[str]]: 60 | """ 61 | Performs batched predictions using the given dataset as inputs. 62 | 63 | The default implementation invokes #predict multiple times. Subclasses can override 64 | this method to speed up the prediction by using batching. 65 | 66 | :param dataset: a list of inputs. 67 | :param tbar: an optional tqdm progress bar to show prediction progress. 68 | :return: a list of the return values of #predict. 69 | """ 70 | if tbar is not None: 71 | tbar.reset(len(dataset)) 72 | 73 | results = [] 74 | for data in dataset: 75 | results.append(self.predict(data, gpu_id=gpu_id)) 76 | if tbar is not None: 77 | tbar.update(1) 78 | 79 | return results 80 | 81 | def save(self) -> None: 82 | """ 83 | Saves the current model to the work_dir. 84 | Default behavior is to serialize the entire object in model.pkl. 85 | """ 86 | if not self.no_save: 87 | IOUtils.dump(self.model_work_dir / "model.pkl", self, IOUtils.Format.pkl) 88 | 89 | @classmethod 90 | def load(cls, work_dir) -> "CGModelBase": 91 | """ 92 | Loads a model from the work_dir. 93 | Default behavior is to deserialize the object from model.pkl, resetting its work_dir.
94 | """ 95 | obj = IOUtils.load(work_dir / "model.pkl", IOUtils.Format.pkl) 96 | obj.model_work_dir = work_dir 97 | return obj 98 | -------------------------------------------------------------------------------- /python/tseval/comgen/model/TransformerACL20.py: -------------------------------------------------------------------------------- 1 | import stat 2 | import tempfile 3 | from pathlib import Path 4 | from subprocess import TimeoutExpired 5 | from typing import List, Optional 6 | 7 | from recordclass import RecordClass 8 | from seutil import BashUtils, IOUtils, LoggingUtils 9 | from tqdm import tqdm 10 | 11 | from tseval.comgen.model.CGModelBase import CGModelBase 12 | from tseval.data.MethodData import MethodData 13 | from tseval.Environment import Environment 14 | from tseval.Macros import Macros 15 | from tseval.util.ModelUtils import ModelUtils 16 | from tseval.util.TrainConfig import TrainConfig 17 | from tseval.Utils import Utils 18 | 19 | logger = LoggingUtils.get_logger(__name__) 20 | 21 | 22 | class TransformerACL20Config(RecordClass): 23 | max_src_len: int = 150 24 | max_tgt_len: int = 50 25 | use_rnn: bool = False 26 | seed: int = None 27 | 28 | 29 | class TransformerACL20(CGModelBase): 30 | ENV_NAME = "tseval-CG-TransformerACL20" 31 | SRC_DIR = Macros.work_dir / "src" / "CG-TransformerACL20" 32 | 33 | @classmethod 34 | def prepare_env(cls): 35 | Utils.expect_dir_or_suggest_dvc_pull(cls.SRC_DIR) 36 | s = "#!/bin/bash\n" 37 | s += "set -e\n" 38 | s += f"source {Environment.conda_init_path}\n" 39 | s += f"conda env remove --name {cls.ENV_NAME}\n" 40 | s += f"conda create --name {cls.ENV_NAME} python=3.7 pip -y\n" 41 | s += f"conda activate {cls.ENV_NAME}\n" 42 | s += "set -x\n" 43 | s += f"cd {cls.SRC_DIR}\n" 44 | # Pytorch 1.4.0 -> max cuda version 10.1 45 | cuda_toolkit_spec = Environment.get_cuda_toolkit_spec() 46 | if cuda_toolkit_spec in ["cudatoolkit=11.0", "cudatoolkit=10.2"]: 47 | cuda_toolkit_spec = "cudatoolkit=10.1" 48 | s += f"conda install -y pytorch==1.4.0 torchvision==0.5.0 {cuda_toolkit_spec} -c pytorch\n" 49 | s += f"pip install -r requirements.txt\n" 50 | t = Path(tempfile.mktemp(prefix="tseval")) 51 | IOUtils.dump(t, s, IOUtils.Format.txt) 52 | t.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) 53 | print(f"Preparing env for {__name__}...") 54 | rr = BashUtils.run(t) 55 | if rr.return_code == 0: 56 | print("Success!") 57 | else: 58 | print("Failed!") 59 | print(f"STDOUT:\n{rr.stdout}") 60 | print(f"STDERR:\n{rr.stderr}") 61 | print(f"^^^ Preparing env for {__name__} failed!") 62 | 63 | def __init__( 64 | self, 65 | model_work_dir: Path, 66 | no_save: bool = False, 67 | max_src_len: int = 150, 68 | max_tgt_len: int = 50, 69 | use_rnn: bool = False, 70 | seed: int = ModelUtils.get_random_seed(), 71 | ): 72 | super().__init__(model_work_dir, no_save) 73 | if not self.SRC_DIR.is_dir(): 74 | raise RuntimeError(f"Environment missing (expected at {self.SRC_DIR})") 75 | 76 | self.config = TransformerACL20Config( 77 | max_src_len=max_src_len, 78 | max_tgt_len=max_tgt_len, 79 | use_rnn=use_rnn, 80 | seed=seed, 81 | ) 82 | self.train_finished = False 83 | 84 | def prepare_data(self, dataset: List[MethodData], data_dir: Path, sn: str): 85 | sn_dir = data_dir / "java" / sn 86 | IOUtils.rm_dir(sn_dir) 87 | sn_dir.mkdir(parents=True) 88 | 89 | # subtokenized code 90 | IOUtils.dump( 91 | sn_dir / "code.original_subtoken", 92 | [" ".join(d.code[:self.config.max_src_len]) for d in dataset], 93 | IOUtils.Format.txtList, 94 | ) 95 | 96 | # subtokenized comment 97 | 
IOUtils.dump( 98 | sn_dir / "javadoc.original", 99 | [" ".join(d.comment_summary[:self.config.max_tgt_len]) for d in dataset], 100 | IOUtils.Format.txtList, 101 | ) 102 | 103 | # tokenized code 104 | with open(sn_dir / "code.original", "w") as f: 105 | for d in dataset: 106 | tokens = ModelUtils.regroup_subtokens(d.code, d.misc["code_src_ids"]) 107 | f.write(" ".join(tokens) + "\n") 108 | 109 | def train( 110 | self, 111 | train_dataset: List[MethodData], 112 | val_dataset: List[MethodData], 113 | resources_path: Optional[Path] = None, 114 | train_config: TrainConfig = None, 115 | ): 116 | if train_config is None: 117 | train_config = TrainConfig() 118 | 119 | # Prepare data 120 | data_dir = self.model_work_dir / "data" 121 | if not data_dir.is_dir(): 122 | data_dir.mkdir(parents=True) 123 | self.prepare_data(train_dataset, data_dir, "train") 124 | self.prepare_data(val_dataset, data_dir, "dev") 125 | 126 | # Prepare script 127 | model_dir = self.model_work_dir / "model" 128 | s = "#!/bin/bash\n" 129 | s += "set -e\n" 130 | s += f"source {Environment.conda_init_path}\n" 131 | s += f"conda activate {self.ENV_NAME}\n" 132 | s += "set -x\n" 133 | s += f"cd {self.SRC_DIR}\n" 134 | s += f"""MKL_THREADING_LAYER=GNU CUDA_VISIBLE_DEVICES={train_config.gpu_id} timeout {train_config.train_session_time} python -W ignore '{self.SRC_DIR}/main/train.py' \\ 135 | --data_workers 5 \\ 136 | --dataset_name java \\ 137 | --data_dir '{data_dir}/' \\ 138 | --model_dir '{model_dir}' \\ 139 | --model_name model \\ 140 | --train_src train/code.original_subtoken \\ 141 | --train_tgt train/javadoc.original \\ 142 | --dev_src dev/code.original_subtoken \\ 143 | --dev_tgt dev/javadoc.original \\ 144 | --uncase True \\ 145 | --use_src_word True \\ 146 | --use_src_char False \\ 147 | --use_tgt_word True \\ 148 | --use_tgt_char False \\ 149 | --max_src_len {self.config.max_src_len} \\ 150 | --max_tgt_len {self.config.max_tgt_len} \\ 151 | --emsize 512 \\ 152 | --fix_embeddings False \\ 153 | --src_vocab_size 50000 \\ 154 | --tgt_vocab_size 30000 \\ 155 | --share_decoder_embeddings True \\ 156 | --max_examples -1 \\ 157 | --batch_size 32 \\ 158 | --test_batch_size 64 \\ 159 | --num_epochs 200 \\ 160 | --dropout_emb 0.2 \\ 161 | --dropout 0.2 \\ 162 | --copy_attn True \\ 163 | --early_stop 20 \\ 164 | --optimizer adam \\ 165 | --lr_decay 0.99 \\ 166 | --valid_metric bleu \\ 167 | --checkpoint True \\ 168 | --random_seed {self.config.seed} \\ 169 | """ 170 | 171 | if not self.config.use_rnn: 172 | s += """--model_type transformer \\ 173 | --num_head 8 \\ 174 | --d_k 64 \\ 175 | --d_v 64 \\ 176 | --d_ff 2048 \\ 177 | --src_pos_emb False \\ 178 | --tgt_pos_emb True \\ 179 | --max_relative_pos 32 \\ 180 | --use_neg_dist True \\ 181 | --nlayers 6 \\ 182 | --trans_drop 0.2 \\ 183 | --warmup_steps 2000 \\ 184 | --learning_rate 0.0001 185 | """ 186 | else: 187 | s += """--model_type rnn \\ 188 | --conditional_decoding False \\ 189 | --nhid 512 \\ 190 | --nlayers 2 \\ 191 | --use_all_enc_layers False \\ 192 | --dropout_rnn 0.2 \\ 193 | --reuse_copy_attn True \\ 194 | --learning_rate 0.002 \\ 195 | --grad_clipping 5.0 196 | """ 197 | 198 | script_path = self.model_work_dir / "train.sh" 199 | stdout_path = self.model_work_dir / "train.stdout" 200 | stderr_path = self.model_work_dir / "train.stderr" 201 | IOUtils.dump(script_path, s, IOUtils.Format.txt) 202 | script_path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) 203 | 204 | with IOUtils.cd(self.model_work_dir): 205 | try: 206 | logger.info(f"=====Starting train\nScript: 
{script_path}\nSTDOUT: {stdout_path}\nSTDERR: {stderr_path}\n=====") 207 | rr = BashUtils.run(f"{script_path} 1>{stdout_path} 2>{stderr_path}", timeout=train_config.train_session_time) 208 | if rr.return_code != 0: 209 | raise RuntimeError(f"Train returned {rr.return_code}; check STDERR at {stderr_path}") 210 | except (TimeoutExpired, KeyboardInterrupt): 211 | logger.warning("Training not finished") 212 | self.train_finished = False 213 | return 214 | except: 215 | logger.warning(f"Error during training!") 216 | raise 217 | 218 | # If we can reach here, training should be finished 219 | self.train_finished = True 220 | # Remove the big checkpoint file 221 | IOUtils.rm(model_dir / "model.mdl.checkpoint") 222 | return 223 | 224 | def is_train_finished(self) -> bool: 225 | return self.train_finished 226 | 227 | def predict(self, data: MethodData, gpu_id: int = 0) -> List[str]: 228 | return self.batch_predict([data])[0] 229 | 230 | def batch_predict( 231 | self, 232 | dataset: List[MethodData], 233 | tbar: Optional[tqdm] = None, 234 | gpu_id: int = 0, 235 | ) -> List[List[str]]: 236 | # Prepare data 237 | data_dir = Path(tempfile.mkdtemp(prefix="tseval")) 238 | 239 | # Use the dummy comment_summary field to carry id information, so that we know what ids are deleted 240 | for i, d in enumerate(dataset): 241 | d.comment_summary = [str(i)] 242 | self.prepare_data(dataset, data_dir, "test") 243 | 244 | # Prepare script 245 | model_dir = self.model_work_dir / "model" 246 | s = "#!/bin/bash\n" 247 | s += "set -e\n" 248 | s += f"source {Environment.conda_init_path}\n" 249 | s += f"conda activate {self.ENV_NAME}\n" 250 | s += "set -x\n" 251 | s += f"cd {self.SRC_DIR}\n" 252 | # Reducing test_batch_size to 4, otherwise it will delete some test data due to some bug 253 | s += f"""MKL_THREADING_LAYER=GNU CUDA_VISIBLE_DEVICES={gpu_id} python -W ignore '{self.SRC_DIR}/main/test.py' \\ 254 | --data_workers 5 \\ 255 | --dataset_name java \\ 256 | --data_dir '{data_dir}/' \\ 257 | --model_dir '{model_dir}' \\ 258 | --model_name model \\ 259 | --dev_src test/code.original_subtoken \\ 260 | --dev_tgt test/javadoc.original \\ 261 | --uncase True \\ 262 | --max_examples -1 \\ 263 | --max_src_len {self.config.max_src_len} \\ 264 | --max_tgt_len {self.config.max_tgt_len} \\ 265 | --test_batch_size 4 \\ 266 | --beam_size 4 \\ 267 | --n_best 1 \\ 268 | --block_ngram_repeat 3 \\ 269 | --stepwise_penalty False \\ 270 | --coverage_penalty none \\ 271 | --length_penalty none \\ 272 | --beta 0 \\ 273 | --gamma 0 \\ 274 | --replace_unk 275 | """ 276 | 277 | script_path = Path(tempfile.mktemp(prefix="tseval.test.sh-")) 278 | stdout_path = Path(tempfile.mktemp(prefix="tseval.test.stdout-")) 279 | stderr_path = Path(tempfile.mktemp(prefix="tseval.test.stderr-")) 280 | IOUtils.dump(script_path, s, IOUtils.Format.txt) 281 | script_path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR) 282 | 283 | # Run predictions 284 | with IOUtils.cd(self.model_work_dir): 285 | logger.info(f"=====Starting eval\nScript: {script_path}\nSTDOUT: {stdout_path}\nSTDERR: {stderr_path}\n=====") 286 | rr = BashUtils.run(f"{script_path} 1>{stdout_path} 2>{stderr_path}") 287 | if rr.return_code != 0: 288 | raise RuntimeError(f"Eval returned {rr.return_code}; check STDERR at {stderr_path}") 289 | 290 | # Load predictions 291 | beam_res = IOUtils.load(model_dir / "model_beam.json", IOUtils.Format.jsonList) 292 | predictions = [[] for _ in range(len(dataset))] 293 | for x in beam_res: 294 | predictions[int(x["references"][0])] = x["predictions"][0].split(" 
") 295 | 296 | # Delete temp files 297 | IOUtils.rm_dir(data_dir) 298 | IOUtils.rm(script_path) 299 | IOUtils.rm(stdout_path) 300 | IOUtils.rm(stderr_path) 301 | 302 | return predictions 303 | 304 | def save(self) -> None: 305 | # Save config and training status 306 | IOUtils.dump(self.model_work_dir / "config.json", IOUtils.jsonfy(self.config), IOUtils.Format.jsonPretty) 307 | IOUtils.dump(self.model_work_dir / "train_finished.json", self.train_finished) 308 | 309 | # Model should already be saved/checkpointed to the correct path 310 | return 311 | 312 | @classmethod 313 | def load(cls, model_work_dir) -> "CGModelBase": 314 | obj = TransformerACL20(model_work_dir) 315 | obj.config = IOUtils.dejsonfy(IOUtils.load(model_work_dir / "config.json"), TransformerACL20Config) 316 | obj.train_finished = IOUtils.load(model_work_dir / "train_finished.json") 317 | return obj 318 | -------------------------------------------------------------------------------- /python/tseval/comgen/model/__init__.py: -------------------------------------------------------------------------------- 1 | def get_model_cls(name: str) -> type: 2 | if name == "TransformerACL20": 3 | from tseval.comgen.model.TransformerACL20 import TransformerACL20 4 | return TransformerACL20 5 | elif name == "DeepComHybridESE19": 6 | from tseval.comgen.model.DeepComHybridESE19 import DeepComHybridESE19 7 | return DeepComHybridESE19 8 | else: 9 | raise ValueError(f"No model with name {name}") 10 | -------------------------------------------------------------------------------- /python/tseval/data/MethodData.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import * 4 | 5 | from recordclass import RecordClass 6 | from seutil import IOUtils 7 | from tqdm import tqdm 8 | 9 | 10 | class MethodData(RecordClass): 11 | id: int = -1 12 | 13 | # Project name 14 | prj: str = None 15 | # Years that has this data 16 | years: List[int] = None 17 | 18 | # Method name 19 | name: str = None 20 | # Code (subtokenized version after preprocessing) 21 | code: Union[str, List[str]] = None 22 | # Code with masking its name 23 | code_masked: str = None 24 | # Comment (full, including tags) 25 | comment: str = None 26 | # Comment (summary first sentence) (subtokenized version after preprocessing) 27 | comment_summary: Union[str, List[str]] = None 28 | # Class name 29 | cname: str = None 30 | # Qualified class name 31 | qcname: str = None 32 | # File relative path 33 | path: str = None 34 | # Return type 35 | ret: str = None 36 | # Parameter types 37 | params: List[Tuple[str, str]] = None 38 | 39 | misc: dict = None 40 | 41 | def init(self): 42 | self.years = [] 43 | self.params = [] 44 | self.misc = {} 45 | return 46 | 47 | def fill_none(self): 48 | if self.years is None: 49 | self.years = [] 50 | if self.params is None: 51 | self.params = [] 52 | if self.misc is None: 53 | self.misc = {} 54 | 55 | @classmethod 56 | def save_dataset( 57 | cls, 58 | dataset: List["MethodData"], 59 | save_dir: Path, 60 | exist_ok: bool = True, 61 | append: bool = False, 62 | only: Optional[Iterable[str]] = None, 63 | ): 64 | """ 65 | Saves dataset to save_dir. Different fields are saved in different files in the 66 | directory. Call graphs are shared for data from one project. 67 | :param dataset: the list of data to save. 68 | :param save_dir: the path to save. 
69 | :param exist_ok: if False, requires that save_dir doesn't exist; otherwise, 70 | existing files in save_dir will be modified. 71 | :param append: if True, append to current saved data (requires exist_ok=True); 72 | otherwise, wipes out existing data at save_dir. 73 | :param only: only save certain fields; the files corresponding to the other fields 74 | are not touched; id are always saved. 75 | """ 76 | save_dir.mkdir(parents=True, exist_ok=exist_ok) 77 | 78 | IOUtils.dump(save_dir / "id.jsonl", [d.id for d in dataset], IOUtils.Format.jsonList, append=append) 79 | 80 | if only is None or "prj" in only: 81 | IOUtils.dump(save_dir / "prj.jsonl", [d.prj for d in dataset], IOUtils.Format.jsonList, append=append) 82 | 83 | if only is None or "years" in only: 84 | IOUtils.dump(save_dir / "years.jsonl", [d.years for d in dataset], IOUtils.Format.jsonList, append=append) 85 | 86 | if only is None or "name" in only: 87 | IOUtils.dump(save_dir / "name.jsonl", [d.name for d in dataset], IOUtils.Format.jsonList, append=append) 88 | 89 | if only is None or "code" in only: 90 | IOUtils.dump(save_dir / "code.jsonl", [d.code for d in dataset], IOUtils.Format.jsonList, append=append) 91 | 92 | if only is None or "code_masked" in only: 93 | IOUtils.dump(save_dir / "code_masked.jsonl", [d.code_masked for d in dataset], IOUtils.Format.jsonList, append=append) 94 | 95 | if only is None or "comment" in only: 96 | IOUtils.dump(save_dir / "comment.jsonl", [d.comment for d in dataset], IOUtils.Format.jsonList, append=append) 97 | 98 | if only is None or "comment_summary" in only: 99 | IOUtils.dump(save_dir / "comment_summary.jsonl", [d.comment_summary for d in dataset], IOUtils.Format.jsonList, append=append) 100 | 101 | if only is None or "cname" in only: 102 | IOUtils.dump(save_dir / "cname.jsonl", [d.cname for d in dataset], IOUtils.Format.jsonList, append=append) 103 | 104 | if only is None or "qcname" in only: 105 | IOUtils.dump(save_dir / "qcname.jsonl", [d.qcname for d in dataset], IOUtils.Format.jsonList, append=append) 106 | 107 | if only is None or "path" in only: 108 | IOUtils.dump(save_dir / "path.jsonl", [d.path for d in dataset], IOUtils.Format.jsonList, append=append) 109 | 110 | if only is None or "ret" in only: 111 | IOUtils.dump(save_dir / "ret.jsonl", [d.ret for d in dataset], IOUtils.Format.jsonList, append=append) 112 | 113 | if only is None or "params" in only: 114 | IOUtils.dump(save_dir / "params.jsonl", [d.params for d in dataset], IOUtils.Format.jsonList, append=append) 115 | 116 | if only is None or "misc" in only: 117 | IOUtils.dump(save_dir / "misc.jsonl", [d.misc for d in dataset], IOUtils.Format.jsonList, append=append) 118 | 119 | @classmethod 120 | def iter_load_dataset( 121 | cls, 122 | save_dir: Path, 123 | only: Optional[Iterable[str]] = None, 124 | ) -> Generator["MethodData", None, None]: 125 | """ 126 | Iteratively loads dataset from the save directory. 127 | :param save_dir: the directory to load data from. 128 | :param only: only load certain fields; the other fields are not filled in the 129 | loaded data; id is always loaded. 130 | :return: a generator iteratively loading the dataset. 
131 | """ 132 | if not save_dir.is_dir(): 133 | raise FileNotFoundError(f"Not found saved data at {save_dir}") 134 | 135 | # First, load all ids 136 | ids = IOUtils.load(save_dir / "id.jsonl", IOUtils.Format.jsonList) 137 | 138 | # The types of some line-by-line loaded fields 139 | f2type = {} 140 | f2file = {} 141 | for f in ["prj", "years", "name", "code", "code_masked", "comment", "comment_summary", "cname", "qcname", "path", "ret", "params", "misc"]: 142 | if only is None or f in only: 143 | f2file[f] = open(save_dir / f"{f}.jsonl", "r") 144 | 145 | for i in ids: 146 | d = MethodData(id=i) 147 | 148 | # Load line-by-line fields 149 | for f in f2file.keys(): 150 | o = json.loads(f2file[f].readline()) 151 | if f in f2type: 152 | o = IOUtils.dejsonfy(o, f2type[f]) 153 | setattr(d, f, o) 154 | 155 | yield d 156 | 157 | # Close all files 158 | for file in f2file.values(): 159 | file.close() 160 | 161 | @classmethod 162 | def load_dataset( 163 | cls, 164 | save_dir: Path, 165 | only: Optional[List[str]] = None, 166 | expected_ids: List[int] = None, 167 | tbar: Optional[tqdm] = None, 168 | ) -> List["MethodData"]: 169 | """ 170 | Loads the dataset from save_dir. 171 | 172 | :param expected_ids: if provided, the list of data ids to load; the returned dataset 173 | will only contain these data. 174 | :param tbar: an optional progress bar. 175 | Other parameters are the same as #iter_load_dataset. 176 | """ 177 | dataset = [] 178 | 179 | # Load all data by default 180 | if expected_ids is None: 181 | expected_ids = IOUtils.load(save_dir / "id.jsonl", IOUtils.Format.jsonList) 182 | 183 | # Convert to set to speed up checking "has" relation 184 | expected_ids = set(expected_ids) 185 | 186 | if tbar is not None: 187 | tbar.set_description("Loading dataset") 188 | tbar.reset(len(expected_ids)) 189 | 190 | for d in cls.iter_load_dataset(save_dir=save_dir, only=only): 191 | if d.id in expected_ids: 192 | dataset.append(d) 193 | if tbar is not None: 194 | tbar.update(1) 195 | 196 | # Early stop loading if all data have been loaded 197 | if len(dataset) == len(expected_ids): 198 | break 199 | 200 | return dataset 201 | -------------------------------------------------------------------------------- /python/tseval/data/RevisionIds.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from recordclass import RecordClass 4 | 5 | 6 | class RevisionIds(RecordClass): 7 | 8 | revision: str = None 9 | method_ids: List[int] = None 10 | 11 | def init(self): 12 | self.method_ids = [] 13 | -------------------------------------------------------------------------------- /python/tseval/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/data/__init__.py -------------------------------------------------------------------------------- /python/tseval/eval/EvalHelper.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import math 4 | import random 5 | from typing import Dict, List, Tuple 6 | 7 | from seutil import IOUtils 8 | 9 | from tseval.data.MethodData import MethodData 10 | from tseval.Macros import Macros 11 | from tseval.Utils import Utils 12 | 13 | 14 | class EvalHelper: 15 | 16 | def get_splits( 17 | self, 18 | split_name: str, 19 | seed: int, 20 | prj_val_ratio: float = 0.1, 21 | 
prj_test_ratio: float = 0.2, 22 | inprj_val_ratio: float = 0.1, 23 | inprj_test_ratio: float = 0.2, 24 | train_year: int = 2019, 25 | val_year: int = 2020, 26 | test_year: int = 2021, 27 | debug: bool = False, 28 | ): 29 | """ 30 | Gets {mixed-project, cross-project, temporally} splits for the given seed and configurations. 31 | 32 | :param debug: take maximum of 500/100/100 data in the result train/val/test sets. 33 | """ 34 | split_dir = Macros.work_dir / "split" / split_name 35 | IOUtils.rm_dir(split_dir) 36 | split_dir.mkdir(parents=True) 37 | 38 | # Save configs 39 | IOUtils.dump( 40 | split_dir / "config.json", 41 | { 42 | "seed": seed, 43 | "prj_val_ratio": prj_val_ratio, 44 | "prj_test_ratio": prj_test_ratio, 45 | "inprj_val_ratio": inprj_val_ratio, 46 | "inprj_test_ratio": inprj_test_ratio, 47 | "train_year": train_year, 48 | "val_year": val_year, 49 | "test_year": test_year, 50 | "debug": debug, 51 | }, 52 | IOUtils.Format.jsonNoSort, 53 | ) 54 | stats = {} 55 | 56 | # Load shared data 57 | dataset: List[MethodData] = MethodData.load_dataset(Macros.work_dir / "shared") 58 | 59 | # Initialize random state 60 | random.seed(seed) 61 | 62 | all_prjs = list(sorted(set([d.prj for d in dataset]))) 63 | prj2ids = collections.defaultdict(list) 64 | for d in dataset: 65 | prj2ids[d.prj].append(d.id) 66 | 67 | # Get project split 68 | prj_split_names: Dict[str, List[str]] = {sn: l for sn, l in zip( 69 | [Macros.train, Macros.val, Macros.test], 70 | self.split(all_prjs, prj_val_ratio, prj_test_ratio), 71 | )} 72 | prj_split: Dict[str, List[int]] = {sn: sum([prj2ids[n] for n in names], []) 73 | for sn, names in prj_split_names.items()} 74 | for sn in [Macros.train, Macros.val, Macros.test]: 75 | IOUtils.dump(split_dir / f"prj-{sn}.jsonl", prj_split_names[sn], IOUtils.Format.jsonList) 76 | IOUtils.dump(split_dir / f"prj-split-{sn}.jsonl", prj_split[sn], IOUtils.Format.jsonList) 77 | stats[f"num_prj_{sn}"] = len(prj_split_names[sn]) 78 | stats[f"num_prj_split_{sn}"] = len(prj_split[sn]) 79 | 80 | # Get in-project splits 81 | prj2inprj_splits: Dict[str, Dict[str, List[str]]] = {} 82 | 83 | for prj in all_prjs: 84 | prj2inprj_splits[prj] = {sn: l for sn, l in zip( 85 | [Macros.train, Macros.val, Macros.test], 86 | self.split(prj2ids[prj], inprj_val_ratio, inprj_test_ratio), 87 | )} 88 | inprj_split: Dict[str, List[int]] = {sn: sum([prj2inprj_splits[prj][sn] for prj in all_prjs], []) 89 | for sn in [Macros.train, Macros.val, Macros.test]} 90 | for sn in [Macros.train, Macros.val, Macros.test]: 91 | IOUtils.dump(split_dir / f"inprj-split-{sn}.jsonl", inprj_split[sn], IOUtils.Format.jsonList) 92 | stats[f"num_inprj_split_{sn}"] = len(inprj_split[sn]) 93 | 94 | # Get year splits 95 | year_split: Dict[str, List[int]] = {sn: [] for sn in [Macros.train, Macros.val, Macros.test]} 96 | for d in dataset: 97 | min_year = min(d.years) 98 | if min_year <= train_year: 99 | year_split[Macros.train].append(d.id) 100 | elif min_year <= val_year: 101 | year_split[Macros.val].append(d.id) 102 | elif min_year <= test_year: 103 | year_split[Macros.test].append(d.id) 104 | for sn in [Macros.train, Macros.val, Macros.test]: 105 | IOUtils.dump(split_dir / f"year-split-{sn}.jsonl", year_split[sn], IOUtils.Format.jsonList) 106 | stats[f"num_year_split_{sn}"] = len(year_split[sn]) 107 | 108 | # Get actual mixed-prj/cross-prj/temporally splits 109 | train_size = min(len(inprj_split[Macros.train]), len(prj_split[Macros.train]), len(year_split[Macros.train])) 110 | split_sn2split_ids: Dict[Tuple[str, str], List[int]] = {} 111 
| for sn in [Macros.train, Macros.val, Macros.test]: 112 | split_sn2split_ids[(Macros.mixed_prj, sn)] = list(sorted(inprj_split[sn])) 113 | split_sn2split_ids[(Macros.cross_prj, sn)] = list(sorted(prj_split[sn])) 114 | split_sn2split_ids[(Macros.temporally, sn)] = list(sorted(year_split[sn])) 115 | 116 | for s1_i, s1 in enumerate(Macros.split_types): 117 | # train/eval/test_standard 118 | for sn in [Macros.train, Macros.val, Macros.test]: 119 | split_ids = split_sn2split_ids[(s1, sn)] 120 | 121 | # Downsample train set 122 | if sn == Macros.train: 123 | IOUtils.dump(split_dir / f"{s1}-{sn}_full.json", split_ids, IOUtils.Format.json) 124 | random.shuffle(split_ids) 125 | split_ids = list(sorted(split_ids[:train_size])) 126 | 127 | # Debugging 128 | if debug: 129 | if sn == Macros.train: 130 | split_ids = split_ids[:500] 131 | else: 132 | split_ids = split_ids[:100] 133 | 134 | sn_oname = sn if sn != Macros.test else Macros.test_standard 135 | IOUtils.dump(split_dir / f"{s1}-{sn_oname}.json", split_ids, IOUtils.Format.json) 136 | stats[f"num_{s1}_{sn_oname}"] = len(split_ids) 137 | 138 | # test_common set 139 | for s2_i in range(s1_i+1, len(Macros.split_types)): 140 | s2 = Macros.split_types[s2_i] 141 | split_ids = self.intersect( 142 | split_sn2split_ids[(s1, Macros.test)], 143 | split_sn2split_ids[(s2, Macros.test)], 144 | ) 145 | 146 | # Debugging 147 | if debug: 148 | split_ids = split_ids[:100] 149 | 150 | IOUtils.dump(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", split_ids, IOUtils.Format.json) 151 | stats[f"num_{s1}-{s2}_{Macros.test_common}"] = len(split_ids) 152 | 153 | # Save stats 154 | IOUtils.dump(split_dir / "stats.json", stats, IOUtils.Format.jsonNoSort) 155 | 156 | # Suggest dvc command 157 | print(Utils.suggest_dvc_add(split_dir)) 158 | return 159 | 160 | @classmethod 161 | def split( 162 | cls, 163 | l: List, 164 | val_ratio: float, 165 | test_ratio: float, 166 | ) -> Tuple[List, List, List]: 167 | assert val_ratio > 0 and test_ratio > 0 and val_ratio + test_ratio < 1 168 | lcopy = copy.copy(l) 169 | random.shuffle(lcopy) 170 | test_val_split = int(math.ceil(len(lcopy) * test_ratio)) 171 | val_train_split = int(math.ceil(len(lcopy) * (test_ratio + val_ratio))) 172 | return ( 173 | lcopy[val_train_split:], 174 | lcopy[test_val_split:val_train_split], 175 | lcopy[:test_val_split], 176 | ) 177 | 178 | @classmethod 179 | def intersect(cls, *lists: List[int])-> List[int]: 180 | return list(sorted(set.intersection(*[set(l) for l in lists]))) 181 | 182 | @classmethod 183 | def get_task_specific_eval_helper(cls, task): 184 | if task == Macros.com_gen: 185 | from tseval.comgen.eval.CGEvalHelper import CGEvalHelper 186 | return CGEvalHelper() 187 | elif task == Macros.met_nam: 188 | from tseval.metnam.eval.MNEvalHelper import MNEvalHelper 189 | return MNEvalHelper() 190 | else: 191 | raise KeyError(f"Invalid task {task}") 192 | 193 | def exp_prepare(self, **options): 194 | task = options.pop("task") 195 | eh = self.get_task_specific_eval_helper(task) 196 | setup_cls_name = options.pop("setup") 197 | setup_name = options.pop("setup_name") 198 | eh.exp_prepare(setup_cls_name, setup_name, **options) 199 | 200 | def exp_train(self, **options): 201 | task = options.pop("task") 202 | eh = self.get_task_specific_eval_helper(task) 203 | setup_name = options.pop("setup_name") 204 | exp_name = options.pop("exp_name") 205 | model_name = options.pop("model_name") 206 | cont_train = Utils.get_option_as_boolean(options, "cont_train", pop=True) 207 | no_save = 
Utils.get_option_as_boolean(options, "no_save", pop=True) 208 | eh.exp_train(setup_name, exp_name, model_name, cont_train, no_save, **options) 209 | 210 | def exp_eval(self, **options): 211 | task = options.pop("task") 212 | eh = self.get_task_specific_eval_helper(task) 213 | setup_name = options.pop("setup_name") 214 | exp_name = options.pop("exp_name") 215 | action = options.pop("action") 216 | gpu_id = Utils.get_option_and_pop(options, "gpu_id", 0) 217 | eh.exp_eval(setup_name, exp_name, action, gpu_id=gpu_id) 218 | 219 | def exp_compute_metrics(self, **options): 220 | task = options.pop("task") 221 | eh = self.get_task_specific_eval_helper(task) 222 | setup_name = options.pop("setup_name") 223 | exp_name = options.pop("exp_name") 224 | action = options.pop("action") 225 | eh.exp_compute_metrics(setup_name, exp_name, action) 226 | -------------------------------------------------------------------------------- /python/tseval/eval/EvalMetrics.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import math 3 | from typing import * 4 | 5 | import nltk 6 | from rouge import Rouge 7 | 8 | 9 | @functools.lru_cache(maxsize=128_000) 10 | def bleu_cached(gold: Tuple[str], pred: Tuple[str]) -> float: 11 | if len(pred) == 0: 12 | return 0 13 | return nltk.translate.bleu_score.sentence_bleu( 14 | [list(gold)], 15 | list(pred), 16 | smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2, 17 | auto_reweigh=True, 18 | ) 19 | 20 | 21 | @functools.lru_cache(maxsize=128_000) 22 | def token_acc_cached(gold: Tuple[str], pred: Tuple[str]) -> float: 23 | matches = len([ 24 | i for i in range(min(len(gold), len(pred))) 25 | if gold[i] == pred[i] 26 | ]) 27 | return matches / max(len(gold), len(pred)) 28 | 29 | 30 | @functools.lru_cache(maxsize=128_000) 31 | def rouge_l_cached(gold: Tuple[str], pred: Tuple[str]) -> Dict[str, float]: 32 | # Replace the "." characters (e.g., in identifier names), otherwise they'll always be considered as sentence boundaries 33 | hyp = " ".join(pred).replace(".", "") 34 | ref = " ".join(gold).replace(".", "") 35 | 36 | if len(hyp) == 0 or len(ref) == 0: 37 | return {'r': 0, 'p': 0, 'f': 0} 38 | 39 | rouge = Rouge() 40 | scores = rouge.get_scores(hyps=hyp, refs=ref, avg=True) 41 | return scores['rouge-l'] 42 | 43 | 44 | @functools.lru_cache(maxsize=128_000) 45 | def set_match_cached(gold: Tuple[str], pred: Tuple[str]) -> Dict[str, float]: 46 | if len(gold) == 0 or len(pred) == 0: 47 | return {"r": 0, "p": 0, "f": 0} 48 | 49 | gold_unique_tokens = set(gold) 50 | pred_unique_tokens = set(pred) 51 | match_tokens = gold_unique_tokens & pred_unique_tokens 52 | precision = min(len(match_tokens) / len(pred_unique_tokens), 1) 53 | recall = min(len(match_tokens) / len(gold_unique_tokens), 1) 54 | if precision == 0 or recall == 0: 55 | f1 = 0 56 | else: 57 | f1 = 2 / (1 / precision + 1 / recall) 58 | return {"r": recall, "p": precision, "f": f1} 59 | 60 | 61 | @functools.lru_cache(maxsize=128_000) 62 | def meteor_cached(gold: Tuple[str], pred: Tuple[str]) -> float: 63 | if len(gold) == 0 or len(pred) == 0: 64 | return 0 65 | 66 | return nltk.translate.meteor_score.single_meteor_score( 67 | " ".join(gold), 68 | " ".join(pred) 69 | ) 70 | 71 | 72 | @functools.lru_cache(maxsize=128_000) 73 | def near_duplicate_similarity_cached( 74 | gold: Tuple[str], 75 | pred: Tuple[str], 76 | threshold: float = 0.1, 77 | ) -> float: 78 | """ 79 | Computes the approximate token-level accuracy between gold and pred. 
80 | 81 | Returns: 82 | token-level accuracy - if not exact match and mismatching tokens <= threshold; 83 | 0 - otherwise. 84 | """ 85 | mismatch_allowed = int(math.ceil(threshold * min(len(gold), len(pred)))) 86 | 87 | # Check length difference 88 | if abs(len(gold) - len(pred)) >= mismatch_allowed: 89 | return 0 90 | 91 | # Count number of mismatches 92 | mismatch_count = 0 93 | max_len = max(len(gold), len(pred)) 94 | for i in range(max_len): 95 | if i >= len(gold): 96 | mismatch_count += len(pred) - i 97 | break 98 | if i >= len(pred): 99 | mismatch_count += len(gold) - i 100 | break 101 | if gold[i] != pred[i]: 102 | mismatch_count += 1 103 | if mismatch_count >= mismatch_allowed: 104 | return 0 105 | if mismatch_count >= mismatch_allowed: 106 | return 0 107 | else: 108 | return 1 - mismatch_count / max_len 109 | 110 | 111 | class EvalMetrics: 112 | 113 | @classmethod 114 | def batch_exact_match(cls, golds: List[Any], preds: List[Any]) -> List[float]: 115 | """ 116 | return[i] = (golds[i] == preds[i]) ? 1 : 0 117 | """ 118 | assert len(golds) == len(preds) 119 | 120 | results = [] 121 | for gold, pred in zip(golds, preds): 122 | if gold == pred: 123 | results.append(1) 124 | else: 125 | results.append(0) 126 | return results 127 | 128 | @classmethod 129 | def token_acc(cls, gold: List[str], pred: List[str]) -> float: 130 | return token_acc_cached(tuple(gold), tuple(pred)) 131 | 132 | @classmethod 133 | def batch_token_acc(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]: 134 | assert len(golds) == len(preds) 135 | return [ 136 | token_acc_cached(tuple(gold), tuple(pred)) 137 | for gold, pred in zip(golds, preds) 138 | ] 139 | 140 | @classmethod 141 | def bleu(cls, gold: List[str], pred: List[str]) -> float: 142 | """ 143 | return = BLEU([gold], pred) 144 | """ 145 | return bleu_cached(tuple(gold), tuple(pred)) 146 | 147 | @classmethod 148 | def batch_bleu(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]: 149 | """ 150 | return[i] = #bleu(golds[i], preds[i]) 151 | """ 152 | assert len(golds) == len(preds) 153 | return [ 154 | bleu_cached(tuple(gold), tuple(pred)) 155 | for gold, pred in zip(golds, preds) 156 | ] 157 | 158 | @classmethod 159 | def rouge_l(cls, gold: List[str], pred: List[str]) -> Dict[str, float]: 160 | """ 161 | return = rouge l metric computed for given sequences 162 | """ 163 | return rouge_l_cached(tuple(gold), tuple(pred)) 164 | 165 | @classmethod 166 | def batch_rouge_l(cls, golds: List[List[str]], preds: List[List[str]]) -> List[Dict[str, float]]: 167 | """ 168 | return[i] = #rouge_l(golds[i], preds[i]) 169 | """ 170 | assert len(golds) == len(preds) 171 | return [ 172 | rouge_l_cached(tuple(gold), tuple(pred)) 173 | for gold, pred in zip(golds, preds) 174 | ] 175 | 176 | @classmethod 177 | def set_match(cls, gold: List[str], pred: List[str]) -> Dict[str, float]: 178 | return set_match_cached(tuple(gold), tuple(pred)) 179 | 180 | @classmethod 181 | def batch_set_match(cls, golds: List[List[str]], preds: List[List[str]]) -> List[Dict[str, float]]: 182 | assert len(golds) == len(preds) 183 | return [ 184 | set_match_cached(tuple(gold), tuple(pred)) 185 | for gold, pred in zip(golds, preds) 186 | ] 187 | 188 | @classmethod 189 | def batch_set_match_f1(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]: 190 | assert len(golds) == len(preds) 191 | return [ 192 | set_match_cached(tuple(gold), tuple(pred))["f"] 193 | for gold, pred in zip(golds, preds) 194 | ] 195 | 196 | @classmethod 197 | def meteor(cls, gold: 
List[str], pred: List[str]) -> float: 198 | return meteor_cached(tuple(gold), tuple(pred)) 199 | 200 | @classmethod 201 | def batch_meteor(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]: 202 | assert len(golds) == len(preds) 203 | return [ 204 | meteor_cached(tuple(gold), tuple(pred)) 205 | for gold, pred in zip(golds, preds) 206 | ] 207 | 208 | @classmethod 209 | def near_duplicate_similarity(cls, gold: List[str], pred: List[str]) -> float: 210 | return near_duplicate_similarity_cached(tuple(gold), tuple(pred)) 211 | 212 | @classmethod 213 | def batch_near_duplicate_similarity(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]: 214 | assert len(golds) == len(preds) 215 | return [ 216 | near_duplicate_similarity_cached(tuple(gold), tuple(pred)) 217 | for gold, pred in zip(golds, preds) 218 | ] 219 | -------------------------------------------------------------------------------- /python/tseval/eval/EvalSetupBase.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from pathlib import Path 3 | from typing import List 4 | 5 | 6 | class EvalSetupBase: 7 | 8 | def __init__(self, work_dir: Path, work_subdir: Path, setup_name: str): 9 | """ 10 | All setups require the work_dir parameter, which is configured by EvalHelper. 11 | Any other parameters for the implemented setup should be passed in via the constructor. 12 | 13 | data_dir is a unit directory managed by dvc. 14 | """ 15 | self.work_dir = work_dir 16 | self.work_subdir = work_subdir 17 | self.setup_name = setup_name 18 | return 19 | 20 | @property 21 | def setup_dir(self): 22 | return self.work_subdir / "setup" / self.setup_name 23 | 24 | @property 25 | def shared_data_dir(self): 26 | return self.work_dir / "shared" 27 | 28 | def get_split_dir(self, split_name: str): 29 | return self.work_dir / "split" / split_name 30 | 31 | @property 32 | def data_dir(self): 33 | return self.setup_dir / "data" 34 | 35 | def get_exp_dir(self, exp_name: str): 36 | return self.work_subdir / "exp" / self.setup_name / exp_name 37 | 38 | def get_result_dir(self, exp_name: str): 39 | return self.work_subdir / "result" / self.setup_name / exp_name 40 | 41 | def get_metric_dir(self, exp_name: str): 42 | return self.work_subdir / "metric" / self.setup_name / exp_name 43 | 44 | @abc.abstractmethod 45 | def prepare(self) -> None: 46 | """ 47 | Prepares this eval setup: obtains the training/validation/testing sets, 48 | reads data from shared_data_dir, and saves any processed data to data_dir. 49 | 50 | setup_dir (which includes data_dir) is managed by dvc. 51 | """ 52 | raise NotImplementedError 53 | 54 | @abc.abstractmethod 55 | def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None: 56 | """ 57 | Trains the model, loads data from data_dir, saves the model to 58 | get_exp_dir(exp_name) (intermediate files, e.g., logs, can also be saved there). 59 | 60 | get_exp_dir(exp_name) is managed by dvc. 61 | 62 | :param exp_name: name given to this experiment. 63 | :param model_name: the model's name. 64 | :param cont_train: if True and if there is already a partially trained model in 65 | the save directory, load that model and continue training; otherwise, ignore 66 | any possible partially trained models. 67 | :param no_save: if True, avoids saving anything during training. 68 | :param options: options for initializing the model.
69 | """ 70 | raise NotImplementedError 71 | 72 | @abc.abstractmethod 73 | def eval(self, exp_name: str, actions: List[str] = None, gpu_id: int = 0) -> None: 74 | """ 75 | Evaluates the model (usually, on both validation and testing set), 76 | loads data from data_dir, 77 | loads trained model from get_exp_dir(exp_name), 78 | saves results to get_result_dir(exp_name). 79 | 80 | get_result_dir(exp_name) is managed by dvc. 81 | 82 | :param exp_name: name of the experiment. 83 | :param actions: a list of eval actions requested. 84 | """ 85 | raise NotImplementedError 86 | 87 | @abc.abstractmethod 88 | def compute_metrics(self, exp_name: str, actions: List[str] = None) -> None: 89 | """ 90 | Computes metrics on the prediction results, 91 | loads data from data_dir, 92 | loads results from get_result_dir(exp_name), 93 | saves metrics to get_metric_dir(exp_name). 94 | 95 | get_metric_dir(exp_name) is managed by git. 96 | 97 | :param exp_name: name of the experiment. 98 | :param actions: a list of eval actions requested (to compute metrics for). 99 | """ 100 | raise NotImplementedError 101 | -------------------------------------------------------------------------------- /python/tseval/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/eval/__init__.py -------------------------------------------------------------------------------- /python/tseval/main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | import time 4 | from pathlib import Path 5 | 6 | import pkg_resources 7 | from seutil import CliUtils, IOUtils, LoggingUtils 8 | 9 | from tseval.Environment import Environment 10 | from tseval.Macros import Macros 11 | from tseval.Utils import Utils 12 | 13 | # Check seutil version 14 | EXPECTED_SEUTIL_VERSION = "0.5.6" 15 | if pkg_resources.get_distribution("seutil").version < EXPECTED_SEUTIL_VERSION: 16 | print(f"seutil version does not meet expectation! 
Expected version: {EXPECTED_SEUTIL_VERSION}, current installed version: {pkg_resources.get_distribution('seutil').version}", file=sys.stderr) 17 | print(f"Hint: either upgrade seutil, or modify the expected version (after confirmation that the version will work)", file=sys.stderr) 18 | sys.exit(-1) 19 | 20 | 21 | logging_file = Macros.python_dir / "experiment.log" 22 | LoggingUtils.setup(filename=str(logging_file)) 23 | 24 | logger = LoggingUtils.get_logger(__name__) 25 | 26 | 27 | # ========== 28 | # Data collection, sample 29 | 30 | def collect_repos(**options): 31 | from tseval.collector.DataCollector import DataCollector 32 | DataCollector().search_github_java_repos() 33 | 34 | 35 | def filter_repos(**options): 36 | from tseval.collector.DataCollector import DataCollector 37 | DataCollector().filter_repos( 38 | year_end=options.get("year_end", 2021), 39 | year_cnt=options.get("year_cnt", 3), 40 | loc_min=options.get("loc_min", 1e6), 41 | loc_max=options.get("loc_max", 2e6), 42 | star_min=options.get("star_min", 20), 43 | ) 44 | 45 | 46 | def collect_raw_data(**options): 47 | from tseval.collector.DataCollector import DataCollector 48 | DataCollector().collect_raw_data_projects( 49 | year_end=options.get("year_end", 2021), 50 | year_cnt=options.get("year_cnt", 3), 51 | skip_collected=Utils.get_option_as_boolean(options, "skip_collected"), 52 | project_names=Utils.get_option_as_list(options, "projects"), 53 | ) 54 | 55 | 56 | def process_raw_data(**options): 57 | from tseval.collector.DataCollector import DataCollector 58 | DataCollector().process_raw_data( 59 | year_end=options.get("year_end", 2021), 60 | year_cnt=options.get("year_cnt", 3), 61 | ) 62 | 63 | 64 | def get_splits(**options): 65 | from tseval.eval.EvalHelper import EvalHelper 66 | EvalHelper().get_splits( 67 | split_name=options["split"], 68 | seed=options.get("seed", 7), 69 | prj_val_ratio=options.get("prj_val_ratio", 0.1), 70 | prj_test_ratio=options.get("prj_test_ratio", 0.2), 71 | inprj_val_ratio=options.get("inprj_val_ratio", 0.1), 72 | inprj_test_ratio=options.get("inprj_test_ratio", 0.2), 73 | train_year=options.get("train_year", 2019), 74 | val_year=options.get("val_year", 2020), 75 | test_year=options.get("test_year", 2021), 76 | debug=Utils.get_option_as_boolean(options, "debug"), 77 | ) 78 | 79 | 80 | # ========== 81 | # Machine learning 82 | 83 | def prepare_envs(**options): 84 | which = Utils.get_option_as_list(options, "which") 85 | if which is None or "TransformerACL20" in which: 86 | from tseval.comgen.model.TransformerACL20 import TransformerACL20 87 | TransformerACL20.prepare_env() 88 | if which is None or "DeepComHybridESE19" in which: 89 | from tseval.comgen.model.DeepComHybridESE19 import DeepComHybridESE19 90 | DeepComHybridESE19.prepare_env() 91 | if which is None or "Code2SeqICLR19" in which: 92 | from tseval.metnam.model.Code2SeqICLR19 import Code2SeqICLR19 93 | Code2SeqICLR19.prepare_env() 94 | if which is None or "Code2VecPOPL19" in which: 95 | from tseval.metnam.model.Code2VecPOPL19 import Code2VecPOPL19 96 | Code2VecPOPL19.prepare_env() 97 | 98 | 99 | def exp_prepare(**options): 100 | from tseval.eval.EvalHelper import EvalHelper 101 | EvalHelper().exp_prepare(**options) 102 | 103 | 104 | def exp_train(**options): 105 | from tseval.eval.EvalHelper import EvalHelper 106 | EvalHelper().exp_train(**options) 107 | 108 | 109 | def exp_eval(**options): 110 | from tseval.eval.EvalHelper import EvalHelper 111 | EvalHelper().exp_eval(**options) 112 | 113 | 114 | def exp_compute_metrics(**options): 115 | 
from tseval.eval.EvalHelper import EvalHelper 116 | EvalHelper().exp_compute_metrics(**options) 117 | 118 | 119 | # ========== 120 | # Table & Plot 121 | 122 | def make_tables(**options): 123 | from tseval.Table import Table 124 | Table().make_tables(options) 125 | 126 | 127 | def make_plots(**options): 128 | from tseval.Plot import Plot 129 | Plot().make_plots(options) 130 | 131 | 132 | # ========== 133 | # Metrics collection 134 | 135 | def collect_metrics(**options): 136 | from tseval.collector.MetricsCollector import MetricsCollector 137 | 138 | mc = MetricsCollector() 139 | mc.collect_metrics(**options) 140 | return 141 | 142 | 143 | # ========== 144 | # Collect and analyze results 145 | 146 | def analyze_check_files(**options): 147 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 148 | ExperimentsAnalyzer( 149 | exps_spec_path=Path(options["exps"]), 150 | output_prefix=options.get("output"), 151 | ).check_files() 152 | 153 | 154 | def analyze_recompute_metrics(**options): 155 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 156 | ExperimentsAnalyzer( 157 | exps_spec_path=Path(options["exps"]), 158 | output_prefix=options.get("output"), 159 | ).recompute_metrics() 160 | 161 | 162 | def analyze_extract_metrics(**options): 163 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 164 | ExperimentsAnalyzer( 165 | exps_spec_path=Path(options["exps"]), 166 | output_prefix=options.get("output"), 167 | ).extract_metrics() 168 | 169 | 170 | def analyze_sign_test(**options): 171 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 172 | ExperimentsAnalyzer( 173 | exps_spec_path=Path(options["exps"]), 174 | output_prefix=options.get("output"), 175 | ).sign_test_default() 176 | 177 | 178 | def analyze_make_tables(**options): 179 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 180 | ExperimentsAnalyzer( 181 | exps_spec_path=Path(options["exps"]), 182 | output_prefix=options.get("output"), 183 | ).make_tables_default() 184 | 185 | 186 | def analyze_make_plots(**options): 187 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 188 | ExperimentsAnalyzer( 189 | exps_spec_path=Path(options["exps"]), 190 | output_prefix=options.get("output"), 191 | ).make_plots_default() 192 | 193 | 194 | def analyze_sample_results(**options): 195 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 196 | ExperimentsAnalyzer( 197 | exps_spec_path=Path(options["exps"]), 198 | output_prefix=options.get("output"), 199 | ).sample_results( 200 | seed=options.get("seed", 7), 201 | count=options.get("count", 100), 202 | ) 203 | 204 | 205 | def analyze_extract_data_similarities(**options): 206 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 207 | ExperimentsAnalyzer( 208 | exps_spec_path=Path(options["exps"]), 209 | output_prefix=options.get("output"), 210 | ).extract_data_similarities() 211 | 212 | 213 | def analyze_near_duplicates(**options): 214 | from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer 215 | ExperimentsAnalyzer( 216 | exps_spec_path=Path(options["exps"]), 217 | output_prefix=options.get("output"), 218 | ).filter_near_duplicates_and_analyze( 219 | code_sim_threshold=options["code_sim"], 220 | nl_sim_threshold=options["nl_sim"], 221 | config_name=options["config"], 222 | only_tables_plots=Utils.get_option_as_boolean(options, "only_tables_plots", default=False), 223 | ) 224 | 225 | 226 | # ========== 227 | # Main 228 | 229 | def 
normalize_options(opts: dict) -> dict: 230 | # Set a different log file 231 | if "log_path" in opts: 232 | logger.info(f"Switching to log file {opts['log_path']}") 233 | LoggingUtils.setup(filename=opts['log_path']) 234 | 235 | # Set debug mode 236 | if "debug" in opts and str(opts["debug"]).lower() != "false": 237 | Environment.is_debug = True 238 | logger.debug("Debug mode on") 239 | logger.debug(f"Command line options: {opts}") 240 | 241 | # Set parallel mode - all automatic installations are disabled 242 | if "parallel" in opts and str(opts["parallel"]).lower() != "false": 243 | Environment.is_parallel = True 244 | logger.warning("Parallel mode on") 245 | 246 | # Set/report random seed 247 | if "random_seed" in opts: 248 | Environment.random_seed = int(opts["random_seed"]) 249 | else: 250 | Environment.random_seed = time.time_ns() 251 | 252 | random.seed(Environment.random_seed) 253 | logger.info(f"Random seed is {Environment.random_seed}") 254 | return opts 255 | 256 | 257 | if __name__ == "__main__": 258 | CliUtils.main(sys.argv[1:], globals(), normalize_options) 259 | -------------------------------------------------------------------------------- /python/tseval/metnam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/metnam/__init__.py -------------------------------------------------------------------------------- /python/tseval/metnam/eval/MNEvalHelper.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional 3 | 4 | from seutil import IOUtils, LoggingUtils 5 | 6 | from tseval.eval.EvalSetupBase import EvalSetupBase 7 | from tseval.Macros import Macros 8 | from tseval.metnam.eval import get_setup_cls 9 | from tseval.Utils import Utils 10 | 11 | logger = LoggingUtils.get_logger(__name__) 12 | 13 | 14 | class MNEvalHelper: 15 | 16 | def __init__(self): 17 | self.work_subdir: Path = Macros.work_dir / "MN" 18 | 19 | def exp_prepare(self, setup: str, setup_name: str, **cmd_options): 20 | # Clean the setup dir 21 | setup_dir: Path = self.work_subdir / "setup" / setup_name 22 | IOUtils.rm_dir(setup_dir) 23 | setup_dir.mkdir(parents=True) 24 | 25 | # Initialize setup 26 | setup_cls = get_setup_cls(setup) 27 | setup_options, unk_options, missing_options = Utils.parse_cmd_options_for_type( 28 | cmd_options, 29 | setup_cls, 30 | ["self", "work_dir", "work_subdir", "setup_name"], 31 | ) 32 | if len(missing_options) > 0: 33 | raise KeyError(f"Missing options: {missing_options}") 34 | if len(unk_options) > 0: 35 | logger.warning(f"Unrecognized options: {unk_options}") 36 | setup_obj: EvalSetupBase = setup_cls( 37 | work_dir=Macros.work_dir, 38 | work_subdir=self.work_subdir, 39 | setup_name=setup_name, 40 | **setup_options, 41 | ) 42 | 43 | # Save setup configs 44 | setup_options["setup"] = setup 45 | IOUtils.dump(setup_dir / "setup_config.json", setup_options, IOUtils.Format.jsonNoSort) 46 | 47 | # Prepare data 48 | setup_obj.prepare() 49 | 50 | # Print dvc commands 51 | print(Utils.suggest_dvc_add(setup_obj.setup_dir)) 52 | 53 | def load_setup(self, setup_dir: Path, setup_name: str) -> EvalSetupBase: 54 | """ 55 | Loads the setup from its save directory, updating its setup_name.
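        Example (hypothetical setup name):
            helper = MNEvalHelper()
            setup = helper.load_setup(helper.work_subdir / "setup" / "MN-demo", "MN-demo")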
56 | """ 57 | config = IOUtils.load(setup_dir / "setup_config.json", IOUtils.Format.json) 58 | setup_cls = get_setup_cls(config.pop("setup")) 59 | setup_obj = setup_cls(work_dir=Macros.work_dir, work_subdir=self.work_subdir, setup_name=setup_name, **config) 60 | return setup_obj 61 | 62 | def exp_train( 63 | self, 64 | setup_name: str, 65 | exp_name: str, 66 | model_name: str, 67 | cont_train: bool, 68 | no_save: bool, 69 | **cmd_options, 70 | ): 71 | # Load saved setup 72 | setup_dir = self.work_subdir / "setup" / setup_name 73 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 74 | setup = self.load_setup(setup_dir, setup_name) 75 | 76 | if not cont_train: 77 | # Delete existing trained model 78 | IOUtils.rm_dir(setup.get_exp_dir(exp_name)) 79 | 80 | # Invoke training 81 | setup.train(exp_name, model_name, cont_train, no_save, **cmd_options) 82 | 83 | # Print dvc commands 84 | print(Utils.suggest_dvc_add(setup.get_exp_dir(exp_name))) 85 | 86 | def exp_eval( 87 | self, 88 | setup_name: str, 89 | exp_name: str, 90 | action: Optional[str], 91 | gpu_id: int = 0, 92 | ): 93 | # Load saved setup 94 | setup_dir = self.work_subdir / "setup" / setup_name 95 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 96 | setup = self.load_setup(setup_dir, setup_name) 97 | 98 | # Invoke eval 99 | setup.eval(exp_name, action, gpu_id=gpu_id) 100 | 101 | # Print dvc commands 102 | print(Utils.suggest_dvc_add(setup.get_result_dir(exp_name))) 103 | 104 | def exp_compute_metrics( 105 | self, 106 | setup_name: str, 107 | exp_name: str, 108 | action: Optional[str] = None, 109 | ): 110 | # Load saved setup 111 | setup_dir = self.work_subdir / "setup" / setup_name 112 | Utils.expect_dir_or_suggest_dvc_pull(setup_dir) 113 | setup = self.load_setup(setup_dir, setup_name) 114 | 115 | # Invoke eval 116 | setup.compute_metrics(exp_name, action) 117 | 118 | # Print dvc commands 119 | print(Utils.suggest_dvc_add(setup.get_metric_dir(exp_name))) 120 | -------------------------------------------------------------------------------- /python/tseval/metnam/eval/MNModelLoader.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from seutil import IOUtils, LoggingUtils 4 | 5 | from tseval.metnam.model import get_model_cls 6 | from tseval.metnam.model.MNModelBase import MNModelBase 7 | from tseval.Utils import Utils 8 | 9 | logger = LoggingUtils.get_logger(__name__) 10 | 11 | 12 | class MNModelLoader: 13 | 14 | @classmethod 15 | def init_or_load_model( 16 | cls, 17 | model_name: str, 18 | exp_dir: Path, 19 | cont_train: bool, 20 | no_save: bool, 21 | cmd_options: dict, 22 | ) -> MNModelBase: 23 | model_cls = get_model_cls(model_name) 24 | model_work_dir = exp_dir / "model" 25 | 26 | if cont_train and model_work_dir.is_dir() and not no_save: 27 | # Restore model name 28 | loaded_model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt) 29 | if model_name != loaded_model_name: 30 | raise ValueError(f"Contradicting model name (saved: {model_name}; new {loaded_model_name})") 31 | 32 | # Warning about any additional command line arguments 33 | if len(cmd_options) > 0: 34 | logger.warning(f"These options will not be used in cont_train mode: {cmd_options}") 35 | 36 | # Load existing model 37 | model: MNModelBase = model_cls.load(model_work_dir) 38 | else: 39 | 40 | if not no_save: 41 | exp_dir.mkdir(parents=True, exist_ok=True) 42 | 43 | # Save model name 44 | IOUtils.dump(exp_dir / "model_name.txt", model_name, IOUtils.Format.txt) 45 | 46 | # Prepare 
47 | IOUtils.rm(model_work_dir) 48 | model_work_dir.mkdir(parents=True) 49 | 50 | # Initialize the model, using command line arguments 51 | model_options, unk_options, missing_options = Utils.parse_cmd_options_for_type( 52 | cmd_options, 53 | model_cls, 54 | ["self", "model_work_dir"], 55 | ) 56 | if len(missing_options) > 0: 57 | raise KeyError(f"Missing options: {missing_options}") 58 | if len(unk_options) > 0: 59 | logger.warning(f"Unrecognized options: {unk_options}") 60 | 61 | model: MNModelBase = model_cls(model_work_dir=model_work_dir, no_save=no_save, **model_options) 62 | 63 | if not no_save: 64 | # Save model configs 65 | IOUtils.dump(exp_dir / "model_config.json", model_options, IOUtils.Format.jsonNoSort) 66 | return model 67 | 68 | @classmethod 69 | def load_model(cls, exp_dir: Path) -> MNModelBase: 70 | """ 71 | Loads a trained model from exp_dir. Gets the model name from model_name.txt. 72 | """ 73 | Utils.expect_dir_or_suggest_dvc_pull(exp_dir) 74 | model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt) 75 | model_cls = get_model_cls(model_name) 76 | model_dir = exp_dir / "model" 77 | return model_cls.load(model_dir) 78 | -------------------------------------------------------------------------------- /python/tseval/metnam/eval/StandardSetup.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import time 4 | from pathlib import Path 5 | from typing import Dict, List 6 | 7 | import numpy as np 8 | from seutil import IOUtils, LoggingUtils 9 | from tqdm import tqdm 10 | 11 | from tseval.data.MethodData import MethodData 12 | from tseval.eval.EvalMetrics import EvalMetrics 13 | from tseval.eval.EvalSetupBase import EvalSetupBase 14 | from tseval.Macros import Macros 15 | from tseval.metnam.eval.MNModelLoader import MNModelLoader 16 | from tseval.metnam.model.MNModelBase import MNModelBase 17 | from tseval.util.ModelUtils import ModelUtils 18 | from tseval.util.TrainConfig import TrainConfig 19 | from tseval.Utils import Utils 20 | 21 | logger = LoggingUtils.get_logger(__name__) 22 | 23 | 24 | class StandardSetup(EvalSetupBase): 25 | 26 | # Validation set on self's split type 27 | EVAL_VAL = Macros.val 28 | # Test_standard set on self's split type 29 | EVAL_TESTS = Macros.test_standard 30 | # Test_common sets, pairwise between self's split type and each other split type 31 | EVAL_TESTC = Macros.test_common 32 | 33 | EVAL_ACTIONS = [EVAL_VAL, EVAL_TESTS, EVAL_TESTC] 34 | DEFAULT_EVAL_ACTION = EVAL_TESTC 35 | 36 | def __init__( 37 | self, 38 | work_dir: Path, 39 | work_subdir: Path, 40 | setup_name: str, 41 | split_name: str, 42 | split_type: str, 43 | ): 44 | super().__init__(work_dir, work_subdir, setup_name) 45 | self.split_name = split_name 46 | self.split_type = split_type 47 | 48 | def prepare(self) -> None: 49 | # Check and prepare directories 50 | split_dir = self.get_split_dir(self.split_name) 51 | Utils.expect_dir_or_suggest_dvc_pull(self.shared_data_dir) 52 | Utils.expect_dir_or_suggest_dvc_pull(split_dir) 53 | IOUtils.rm_dir(self.data_dir) 54 | self.data_dir.mkdir(parents=True) 55 | 56 | # Copy split indexes 57 | all_indexes = [] 58 | for sn in [Macros.train, Macros.val, Macros.test_standard]: 59 | ids = IOUtils.load(split_dir / f"{self.split_type}-{sn}.json", IOUtils.Format.json) 60 | all_indexes += ids 61 | IOUtils.dump(self.data_dir / f"split_{sn}.json", ids, IOUtils.Format.json) 62 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
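            # For each pair of split types involving self.split_type, copy the shared
            # test_common ids; e.g., with hypothetical split type names, split_type
            # "temporal" might yield pairs ("mixedprj", "temporal") and ("crossprj", "temporal").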
63 | ids = IOUtils.load(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", IOUtils.Format.json) 64 | all_indexes += ids 65 | IOUtils.dump( 66 | self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", 67 | ids, 68 | IOUtils.Format.json, 69 | ) 70 | all_indexes = list(sorted(set(all_indexes))) 71 | 72 | # Load raw data 73 | tbar = tqdm() 74 | dataset: List[MethodData] = MethodData.load_dataset( 75 | self.shared_data_dir, 76 | expected_ids=all_indexes, 77 | only=["code", "code_masked", "name", "misc"], 78 | tbar=tbar, 79 | ) 80 | tbar.close() 81 | 82 | # Subtokenize code and method names 83 | tbar = tqdm() 84 | 85 | tbar.set_description("Subtokenizing") 86 | tbar.reset(len(dataset)) 87 | orig_code_list = [] 88 | name_list = [] 89 | for d in dataset: 90 | # Replace the mask token (assumed here to be "<MASK>") such that the code remains parsable 91 | code_masked = d.code_masked.replace("<MASK>", "METHODNAMEMASK") 92 | d.fill_none() 93 | d.misc["orig_code"] = d.code 94 | d.misc["orig_code_masked"] = code_masked 95 | d.misc["orig_name"] = d.name 96 | orig_code_list.append(code_masked) 97 | name_list.append(d.name) 98 | tbar.update(1) 99 | 100 | tokenized_code_list = ModelUtils.tokenize_javaparser_batch(orig_code_list, dup_share=False, tbar=tbar) 101 | 102 | tbar.set_description("Subtokenizing") 103 | tbar.reset(len(dataset)) 104 | for d, tokenized_code, name in zip(dataset, tokenized_code_list, name_list): 105 | d.code, d.misc["code_src_ids"] = ModelUtils.subtokenize_batch(tokenized_code) 106 | d.code_masked = None 107 | d.name = ModelUtils.subtokenize(name) 108 | 109 | # Convert name to lower case 110 | d.name = [t.lower() for t in d.name] 111 | tbar.update(1) 112 | tbar.close() 113 | 114 | # Clean eval ids 115 | indexed_dataset = {d.id: d for d in dataset} 116 | for sn in [Macros.val, Macros.test_standard] + [f"{Macros.test_common}-{x}-{y}" for x, y in Macros.get_pairwise_split_types_with(self.split_type)]: 117 | eval_ids = IOUtils.load(self.data_dir / f"split_{sn}.json", IOUtils.Format.json) 118 | IOUtils.dump(self.data_dir / f"split_{sn}.json", self.clean_eval_set(indexed_dataset, eval_ids), IOUtils.Format.json) 119 | 120 | # Save dataset 121 | MethodData.save_dataset(dataset, self.data_dir) 122 | 123 | def clean_eval_set(self, indexed_dataset: Dict[int, MethodData], eval_ids: List[int]) -> List[int]: 124 | """ 125 | Keeps the eval set absolutely clean by: 126 | - removing duplicate (name, code) pairs. 127 | indexed_dataset should already be subtokenized.
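        For example, if two ids in eval_ids map to the same subtokenized (code, name)
        pair, only the one appearing first in eval_ids is kept.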
128 | """ 129 | seen_data = set() 130 | clean_eval_ids = [] 131 | for i in eval_ids: 132 | data = indexed_dataset[i] 133 | 134 | # Remove duplicate (name, code) pairs 135 | data_key = (tuple(data.code), tuple(data.name)) 136 | if data_key in seen_data: 137 | continue 138 | else: 139 | seen_data.add(data_key) 140 | 141 | clean_eval_ids.append(i) 142 | return clean_eval_ids 143 | 144 | def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None: 145 | # Init or load model 146 | exp_dir = self.get_exp_dir(exp_name) 147 | train_config = TrainConfig.get_train_config_from_cmd_options(options) 148 | model = MNModelLoader.init_or_load_model(model_name, exp_dir, cont_train, no_save, options) 149 | if not no_save: 150 | IOUtils.dump(exp_dir / "train_config.jsonl", [IOUtils.jsonfy(train_config)], IOUtils.Format.jsonList, append=True) 151 | 152 | # Load data 153 | tbar = tqdm(desc="Loading data") 154 | dataset = MethodData.load_dataset(self.data_dir, tbar=tbar) 155 | indexed_dataset = {d.id: d for d in dataset} 156 | 157 | tbar.set_description("Loading data | take indexes") 158 | tbar.reset(2) 159 | 160 | train_ids = IOUtils.load(self.data_dir / f"split_{Macros.train}.json", IOUtils.Format.json) 161 | train_dataset = [indexed_dataset[i] for i in train_ids] 162 | tbar.update(1) 163 | 164 | val_ids = IOUtils.load(self.data_dir / f"split_{Macros.val}.json", IOUtils.Format.json) 165 | val_dataset = [indexed_dataset[i] for i in val_ids] 166 | tbar.update(1) 167 | 168 | tbar.close() 169 | 170 | # Train model 171 | start = time.time() 172 | model.train(train_dataset, val_dataset, resources_path=self.data_dir, train_config=train_config) 173 | end = time.time() 174 | 175 | if not no_save: 176 | model.save() 177 | IOUtils.dump(exp_dir / "train_time.json", end - start, IOUtils.Format.json) 178 | 179 | def eval_one(self, exp_name: str, eval_ids: List[int], prefix: str, indexed_dataset: Dict[int, MethodData], model: MNModelBase, gpu_id: int = 0): 180 | # Prepare output directory 181 | result_dir = self.get_result_dir(exp_name) 182 | result_dir.mkdir(parents=True, exist_ok=True) 183 | 184 | # Prepare eval data (remove target) 185 | eval_dataset = [indexed_dataset[i] for i in eval_ids] 186 | golds = [] 187 | for d in eval_dataset: 188 | golds.append(d.name) 189 | d.name = ["dummy"] 190 | d.misc["orig_name"] = "dummy" 191 | 192 | # Perform batched queries 193 | tbar = tqdm(desc=f"Predicting | {prefix}") 194 | eval_start = time.time() 195 | predictions = model.batch_predict(eval_dataset, tbar=tbar, gpu_id=gpu_id) 196 | eval_end = time.time() 197 | tbar.close() 198 | 199 | eval_time = eval_end - eval_start 200 | 201 | # Save predictions & golds 202 | IOUtils.dump(result_dir / f"{prefix}_predictions.jsonl", predictions, IOUtils.Format.jsonList) 203 | IOUtils.dump(result_dir / f"{prefix}_golds.jsonl", golds, IOUtils.Format.jsonList) 204 | IOUtils.dump(result_dir / f"{prefix}_eval_time.json", eval_time, IOUtils.Format.json) 205 | 206 | def eval(self, exp_name: str, action: str = None, gpu_id: int = 0) -> None: 207 | if action is None: 208 | action = self.DEFAULT_EVAL_ACTION 209 | if action not in self.EVAL_ACTIONS: 210 | raise RuntimeError(f"Unknown eval action {action}") 211 | 212 | # Load eval data 213 | tbar = tqdm(desc="Loading data") 214 | dataset = MethodData.load_dataset(self.data_dir, tbar=tbar) 215 | indexed_dataset = {d.id: d for d in dataset} 216 | tbar.close() 217 | 218 | # Load model 219 | exp_dir = self.get_exp_dir(exp_name) 220 | model: MNModelBase = 
MNModelLoader.load_model(exp_dir) 221 | if not model.is_train_finished(): 222 | logger.warning(f"Model not finished training, at {exp_dir}") 223 | 224 | # Invoke eval_one with specific data ids 225 | if action in [self.EVAL_VAL, self.EVAL_TESTS]: 226 | self.eval_one( 227 | exp_name, 228 | IOUtils.load(self.data_dir / f"split_{action}.json", IOUtils.Format.json), 229 | action, 230 | indexed_dataset, 231 | model, 232 | gpu_id=gpu_id, 233 | ) 234 | elif action == self.EVAL_TESTC: 235 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type): 236 | self.eval_one( 237 | exp_name, 238 | IOUtils.load(self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", IOUtils.Format.json), 239 | f"{Macros.test_common}-{s1}-{s2}", 240 | copy.deepcopy(indexed_dataset), 241 | model, 242 | gpu_id=gpu_id, 243 | ) 244 | else: 245 | raise RuntimeError(f"Unknown action {action}") 246 | 247 | def compute_metrics_one(self, exp_name: str, prefix: str): 248 | # Prepare output directory 249 | metric_dir = self.get_metric_dir(exp_name) 250 | metric_dir.mkdir(parents=True, exist_ok=True) 251 | 252 | # Load golds and predictions 253 | result_dir = self.get_result_dir(exp_name) 254 | Utils.expect_dir_or_suggest_dvc_pull(result_dir) 255 | golds = IOUtils.load(result_dir / f"{prefix}_golds.jsonl", IOUtils.Format.jsonList) 256 | predictions = IOUtils.load(result_dir / f"{prefix}_predictions.jsonl", IOUtils.Format.jsonList) 257 | 258 | metrics_list: Dict[str, List] = collections.defaultdict(list) 259 | metrics_list["exact_match"] = EvalMetrics.batch_exact_match(golds, predictions) 260 | metrics_list["token_acc"] = EvalMetrics.batch_token_acc(golds, predictions) 261 | metrics_list["bleu"] = EvalMetrics.batch_bleu(golds, predictions) 262 | rouge_l_res = EvalMetrics.batch_rouge_l(golds, predictions) 263 | metrics_list["rouge_l_f"] = [x["f"] for x in rouge_l_res] 264 | metrics_list["rouge_l_p"] = [x["p"] for x in rouge_l_res] 265 | metrics_list["rouge_l_r"] = [x["r"] for x in rouge_l_res] 266 | metrics_list["meteor"] = EvalMetrics.batch_meteor(golds, predictions) 267 | set_match_res = EvalMetrics.batch_set_match(golds, predictions) 268 | metrics_list["set_match_f"] = [x["f"] for x in set_match_res] 269 | metrics_list["set_match_p"] = [x["p"] for x in set_match_res] 270 | metrics_list["set_match_r"] = [x["r"] for x in set_match_res] 271 | 272 | # Take average 273 | metrics = {} 274 | for k, l in metrics_list.items(): 275 | metrics[k] = np.mean(l).item() 276 | 277 | # Save metrics 278 | IOUtils.dump(metric_dir / f"{prefix}_metrics.json", metrics, IOUtils.Format.jsonNoSort) 279 | IOUtils.dump(metric_dir / f"{prefix}_metrics.txt", [f"{k}: {v}" for k, v in metrics.items()], IOUtils.Format.txtList) 280 | IOUtils.dump(metric_dir / f"{prefix}_metrics_list.pkl", metrics_list, IOUtils.Format.pkl) 281 | 282 | def compute_metrics(self, exp_name: str, action: str = None) -> None: 283 | if action is None: 284 | action = self.DEFAULT_EVAL_ACTION 285 | if action not in self.EVAL_ACTIONS: 286 | raise RuntimeError(f"Unknown eval action {action}") 287 | 288 | if action in [self.EVAL_VAL, self.EVAL_TESTS]: 289 | self.compute_metrics_one( 290 | exp_name, 291 | action, 292 | ) 293 | elif action == self.EVAL_TESTC: 294 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type): 295 | self.compute_metrics_one( 296 | exp_name, 297 | f"{Macros.test_common}-{s1}-{s2}", 298 | ) 299 | else: 300 | raise RuntimeError(f"Unknown action {action}") 301 | -------------------------------------------------------------------------------- 
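Note on compute_metrics_one above: each metric is computed per example and then averaged over the eval set. A minimal sketch of that aggregation, using hypothetical golds/predictions and plain exact match standing in for the full EvalMetrics suite:

```
import numpy as np

# Hypothetical subtokenized golds and predictions for three examples
golds = [["get", "name"], ["set", "value"], ["to", "string"]]
predictions = [["get", "name"], ["get", "value"], ["to", "string"]]

# Per-example metric values, in the shape EvalMetrics' batch_* helpers return them
metrics_list = {"exact_match": [float(g == p) for g, p in zip(golds, predictions)]}

# Average each metric over the eval set, as compute_metrics_one does
metrics = {k: np.mean(l).item() for k, l in metrics_list.items()}
print(metrics)  # {'exact_match': 0.6666666666666666}
```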
/python/tseval/metnam/eval/__init__.py: -------------------------------------------------------------------------------- 1 | def get_setup_cls(name: str) -> type: 2 | if name == "StandardSetup": 3 | from tseval.metnam.eval.StandardSetup import StandardSetup 4 | return StandardSetup 5 | else: 6 | raise ValueError(f"No setup with name {name}") 7 | -------------------------------------------------------------------------------- /python/tseval/metnam/model/MNModelBase.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from pathlib import Path 3 | from typing import List, Optional 4 | 5 | from seutil import IOUtils 6 | from tqdm import tqdm 7 | 8 | from tseval.data.MethodData import MethodData 9 | from tseval.util.TrainConfig import TrainConfig 10 | 11 | 12 | class MNModelBase(abc.ABC): 13 | 14 | def __init__(self, model_work_dir: Path, no_save: bool = False): 15 | self.model_work_dir = model_work_dir 16 | self.no_save = no_save 17 | 18 | @abc.abstractmethod 19 | def train( 20 | self, 21 | train_dataset: List[MethodData], 22 | val_dataset: List[MethodData], 23 | resources_path: Optional[Path] = None, 24 | train_config: Optional[TrainConfig] = None, 25 | ): 26 | """ 27 | Trains the model. 28 | 29 | :param train_dataset: training set. 30 | :param val_dataset: validation set. 31 | :param resources_path: path to resources that could be shared by multiple models' training processes, 32 | e.g., pre-trained embeddings. 33 | """ 34 | raise NotImplementedError 35 | 36 | @abc.abstractmethod 37 | def is_train_finished(self) -> bool: 38 | raise NotImplementedError 39 | 40 | @abc.abstractmethod 41 | def predict( 42 | self, 43 | data: MethodData, 44 | gpu_id: int = 0, 45 | ) -> List[str]: 46 | """ 47 | Predicts the method name given the context in data. 48 | 49 | :param data: the input method data, with its name replaced by a placeholder. 50 | :return: a list of predicted method name subtokens. 51 | """ 52 | raise NotImplementedError 53 | 54 | def batch_predict( 55 | self, 56 | dataset: List[MethodData], 57 | tbar: Optional[tqdm] = None, 58 | gpu_id: int = 0, 59 | ) -> List[List[str]]: 60 | """ 61 | Performs batched predictions using the given dataset as inputs. 62 | 63 | The default implementation invokes #predict multiple times. Subclasses can override 64 | this method to speed up the prediction by using batching. 65 | 66 | :param dataset: a list of inputs. 67 | :param tbar: an optional tqdm progress bar to show prediction progress. 68 | :return: a list of the return values of #predict. 69 | """ 70 | if tbar is not None: 71 | tbar.reset(len(dataset)) 72 | 73 | results = [] 74 | for data in dataset: 75 | results.append(self.predict(data, gpu_id=gpu_id)) 76 | if tbar is not None: 77 | tbar.update(1) 78 | 79 | return results 80 | 81 | def save(self) -> None: 82 | """ 83 | Saves the current model at the work_dir. 84 | Default behavior is to serialize the entire object in model.pkl. 85 | """ 86 | if not self.no_save: 87 | IOUtils.dump(self.model_work_dir / "model.pkl", self, IOUtils.Format.pkl) 88 | 89 | @classmethod 90 | def load(cls, work_dir: Path) -> "MNModelBase": 91 | """ 92 | Loads a model from the work_dir. 93 | Default behavior is to deserialize the object from model.pkl, resetting its work_dir.
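        Example (hypothetical experiment directory):
            model = Code2SeqICLR19.load(exp_dir / "model")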
94 | """ 95 | obj = IOUtils.load(work_dir / "model.pkl", IOUtils.Format.pkl) 96 | obj.model_work_dir = work_dir 97 | return obj 98 | -------------------------------------------------------------------------------- /python/tseval/metnam/model/__init__.py: -------------------------------------------------------------------------------- 1 | def get_model_cls(name: str) -> type: 2 | if name == "Code2SeqICLR19": 3 | from tseval.metnam.model.Code2SeqICLR19 import Code2SeqICLR19 4 | return Code2SeqICLR19 5 | elif name == "Code2VecPOPL19": 6 | from tseval.metnam.model.Code2VecPOPL19 import Code2VecPOPL19 7 | return Code2VecPOPL19 8 | else: 9 | raise ValueError(f"No model with name {name}") 10 | -------------------------------------------------------------------------------- /python/tseval/util/ModelUtils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import re 4 | import tempfile 5 | import time 6 | from pathlib import Path 7 | from typing import Dict, List, Optional, OrderedDict, Tuple 8 | 9 | from seutil import BashUtils, IOUtils, LoggingUtils 10 | from tqdm import tqdm 11 | 12 | from tseval.Environment import Environment 13 | 14 | logger = LoggingUtils.get_logger(__name__) 15 | 16 | 17 | class ModelUtils: 18 | 19 | TOKENIZE_JAVAPARSER_BATCH_SIZE = 10000 20 | 21 | @classmethod 22 | def get_random_seed(cls) -> int: 23 | """ 24 | Generates a random int as seed, within the range of [0, 2**32) 25 | The seed is generated based on current time 26 | """ 27 | return time.time_ns() % (2**32) 28 | 29 | @classmethod 30 | def tokenize_javaparser(cls, code: str) -> List[str]: 31 | return cls.tokenize_javaparser_batch([code])[0] 32 | 33 | @classmethod 34 | def tokenize_javaparser_batch( 35 | cls, 36 | code_list: List[str], 37 | dup_share: bool = True, 38 | tbar: Optional[tqdm] = None, 39 | ) -> List[List[str]]: 40 | """ 41 | Tokenizes a list of code using JavaParser. 42 | :param code_list: a list of code to be tokenized. 43 | :param dup_share: if True (default), the returned lists of tokens will be shared across duplicate code 44 | (thus modifying one of them will affect others). 45 | :param tbar: optional tqdm progress bar. 46 | :returns a list of tokenized code. 
47 | """ 48 | # get an unique list of code to tokenize, maintain a backward mapping 49 | code_2_id: OrderedDict[str, int] = collections.OrderedDict() 50 | ids: List[int] = [] 51 | for c in code_list: 52 | ids.append(code_2_id.setdefault(c, len(code_2_id))) 53 | 54 | unique_code_list = list(code_2_id.keys()) 55 | if tbar is not None: 56 | tbar.set_description(f"JavaParser Tokenize ({len(unique_code_list)}U/{len(code_list)})") 57 | tbar.reset(len(unique_code_list)) 58 | 59 | # Tokenize (with batching) 60 | unique_tokens_list = [] 61 | for beg in range(0, len(unique_code_list), cls.TOKENIZE_JAVAPARSER_BATCH_SIZE): 62 | unique_tokens_list += cls.tokenize_javaparser_batch_( 63 | unique_code_list[beg:beg + cls.TOKENIZE_JAVAPARSER_BATCH_SIZE], tbar=tbar, 64 | ) 65 | 66 | if dup_share: 67 | return [unique_tokens_list[i] for i in ids] 68 | else: 69 | return [copy.copy(unique_tokens_list[i]) for i in ids] 70 | 71 | @classmethod 72 | def tokenize_javaparser_batch_( 73 | cls, 74 | code_list: List[str], 75 | tbar: Optional[tqdm] = None, 76 | ): 77 | # Use JavaParser to tokenize 78 | Environment.require_collector() 79 | 80 | tokenizer_inputs = [] 81 | for code in code_list: 82 | tokenizer_inputs.append({ 83 | "index": len(tokenizer_inputs), 84 | "code": code, 85 | }) 86 | 87 | inputs_file = Path(tempfile.mktemp()) 88 | IOUtils.dump(inputs_file, tokenizer_inputs, IOUtils.Format.json) 89 | outputs_file = Path(tempfile.mktemp()) 90 | 91 | BashUtils.run( 92 | f"java -cp {Environment.collector_jar} org.tseval.ExtractToken '{inputs_file}' '{outputs_file}'", 93 | expected_return_code=0, 94 | ) 95 | 96 | tokenizer_outputs = IOUtils.load(outputs_file, IOUtils.Format.json) 97 | IOUtils.rm(inputs_file) 98 | IOUtils.rm(outputs_file) 99 | 100 | # Check for tokenizer failures 101 | for code, output in zip(code_list, tokenizer_outputs): 102 | if len(code.strip()) == 0: 103 | logger.warning(f"Empty code: {code}") 104 | continue 105 | if len(output["tokens"]) == 0: 106 | logger.warning(f"Tokenizer failed: {code}") 107 | 108 | if tbar is not None: 109 | tbar.update(len(code_list)) 110 | 111 | return [d["tokens"] for d in tokenizer_outputs] 112 | 113 | RE_SUBTOKENIZE = re.compile(r"(?<=[_$])(?!$)|(?" 115 | 116 | @classmethod 117 | def is_identifier(cls, token: str) -> bool: 118 | return len(token) > 0 and \ 119 | (token[0].isalpha() or token[0] in "_$") and \ 120 | all([c.isalnum() or c in "_$" for c in token]) 121 | 122 | @classmethod 123 | def subtokenize(cls, token: str) -> List[str]: 124 | """ 125 | Subtokenizes an identifier name into subtokens, by CamelCase and snake_case. 126 | """ 127 | # Only subtokenize identifier words (starts with letter _$, contains only alnum and _$) 128 | if cls.is_identifier(token): 129 | return cls.RE_SUBTOKENIZE.split(token) 130 | else: 131 | return [token] 132 | 133 | @classmethod 134 | def subtokenize_batch(cls, tokens: List[str]) -> Tuple[List[str], List[int]]: 135 | """ 136 | Subtokenizes list of tokens. 137 | :return a list of subtokens, and a list of pointers to the original token indices. 138 | """ 139 | sub_tokens = [] 140 | src_indices = [] 141 | for i, token in enumerate(tokens): 142 | new_sub_tokens = cls.subtokenize(token) 143 | sub_tokens += new_sub_tokens 144 | src_indices += [i] * len(new_sub_tokens) 145 | return sub_tokens, src_indices 146 | 147 | @classmethod 148 | def subtokenize_space_batch(cls, tokens: List[str]) -> Tuple[List[str], List[int]]: 149 | """ 150 | Subtokenizes list of tokens, and inserts special token when necessary 151 | (between two identifiers). 
152 | :return: a list of subtokens, and a list of pointers to the original token indices. 153 | """ 154 | sub_tokens = [] 155 | src_indices = [] 156 | last_is_identifier = False 157 | for i, token in enumerate(tokens): 158 | is_identifier = cls.is_identifier(token) 159 | if last_is_identifier and is_identifier: 160 | sub_tokens.append(cls.SPACE_TOKEN) 161 | src_indices.append(-1) 162 | new_sub_tokens = cls.subtokenize(token) 163 | sub_tokens += new_sub_tokens 164 | src_indices += [i] * len(new_sub_tokens) 165 | last_is_identifier = is_identifier 166 | return sub_tokens, src_indices 167 | 168 | @classmethod 169 | def regroup_subtokens(cls, subtokens: List[str], src_indices: List[int]) -> List[str]: 170 | """ 171 | Given a list of subtokens and the original token indices, groups them back into tokens. 172 | :param subtokens: a list of subtokens. 173 | :param src_indices: the i-th index should point to the original token that 174 | sub_tokens[i] belongs to; -1 means it is a special sub_token. 175 | :return: a list of tokens after regrouping. 176 | """ 177 | id2tokens: Dict[int, str] = collections.defaultdict(str) 178 | for subtoken, i in zip(subtokens, src_indices): 179 | if i >= 0: 180 | id2tokens[i] += subtoken 181 | return [id2tokens[i] for i in sorted(id2tokens.keys())] 182 | -------------------------------------------------------------------------------- /python/tseval/util/TrainConfig.py: -------------------------------------------------------------------------------- 1 | from typing import get_type_hints 2 | 3 | from recordclass import RecordClass 4 | 5 | 6 | class TrainConfig(RecordClass): 7 | train_session_time: int = 20 * 3600 8 | gpu_id: int = 0 9 | 10 | @classmethod 11 | def get_train_config_from_cmd_options(cls, options: dict) -> "TrainConfig": 12 | """ 13 | Gets a TrainConfig from the command line options (the options will be modified 14 | in place to remove the parsed fields). 15 | """ 16 | field_values = {} 17 | for f, t in get_type_hints(cls).items(): 18 | if f in options: 19 | field_values[f] = t(options.pop(f)) 20 | return cls(**field_values) 21 | -------------------------------------------------------------------------------- /python/tseval/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/util/__init__.py --------------------------------------------------------------------------------
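The metnam pipeline above is driven end to end through MNEvalHelper (normally via python/tseval/main.py and run.sh). A minimal sketch of one experiment, with hypothetical setup/experiment/split names:

```
from tseval.metnam.eval.MNEvalHelper import MNEvalHelper

# All names below are hypothetical; real split names and types come from the
# data processing step (get_splits in python/tseval/main.py)
helper = MNEvalHelper()

# 1. Materialize a StandardSetup over an existing split
helper.exp_prepare(setup="StandardSetup", setup_name="MN-demo", split_name="split0", split_type="temporal")

# 2. Train a model in that setup
helper.exp_train(setup_name="MN-demo", exp_name="code2seq-run0", model_name="Code2SeqICLR19", cont_train=False, no_save=False)

# 3. Predict with the default eval action, then aggregate metrics
helper.exp_eval(setup_name="MN-demo", exp_name="code2seq-run0", action=None)
helper.exp_compute_metrics(setup_name="MN-demo", exp_name="code2seq-run0")
```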