├── .gitignore
├── LICENSE
├── README.md
├── collector
├── .gitignore
├── pom.xml
└── src
│ └── main
│ └── java
│ └── org
│ └── tseval
│ ├── ExtractToken.java
│ ├── MaskMethodNameVisitor.java
│ ├── MethodDataCollector.java
│ ├── MethodDataCollectorVisitor.java
│ ├── data
│ ├── MethodData.java
│ ├── ProjectData.java
│ └── RevisionIds.java
│ └── util
│ ├── AbstractConfig.java
│ ├── BashUtils.java
│ ├── NLPUtils.java
│ └── Option.java
└── python
├── .gitignore
├── exps
├── cg.yaml
└── mn.yaml
├── prepare_conda_env.sh
├── requirements.txt
├── run.sh
├── switch-cuda.sh
└── tseval
├── Environment.py
├── Macros.py
├── Plot.py
├── Table.py
├── Utils.py
├── __init__.py
├── collector
├── DataCollector.py
├── ExperimentsAnalyzer.py
├── MetricsCollector.py
└── __init__.py
├── comgen
├── __init__.py
├── eval
│ ├── CGEvalHelper.py
│ ├── CGModelLoader.py
│ ├── StandardSetup.py
│ └── __init__.py
└── model
│ ├── CGModelBase.py
│ ├── DeepComHybridESE19.py
│ ├── TransformerACL20.py
│ └── __init__.py
├── data
├── MethodData.py
├── RevisionIds.py
└── __init__.py
├── eval
├── EvalHelper.py
├── EvalMetrics.py
├── EvalSetupBase.py
└── __init__.py
├── main.py
├── metnam
├── __init__.py
├── eval
│ ├── MNEvalHelper.py
│ ├── MNModelLoader.py
│ ├── StandardSetup.py
│ └── __init__.py
└── model
│ ├── Code2SeqICLR19.py
│ ├── Code2VecPOPL19.py
│ ├── MNModelBase.py
│ └── __init__.py
└── util
├── ModelUtils.py
├── TrainConfig.py
└── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp files
2 | *~
3 | \#*\#
4 | .DS_Store
5 |
6 | # Temp directories
7 |
8 | /_downloads/
9 | /_results/
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 EngineeringSoftware
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Impact of Evaluation Methodologies on Code Summarization
2 |
3 | This repo hosts the code and data for the following ACL 2022 paper:
4 |
5 | Title: [Impact of Evaluation Methodologies on Code Summarization][paper-utcs]
6 |
7 | Authors: [Pengyu Nie](https://pengyunie.github.io/), [Jiyang Zhang](https://jiyangzhang.github.io/), [Junyi Jessy Li](https://jessyli.com/), [Raymond J. Mooney](https://www.cs.utexas.edu/~mooney/), [Milos Gligoric](https://users.ece.utexas.edu/~gligoric/)
8 |
9 | ```bibtex
10 | @inproceedings{NieETAL22EvalMethodologies,
11 | title = {Impact of Evaluation Methodologies on Code Summarization},
12 | author = {Pengyu Nie and Jiyang Zhang and Junyi Jessy Li and Raymond J. Mooney and Milos Gligoric},
13 | pages = {to appear},
14 | booktitle = {Annual Meeting of the Association for Computational Linguistics},
15 | year = {2022},
16 | }
17 | ```
18 |
19 | ## Introduction
20 |
21 | This repo contains the code and data for producing the experiments in
22 | [Impact of Evaluation Methodologies on Code
23 | Summarization][paper-utcs]. In this work, we study the impact of
24 | evaluation methodologies, i.e., the way people split datasets into
25 | training, validation, and test sets, in the field of code
26 | summarization. We introduce the time-segmented evaluation methodology,
27 | which is novel to the code summarization research community, and
28 | compare it with the mixed-project and cross-project methodologies that
29 | have been commonly used.
30 |
31 | The code includes:
32 | * a data collection tool for collecting (method, comment) pairs with
33 | timestamps.
34 | * a data processing pipeline for splitting a dataset following the
35 | three evaluation methodologies.
36 | * scripts for running four recent machine learning models for code
37 | summarization and comparing their results across methodologies.
38 |
39 | **How to...**
40 | * **reproduce the training and evaluation of ML models on our
41 | collected dataset**: [install dependency][sec-dependency], [download
42 | all data][sec-downloads], and follow the instructions
43 | [here][sec-traineval].
44 | * **reproduce our full study from scratch**: [install
45 | dependency][sec-dependency], [download `_work/src`][sec-downloads]
46 | (the source code for the ML models used in our study), and follow
47 | the instructions to [collect data][sec-collect], [process
48 | data][sec-process], and [train and evaluate models][sec-traineval].
49 |
50 |
51 | ## Table of Contents
52 |
53 | 1. [Dependency][sec-dependency]
54 | 2. [Data Downloads][sec-downloads]
55 | 3. [Code for Collecting Data][sec-collect]
56 | 4. [Code for Processing Data][sec-process]
57 | 5. [Code for Training and Evaluating Models][sec-traineval]
58 |
59 | ## Dependency
60 | [sec-dependency]: #dependency
61 |
62 | Our code requires the following hardware and software environments.
63 |
64 | * Operating system: Linux (tested on Ubuntu 20.04)
65 | * Minimum disk space: 4 GB
66 | * Python: 3.8
67 | * Java: 8
68 | * Maven: 3.6.3
69 | * Anaconda/Miniconda: appropriate versions for Python 3.8 or higher
70 |
71 | Additional requirements for training and evaluating ML models:
72 |
73 | * GPU: NVIDIA GTX 1080 or better
74 | * CUDA: 10.0 ~ 11.0
75 | * Disk space: 2 GB per trained model
76 |
77 | [Anaconda](https://www.anaconda.com/products/individual#Downloads) or
78 | [Miniconda](https://docs.conda.io/en/latest/miniconda.html) is
79 | required for installing the other Python library dependencies. Once
80 | Anaconda/Miniconda is installed, you can use the following command to
81 | setup a virtual environment, named `tseval`, with the Python library
82 | dependencies installed:
83 |
84 | ```
85 | cd python/
86 | ./prepare_conda_env.sh
87 | ```
88 |
89 | And then use `conda activate tseval` to activate the created virtual
90 | environment.
91 |
92 | The Java code `collector` will automatically be compiled as needed in
93 | our Python code. The Java library dependencies are automatically
94 | downloaded, by the Maven build system, during this process.
95 |
96 |
97 | ## Data Downloads
98 | [sec-downloads]: #data-downloads
99 |
100 | All our data is hosted on UTBox via [a shared folder](https://utexas.box.com/s/32qq85ttp9js0qqnv68ebhvr19ajo7v9).
101 |
102 | Data should be downloaded to this directory with the same directory
103 | structure (e.g., `_work/src` from the shared folder should be
104 | downloaded as `_work/src` under current directory).
105 |
106 |
107 | ## Code for Collecting Data
108 | [sec-collect]: #code-for-collecting-data
109 |
110 | ### Collect the list of popular Java projects on GitHub
111 |
112 | ```
113 | python -m tseval.main collect_repos
114 | python -m tseval.main filter_repos
115 | ```
116 |
117 | Results are generated to `results/repos/`:
118 |
119 | * `github-java-repos.json` is the full list of projects returned by
120 | the GitHub API.
121 |
122 | * `filtered-repos.json` is the list of projects filtered according to
123 | the conditions in our paper.
124 |
125 | * `*-logs.json` documents the time, configurations, and metrics of the
126 | collection/filtering of the list.
127 |
128 | Note that the list of projects may already differ from the list of
129 | projects we used, because old projects may be removed, and the
130 | ordering of projects may change.
131 |
132 | ### Collect raw dataset
133 |
134 | Requires the list of projects at `results/repos/filtered-repos.json`
135 |
136 | ```
137 | python -m tseval.main collect_raw_data
138 | ```
139 |
140 | Results are generated to `_raw_data/`. Each project's raw data
141 | is in a directory named `$user_$repo` (e.g., `apache_commons-codec`):
142 |
143 | * `method-data.json` is the list of method samples (includes code, API
144 | comments, etc.) extracted from the project at the selected revisions
145 | (at Jan 1st of 2018, 2019, 2020, 2021).
146 |
147 | * `revision-ids.json` is the mapping from revision to the method
148 | samples that are available at that revision.
149 |
150 | * `filtered-counters.json` is the count of samples discarded during
151 | collection according to our paper.
152 |
153 | * `log.txt` is the log of the collection.
154 |
155 | ## Code for Processing Data
156 | [sec-process]: #code-for-processing-data
157 |
158 |
159 | ### Process raw data to use our data structure (tseval.data.MethodData)
160 |
161 | Requires the raw data at `_raw_data/`.
162 |
163 | ```
164 | python -m tseval.main process_raw_data
165 | ```
166 |
167 | Results are generated to `_work/shared/`:
168 |
169 | * `*.jsonl` files are the dataset, where each file stores one field of
170 | all samples, and each line stores the field for one sample.
171 |
172 | * `filtered-counters.json` is the combined count of samples discarded
173 | during collection.
174 |
175 |
176 | ### Apply methodologies (non task-specific part)
177 |
178 | Requires the dataset at `_work/shared/`.
179 |
180 | ```
181 | python -m tseval.main get_splits --seed=7 --split=Full
182 | ```
183 |
184 | Results are generated to `_work/split/Full/`:
185 |
186 | * `$X-$Y.json`, where X in {MP, CP, T} and Y in {train, val, test_standard}; and
187 | `$X1-$X2-test_common.json`, where X1, X2 in {MP, CP, T}.
188 |
189 | - each file contains a list of ids.
190 |
191 | - MP = mixed-project; CP = cross-project; T = temporally.
192 |
193 | - train = training; val = validation;
194 | test_standard = standard test; test_common = common test.
195 |
196 | ### Apply methodologies (task-specific part)
197 |
198 | Requires the dataset at `_work/shared/` and the splits at
199 | `_work/split/Full/`.
200 |
201 | From this point on, we define two variables to use in our commands:
202 |
203 | * `$task`: to indicate the targeted code summarization task.
204 | - CG: comment generation.
205 | - MN: method naming.
206 |
207 | * `$method`: to indicate the methodology used.
208 | - MP: mixed-project.
209 | - CP: cross-project.
210 | - T: temporally.
211 |
212 | ```
213 | python -m tseval.main exp_prepare \
214 | --task=$task \
215 | --setup=StandardSetup \
216 | --setup_name=$method \
217 | --split_name=Full \
218 | --split_type=$method
219 | # Example: python -m tseval.main exp_prepare \
220 | # --task=CG \
221 | # --setup=StandardSetup \
222 | # --setup_name=T \
223 | # --split_name=Full \
224 | # --split_type=T
225 | ```
226 |
227 | Results are generated to `_work/$task/setup/$method/`:
228 |
229 | * `data/` contains the dataset (jsonl files) and splits (ids in
230 | Train/Val/TestT/TestC sets).
231 |
232 | * `setup_config.json` documents the configurations of this
233 | methodology.
234 |
235 | ## Code for Training and Evaluating Models
236 | [sec-traineval]: #code-for-training-and-evaluating-models
237 |
238 | ### Prepare the Python environments for ML models
239 |
240 | Requires Anaconda/Miniconda, and the models' source code at `_work/src/`.
241 |
242 | ```
243 | python -m tseval.main prepare_envs --which=$model_cls
244 | # Example: python -m tseval.main prepare_envs --which=TransformerACL20
245 | ```
246 |
247 | Where the `$model_cls` for each model can be looked up in this table
248 | (Transformer and Seq2Seq are using the same model class and
249 | environment):
250 |
251 | | $task | $model_cls | Model |
252 | |:------|:-------------------|:--------------|
253 | | CG | DeepComHybridESE19 | DeepComHybrid |
254 | | CG | TransformerACL20 | Transformer |
255 | | CG | TransformerACL20 | Seq2Seq |
256 | | MN | Code2VecPOPL19 | Code2Vec |
257 | | MN | Code2SeqICLR19 | Code2Seq |
258 |
259 | The name of the conda environment created is `tseval-$task-$model_cls`.
260 |
261 | ### Train ML models under a methodology
262 |
263 | Requires the dataset at `_work/$task/setup/$method/`, and
264 | activating the right conda environment
265 | (`conda activate tseval-$task-$model_cls`).
266 |
267 | ```
268 | python -m tseval.main exp_train \
269 | --task=$task \
270 | --setup_name=$method \
271 | --model_name=$model_cls \
272 | --exp_name=$exp_name \
273 | --seed=$seed \
274 | $model_args
275 | # Example: python -m tseval.main exp_train \
276 | # --task=CG \
277 | # --setup_name=T \
278 | # --model_name=TransformerACL20 \
279 | # --exp_name=Transformer \
280 | # --seed=4182
281 | ```
282 |
283 | Where `$exp_name` is the name of the output directory; `$seed` is the
284 | random seed (integer) to control the random process in the experiments
285 | (the `--seed=$seed` argument can be omitted for a random run using the
286 | current timestamp as seed); `$model_args` is potential additional
287 | arguments for the model and can be looked up in the following table:
288 |
289 | | $task | Model | $model_args |
290 | |:------|:--------------|:---------------|
291 | | CG | DeepComHybrid | (empty) |
292 | | CG | Transformer | (empty) |
293 | | CG | Seq2Seq | --use_rnn=True |
294 | | MN | Code2Vec | (empty) |
295 | | MN | Code2Seq | (empty) |
296 |
297 | Results are generated to `_work/$task/exp/$method/$exp_name/`:
298 |
299 | * `model/` the trained model.
300 |
301 | * Other files document the configurations for initializing and
302 | training the model.
303 |
304 | ### Evaluate ML models
305 |
306 | Requires the dataset at `_work/$task/setup/$method/`, the trained
307 | model at `_work/$task/exp/$method/$exp_name/`, and activating the
308 | right conda environment (`conda activate tseval-$task-$model_cls`).
309 |
310 | ```
311 | for action in val test_standard test_common; do
312 | python -m tseval.main exp_eval \
313 | --task=$task \
314 | --setup_name=$method \
315 | --exp_name=$exp_name \
316 | --action=$action
317 | done
318 | # Example: for action in val test_standard test_common; do
319 | # python -m tseval.main exp_eval \
320 | # --task=CG \
321 | # --setup_name=T \
322 | # --exp_name=Transformer \
323 | # --action=$action
324 | #done
325 | ```
326 |
327 | Results are generated to `_work/$task/result/$method/$exp_name/`:
328 |
329 | * `$X_predictions.jsonl`: the predictions.
330 |
331 | * `$X_golds.jsonl`: the golds (ground truths).
332 |
333 | * `$X_eval_time.jsonl`: the time taken for the evaluation.
334 |
335 | * Where $X in {val, test_standard, test_common-$method-$method1 (where $method1 != $method)}.
336 |
337 |
338 | ### Compute automatic metrics
339 |
340 | Requires the evaluation results at
341 | `_work/$task/result/$method/$exp_name/`, and the use of `tseval`
342 | environment (`conda activate tseval`).
343 |
344 | ```
345 | for action in val test_standard test_common; do
346 | python -m tseval.main exp_compute_metrics \
347 | --task=$task \
348 | --setup_name=$method \
349 | --exp_name=$exp_name \
350 | --action=$action
351 | done
352 | # Example: for action in val test_standard test_common; do
353 | # python -m tseval.main exp_compute_metrics \
354 | # --task=CG \
355 | # --setup_name=T \
356 | # --exp_name=Transformer \
357 | # --action=$action
358 | #done
359 | ```
360 |
361 | Results are generated to `_work/$task/metric/$method/$exp_name/`:
362 |
363 | * `$X_metrics.json`: the average of automatic metrics.
364 |
365 | * `$X_metrics_list.pkl`: the (compressed) list of automatic metrics per sample.
366 |
367 | * Where $X in {val, test_standard, test_common-$method-$method1 (where $method1 != $method)}.
368 |
369 |
370 |
371 | [paper-arxiv]: https://arxiv.org/abs/2108.09619
372 | [paper-utcs]: https://www.cs.utexas.edu/users/ai-lab/downloadPublication.php?filename=http://www.cs.utexas.edu/users/ml/papers/nie.acl2022.pdf&pubid=127948
373 |
--------------------------------------------------------------------------------
/collector/.gitignore:
--------------------------------------------------------------------------------
1 | # Additional IntelliJ
2 | *.iml
3 | .idea/
4 |
5 |
6 | ### IntelliJ
7 |
8 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
9 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
10 |
11 | # User-specific stuff
12 | .idea/**/workspace.xml
13 | .idea/**/tasks.xml
14 | .idea/**/usage.statistics.xml
15 | .idea/**/dictionaries
16 | .idea/**/shelf
17 |
18 | # Generated files
19 | .idea/**/contentModel.xml
20 |
21 | # Sensitive or high-churn files
22 | .idea/**/dataSources/
23 | .idea/**/dataSources.ids
24 | .idea/**/dataSources.local.xml
25 | .idea/**/sqlDataSources.xml
26 | .idea/**/dynamic.xml
27 | .idea/**/uiDesigner.xml
28 | .idea/**/dbnavigator.xml
29 |
30 | # Gradle
31 | .idea/**/gradle.xml
32 | .idea/**/libraries
33 |
34 | # Gradle and Maven with auto-import
35 | # When using Gradle or Maven with auto-import, you should exclude module files,
36 | # since they will be recreated, and may cause churn. Uncomment if using
37 | # auto-import.
38 | # .idea/modules.xml
39 | # .idea/*.iml
40 | # .idea/modules
41 | # *.iml
42 | # *.ipr
43 |
44 | # CMake
45 | cmake-build-*/
46 |
47 | # Mongo Explorer plugin
48 | .idea/**/mongoSettings.xml
49 |
50 | # File-based project format
51 | *.iws
52 |
53 | # IntelliJ
54 | out/
55 |
56 | # mpeltonen/sbt-idea plugin
57 | .idea_modules/
58 |
59 | # JIRA plugin
60 | atlassian-ide-plugin.xml
61 |
62 | # Cursive Clojure plugin
63 | .idea/replstate.xml
64 |
65 | # Crashlytics plugin (for Android Studio and IntelliJ)
66 | com_crashlytics_export_strings.xml
67 | crashlytics.properties
68 | crashlytics-build.properties
69 | fabric.properties
70 |
71 | # Editor-based Rest Client
72 | .idea/httpRequests
73 |
74 | # Android studio 3.1+ serialized cache file
75 | .idea/caches/build_file_checksums.ser
76 |
77 |
78 | ### Java
79 |
80 | # Compiled class file
81 | *.class
82 |
83 | # Log file
84 | *.log
85 |
86 | # BlueJ files
87 | *.ctxt
88 |
89 | # Mobile Tools for Java (J2ME)
90 | .mtj.tmp/
91 |
92 | # Package Files #
93 | *.jar
94 | *.war
95 | *.nar
96 | *.ear
97 | *.zip
98 | *.tar.gz
99 | *.rar
100 |
101 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
102 | hs_err_pid*
103 |
104 |
105 | ### Maven
106 |
107 | target/
108 | pom.xml.tag
109 | pom.xml.releaseBackup
110 | pom.xml.versionsBackup
111 | pom.xml.next
112 | release.properties
113 | dependency-reduced-pom.xml
114 | buildNumber.properties
115 | .mvn/timing.properties
116 | # https://github.com/takari/maven-wrapper#usage-without-binary-jar
117 | .mvn/wrapper/maven-wrapper.jar
118 |
--------------------------------------------------------------------------------
/collector/pom.xml:
--------------------------------------------------------------------------------
1 |
3 | 4.0.0
4 | org.tseval
5 | collector
6 | jar
7 | 0.1-dev
8 | collector
9 | http://maven.apache.org
10 |
11 |
12 | 1.8
13 | 1.8
14 |
15 |
16 |
17 |
18 |
19 | org.apache.commons
20 | commons-text
21 | 1.8
22 |
23 |
24 |
25 | com.google.code.gson
26 | gson
27 | 2.9.0
28 |
29 |
30 |
31 | com.github.javaparser
32 | javaparser-symbol-solver-core
33 | 3.15.12
34 |
35 |
36 |
37 | edu.stanford.nlp
38 | stanford-corenlp
39 | 4.4.0
40 |
41 |
42 |
43 | junit
44 | junit
45 | 4.12
46 | test
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | org.apache.maven.plugins
56 | maven-shade-plugin
57 | 3.2.1
58 |
59 |
60 | package
61 |
62 | shade
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/ExtractToken.java:
--------------------------------------------------------------------------------
1 | package org.tseval;
2 |
3 | import com.github.javaparser.ParseProblemException;
4 | import com.github.javaparser.StaticJavaParser;
5 | import com.github.javaparser.ast.Node;
6 | import com.github.javaparser.ast.body.MethodDeclaration;
7 | import com.github.javaparser.ast.type.ClassOrInterfaceType;
8 | import com.github.javaparser.printer.PrettyPrinterConfiguration;
9 | import com.google.gson.Gson;
10 | import com.google.gson.GsonBuilder;
11 | import com.google.gson.JsonArray;
12 | import com.google.gson.JsonDeserializer;
13 | import com.google.gson.JsonObject;
14 | import com.google.gson.JsonParseException;
15 | import com.google.gson.JsonSerializer;
16 | import com.google.gson.reflect.TypeToken;
17 |
18 | import java.io.BufferedReader;
19 | import java.io.BufferedWriter;
20 | import java.io.IOException;
21 | import java.lang.reflect.InvocationTargetException;
22 | import java.lang.reflect.Method;
23 | import java.nio.file.Files;
24 | import java.nio.file.Path;
25 | import java.nio.file.Paths;
26 | import java.util.Arrays;
27 | import java.util.LinkedList;
28 | import java.util.List;
29 |
30 | /**
31 | * Utility class for extracting tokens from methods.
32 | * Used for tokenization and calculating metrics.
33 | */
34 | public class ExtractToken { // NOTE(review): generic type arguments (e.g. List<String>, JsonDeserializer<InputIndexAndCode>) appear stripped throughout this dump by tag sanitization — confirm against the original source.
35 |
36 | /**
37 | * Input data format class for {@link ExtractToken}.
38 | */
39 | public static class InputIndexAndCode { // One input record: {"index": int, "code": string}.
40 | public int index;
41 | public String code;
42 |
43 | public static final JsonDeserializer sDeserializer = getDeserializer();
44 | public static JsonDeserializer getDeserializer() { // Builds a Gson deserializer reading the two fields; malformed json is rethrown as JsonParseException.
45 | return (json, type, context) -> {
46 | try {
47 | InputIndexAndCode obj = new InputIndexAndCode();
48 | JsonObject jObj = json.getAsJsonObject();
49 | obj.index = jObj.get("index").getAsInt();
50 | obj.code = jObj.get("code").getAsString();
51 | return obj;
52 | } catch (IllegalStateException e) {
53 | throw new JsonParseException(e); // wrong json shape (e.g. not an object, wrong field type)
54 | }
55 | };
56 | }
57 | }
58 |
59 | public static class OutputData { // One output record: {"index": int, "tokens": [string]}.
60 | public int index;
61 | public List tokens;
62 |
63 | public static final JsonSerializer sSerializer = getSerializer();
64 | public static JsonSerializer getSerializer() { // Builds a Gson serializer emitting {"index": ..., "tokens": [...]}.
65 | return (d, type, jsonSerializationContext) -> {
66 | JsonObject jObj = new JsonObject();
67 | jObj.addProperty("index", d.index);
68 | JsonArray jTokens = new JsonArray();
69 | for (String token: d.tokens) {
70 | jTokens.add(token);
71 | }
72 | jObj.add("tokens", jTokens);
73 | return jObj;
74 | };
75 | }
76 | }
77 |
78 | // Gson: For json (de)serialization
79 | private static final Gson GSON = new GsonBuilder()
80 | .disableHtmlEscaping()
81 | .registerTypeAdapter(InputIndexAndCode.class, InputIndexAndCode.sDeserializer)
82 | .registerTypeAdapter(OutputData.class, OutputData.sSerializer)
83 | .create();
84 |
85 | // For JavaParser pretty-printing AST to code without comments
86 | private static final PrettyPrinterConfiguration METHOD_PPRINT_CONFIG = new PrettyPrinterConfiguration(); // NOTE(review): not referenced anywhere in this file's visible code — possibly leftover; confirm.
87 | static {
88 | METHOD_PPRINT_CONFIG
89 | .setPrintJavadoc(false)
90 | .setPrintComments(false);
91 | }
92 |
93 | private static Node parseWhatever(String code) { // Best-effort parse: whole compilation unit first, then other construct kinds via reflection; returns null when nothing parses.
94 | // First try to parse as Compilation Unit
95 | try {
96 | return StaticJavaParser.parse(code);
97 | } catch (ParseProblemException ignored) {
98 | }
99 |
100 | // Then, try several types
101 | List> possibleTypes = Arrays.asList(
102 | MethodDeclaration.class,
103 | ClassOrInterfaceType.class
104 | );
105 | for (Class> t : possibleTypes) {
106 | try {
107 | Method parseMethod = StaticJavaParser.class.getMethod("parse" + t.getSimpleName(), String.class); // resolves e.g. StaticJavaParser.parseMethodDeclaration / parseClassOrInterfaceType
108 | Node n = (Node) parseMethod.invoke(null, code);
109 | return n;
110 | } catch (NoSuchMethodException | IllegalAccessException e) {
111 | throw new RuntimeException(e); // programming error: the parse method should always exist and be accessible
112 | } catch (InvocationTargetException ignored) {
113 | } // parse failure for this construct kind; fall through to the next candidate
114 | }
115 |
116 | // If all fails, return null
117 | return null;
118 | }
119 |
120 |
121 | /**
122 | * Main entry point.
123 | *
124 | * @param args expect exactly two arguments: the input file path, the output file path.
125 | */
126 | public static void main(String... args) {
127 | // Load arguments
128 | if (args.length != 2) {
129 | throw new RuntimeException("Args: inputPath outputPath");
130 | }
131 | Path inputPath = Paths.get(args[0]);
132 | Path outputPath = Paths.get(args[1]);
133 |
134 | // Load inputs
135 | List inputList;
136 | try (BufferedReader r = Files.newBufferedReader(inputPath)) {
137 | inputList = GSON.fromJson(
138 | r,
139 | TypeToken.getParameterized(
140 | TypeToken.get(List.class).getType(),
141 | TypeToken.get(InputIndexAndCode.class).getType()
142 | ).getType()
143 | );
144 | } catch (IOException e) {
145 | throw new RuntimeException(e);
146 | }
147 |
148 | // Get tokens
149 | List outputList = new LinkedList<>();
150 | for (InputIndexAndCode input: inputList) {
151 | OutputData output = new OutputData();
152 | Node parsed = parseWhatever(input.code);
153 |
154 | if (parsed == null) {
155 | // Return empty list, representing the code is non-parsable
156 | output.index = input.index;
157 | output.tokens = new LinkedList<>();
158 | } else {
159 | List tokens = new LinkedList<>();
160 | parsed.getTokenRange().get() // NOTE(review): bare Optional.get() — assumes freshly parsed nodes always carry a token range; confirm.
161 | .forEach(t -> {
162 | if (!t.getCategory().isWhitespaceOrComment()) { // keep only lexical tokens; drop whitespace and comments
163 | tokens.add(t.getText());
164 | }
165 | });
166 | output.index = input.index;
167 | output.tokens = tokens;
168 | }
169 |
170 | outputList.add(output);
171 | }
172 |
173 | // Save outputs
174 | try (BufferedWriter w = Files.newBufferedWriter(outputPath)) {
175 | w.write(GSON.toJson(
176 | outputList,
177 | TypeToken.getParameterized(
178 | TypeToken.get(List.class).getType(),
179 | TypeToken.get(OutputData.class).getType()
180 | ).getType()
181 | ));
182 | } catch (IOException e) {
183 | throw new RuntimeException(e);
184 | }
185 | }
186 | }
187 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/MaskMethodNameVisitor.java:
--------------------------------------------------------------------------------
1 | package org.tseval;
2 |
3 | import com.github.javaparser.ast.body.MethodDeclaration;
4 | import com.github.javaparser.ast.expr.MethodCallExpr;
5 | import com.github.javaparser.ast.expr.SimpleName;
6 | import com.github.javaparser.ast.visitor.ModifierVisitor;
7 | import com.github.javaparser.ast.visitor.Visitable;
8 |
9 | public class MaskMethodNameVisitor extends ModifierVisitor { // Replaces every occurrence of a target method name with a mask token. NOTE(review): type argument of ModifierVisitor (likely <Context>) appears stripped by the dump; confirm.
10 |
11 | public static final MaskMethodNameVisitor sVisitor = new MaskMethodNameVisitor(); // Shared instance; the visitor keeps no per-traversal state (all state lives in Context).
12 |
13 | private static final SimpleName MASK = new SimpleName(""); // NOTE(review): the mask literal shows as "" here — an angle-bracket token (e.g. "<MASK>") was likely eaten by the dump's tag stripping; verify against the original source.
14 |
15 | public static class Context { // Traversal argument: the method name to be masked.
16 | public String name;
17 | public Context(String name) {
18 | this.name = name;
19 | }
20 | }
21 |
22 | @Override
23 | public Visitable visit(MethodDeclaration n, Context arg) { // Masks the declaration whose name matches arg.name.
24 | MethodDeclaration ret = (MethodDeclaration) super.visit(n, arg);
25 |
26 | if (n.getNameAsString().equals(arg.name)) {
27 | ret.setName(MASK.clone()); // clone() so the shared MASK node is never attached to an AST
28 | }
29 |
30 | return ret;
31 | }
32 |
33 | @Override
34 | public Visitable visit(MethodCallExpr n, Context arg) { // Masks call sites too; match is purely by simple name (no overload/receiver resolution).
35 | MethodCallExpr ret = (MethodCallExpr) super.visit(n, arg);
36 |
37 | if (n.getNameAsString().equals(arg.name)) {
38 | ret.setName(MASK.clone());
39 | }
40 |
41 | return ret;
42 | }
43 | }
44 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/MethodDataCollector.java:
--------------------------------------------------------------------------------
1 | package org.tseval;
2 |
3 | import com.github.javaparser.ParseProblemException;
4 | import com.github.javaparser.StaticJavaParser;
5 | import com.github.javaparser.ast.CompilationUnit;
6 | import com.google.gson.Gson;
7 | import com.google.gson.GsonBuilder;
8 | import com.google.gson.stream.JsonWriter;
9 | import org.tseval.MethodDataCollectorVisitor.FilteredReason;
10 | import org.tseval.data.MethodData;
11 | import org.tseval.data.RevisionIds;
12 | import org.tseval.util.BashUtils;
13 | import org.tseval.util.AbstractConfig;
14 | import org.tseval.util.Option;
15 |
16 | import java.io.BufferedWriter;
17 | import java.io.IOException;
18 | import java.nio.file.Files;
19 | import java.nio.file.Path;
20 | import java.nio.file.Paths;
21 | import java.nio.file.StandardOpenOption;
22 | import java.util.Arrays;
23 | import java.util.Comparator;
24 | import java.util.HashMap;
25 | import java.util.LinkedList;
26 | import java.util.List;
27 | import java.util.Map;
28 | import java.util.Objects;
29 | import java.util.stream.Collectors;
30 |
31 | public class MethodDataCollector {
32 |
33 | public static class Config extends AbstractConfig { // Run configuration, loaded from a json file (see main).
34 |
35 | // A list of shas, separated by space
36 | @Option public String shas;
37 | @Option public Path projectDir;
38 | @Option public Path outputDir;
39 | @Option public Path logFile;
40 |
41 | public boolean repOk() { // Representation invariant: all three paths set and shas non-empty.
42 | if (projectDir == null || outputDir == null || logFile == null) {
43 | return false;
44 | }
45 |
46 | if (shas == null || shas.length() == 0) {
47 | return false;
48 | }
49 |
50 | return true;
51 | }
52 | }
53 | public static Config sConfig; // Loaded once in main(); read by all static methods.
54 |
55 | public static final Gson GSON;
56 | public static final Gson GSON_NO_PPRINT;
57 | static {
58 | GsonBuilder gsonBuilder = new GsonBuilder()
59 | .disableHtmlEscaping()
60 | .serializeNulls()
61 | .registerTypeAdapter(MethodData.class, MethodData.sSerDeser)
62 | .registerTypeAdapter(RevisionIds.class, RevisionIds.sSerializer);
63 | GSON_NO_PPRINT = gsonBuilder.create(); // compact writer: snapshot taken before pretty-printing is switched on
64 | gsonBuilder.setPrettyPrinting();
65 | GSON = gsonBuilder.create();
66 | }
67 |
68 | private static Map sMethodDataIdHashMap = new HashMap<>(); // NOTE(review): map type arguments stripped by the dump; appears to map a method's hash to its assigned id — confirm.
69 | private static int sCurrentMethodDataId = 0; // next id to hand out
70 | private static Map> sFileCache = new HashMap<>(); // file-content hash -> method ids, so identical files are not re-parsed across shas (see collectSha)
71 | private static Map sFilteredCounters = MethodDataCollectorVisitor.initFilteredCounters(); // per-FilteredReason discard counts
72 |
73 | private static JsonWriter sMethodDataWriter; // streams method-data.json (opened/closed in collect)
74 | private static JsonWriter sMethodProjectRevisionWriter; // streams revision-ids.json (opened/closed in collect)
75 |
76 | public static void main(String... args) { // Entry point: args[0] is the path to a json Config; then runs the collection.
77 | if (args.length != 1) {
78 | System.err.println("Exactly one argument, the path to the json config, is required");
79 | System.exit(-1);
80 | }
81 |
82 | sConfig = AbstractConfig.load(Paths.get(args[0]), Config.class);
83 | collect();
84 | }
85 |
86 | public static void collect() { // Drives collection over every configured sha; writes method-data.json, revision-ids.json and filtered-counters.json under outputDir.
87 | try {
88 | // Init the writers for saving
89 | sMethodDataWriter = GSON.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("method-data.json")));
90 | sMethodDataWriter.beginArray();
91 | sMethodProjectRevisionWriter = GSON_NO_PPRINT.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("revision-ids.json")));
92 | sMethodProjectRevisionWriter.beginArray();
93 |
94 | // Collect for each sha (chronological order)
95 | for (String sha: sConfig.shas.split(" ")) {
96 | log("Sha " + sha);
97 | collectSha(sha);
98 | }
99 |
100 | // Save filtered counters
101 | JsonWriter filteredCountersWriter = GSON.newJsonWriter(Files.newBufferedWriter(sConfig.outputDir.resolve("filtered-counters.json")));
102 | filteredCountersWriter.beginObject();
103 | for (FilteredReason fr : FilteredReason.values()) {
104 | filteredCountersWriter.name(fr.getKey());
105 | filteredCountersWriter.value(sFilteredCounters.get(fr));
106 | }
107 | filteredCountersWriter.endObject();
108 | filteredCountersWriter.close();
109 |
110 | // Close writers
111 | sMethodDataWriter.endArray();
112 | sMethodDataWriter.close();
113 | sMethodProjectRevisionWriter.endArray();
114 | sMethodProjectRevisionWriter.close();
115 | } catch (IOException e) { // NOTE(review): the writers are never closed if an exception is thrown mid-collection (no finally / try-with-resources).
116 | throw new RuntimeException(e);
117 | }
118 | }
119 |
120 | public static void log(String msg) {
121 | if (sConfig.logFile != null) {
122 | try (BufferedWriter fw = Files.newBufferedWriter(sConfig.logFile, StandardOpenOption.APPEND, StandardOpenOption.CREATE)) {
123 | fw.write("[" + Thread.currentThread().getId() + "]" + msg + "\n");
124 | // (new Throwable()).printStackTrace(fos);
125 | } catch (IOException e) {
126 | System.err.println("Couldn't log to " + sConfig.logFile);
127 | System.exit(-1);
128 | }
129 | }
130 | }
131 |
132 | private static void collectSha(String sha) throws IOException {
133 | // Check out the sha
134 | BashUtils.run("cd " + sConfig.projectDir + " && git checkout -f " + sha, 0);
135 |
136 | // Find all java files
137 | List javaFiles = Files.walk(sConfig.projectDir)
138 | .filter(Files::isRegularFile)
139 | .filter(p -> p.toString().endsWith(".java"))
140 | .sorted(Comparator.comparing(Object::toString))
141 | .collect(Collectors.toList());
142 | log("In revision " + sha +", got " + javaFiles.size() + " files to parse");
143 |
144 | // For each java file, parse and get methods
145 | MethodDataCollectorVisitor visitor = new MethodDataCollectorVisitor();
146 | List idsRevision = new LinkedList<>();
147 | int parseErrorCount = 0;
148 | int filteredCount = 0;
149 | int reuseFileCount = 0;
150 | int parseFileCount = 0;
151 | for (Path javaFile : javaFiles) {
152 | // Skip parsing identical files, just add the ids
153 | int fileHash = getFileHash(javaFile);
154 | List idsFile = sFileCache.get(fileHash);
155 |
156 | if (idsFile == null) {
157 | // Actually parse this file and collect ids
158 | idsFile = new LinkedList<>();
159 | String path = sConfig.projectDir.relativize(javaFile).toString();
160 |
161 | MethodDataCollectorVisitor.Context context = new MethodDataCollectorVisitor.Context();
162 | try {
163 | CompilationUnit cu = StaticJavaParser.parse(javaFile);
164 | cu.accept(visitor, context);
165 | } catch (ParseProblemException e) {
166 | ++parseErrorCount;
167 | continue;
168 | }
169 |
170 | for (FilteredReason fr : FilteredReason.values()) {
171 | sFilteredCounters.compute(fr, (k, v) -> v + context.filteredCounters.get(k));
172 | filteredCount += context.filteredCounters.get(fr);
173 | }
174 |
175 | for (MethodData methodData : context.methodDataList) {
176 | // Reuse (for duplicate data) or allocate the data id
177 | methodData.path = path;
178 | int methodId = addMethodData(methodData);
179 | idsFile.add(methodId);
180 | }
181 |
182 | // Update file cache
183 | sFileCache.put(fileHash, idsFile);
184 | ++parseFileCount;
185 | } else {
186 | ++reuseFileCount;
187 | }
188 |
189 | idsRevision.addAll(idsFile);
190 | }
191 |
192 | // Create and save MethodProjectRevision
193 | RevisionIds revisionIds = new RevisionIds();
194 | revisionIds.revision = sha;
195 | revisionIds.methodIds = idsRevision;
196 | addRevisionIds(revisionIds);
197 |
198 | log("Parsed " + parseFileCount + " files. " +
199 | "Reused " + reuseFileCount + " files. " +
200 | "Parsing error for " + parseErrorCount + " files. " +
201 | "Ignored " + filteredCount + " methods. " +
202 | "Total collected " + sMethodDataIdHashMap.size() + " methods.");
203 | }
204 |
205 |
206 | private static int getFileHash(Path javaFile) throws IOException {
207 | // Hash both the path and the content
208 | return Objects.hash(javaFile.toString(), Arrays.hashCode(Files.readAllBytes(javaFile)));
209 | }
210 |
    /**
     * Saves the method data and returns its id, deduplicating on
     * (path, code, comment): if an identical method was seen before, its
     * existing id is returned and nothing new is written.
     * NOTE(review): dedup keys on Objects.hash of the three fields, so two
     * distinct methods with a colliding hash would be conflated — presumably
     * acceptable at this scale; confirm.
     */
    private static int addMethodData(MethodData methodData) {
        // Don't duplicate previous appeared methods (keys: path, code, comment)
        int hash = Objects.hash(methodData.path, methodData.code, methodData.comment);
        Integer prevMethodDataId = sMethodDataIdHashMap.get(hash);
        if (prevMethodDataId != null) {
            // If this method data already existed before, retrieve its id
            return prevMethodDataId;
        } else {
            // Allocate a new id and save this data to the hash map
            methodData.id = sCurrentMethodDataId;
            ++sCurrentMethodDataId;
            sMethodDataIdHashMap.put(hash, methodData.id);

            // Save the method data
            GSON.toJson(methodData, MethodData.class, sMethodDataWriter);
            return methodData.id;
        }
    }
229 |
    /** Writes the revision's method-id list straight to the revision-ids json file. */
    private static void addRevisionIds(RevisionIds revisionIds) {
        // Directly write to file
        GSON_NO_PPRINT.toJson(revisionIds, RevisionIds.class, sMethodProjectRevisionWriter);
    }
234 | }
235 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/MethodDataCollectorVisitor.java:
--------------------------------------------------------------------------------
1 | package org.tseval;
2 |
3 | import com.github.javaparser.ast.PackageDeclaration;
4 | import com.github.javaparser.ast.body.AnnotationDeclaration;
5 | import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
6 | import com.github.javaparser.ast.body.EnumDeclaration;
7 | import com.github.javaparser.ast.body.MethodDeclaration;
8 | import com.github.javaparser.ast.body.Parameter;
9 | import com.github.javaparser.ast.body.TypeDeclaration;
10 | import com.github.javaparser.ast.visitor.VoidVisitorAdapter;
11 | import com.github.javaparser.javadoc.Javadoc;
12 | import com.github.javaparser.javadoc.JavadocBlockTag;
13 | import com.github.javaparser.javadoc.description.JavadocDescription;
14 | import com.github.javaparser.javadoc.description.JavadocDescriptionElement;
15 | import com.github.javaparser.javadoc.description.JavadocInlineTag;
16 | import com.github.javaparser.printer.PrettyPrinterConfiguration;
17 | import org.apache.commons.lang3.tuple.Pair;
18 | import org.tseval.data.MethodData;
19 | import org.tseval.util.NLPUtils;
20 |
21 | import java.util.HashMap;
22 | import java.util.LinkedList;
23 | import java.util.List;
24 | import java.util.Map;
25 |
26 | public class MethodDataCollectorVisitor extends VoidVisitorAdapter {
27 |
28 | private static final int METHOD_LENGTH_MAX = 10_000;
29 |
30 | public static class Context {
31 | String packageName = "";
32 | String className = "";
33 | List methodDataList = new LinkedList<>();
34 | Map filteredCounters = initFilteredCounters();
35 | }
36 |
37 | public enum FilteredReason {
38 | CODE_TOO_LONG("code_too_long"),
39 | CODE_NON_ENGLISH("code_non_english"),
40 | COMMENT_NON_ENGLISH("comment_non_english"),
41 | EMPTY_BODY("empty_body"),
42 | EMPTY_COMMENT_SUMMARY("empty_comment_summary"),
43 | EMPTY_COMMENT("empty_comment");
44 |
45 | private final String key;
46 | FilteredReason(String key) {
47 | this.key = key;
48 | }
49 |
50 | public String getKey() {
51 | return key;
52 | }
53 | }
54 |
55 | public static Map initFilteredCounters() {
56 | Map filteredCounters = new HashMap<>();
57 | for (FilteredReason fr : FilteredReason.values()) {
58 | filteredCounters.put(fr, 0);
59 | }
60 | return filteredCounters;
61 | }
62 |
63 | private static PrettyPrinterConfiguration METHOD_PPRINT_CONFIG = new PrettyPrinterConfiguration();
64 | static {
65 | METHOD_PPRINT_CONFIG
66 | .setPrintJavadoc(false)
67 | .setPrintComments(false);
68 | }
69 |
70 | @Override
71 | public void visit(ClassOrInterfaceDeclaration n, Context context) {
72 | commonVisitTypeDeclaration(n, context);
73 | super.visit(n, context);
74 | }
75 |
76 | @Override
77 | public void visit(AnnotationDeclaration n, Context context) {
78 | commonVisitTypeDeclaration(n, context);
79 | super.visit(n, context);
80 | }
81 |
82 | @Override
83 | public void visit(EnumDeclaration n, Context context) {
84 | commonVisitTypeDeclaration(n, context);
85 | super.visit(n, context);
86 | }
87 |
88 | public void commonVisitTypeDeclaration(TypeDeclaration> n, Context context) {
89 | // Update context class name
90 | context.className = n.getNameAsString();
91 | }
92 |
93 | @Override
94 | public void visit(PackageDeclaration n, Context context) {
95 | // Update context package name
96 | context.packageName = n.getNameAsString();
97 | super.visit(n, context);
98 | }
99 |
100 | private String dollaryClassName(TypeDeclaration> n) {
101 | if (n.isNestedType()) {
102 | return dollaryClassName((TypeDeclaration>) n.getParentNode().get()) + "$" + n.getNameAsString();
103 | }
104 | return n.getNameAsString();
105 | }
106 |
107 | @Override
108 | public void visit(MethodDeclaration n, Context context) {
109 | MethodData methodData = new MethodData();
110 |
111 | if (!n.getBody().isPresent() || n.getBody().get().isEmpty()) {
112 | // Ignore if no/empty method body
113 | context.filteredCounters.compute(FilteredReason.EMPTY_BODY, (k, v) -> v+1);
114 | return;
115 | }
116 |
117 | methodData.name = n.getNameAsString();
118 |
119 | if (n.getJavadoc().isPresent()) {
120 | Javadoc javadoc = n.getJavadoc().get();
121 | methodData.comment = javadoc.toText();
122 |
123 | if (!NLPUtils.isValidISOLatin(methodData.comment)) {
124 | // Ignore if comment is not English
125 | context.filteredCounters.compute(FilteredReason.COMMENT_NON_ENGLISH, (k, v) -> v+1);
126 | return;
127 | }
128 |
129 | methodData.commentSummary = NLPUtils.getFirstSentence(javadocDescToTextNoInlineTags(javadoc.getDescription())).orElse(null);
130 |
131 | if (methodData.commentSummary == null) {
132 | // Ignore if the comment summary is empty
133 | context.filteredCounters.compute(FilteredReason.EMPTY_COMMENT_SUMMARY, (k, v) -> v+1);
134 | return;
135 | }
136 | } else {
137 | // Ignore if comment is empty
138 | context.filteredCounters.compute(FilteredReason.EMPTY_COMMENT, (k, v) -> v+1);
139 | return;
140 | }
141 |
142 | methodData.code = n.toString(METHOD_PPRINT_CONFIG);
143 |
144 | // Ignore if method is too long
145 | if (methodData.code.length() > METHOD_LENGTH_MAX) {
146 | context.filteredCounters.compute(FilteredReason.CODE_TOO_LONG, (k, v) -> v+1);
147 | return;
148 | }
149 |
150 | // Ignore if method is not English
151 | if (!NLPUtils.isValidISOLatin(methodData.code)) {
152 | context.filteredCounters.compute(FilteredReason.CODE_NON_ENGLISH, (k, v) -> v+1);
153 | return;
154 | }
155 |
156 | MethodDeclaration maskedMD = n.clone();
157 | maskedMD.accept(MaskMethodNameVisitor.sVisitor, new MaskMethodNameVisitor.Context(methodData.name));
158 | methodData.codeMasked = maskedMD.toString(METHOD_PPRINT_CONFIG);
159 |
160 | try {
161 | methodData.cname = dollaryClassName((TypeDeclaration>) n.getParentNode().get());
162 | }
163 | catch (Exception e) {
164 | methodData.cname = context.className;
165 | }
166 |
167 | if (!context.packageName.isEmpty()) {
168 | methodData.qcname = context.packageName + "." + methodData.cname;
169 | } else {
170 | methodData.qcname = methodData.cname;
171 | }
172 |
173 | methodData.ret = n.getType().asString();
174 | for (Parameter param : n.getParameters()) {
175 | methodData.params.add(Pair.of(param.getType().asString(), param.getNameAsString()));
176 | }
177 |
178 | context.methodDataList.add(methodData);
179 |
180 | // Note: no recursive visiting the body of this method
181 | }
182 |
183 | static String javadocDescToTextNoInlineTags(JavadocDescription desc) {
184 | StringBuilder sb = new StringBuilder();
185 | for (JavadocDescriptionElement e : desc.getElements()) {
186 | if (e instanceof JavadocInlineTag) {
187 | sb.append(((JavadocInlineTag) e).getContent());
188 | } else {
189 | sb.append(e.toText());
190 | }
191 | }
192 | return sb.toString();
193 | }
194 | }
195 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/data/MethodData.java:
--------------------------------------------------------------------------------
1 | package org.tseval.data;
2 |
3 | import com.google.gson.JsonArray;
4 | import com.google.gson.JsonDeserializationContext;
5 | import com.google.gson.JsonDeserializer;
6 | import com.google.gson.JsonElement;
7 | import com.google.gson.JsonObject;
8 | import com.google.gson.JsonParseException;
9 | import com.google.gson.JsonSerializationContext;
10 | import com.google.gson.JsonSerializer;
11 | import org.apache.commons.lang3.tuple.Pair;
12 |
13 | import java.lang.reflect.Type;
14 | import java.util.ArrayList;
15 | import java.util.List;
16 |
17 | public class MethodData {
18 |
19 | public int id;
20 | public String prj;
21 |
22 | public String name;
23 | public String code;
24 | public String codeMasked;
25 | public String comment;
26 | public String commentSummary;
27 | public String cname;
28 | public String qcname;
29 | public String path;
30 | public String ret;
31 | public List> params = new ArrayList<>();
32 |
33 | // Serialization
34 | public static JsonSerializer sSerializer = getSerializer();
35 |
36 | public static JsonSerializer getSerializer() {
37 | return (obj, type, jsonSerializationContext) -> {
38 | JsonObject jObj = new JsonObject();
39 |
40 | jObj.addProperty("id", obj.id);
41 | jObj.addProperty("prj", obj.prj);
42 | jObj.addProperty("name", obj.name);
43 | jObj.addProperty("code", obj.code);
44 | jObj.addProperty("code_masked", obj.codeMasked);
45 | jObj.addProperty("comment", obj.comment);
46 | jObj.addProperty("comment_summary", obj.commentSummary);
47 | jObj.addProperty("cname", obj.cname);
48 | jObj.addProperty("qcname", obj.qcname);
49 | jObj.addProperty("path", obj.path);
50 | jObj.addProperty("ret", obj.ret);
51 | JsonArray aParams = new JsonArray();
52 | for (Pair param : obj.params) {
53 | JsonArray aParam = new JsonArray();
54 | aParam.add(param.getLeft());
55 | aParam.add(param.getRight());
56 | aParams.add(aParam);
57 | }
58 | jObj.add("params", aParams);
59 |
60 | return jObj;
61 | };
62 | }
63 |
64 | // Deserialization
65 | public static final JsonDeserializer sDeserializer = getDeserializer();
66 |
67 | public static JsonDeserializer getDeserializer() {
68 | return (json, type, context) -> {
69 | try {
70 | MethodData obj = new MethodData();
71 |
72 | JsonObject jObj = json.getAsJsonObject();
73 | obj.id = jObj.get("id").getAsInt();
74 | obj.prj = jObj.get("prj").getAsString();
75 | obj.name = jObj.get("name").getAsString();
76 | obj.code = jObj.get("code").getAsString();
77 | obj.codeMasked = jObj.get("code_masked").getAsString();
78 | obj.comment = jObj.get("comment").getAsString();
79 | obj.commentSummary = jObj.get("comment_summary").getAsString();
80 | obj.cname = jObj.get("cname").getAsString();
81 | obj.qcname = jObj.get("qcname").getAsString();
82 | obj.path = jObj.get("path").getAsString();
83 | obj.ret = jObj.get("ret").getAsString();
84 | JsonArray aParams = jObj.getAsJsonArray("params");
85 | for (JsonElement aParamElem : aParams) {
86 | JsonArray aParam = aParamElem.getAsJsonArray();
87 | obj.params.add(Pair.of(aParam.get(0).getAsString(), aParam.get(1).getAsString()));
88 | }
89 |
90 | return obj;
91 | } catch (IllegalStateException e) {
92 | throw new JsonParseException(e);
93 | }
94 | };
95 | }
96 |
97 | // Ser+Deser
98 | public static SerDeser sSerDeser = new SerDeser();
99 | private static class SerDeser implements JsonSerializer, JsonDeserializer {
100 | @Override
101 | public MethodData deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException {
102 | return sDeserializer.deserialize(json, typeOfT, context);
103 | }
104 |
105 | @Override
106 | public JsonElement serialize(MethodData src, Type typeOfSrc, JsonSerializationContext context) {
107 | return sSerializer.serialize(src, typeOfSrc, context);
108 | }
109 | }
110 | }
111 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/data/ProjectData.java:
--------------------------------------------------------------------------------
1 | package org.tseval.data;
2 |
3 | import com.google.gson.JsonDeserializer;
4 | import com.google.gson.JsonElement;
5 | import com.google.gson.JsonObject;
6 | import com.google.gson.JsonParseException;
7 |
8 | import java.util.HashMap;
9 | import java.util.LinkedList;
10 | import java.util.List;
11 | import java.util.Map;
12 |
13 | public class ProjectData {
14 |
15 | public String name;
16 | public String url;
17 | public List revisions = new LinkedList<>();
18 | public Map> parentRevisions = new HashMap<>();
19 | public Map yearRevisions = new HashMap<>();
20 |
21 | // Deserialization
22 | public static final JsonDeserializer sDeserializer = getDeserializer();
23 |
24 | public static JsonDeserializer getDeserializer() {
25 | return (json, type, context) -> {
26 | try {
27 | ProjectData obj = new ProjectData();
28 |
29 | JsonObject jObj = json.getAsJsonObject();
30 | obj.name = jObj.get("name").getAsString();
31 | obj.url = jObj.get("url").getAsString();
32 |
33 | // revisions
34 | for (JsonElement eRevision : jObj.get("revisions").getAsJsonArray()) {
35 | obj.revisions.add(eRevision.getAsString());
36 | }
37 |
38 | // parent revisions
39 | JsonObject jObjParentRevisions = jObj.get("parent_revisions").getAsJsonObject();
40 | for (Map.Entry entry : jObjParentRevisions.entrySet()) {
41 | List parentRevisions = new LinkedList<>();
42 | for (JsonElement eParent : entry.getValue().getAsJsonArray()) {
43 | parentRevisions.add(eParent.getAsString());
44 | }
45 | obj.parentRevisions.put(entry.getKey(), parentRevisions);
46 | }
47 |
48 | // year revisions
49 | JsonObject jObjYearRevisions = jObj.get("year_revisions").getAsJsonObject();
50 | for (Map.Entry entry : jObjYearRevisions.entrySet()) {
51 | obj.yearRevisions.put(entry.getKey(), entry.getValue().getAsString());
52 | }
53 |
54 |
55 | return obj;
56 | } catch (IllegalStateException e) {
57 | throw new JsonParseException(e);
58 | }
59 | };
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/data/RevisionIds.java:
--------------------------------------------------------------------------------
1 | package org.tseval.data;
2 |
3 | import com.google.gson.JsonArray;
4 | import com.google.gson.JsonObject;
5 | import com.google.gson.JsonSerializer;
6 |
7 | import java.util.LinkedList;
8 | import java.util.List;
9 |
10 | public class RevisionIds {
11 |
12 | public String revision;
13 | public List methodIds = new LinkedList<>();
14 |
15 | // Serialization
16 | public static JsonSerializer sSerializer = getSerializer();
17 |
18 | public static JsonSerializer getSerializer() {
19 | return (obj, type, jsonSerializationContext) -> {
20 | JsonObject jObj = new JsonObject();
21 |
22 | jObj.addProperty("revision", obj.revision);
23 |
24 | JsonArray aMethodIds = new JsonArray();
25 | for (int methodId : obj.methodIds) {
26 | aMethodIds.add(methodId);
27 | }
28 | jObj.add("method_ids", aMethodIds);
29 |
30 | return jObj;
31 | };
32 | }
33 | }
34 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/util/AbstractConfig.java:
--------------------------------------------------------------------------------
1 | package org.tseval.util;
2 |
3 | import com.google.gson.JsonElement;
4 | import com.google.gson.JsonParser;
5 | import org.apache.commons.lang3.tuple.Pair;
6 |
7 | import java.io.FileNotFoundException;
8 | import java.io.FileReader;
9 | import java.io.Reader;
10 | import java.lang.reflect.Field;
11 | import java.nio.file.Path;
12 | import java.nio.file.Paths;
13 | import java.util.Collections;
14 | import java.util.HashMap;
15 | import java.util.Map;
16 | import java.util.function.Function;
17 |
18 | public abstract class AbstractConfig {
19 |
20 | public boolean repOk() {
21 | return true;
22 | }
23 |
24 | public void autoInfer() {}
25 |
26 | private static final Map, Function> SUPPORTED_CLASSES_2_READERS;
27 | static {
28 | Map, Function> supportedClasses2Readers = new HashMap<>();
29 | supportedClasses2Readers.put(Boolean.TYPE, JsonElement::getAsBoolean);
30 | supportedClasses2Readers.put(Boolean.class, JsonElement::getAsBoolean);
31 | supportedClasses2Readers.put(Integer.TYPE, JsonElement::getAsInt);
32 | supportedClasses2Readers.put(Integer.class, JsonElement::getAsInt);
33 | supportedClasses2Readers.put(Double.TYPE, JsonElement::getAsDouble);
34 | supportedClasses2Readers.put(Double.class, JsonElement::getAsDouble);
35 | supportedClasses2Readers.put(String.class, JsonElement::getAsString);
36 | supportedClasses2Readers.put(Path.class, e -> Paths.get(e.getAsString()));
37 | SUPPORTED_CLASSES_2_READERS = Collections.unmodifiableMap(supportedClasses2Readers);
38 | }
39 |
40 | public static T load(Path configPath, Class clz) {
41 | try {
42 | return load(new FileReader(configPath.toFile()), clz);
43 | } catch (FileNotFoundException e) {
44 | throw new RuntimeException("Could not load config " + configPath, e);
45 | }
46 | }
47 |
48 | /**
49 | * Loads the config of type {@code T} from the reader, which should provide a json dict.
50 | * @param reader the reader that provides the config
51 | * @param clz T.class
52 | * @param the type of the config
53 | * @return the loaded config
54 | */
55 | public static T load(Reader reader, Class clz) {
56 | try {
57 | T config = clz.newInstance();
58 |
59 | Map, Field>> options = new HashMap<>();
60 | for (Field f : clz.getFields()) {
61 | if (f.getAnnotation(Option.class) != null) {
62 | Class> fType = f.getType();
63 | if (!SUPPORTED_CLASSES_2_READERS.containsKey(fType)) {
64 | throw new RuntimeException("Unsupported option type " + fType + ", for field " + f.getName() + " in class " + clz);
65 | }
66 | options.put(f.getName(), Pair.of(fType, f));
67 | }
68 | }
69 |
70 | JsonElement configJson = JsonParser.parseReader(reader);
71 | if (configJson.isJsonObject()) {
72 | for (Map.Entry entry : configJson.getAsJsonObject().entrySet()) {
73 | Pair, Field> p = options.get(entry.getKey());
74 | if (p != null) {
75 | p.getRight().set(config, SUPPORTED_CLASSES_2_READERS.get(p.getLeft()).apply(entry.getValue()));
76 | } else {
77 | throw new RuntimeException("Unknown config key " + entry.getKey());
78 | }
79 | }
80 | } else {
81 | throw new RuntimeException("Input should be a json dict");
82 | }
83 |
84 | config.autoInfer();
85 | if (!config.repOk()) {
86 | throw new RuntimeException("Config invalid");
87 | }
88 |
89 | return config;
90 | } catch (RuntimeException e) {
91 | throw new RuntimeException("Malformed config", e);
92 | } catch (IllegalAccessException | InstantiationException e) {
93 | throw new RuntimeException("The config class " + clz + " is not correctly set up", e);
94 | }
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/util/BashUtils.java:
--------------------------------------------------------------------------------
1 | package org.tseval.util;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.IOException;
5 | import java.io.InputStreamReader;
6 | import java.nio.file.Path;
7 | import java.nio.file.Paths;
8 | import java.util.stream.Collectors;
9 |
/**
 * Helpers for running bash commands and creating temp files/directories.
 */
public class BashUtils {

    /** Immutable result of one command execution. */
    public static class RunResult {
        public final int exitCode;
        public final String stdout;
        public final String stderr;

        public RunResult(int exitCode, String stdout, String stderr) {
            this.exitCode = exitCode;
            this.stdout = stdout;
            this.stderr = stderr;
        }
    }

    /** Runs cmd under bash without checking its exit code. */
    public static RunResult run(String cmd) {
        return run(cmd, null);
    }

    /**
     * Runs cmd via {@code /bin/bash -c} and captures its output.
     *
     * @param cmd the command line to execute
     * @param expectedReturnCode if non-null, throw when the exit code differs
     * @return the exit code, stdout, and stderr of the finished process
     * @throws RuntimeException on I/O failure, interruption, or exit-code mismatch
     */
    public static RunResult run(String cmd, Integer expectedReturnCode) {
        try {
            Runtime rt = Runtime.getRuntime();
            String[] commands = {"/bin/bash", "-c", cmd};
            Process proc = rt.exec(commands);

            // Drain the output streams BEFORE waiting for termination:
            // waiting first can deadlock when the child fills its pipe buffer.
            // NOTE(review): draining stdout fully before stderr can still
            // block for commands that write heavily to stderr; a reader
            // thread would be needed to be fully safe.
            String stdout = new BufferedReader(new InputStreamReader(proc.getInputStream())).lines().collect(Collectors.joining("\n"));
            String stderr = new BufferedReader(new InputStreamReader(proc.getErrorStream())).lines().collect(Collectors.joining("\n"));
            int exitCode = proc.waitFor();

            // Check expected return code
            if (expectedReturnCode != null && exitCode != expectedReturnCode) {
                // TODO: implement print limit for stdout & stderr, dump to file if they exceed it
                throw new RuntimeException("Expected " + expectedReturnCode + " but returned " + exitCode + " while executing bash command '" + cmd + "'.\n" +
                        "stdout: " + stdout + "\n" +
                        "stderr: " + stderr);
            }
            return new RunResult(exitCode, stdout, stderr);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            // Restore the interrupt status before surfacing the failure
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /** Creates and returns a fresh temporary directory (via mktemp -d). */
    public static Path getTempDir() {
        return Paths.get(run("mktemp -d", 0).stdout.trim());
    }

    /** Creates and returns a fresh temporary file (via mktemp). */
    public static Path getTempFile() {
        return Paths.get(run("mktemp", 0).stdout.trim());
    }
}
60 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/util/NLPUtils.java:
--------------------------------------------------------------------------------
1 | package org.tseval.util;
2 |
3 | import edu.stanford.nlp.simple.Document;
4 |
5 | import java.nio.charset.Charset;
6 | import java.util.Optional;
7 |
8 | public class NLPUtils {
9 |
10 | public static Optional getFirstSentence(String str) {
11 | if (str.trim().isEmpty()) {
12 | return Optional.empty();
13 | }
14 |
15 | try {
16 | Document document = new Document(str);
17 | return Optional.of(document.sentence(0).toString());
18 | } catch(Exception e) {
19 | System.err.println("Cannot get first sentence of: " + str);
20 | return Optional.empty();
21 | }
22 | }
23 |
24 | public static boolean isValidISOLatin(String s) {
25 | return Charset.forName("US-ASCII").newEncoder().canEncode(s);
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/collector/src/main/java/org/tseval/util/Option.java:
--------------------------------------------------------------------------------
1 | package org.tseval.util;
2 |
3 | import java.lang.annotation.ElementType;
4 | import java.lang.annotation.Retention;
5 | import java.lang.annotation.RetentionPolicy;
6 | import java.lang.annotation.Target;
7 |
/**
 * Marks a public field of an {@link AbstractConfig} subclass as an option to
 * be populated from the json config file; retained at runtime for reflection.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface Option {
}
12 |
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | # Log & debug files
2 | /ml-logs/
3 | /debug/
4 |
5 |
6 | # Additional Pycharm
7 | *.iml
8 | .idea/
9 |
10 |
11 | ### Pycharm
12 |
13 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
14 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
15 |
16 | # User-specific stuff
17 | .idea/**/workspace.xml
18 | .idea/**/tasks.xml
19 | .idea/**/usage.statistics.xml
20 | .idea/**/dictionaries
21 | .idea/**/shelf
22 |
23 | # Generated files
24 | .idea/**/contentModel.xml
25 |
26 | # Sensitive or high-churn files
27 | .idea/**/dataSources/
28 | .idea/**/dataSources.ids
29 | .idea/**/dataSources.local.xml
30 | .idea/**/sqlDataSources.xml
31 | .idea/**/dynamic.xml
32 | .idea/**/uiDesigner.xml
33 | .idea/**/dbnavigator.xml
34 |
35 | # Gradle
36 | .idea/**/gradle.xml
37 | .idea/**/libraries
38 |
39 | # Gradle and Maven with auto-import
40 | # When using Gradle or Maven with auto-import, you should exclude module files,
41 | # since they will be recreated, and may cause churn. Uncomment if using
42 | # auto-import.
43 | # .idea/modules.xml
44 | # .idea/*.iml
45 | # .idea/modules
46 | # *.iml
47 | # *.ipr
48 |
49 | # CMake
50 | cmake-build-*/
51 |
52 | # Mongo Explorer plugin
53 | .idea/**/mongoSettings.xml
54 |
55 | # File-based project format
56 | *.iws
57 |
58 | # IntelliJ
59 | out/
60 |
61 | # mpeltonen/sbt-idea plugin
62 | .idea_modules/
63 |
64 | # JIRA plugin
65 | atlassian-ide-plugin.xml
66 |
67 | # Cursive Clojure plugin
68 | .idea/replstate.xml
69 |
70 | # Crashlytics plugin (for Android Studio and IntelliJ)
71 | com_crashlytics_export_strings.xml
72 | crashlytics.properties
73 | crashlytics-build.properties
74 | fabric.properties
75 |
76 | # Editor-based Rest Client
77 | .idea/httpRequests
78 |
79 | # Android studio 3.1+ serialized cache file
80 | .idea/caches/build_file_checksums.ser
81 |
82 |
83 | ### Python
84 |
85 | # Byte-compiled / optimized / DLL files
86 | __pycache__/
87 | *.py[cod]
88 | *$py.class
89 |
90 | # C extensions
91 | *.so
92 |
93 | # Distribution / packaging
94 | .Python
95 | build/
96 | develop-eggs/
97 | dist/
98 | downloads/
99 | eggs/
100 | .eggs/
101 | lib/
102 | lib64/
103 | parts/
104 | sdist/
105 | var/
106 | wheels/
107 | pip-wheel-metadata/
108 | share/python-wheels/
109 | *.egg-info/
110 | .installed.cfg
111 | *.egg
112 | MANIFEST
113 |
114 | # PyInstaller
115 | # Usually these files are written by a python script from a template
116 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
117 | *.manifest
118 | *.spec
119 |
120 | # Installer logs
121 | pip-log.txt
122 | pip-delete-this-directory.txt
123 |
124 | # Unit test / coverage reports
125 | htmlcov/
126 | .tox/
127 | .nox/
128 | .coverage
129 | .coverage.*
130 | .cache
131 | nosetests.xml
132 | coverage.xml
133 | *.cover
134 | *.py,cover
135 | .hypothesis/
136 | .pytest_cache/
137 |
138 | # Translations
139 | *.mo
140 | *.pot
141 |
142 | # Django stuff:
143 | *.log
144 | local_settings.py
145 | db.sqlite3
146 | db.sqlite3-journal
147 |
148 | # Flask stuff:
149 | instance/
150 | .webassets-cache
151 |
152 | # Scrapy stuff:
153 | .scrapy
154 |
155 | # Sphinx documentation
156 | docs/_build/
157 |
158 | # PyBuilder
159 | target/
160 |
161 | # Jupyter Notebook
162 | .ipynb_checkpoints
163 |
164 | # IPython
165 | profile_default/
166 | ipython_config.py
167 |
168 | # pyenv
169 | .python-version
170 |
171 | # pipenv
172 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
173 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
174 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
175 | # install all needed dependencies.
176 | #Pipfile.lock
177 |
178 | # celery beat schedule file
179 | celerybeat-schedule
180 |
181 | # SageMath parsed files
182 | *.sage.py
183 |
184 | # Environments
185 | .env
186 | .venv
187 | env/
188 | venv/
189 | ENV/
190 | env.bak/
191 | venv.bak/
192 |
193 | # Spyder project settings
194 | .spyderproject
195 | .spyproject
196 |
197 | # Rope project settings
198 | .ropeproject
199 |
200 | # mkdocs documentation
201 | /site
202 |
203 | # mypy
204 | .mypy_cache/
205 | .dmypy.json
206 | dmypy.json
207 |
208 | # Pyre type checker
209 | .pyre/
210 |
--------------------------------------------------------------------------------
/python/exps/cg.yaml:
--------------------------------------------------------------------------------
1 | task: CG
2 | setups:
3 | - MP
4 | - CP
5 | - T
6 | num_trials: 3
7 | metrics:
8 | - exact_match
9 | - bleu
10 | - meteor
11 | - rouge_l_f
12 | - rouge_l_p
13 | - rouge_l_r
14 | models:
15 | - name: DeepComHybridESE19
16 | exps:
17 | - DeepComHybridESE19-1
18 | - DeepComHybridESE19-2
19 | - DeepComHybridESE19-3
20 | - name: RNNBaseline
21 | exps:
22 | - RNNBaseline-1
23 | - RNNBaseline-2
24 | - RNNBaseline-3
25 | - name: TransformerACL20
26 | exps:
27 | - TransformerACL20-1
28 | - TransformerACL20-2
29 | - TransformerACL20-3
30 | table_args:
31 | metrics:
32 | - bleu
33 | - meteor
34 | - rouge_l_f
35 | - exact_match
36 | plot_args:
37 | metrics:
38 | bleu: BLEU
39 | meteor: METEOR
40 | rouge_l_f: ROUGE-L
41 | exact_match: EM
42 | metrics_percent:
43 | bleu: True
44 | meteor: True
45 | rouge_l_f: True
46 | exact_match: True
47 | models:
48 | DeepComHybridESE19: DeepComHybrid
49 | RNNBaseline: Seq2Seq
50 | TransformerACL20: Transformer
51 |
--------------------------------------------------------------------------------
/python/exps/mn.yaml:
--------------------------------------------------------------------------------
1 | task: MN
2 | setups:
3 | - MP
4 | - CP
5 | - T
6 | num_trials: 3
7 | metrics:
8 | - exact_match
9 | - set_match_f
10 | - set_match_p
11 | - set_match_r
12 | models:
13 | - name: Code2VecPOPL19
14 | exps:
15 | - Code2VecPOPL19-1
16 | - Code2VecPOPL19-2
17 | - Code2VecPOPL19-3
18 | - name: Code2SeqICLR19
19 | exps:
20 | - Code2SeqICLR19-1
21 | - Code2SeqICLR19-2
22 | - Code2SeqICLR19-3
23 | table_args:
24 | metrics:
25 | - set_match_p
26 | - set_match_r
27 | - set_match_f
28 | - exact_match
29 | plot_args:
30 | metrics:
31 | set_match_p: Precision
32 | set_match_r: Recall
33 | set_match_f: F1
34 | exact_match: EM
35 | metrics_percent:
36 | set_match_p: True
37 | set_match_r: True
38 | set_match_f: True
39 | exact_match: True
40 | models:
41 | Code2VecPOPL19: Code2Vec
42 | Code2SeqICLR19: Code2Seq
43 |
--------------------------------------------------------------------------------
/python/prepare_conda_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | _DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
4 |
5 | # ========== Prepare Conda Environments
6 |
7 | DEFAULT_CONDA_PATH="$HOME/opt/anaconda3/etc/profile.d/conda.sh"
8 | PYTORCH_VERSION=1.7.1
9 | TORCHVISION_VERSION=0.8.2
10 |
# Print the path to conda's shell init script (conda.sh), derived from the
# location of the conda executable on PATH; exits with an error if conda
# cannot be found.
function get_conda_path() {
    local conda_exe
    conda_exe=$(which conda)
    if [[ -z ${conda_exe} ]]; then
        echo "Fail to detect conda! Have you installed Anaconda/Miniconda?" 1>&2
        exit 1
    fi
    # conda lives at <prefix>/bin/conda; conda.sh is at <prefix>/etc/profile.d/
    echo "$(dirname ${conda_exe})/../etc/profile.d/conda.sh"
}
20 |
# Print a CUDA tag for this machine: "cpu" when no NVIDIA driver is present,
# otherwise one of cu100/cu101/cu102/cu110 based on the `nvcc -V` release
# number; exits with an error for any other CUDA version.
function get_cuda_version() {
    if [[ -z "$(which nvidia-smi)" ]]; then
        echo "cpu"
        return
    fi
    local cuda_version_number
    cuda_version_number="$(nvcc -V | grep "release" | sed -E "s/.*release ([^,]+),.*/\1/")"
    if [[ ${cuda_version_number} == 10.0* ]]; then
        echo "cu100"
    elif [[ ${cuda_version_number} == 10.1* ]]; then
        echo "cu101"
    elif [[ ${cuda_version_number} == 10.2* ]]; then
        echo "cu102"
    elif [[ ${cuda_version_number} == 11.0* ]]; then
        echo "cu110"
    else
        echo "Unsupported cuda version $cuda_version_number!" 1>&2
        exit 1
    fi
}
42 |
43 |
# Create (or recreate) a conda environment with a PyTorch build matching the
# local CUDA installation, then install the remaining Python dependencies.
# Args (all optional, positional): env name (default: tseval), path to
# conda.sh (default: auto-detected), cuda version tag (default: auto-detected).
function prepare_conda_env() {
    ### Preparing the base environment "tseval"
    local env_name=${1:-tseval}; shift
    local conda_path=${1:-$(get_conda_path)}; shift
    local cuda_version=${1:-$(get_cuda_version)}; shift

    echo ">>> Preparing conda environment \"${env_name}\", for cuda version: ${cuda_version}; conda at ${conda_path}"

    # Preparation
    set -e
    set -x
    source ${conda_path}
    # Remove any existing env with this name, then create a fresh Python 3.8 env
    conda env remove --name $env_name
    conda create --name $env_name python=3.8 pip -y
    conda activate $env_name

    # PyTorch
    # Map the cuda version tag to the matching cudatoolkit conda package
    local cuda_toolkit="";
    case $cuda_version in
        cpu)
            cuda_toolkit=cpuonly;;
        cu100)
            cuda_toolkit="cudatoolkit=10.0";;
        cu101)
            cuda_toolkit="cudatoolkit=10.1";;
        cu102)
            cuda_toolkit="cudatoolkit=10.2";;
        cu110)
            cuda_toolkit="cudatoolkit=11.0";;
        *)
            echo "Unexpected cuda version $cuda_version!" 1>&2
            exit 1
    esac

    conda install -y pytorch=${PYTORCH_VERSION} torchvision=${TORCHVISION_VERSION} ${cuda_toolkit} -c pytorch

    # Other libraries
    pip install -r requirements.txt
}
83 |
84 |
85 | prepare_conda_env "$@"
86 |
--------------------------------------------------------------------------------
/python/requirements.txt:
--------------------------------------------------------------------------------
1 | javalang~=0.13.0
2 | keras~=2.3.1
3 | nltk~=3.5
4 | recordclass==0.13.2
5 | rouge==1.0.0
6 | scikit-learn~=0.22.1
7 | seaborn==0.11.1
8 | seutil==0.5.7
9 | # tensorflow~=2.1.0
10 | torch==1.7.1
11 | torchtext==0.8.1
12 | tqdm~=4.54.1
13 |
--------------------------------------------------------------------------------
/python/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script documents the exact procedures we use to get the
4 | # dataset, get the models running, and collect the results.
5 |
6 | # Each function is a group of commands, and a later function usually
7 | # requires the execution of all the preceding functions.
8 |
9 | # The commands within each function should always be executed one
10 | # after another sequentially, unless only partial functionality is wanted.
11 |
12 |
13 | _DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
14 |
15 |
16 | # ========== Metrics, Tables, and Plots
17 |
# Compute dataset-level metrics: over the filtered raw data, over the Full
# split, and for each (task, setup) dataset, plus data-leakage statistics.
function metric_dataset() {
    python -m tseval.main collect_metrics --which=raw-data-filtered
    python -m tseval.main collect_metrics --which=split-dataset --split=Full

    # Per-(task, setup) dataset metrics
    for task in CG MN; do
        for setup in CP MP T; do
            python -m tseval.main collect_metrics --which=setup-dataset --task=$task --setup=$setup
        done
    done

    # Per-(task, setup) data-leakage metrics
    for task in CG MN; do
        for setup in CP MP T; do
            python -m tseval.main collect_metrics --which=setup-dataset-leak --task=$task --setup=$setup
        done
    done
}
34 |
35 |
# Generate the dataset-metrics latex tables for the paper.
function tables_dataset() {
    python -m tseval.main make_tables --which=dataset-metrics
}
39 |
40 |
# Generate the dataset-metrics distribution plots, one per task (CG, MN).
function plots_dataset() {
    for task in CG MN; do
        python -m tseval.main make_plots --which=dataset-metrics-dist --task=$task
    done
}
46 |
47 |
# Extract metrics and run statistical significance tests for both experiment
# suites (CG and MN), driven by their yaml configs under ./exps/.
function analyze_exps() {
    for exps in cg mn; do
        python -m tseval.main analyze_extract_metrics --exps=./exps/${exps}.yaml
        python -m tseval.main analyze_sign_test --exps=./exps/${exps}.yaml
    done
}
54 |
55 |
# Generate the experiment-results latex tables for both experiment suites.
function tables_exps() {
    for exps in cg mn; do
        python -m tseval.main analyze_make_tables --exps=./exps/${exps}.yaml
    done
}
61 |
62 |
# Generate the experiment-results plots for both experiment suites.
function plots_exps() {
    for exps in cg mn; do
        python -m tseval.main analyze_make_plots --exps=./exps/${exps}.yaml
    done
}
68 |
69 |
# Extract code/NL similarity data for both experiment suites (input to the
# near-duplicate analysis below).
function extract_data_sim() {
    for exps in cg mn; do
        python -m tseval.main analyze_extract_data_similarities --exps=./exps/${exps}.yaml
    done
}
75 |
76 |
# Run the near-duplicate analysis for both experiment suites under three
# configurations: identical code, identical NL, and >=90% similarity on both.
# NOTE(review): a threshold of 1.1 presumably means "never matched" (similarity
# is at most 1), effectively disabling that side of the check -- confirm.
function analyze_near_duplicates() {
    for exps in cg mn; do
        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=same_code \
               --code_sim=1 --nl_sim=1.1

        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=same_nl \
               --code_sim=1.1 --nl_sim=1

        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=sim_90 \
               --code_sim=0.9 --nl_sim=0.9
    done
}
92 |
93 |
# Same configurations as analyze_near_duplicates, but only regenerate the
# tables and plots from previously computed analysis results.
function analyze_near_duplicates_only_tables_plots() {
    for exps in cg mn; do
        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=same_code \
               --code_sim=1 --nl_sim=1.1 --only_tables_plots

        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=same_nl \
               --code_sim=1.1 --nl_sim=1 --only_tables_plots

        python -m tseval.main analyze_near_duplicates --exps=./exps/${exps}.yaml \
               --config=sim_90 \
               --code_sim=0.9 --nl_sim=0.9 --only_tables_plots
    done
}
109 |
110 |
111 | # ========== Data collection
112 |
# Collect the list of candidate repositories, then filter them.
function collect_repos() {
    python -m tseval.main collect_repos
    python -m tseval.main filter_repos
}
117 |
118 |
# Collect the raw data from the (already collected) repositories, then
# post-process it.
function collect_raw_data() {
    python -m tseval.main collect_raw_data
    python -m tseval.main process_raw_data
}
123 |
124 |
125 | # ========== Eval preparation
126 |
# Prepare the per-model conda environments.
function prepare_envs() {
    # Require tseval conda env first, prepared using prepare_conda_env
    python -m tseval.main prepare_envs
}
131 |
132 |
# Compute the dataset splits (a small Debug split and the Full split), with a
# fixed seed for reproducibility.
function prepare_splits() {
    python -m tseval.main get_splits --seed=7 --split=Debug --debug
    python -m tseval.main get_splits --seed=7 --split=Full
}
137 |
138 |
# End-to-end smoke test of the CG pipeline on the small Debug setup:
# prepare data, train a Transformer model, then run eval and compute metrics
# on each evaluation set.
function cg_debug_workflow() {
    python -m tseval.main exp_prepare --task=CG --setup=StandardSetup --setup_name=Debug --split_name=Debug --split_type=MP
    python -m tseval.main exp_train --task=CG --setup_name=Debug --model_name=TransformerACL20 --exp_name=TransformerACL20
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=CG --setup_name=Debug --exp_name=TransformerACL20 --action=$action
    done
    for action in test_common val test_standard; do
        python -m tseval.main exp_compute_metrics --task=CG --setup_name=Debug --exp_name=TransformerACL20 --action=$action
    done
}
149 |
150 |
# Prepare the CG datasets for each split type (MP/CP/T) on the Full split.
function cg_prepare_setups() {
    for split_type in MP CP T; do
        python -m tseval.main exp_prepare --task=CG --setup=StandardSetup --setup_name=$split_type --split_name=Full --split_type=$split_type
    done
}
156 |
157 |
# Train and evaluate one CG model on the Debug setup; extra CLI arguments are
# forwarded to the training command. Usage: cg_debug_model MODEL [TRAIN_ARGS...]
function cg_debug_model() {
    local model=$1; shift
    # All remaining arguments are forwarded to exp_train.
    # (A stray `shift` that ran after the arguments were already captured has
    # been removed; nothing below reads the positional parameters.)
    local args="$@"
    echo "Arguments to model $model: $args"

    set -e
    set -x
    python -m tseval.main exp_train --task=CG --setup_name=Debug --exp_name=$model \
           --model_name=$model $args
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=CG --setup_name=Debug --exp_name=$model --action=$action
        python -m tseval.main exp_compute_metrics --task=CG --setup_name=Debug --exp_name=$model --action=$action
    done
}
172 |
173 |
# Train and evaluate one CG model on a given setup (MP/CP/T); extra CLI
# arguments are forwarded to the training command.
# Usage: cg_run_model MODEL SETUP [TRAIN_ARGS...]
function cg_run_model() {
    local model=$1; shift
    local setup=$1; shift
    # All remaining arguments are forwarded to exp_train.
    # (A stray `shift` that ran after the arguments were already captured has
    # been removed; nothing below reads the positional parameters.)
    local args="$@"
    echo "Arguments to model $model: $args"

    set -e
    set -x
    python -m tseval.main exp_train --task=CG --setup_name=$setup --exp_name=$model \
           --model_name=$model $args
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=CG --setup_name=$setup --exp_name=$model --action=$action
        python -m tseval.main exp_compute_metrics --task=CG --setup_name=$setup --exp_name=$model --action=$action
    done
}
189 |
190 |
# End-to-end smoke test of the MN pipeline on the small Debug setup:
# prepare data, train a Code2Seq model, then run eval and compute metrics on
# each evaluation set.
function mn_debug_workflow() {
    python -m tseval.main exp_prepare --task=MN --setup=StandardSetup --setup_name=Debug --split_name=Debug --split_type=MP
    python -m tseval.main exp_train --task=MN --setup_name=Debug --model_name=Code2SeqICLR19 --exp_name=Code2SeqICLR19
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=MN --setup_name=Debug --exp_name=Code2SeqICLR19 --action=$action
    done
    for action in test_common val test_standard; do
        python -m tseval.main exp_compute_metrics --task=MN --setup_name=Debug --exp_name=Code2SeqICLR19 --action=$action
    done
}
201 |
202 |
# Prepare the MN datasets for each split type (MP/CP/T) on the Full split.
function mn_prepare_setups() {
    for split_type in MP CP T; do
        python -m tseval.main exp_prepare --task=MN --setup=StandardSetup --setup_name=$split_type --split_name=Full --split_type=$split_type
    done
}
208 |
209 |
# Train and evaluate one MN model on the Debug setup; extra CLI arguments are
# forwarded to the training command. Usage: mn_debug_model MODEL [TRAIN_ARGS...]
function mn_debug_model() {
    local model=$1; shift
    # All remaining arguments are forwarded to exp_train.
    # (A stray `shift` that ran after the arguments were already captured has
    # been removed; nothing below reads the positional parameters.)
    local args="$@"
    echo "Arguments to model $model: $args"

    set -e
    set -x
    python -m tseval.main exp_train --task=MN --setup_name=Debug --exp_name=$model \
           --model_name=$model $args
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=MN --setup_name=Debug --exp_name=$model --action=$action
        python -m tseval.main exp_compute_metrics --task=MN --setup_name=Debug --exp_name=$model --action=$action
    done
}
224 |
225 |
# Train and evaluate one MN model on a given setup (MP/CP/T); extra CLI
# arguments are forwarded to the training command.
# Usage: mn_run_model MODEL SETUP [TRAIN_ARGS...]
function mn_run_model() {
    local model=$1; shift
    local setup=$1; shift
    # All remaining arguments are forwarded to exp_train.
    # (A stray `shift` that ran after the arguments were already captured has
    # been removed; nothing below reads the positional parameters.)
    local args="$@"
    echo "Arguments to model $model: $args"

    set -e
    set -x
    python -m tseval.main exp_train --task=MN --setup_name=$setup --exp_name=$model \
           --model_name=$model $args
    for action in test_common val test_standard; do
        python -m tseval.main exp_eval --task=MN --setup_name=$setup --exp_name=$model --action=$action
        python -m tseval.main exp_compute_metrics --task=MN --setup_name=$setup --exp_name=$model --action=$action
    done
}
241 |
242 |
243 |
244 |
245 | # ==========
246 | # Main function -- program entry point
247 | # This script can be executed as ./run.sh the_function_to_run
248 |
# Dispatch entry point: the first argument names the function to run; any
# remaining arguments are forwarded to it. Commands run from the script's
# directory in a subshell, so the caller's working directory is untouched.
function main() {
    local action=${1:?Need Argument}; shift

    # Quote ${_DIR} so a checkout path containing spaces does not break the cd.
    ( cd "${_DIR}"
      $action "$@"
    )
}
256 |
257 | main "$@"
258 |
--------------------------------------------------------------------------------
/python/switch-cuda.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Copyright (c) 2018 Patrick Hohenecker
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 |
23 | # author: Patrick Hohenecker
24 | # version: 2018.1
25 | # date: May 15, 2018
26 |
27 |
set -e


# ensure that the script has been sourced rather than just executed
# (it must be sourced: it works by exporting variables into the caller's shell)
if [[ "${BASH_SOURCE[0]}" = "${0}" ]]; then
    echo "Please use 'source' to execute switch-cuda.sh!"
    exit 1
fi

INSTALL_FOLDER="/usr/local"  # the location to look for CUDA installations at
TARGET_VERSION=${1}          # the target CUDA version to switch to (if provided)

# if no version to switch to has been provided, then just print all available CUDA installations
if [[ -z ${TARGET_VERSION} ]]; then
    echo "The following CUDA installations have been found (in '${INSTALL_FOLDER}'):"
    ls -l "${INSTALL_FOLDER}" | egrep -o "cuda-[0-9]+\\.[0-9]+$" | while read -r line; do
        echo "* ${line}"
    done
    set +e
    return
# otherwise, check whether there is an installation of the requested CUDA version
elif [[ ! -d "${INSTALL_FOLDER}/cuda-${TARGET_VERSION}" ]]; then
    echo "No installation of CUDA ${TARGET_VERSION} has been found!"
    set +e
    return
fi

# the path of the installation to use
cuda_path="${INSTALL_FOLDER}/cuda-${TARGET_VERSION}"

# filter out those CUDA entries from the PATH that are not needed anymore
path_elements=(${PATH//:/ })
new_path="${cuda_path}/bin"
for p in "${path_elements[@]}"; do
    if [[ ! ${p} =~ ^${INSTALL_FOLDER}/cuda ]]; then
        new_path="${new_path}:${p}"
    fi
done

# filter out those CUDA entries from the LD_LIBRARY_PATH that are not needed anymore
ld_path_elements=(${LD_LIBRARY_PATH//:/ })
new_ld_path="${cuda_path}/lib64:${cuda_path}/extras/CUPTI/lib64"
for p in "${ld_path_elements[@]}"; do
    if [[ ! ${p} =~ ^${INSTALL_FOLDER}/cuda ]]; then
        new_ld_path="${new_ld_path}:${p}"
    fi
done

# update environment variables
export CUDA_HOME="${cuda_path}"
export CUDA_ROOT="${cuda_path}"
export LD_LIBRARY_PATH="${new_ld_path}"
export PATH="${new_path}"

echo "Switched to CUDA ${TARGET_VERSION}."

set +e
# use `return` (not `exit`) since this file is sourced into the caller's shell
return
86 |
--------------------------------------------------------------------------------
/python/tseval/Environment.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from seutil import BashUtils, IOUtils, LoggingUtils, MiscUtils
4 |
5 | from tseval.Macros import Macros
6 |
7 |
class Environment:
    """Runtime environment utilities: global flags, conda/CUDA detection, and tool setup."""

    logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO)

    # ----------
    # Environment variables
    # ----------
    is_debug: bool = False
    random_seed: int = 1
    is_parallel: bool = False

    # ----------
    # Conda & CUDA
    # ----------
    conda_env = "tseval"

    @classmethod
    def get_conda_init_path(cls) -> str:
        """Locate conda's shell init script (conda.sh) relative to the conda executable.

        :raises RuntimeError: if conda is not on PATH.
        """
        which_conda = BashUtils.run("which conda").stdout.strip()
        if len(which_conda) == 0:
            raise RuntimeError("Cannot detect conda environment!")
        # conda lives at <prefix>/bin/conda; conda.sh is at <prefix>/etc/profile.d/
        return str(Path(which_conda).parent.parent / "etc" / "profile.d" / "conda.sh")

    conda_init_path_cached = None

    @MiscUtils.classproperty
    def conda_init_path(cls):
        # Lazily computed and cached, to avoid re-running `which conda`
        if cls.conda_init_path_cached is None:
            cls.conda_init_path_cached = cls.get_conda_init_path()
        return cls.conda_init_path_cached

    @classmethod
    def get_cuda_version(cls) -> str:
        """Detect the CUDA version tag: "cpu", "cu100", "cu101", "cu102", or "cu110".

        Mirrors the detection logic in python/prepare_conda_env.sh.

        :raises RuntimeError: if an unsupported CUDA version is installed.
        """
        which_nvidia_smi = BashUtils.run("which nvidia-smi").stdout.strip()
        if len(which_nvidia_smi) == 0:
            # No NVIDIA driver detected -> CPU-only
            return "cpu"
        cuda_version_number = BashUtils.run(r'nvcc -V | grep "release" | sed -E "s/.*release ([^,]+),.*/\1/"').stdout.strip()
        for prefix, tag in [("10.0", "cu100"), ("10.1", "cu101"), ("10.2", "cu102"), ("11.0", "cu110")]:
            if cuda_version_number.startswith(prefix):
                return tag
        raise RuntimeError(f"Unsupported cuda version {cuda_version_number}!")

    cuda_version_cached = None

    @MiscUtils.classproperty
    def cuda_version(cls):
        # Lazily computed and cached, to avoid re-running nvcc
        if cls.cuda_version_cached is None:
            cls.cuda_version_cached = cls.get_cuda_version()
        return cls.cuda_version_cached

    @classmethod
    def get_cuda_toolkit_spec(cls):
        """Map the detected CUDA version tag to the conda cudatoolkit package spec.

        :raises RuntimeError: if the detected version tag is not recognized.
        """
        cuda_version = cls.cuda_version
        # BUGFIX: "cu100" previously mapped to "cudatoolkit=10.1"; it must be
        # 10.0, matching prepare_conda_env.sh.
        specs = {
            "cpu": "cpuonly",
            "cu100": "cudatoolkit=10.0",
            "cu101": "cudatoolkit=10.1",
            "cu102": "cudatoolkit=10.2",
            "cu110": "cudatoolkit=11.0",
        }
        if cuda_version not in specs:
            raise RuntimeError(f"Unexpected cuda version {cuda_version}!")
        return specs[cuda_version]

    # ----------
    # Tools
    # ----------

    collector_installed = False
    collector_jar = str(Macros.collector_dir / "target" / f"collector-{Macros.collector_version}.jar")

    @classmethod
    def require_collector(cls):
        """Ensure the Java collector is built (mvn install), at most once per process.

        No-op in parallel mode, where workers assume the collector is already built.
        """
        if cls.is_parallel:
            return
        if cls.collector_installed:
            cls.logger.debug("Require collector, and already installed")
            return
        cls.logger.info("Require collector, installing ...")
        with IOUtils.cd(Macros.collector_dir):
            BashUtils.run("mvn clean install -DskipTests", expected_return_code=0)
        cls.collector_installed = True
99 |
--------------------------------------------------------------------------------
/python/tseval/Macros.py:
--------------------------------------------------------------------------------
1 | from typing import *
2 |
3 | import os
4 | from pathlib import Path
5 |
6 |
class Macros:
    """Central registry of project paths and shared constant names."""

    this_dir: Path = Path(os.path.dirname(os.path.realpath(__file__)))
    python_dir: Path = this_dir.parent
    project_dir: Path = python_dir.parent
    paper_dir: Path = project_dir / "papers" / "acl22"

    collector_dir: Path = project_dir / "collector"
    collector_version = "0.1-dev"

    results_dir: Path = project_dir / "results"
    raw_data_dir: Path = project_dir / "_raw_data"
    work_dir: Path = project_dir / "_work"
    repos_downloads_dir: Path = project_dir / "_downloads"

    # Names of the data subsets
    train = "train"
    val = "val"
    test = "test"
    test_common = "test_common"
    test_standard = "test_standard"

    # Split types: mixed-project, cross-project, temporal
    mixed_prj = "MP"
    cross_prj = "CP"
    temporally = "T"
    split_types = [mixed_prj, cross_prj, temporally]

    @classmethod
    def get_pairwise_split_types_with(cls, split: str) -> List[Tuple[str, str]]:
        """Return the pairs combining `split` with each other split type.

        Each pair is ordered so that the split appearing earlier in
        `split_types` comes first, e.g. for "CP": [("MP", "CP"), ("CP", "T")].
        """
        pairs = []
        passed_split = False
        for other in cls.split_types:
            if other == split:
                passed_split = True
                continue
            pairs.append((split, other) if passed_split else (other, split))
        return pairs

    # Task names: comment generation and method naming
    com_gen = "CG"
    met_nam = "MN"

    tasks = ["CG", "MN"]
50 |
--------------------------------------------------------------------------------
/python/tseval/Plot.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from pathlib import Path
3 | from typing import *
4 |
5 | import matplotlib as mpl
6 | import pandas as pd
7 | import seaborn as sns
8 | from matplotlib import pyplot as plt
9 | from seutil import IOUtils, latex, LoggingUtils
10 |
11 | from tseval.Environment import Environment
12 | from tseval.Macros import Macros
13 |
14 |
class Plot:
    """Generates the paper's figures (joint distributions of dataset metrics)."""

    logger = LoggingUtils.get_logger(__name__, LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)

    @classmethod
    def init_plot_libs(cls):
        """Configure the seaborn theme and matplotlib font sizes used by all plots."""
        # Initialize seaborn
        sns.set()
        sns.set_palette("Dark2")
        sns.set_context("paper")
        # Set matplotlib fonts
        mpl.rcParams["axes.titlesize"] = 24
        mpl.rcParams["axes.labelsize"] = 24
        mpl.rcParams["font.size"] = 18
        mpl.rcParams["xtick.labelsize"] = 24
        mpl.rcParams["xtick.major.size"] = 14
        mpl.rcParams["xtick.minor.size"] = 14
        mpl.rcParams["ytick.labelsize"] = 24
        mpl.rcParams["ytick.major.size"] = 14
        mpl.rcParams["ytick.minor.size"] = 14
        mpl.rcParams["legend.fontsize"] = 18
        mpl.rcParams["legend.title_fontsize"] = 18
        # print(mpl.rcParams)

    def __init__(self):
        # Figures are written under the paper's figs/ directory
        self.plots_dir: Path = Macros.paper_dir / "figs"
        IOUtils.mk_dir(self.plots_dir)
        self.init_plot_libs()

    def make_plots(self, options: dict):
        """Entry point: dispatch to a plot-making function based on options["which"]."""
        which = options.pop("which")

        if which == "dataset-metrics-dist":
            self.dataset_metrics_dist(
                task=options["task"],
            )
        else:
            self.logger.warning(f"Unknown plot {which}")

    def dataset_metrics_dist(self, task: str):
        """Plot len(code) x len(comment|name) joint distributions per data subset.

        Writes one PDF per subset (train/val/test_standard per split type, plus
        the pairwise common test sets) into figs/dataset-{task}/, and a plot.tex
        file that arranges the plots into a latex table.

        :param task: "CG" (y axis: comment length) or other, i.e. MN (y axis: name length).
        """
        plots_sub_dir = self.plots_dir / f"dataset-{task}"
        # Path relative to the paper dir, used inside \includegraphics
        plots_sub_dir_rel = str(plots_sub_dir.relative_to(Macros.paper_dir))
        IOUtils.rm_dir(plots_sub_dir)
        plots_sub_dir.mkdir(parents=True)

        # Load metrics list into a DataFrame
        label_x = "code"
        max_x = 200
        if task == "CG":
            label_y = "comment"
            max_y = 60
        else:
            label_y = "name"
            max_y = 8
        # List-of-dicts: one row per data point, keyed by subset name
        lod: List[dict] = []
        seen_split_combination = set()
        for setup in Macros.split_types:
            metrics_list = IOUtils.load(Macros.results_dir / "metrics" / f"setup-dataset-metrics-list_{task}_{setup}.pkl", IOUtils.Format.pkl)
            # train/val/test_standard subsets of this split type
            for sn in [Macros.train, Macros.val, Macros.test_standard]:
                for i, (x, y) in enumerate(zip(
                    metrics_list[f"{sn}_len-{label_x}"], metrics_list[f"{sn}_len-{label_y}"],
                )):
                    lod.append({
                        "i": i,
                        "set_name": f"{sn}-{setup}",
                        label_x: x,
                        label_y: y,
                    })

            # Common test sets shared by pairs of split types; each pair is added once
            for s1, s2 in Macros.get_pairwise_split_types_with(setup):
                if (s1, s2) in seen_split_combination:
                    continue
                for i, (x, y) in enumerate(zip(
                    metrics_list[f"{Macros.test_common}-{s1}-{s2}_len-{label_x}"],
                    metrics_list[f"{Macros.test_common}-{s1}-{s2}_len-{label_y}"],
                )):
                    lod.append({
                        "i": i,
                        "set_name": f"{Macros.test_common}-{s1}-{s2}",
                        label_x: x,
                        label_y: y,
                    })
                seen_split_combination.add((s1, s2))

        df = pd.DataFrame(lod)

        # Make plots
        for sn, df_sn in df.groupby("set_name", as_index=False):
            sn: str
            # Only bottom-row / left-column subplots of the final latex grid get axis labels
            if sn in [f"{x}-{Macros.temporally}" for x in [Macros.train, Macros.val, Macros.test_standard]] + [f"{Macros.test_common}-{Macros.cross_prj}-{Macros.temporally}"]:
                display_xlabel = "len(" + label_x + ")"
            else:
                display_xlabel = None
            if sn.startswith(Macros.train):
                display_ylabel = "len(" + label_y + ")"
            else:
                display_ylabel = None

            # 2D histogram with marginal distributions
            fig = sns.jointplot(
                data=df_sn,
                x=label_x, y=label_y,
                kind="hist",
                xlim=(0, max_x),
                ylim=(0, max_y),
                height=6,
                ratio=3,
                space=.01,
                joint_kws=dict(
                    bins=(12, min(12, max_y-1)),
                    binrange=((0, max_x), (0, max_y)),
                    pmax=.5,
                ),
                color="royalblue",
            )
            fig.set_axis_labels(display_xlabel, display_ylabel)
            plt.tight_layout()
            fig.savefig(plots_sub_dir / f"{sn}.pdf")

        # Generate a tex file that organizes the plots
        f = latex.File(plots_sub_dir / f"plot.tex")
        f.append(r"\begin{center}")
        f.append(r"\begin{footnotesize}")
        f.append(r"\begin{tabular}{|l|c|c|c || l|c|}")
        f.append(r"\hline")
        f.append(r"& \textbf{ATrain} & \textbf{\AVal} & \textbf{\ATestS} & & \textbf{\ATestC} \\")

        # One table row per split type; the right columns show the common test sets
        for sn_l, (s1_r, s2_r) in zip(Macros.split_types, itertools.combinations(Macros.split_types, 2)):
            f.append(r"\hline")
            f.append(latex.Macro(f"TH-ds-{sn_l}").use())
            for sn in [Macros.train, Macros.val, Macros.test_standard]:
                f.append(r" & \begin{minipage}{.18\textwidth}\includegraphics[width=\textwidth]{"
                         + f"{plots_sub_dir_rel}/{sn}-{sn_l}"
                         + r"}\end{minipage}")
            f.append(" & " + latex.Macro(f"TH-ds-{s1_r}-{s2_r}").use())
            f.append(r" & \begin{minipage}{.18\textwidth}\includegraphics[width=\textwidth]{"
                     + f"{plots_sub_dir_rel}/{Macros.test_common}-{s1_r}-{s2_r}"
                     + r"}\end{minipage}")
            f.append(r"\\")
        f.append(r"\hline")
        f.append(r"\end{tabular}")
        f.append(r"\end{footnotesize}")
        f.append(r"\end{center}")
        f.save()
158 |
--------------------------------------------------------------------------------
/python/tseval/Table.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import itertools
3 | from pathlib import Path
4 |
5 | from seutil import IOUtils, latex, LoggingUtils
6 | from seutil.latex.Macro import Macro
7 |
8 | from tseval.Environment import Environment
9 | from tseval.Macros import Macros
10 |
11 |
12 | class Table:
13 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO)
14 |
15 | COLSEP = "COLSEP"
16 | ROWSEP = "ROWSEP"
17 |
18 | def __init__(self):
19 | self.tables_dir: Path = Macros.paper_dir / "tables"
20 | IOUtils.mk_dir(self.tables_dir)
21 |
22 | self.metrics_dir: Path = Macros.results_dir / "metrics"
23 | return
24 |
25 | def make_tables(self, options):
26 | which = options.pop("which")
27 | if which == "dataset-metrics":
28 | self.make_numbers_dataset_metrics()
29 | self.make_table_dataset_metrics_small()
30 | for task in [Macros.com_gen, Macros.met_nam]:
31 | self.make_table_dataset_metrics(task)
32 | else:
33 | self.logger.warning(f"Unknown table name {which}")
34 |
35 | def make_numbers_dataset_metrics(self):
36 | file = latex.File(self.tables_dir / f"numbers-dataset-metrics.tex")
37 |
38 | dataset_filtered_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"raw-data-filtered.json", IOUtils.Format.json)
39 | for k, v in dataset_filtered_metrics.items():
40 | file.append_macro(latex.Macro(f"ds-filter-{k}", f"{v:,d}"))
41 |
42 | dataset_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"split-dataset-metrics_Full.json", IOUtils.Format.json)
43 | for k, v in dataset_metrics.items():
44 | fmt = f",d" if type(v) == int else f",.2f"
45 | file.append_macro(latex.Macro(f"ds-{k}", f"{v:{fmt}}"))
46 |
47 | for task in [Macros.com_gen, Macros.met_nam]:
48 | for split in Macros.split_types:
49 | setup_metrics = IOUtils.load(Macros.results_dir / "metrics" / f"setup-dataset-metrics_{task}_{split}.json", IOUtils.Format.json)
50 | for k, v in setup_metrics.items():
51 | fmt = f",d" if type(v) == int else f",.2f"
52 | if k.startswith("all_"):
53 | # Skip all_ metrics
54 | continue
55 | else:
56 | # Replace train/val/test_standard to train-split/val-split/test_standard-split
57 | for x in [Macros.train, Macros.val, Macros.test_standard]:
58 | if k.startswith(f"{x}_"):
59 | k = f"{x}-{split}_" + k[len(f"{x}_"):]
60 | break
61 | if f"ds-{task}-{k}" in file.macros_indexed:
62 | continue
63 | file.append_macro(latex.Macro(f"ds-{task}-{k}", f"{v:{fmt}}"))
64 |
65 | file.save()
66 |
    def make_table_dataset_metrics(self, task: str):
        """Generate the full dataset-metrics table for one task.

        Rows are metrics (data counts and length statistics); columns are the
        train/val/test_standard sets of each split type plus the pairwise
        common test sets. The cell values reference the number macros emitted
        by make_numbers_dataset_metrics.

        :param task: Macros.com_gen ("CG") or Macros.met_nam ("MN").
        """
        f = latex.File(self.tables_dir / f"table-dataset-metrics-{task}.tex")

        # Metric key -> latex row header; ROWSEP entries become \midrule separators
        metric_2_th = collections.OrderedDict()
        # metric_2_th["num-proj"] = r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-project}}"
        metric_2_th["num-data"] = r"\multicolumn{2}{c|}{\UseMacro{TH-ds-num-data}}"
        metric_2_th["sep-data"] = self.ROWSEP
        metric_2_th["len-code-AVG"] = r"& \UseMacro{TH-ds-len-code-avg}"
        # metric_2_th["len-code-MODE"] = r"& \UseMacro{TH-ds-len-code-mode}"
        # metric_2_th["len-code-MEDIAN"] = r"& \UseMacro{TH-ds-len-code-median}"
        metric_2_th["len-code-le-100"] = r"& \UseMacro{TH-ds-len-code-le100}"
        metric_2_th["len-code-le-150"] = r"& \UseMacro{TH-ds-len-code-le150}"
        metric_2_th["len-code-le-200"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-code}} & \UseMacro{TH-ds-len-code-le200}"
        # Comment-length rows only for comment generation
        if task == Macros.com_gen:
            metric_2_th["sep-cg"] = self.ROWSEP
            metric_2_th["len-comment-AVG"] = r"& \UseMacro{TH-ds-len-comment-avg}"
            # metric_2_th["len-comment-MODE"] = r"& \UseMacro{TH-ds-len-comment-mode}"
            # metric_2_th["len-comment-MEDIAN"] = r"& \UseMacro{TH-ds-len-comment-median}"
            metric_2_th["len-comment-le-20"] = r"& \UseMacro{TH-ds-len-comment-le20}"
            metric_2_th["len-comment-le-30"] = r"& \UseMacro{TH-ds-len-comment-le30}"
            metric_2_th["len-comment-le-50"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-comment}} & \UseMacro{TH-ds-len-comment-le50}"
        # Name-length rows only for method naming
        if task == Macros.met_nam:
            metric_2_th["sep-mn"] = self.ROWSEP
            metric_2_th["len-name-AVG"] = r"& \UseMacro{TH-ds-len-name-avg}"
            # metric_2_th["len-name-MODE"] = r"& \UseMacro{TH-ds-len-name-mode}"
            # metric_2_th["len-name-MEDIAN"] = r"& \UseMacro{TH-ds-len-name-median}"
            metric_2_th["len-name-le-2"] = r"& \UseMacro{TH-ds-len-name-le2}"
            metric_2_th["len-name-le-3"] = r"& \UseMacro{TH-ds-len-name-le3}"
            metric_2_th["len-name-le-6"] = r"\multirow{-4}{*}{\UseMacro{TH-ds-len-name}} & \UseMacro{TH-ds-len-name-le6}"

        # Column keys: train/val/test_standard per split type (separated by
        # COLSEP), followed by the pairwise common test sets
        cols = sum(
            [
                [f"{s1}-{s2}" for s1 in [Macros.train, Macros.val, Macros.test_standard]] + [self.COLSEP]
                for s2 in [Macros.mixed_prj, Macros.cross_prj, Macros.temporally]
            ]
            , [],
        ) + [f"{Macros.test_common}-{x}-{y}"
             for x, y in itertools.combinations([Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 2)]

        # Header
        f.append(r"\begin{table*}[t]")
        f.append(r"\begin{footnotesize}")
        f.append(r"\begin{center}")
        table_name = f"dataset-metrics-{task}"

        f.append(
            r"\begin{tabular}{ l@{\hspace{2pt}}c@{\hspace{2pt}} | "
            r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} "
            r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} "
            r"r@{\hspace{4pt}}r@{\hspace{4pt}}r @{\hspace{3pt}}c@{\hspace{3pt}} "
            r"r@{\hspace{4pt}}r@{\hspace{4pt}}r }"
        )

        f.append(r"\toprule")

        # Line 1
        f.append(
            r"\multicolumn{2}{c|}{}"
            r" & \multicolumn{3}{c}{\UseMacro{TH-ds-MP}} &"
            r" & \multicolumn{3}{c}{\UseMacro{TH-ds-CP}} &"
            r" & \multicolumn{3}{c}{\UseMacro{TH-ds-T}} &"
            r" & \UseMacro{TH-ds-MP-CP} & \UseMacro{TH-ds-MP-T} & \UseMacro{TH-ds-CP-T}"
            r" \\\cline{3-5}\cline{7-9}\cline{11-13}\cline{15-17}"
        )

        # Line 2
        f.append(
            r"\multicolumn{2}{c|}{\multirow{-2}{*}{\THDSStat}}"
            r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &"
            r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &"
            r" & \UseMacro{TH-ds-train} & \UseMacro{TH-ds-val} & \UseMacro{TH-ds-test_standard} &"
            r" & \multicolumn{3}{c}{\UseMacro{TH-ds-test_common}} \\"
        )

        f.append(r"\midrule")

        # Body: one row per metric, one cell per column key
        for metric, row_th in metric_2_th.items():
            if row_th == self.ROWSEP:
                f.append(r"\midrule")
                continue

            f.append(row_th)

            for col in cols:
                if col == self.COLSEP:
                    f.append(" & ")
                    continue
                f.append(" & " + latex.Macro(f"ds-{task}-{col}_{metric}").use())

            f.append(r"\\")

        # Footer
        f.append(r"\bottomrule")
        f.append(r"\end{tabular}")
        f.append(r"\end{center}")
        f.append(r"\end{footnotesize}")
        f.append(r"\vspace{" + latex.Macro(f"TV-{table_name}").use() + "}")
        f.append(r"\caption{" + latex.Macro(f"TC-{table_name}").use() + r"}")
        f.append(r"\end{table*}")

        f.save()
168 |
    def make_table_dataset_metrics_small(self):
        """
        Generates table-dataset-metrics-small.tex: a compact table with the number of
        data per split (train/val/test_standard, plus pairwise test_common sets) for
        both tasks (comment generation and method naming).

        The table body only emits \\UseMacro references; the actual numbers are defined
        in a separate numbers file (presumably produced elsewhere in this class —
        TODO confirm). TikZ overlay arrows connect each split's row to the
        test_common column it contributes to.
        """
        f = latex.File(self.tables_dir / f"table-dataset-metrics-small.tex")

        # Header
        f.append(r"\begin{table}[t]")
        f.append(r"\begin{footnotesize}")
        f.append(r"\begin{center}")
        table_name = f"dataset-metrics-small"

        f.append(r"\begin{tabular}{ @{} l | c @{\hspace{5pt}} r @{\hspace{5pt}} r @{\hspace{5pt}} r @{\hspace{3pt}}c@{\hspace{3pt}} c @{\hspace{5pt}} r@{} }")
        f.append(r"\toprule")
        f.append(r"\textbf{Task} & & \textbf{\ATrain} & \textbf{\AVal} & \textbf{\ATestS} & & & \textbf{\ATestC} \\")

        # One 3-row band per task: one row per split type (MP/CP/T), paired with one
        # pairwise test_common set per row via zip.
        for task in [Macros.com_gen, Macros.met_nam]:

            f.append(r"\midrule")

            for i, (m, p) in enumerate(zip(
                    [Macros.mixed_prj, Macros.cross_prj, Macros.temporally],
                    [f"{x}-{y}" for x, y in itertools.combinations([Macros.mixed_prj, Macros.cross_prj, Macros.temporally], 2)],
            )):
                # The task label spans the 3 rows; \multirow{-3} must be emitted on the
                # band's last row (i == 2) to span upwards.
                if i == 2:
                    f.append(r"\multirow{-3}{*}{\rotatebox[origin=c]{90}{" + latex.Macro(f"TaskM_{task}").use() + r"}}")

                f.append(" & " + latex.Macro(f"TH-ds-{m}").use())
                for sn in [Macros.train, Macros.val, Macros.test_standard]:
                    f.append(" & " + latex.Macro(f"ds-{task}-{sn}-{m}_num-data").use())
                # Invisible TikZ anchor node named "<task><split>-base"; the overlay
                # tikzpicture below draws arrows between these anchors.
                f.append(r" & \tikz[remember picture, baseline] \node[inner sep=2pt, outer sep=0, yshift=1ex] (" + task + m + r"-base) {\phantom{XX}};")
                f.append(" & " + latex.Macro(f"TH-ds-{p}").use())
                f.append(" & " + latex.Macro(f"ds-{task}-{Macros.test_common}-{p}_num-data").use())
                f.append(r"\\")

        f.append(r"\bottomrule")
        f.append(r"\end{tabular}")

        # Overlay: for each pairwise test_common row r, draw two curves from the two
        # contributing split rows (l1 arrow-headed, l2 plain) into r's anchor.
        # Requires two LaTeX passes ("remember picture").
        f.append(r"\begin{tikzpicture}[remember picture, overlay, thick]")
        for task in [Macros.com_gen, Macros.met_nam]:
            for r, (l1, l2) in zip(
                    [Macros.mixed_prj, Macros.cross_prj, Macros.temporally],
                    [
                        (Macros.mixed_prj, Macros.cross_prj),
                        (Macros.mixed_prj, Macros.temporally),
                        (Macros.cross_prj, Macros.temporally),
                    ]
            ):
                f.append(r"\draw[->] (" + task + l1 + r"-base.west) .. controls ($(" + task + r + r"-base.east) - (1em,0)$) .. (" + task + r + r"-base.east);")
                f.append(r"\draw (" + task + l2 + r"-base.west) .. controls ($(" + task + r + r"-base.east) - (1em,0)$) .. (" + task + r + r"-base.east);")

        f.append(r"\end{tikzpicture}")

        # Footer: vertical-space and caption text are themselves macros so they can be
        # tuned from the numbers file without regenerating this table.
        f.append(r"\end{center}")
        f.append(r"\end{footnotesize}")
        f.append(r"\vspace{" + latex.Macro(f"TV-{table_name}").use() + "}")
        f.append(r"\caption{" + latex.Macro(f"TC-{table_name}").use() + "}")
        f.append(r"\end{table}")

        f.save()
226 |
--------------------------------------------------------------------------------
/python/tseval/Utils.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from inspect import Parameter
3 | from pathlib import Path
4 | from typing import *
5 |
6 | import copy
7 | import numpy as np
8 | import typing_inspect
9 | from scipy import stats
10 |
11 |
class Utils:
    """
    Miscellaneous helpers: parsing command line options (as produced by
    seutil.CliUtils), summary statistics over lists, and dvc hint messages.
    """

    @classmethod
    def get_option_as_boolean(cls, options, opt, default=False, pop=False) -> bool:
        """
        Reads option `opt` from `options` as a boolean.

        Any value other than (case-insensitive) "false" counts as True, because
        CliUtils delivers all values as strings.

        :param options: the options dict.
        :param opt: the option name.
        :param default: returned when `opt` is absent.
        :param pop: if True, remove `opt` from `options` after reading it.
        """
        if opt not in options:
            return default
        else:
            # Due to limitations of CliUtils...
            value = options.get(opt, "false")
            if pop:
                del options[opt]
            return str(value).lower() != "false"

    @classmethod
    def get_option_as_list(cls, options, opt, default=None, pop=False) -> list:
        """
        Reads option `opt` from `options` as a list; a scalar value is wrapped in a
        one-element list. Returns a deep copy of `default` when absent (so callers
        can't accidentally share/mutate the default).
        """
        if opt not in options:
            return copy.deepcopy(default)
        else:
            lst = options[opt]
            if pop:
                del options[opt]
            if not isinstance(lst, list):
                lst = [lst]
            return lst

    @classmethod
    def get_option_and_pop(cls, options, opt, default=None) -> Any:
        """
        Pops and returns option `opt` from `options`; returns a deep copy of
        `default` when absent.
        """
        # (fixed: annotation was the builtin `any`, not typing.Any)
        if opt in options:
            return options.pop(opt)
        else:
            return copy.deepcopy(default)

    # Summaries: name -> function computing that summary over a list.
    # Empty inputs yield NaN (except CNT), so downstream tables can render "n/a".
    SUMMARIES_FUNCS: Dict[str, Callable[[Union[list, np.ndarray]], Union[int, float]]] = {
        "AVG": lambda l: np.mean(l) if len(l) > 0 else np.nan,
        "SUM": lambda l: sum(l) if len(l) > 0 else np.nan,
        "MAX": lambda l: max(l) if len(l) > 0 else np.nan,
        "MIN": lambda l: min(l) if len(l) > 0 else np.nan,
        "MEDIAN": lambda l: np.median(l) if len(l) > 0 and np.nan not in l else np.nan,
        "STDEV": lambda l: np.std(l) if len(l) > 0 else np.nan,
        # NOTE(review): `.mode[0]` assumes scipy < 1.11 where mode returns arrays — confirm pinned version
        "MODE": lambda l: stats.mode(l).mode[0].item() if len(l) > 0 else np.nan,
        "CNT": lambda l: len(l),
    }

    # Whether each summary of an int list is still an int (used for formatting).
    SUMMARIES_PRESERVE_INT: Dict[str, bool] = {
        "AVG": False,
        "SUM": True,
        "MAX": True,
        "MIN": True,
        "MEDIAN": False,
        "STDEV": False,
        "MODE": True,
        "CNT": True,
    }

    @classmethod
    def parse_cmd_options_for_type(
        cls,
        options: dict,
        typ: type,
        excluding_params: List[str] = None,
    ) -> Tuple[dict, dict, list]:
        """
        Parses the commandline options (got from seutil.CliUtils) based on the parameters
        and their types specified in typ.__init__.

        Accepted options are popped from `options` (the returned "remaining" dict
        aliases the input dict).

        :param options: the commandline options got from seutil.CliUtils.
        :param typ: the type to initialize.
        :param excluding_params: the list of parameters that are not expected to be
            passed from commandline, by default ["self"].
        :return: two dictionaries and a list:
            a dictionary with options that can be sent to typ.__init__;
            a dictionary that contains the remaining options;
            a list of any missing options required by the typ.__init__.
        """
        if excluding_params is None:
            excluding_params = ["self"]

        accepted_options = {}
        unk_options = options
        missing_options = []

        signature = inspect.signature(typ.__init__)
        for param in signature.parameters.values():
            if param.name in excluding_params:
                continue

            # Positional-only / *args / **kwargs parameters can't be mapped from
            # named commandline options.
            if param.kind == Parameter.POSITIONAL_ONLY \
                    or param.kind == Parameter.VAR_KEYWORD \
                    or param.kind == Parameter.VAR_POSITIONAL:
                raise AssertionError(f"Class {typ.__name__} should not have '/', '**' '*'"
                                     f" parameters in order to be configured from commandline")

            if param.name not in unk_options:
                # `is` comparison: Parameter.empty is a sentinel
                if param.default is Parameter.empty:
                    missing_options.append(param.name)
                    continue
                else:
                    # No need to insert anything to model_options
                    continue

            # Coerce the raw option according to the declared parameter type.
            if param.annotation == bool:
                accepted_options[param.name] = Utils.get_option_as_boolean(unk_options, param.name, pop=True)
            elif typing_inspect.get_origin(param.annotation) == list:
                accepted_options[param.name] = Utils.get_option_as_list(unk_options, param.name, pop=True)
            elif typing_inspect.get_origin(param.annotation) == set:
                accepted_options[param.name] = set(Utils.get_option_as_list(unk_options, param.name, pop=True))
            elif typing_inspect.get_origin(param.annotation) == tuple:
                accepted_options[param.name] = tuple(Utils.get_option_as_list(unk_options, param.name, pop=True))
            else:
                accepted_options[param.name] = unk_options.pop(param.name)

        return accepted_options, unk_options, missing_options

    @classmethod
    def expect_dir_or_suggest_dvc_pull(cls, path: Path):
        """
        Asserts that `path` is an existing directory; if it is absent but a sibling
        .dvc file exists, the error additionally suggests the `dvc pull` command.
        """
        if not path.is_dir():
            dvc_file = path.parent / (path.name+".dvc")
            if dvc_file.exists():
                print(f"{path} does not exist, but {dvc_file} exists. You probably want to dvc pull that file first?")
                print(f"# DVC command to run:")
                print(f"  dvc pull {dvc_file}")
                raise AssertionError(f"{path} does not exist but can be pulled from dvc.\n  dvc pull {dvc_file}")
            else:
                raise AssertionError(f"{path} does not exist.")

    @classmethod
    def suggest_dvc_add(cls, *paths: Path) -> str:
        """
        Builds a hint message with the `dvc add` command for the given paths.
        """
        s = "# DVC commands:\n"
        # Fixed: paths were previously concatenated with no separator, producing an
        # unusable command when more than one path was given.
        s += "  dvc add " + " ".join(str(path) for path in paths)
        s += "\n"
        return s
146 |
--------------------------------------------------------------------------------
/python/tseval/__init__.py:
--------------------------------------------------------------------------------
import os
import sys

# Make the package's parent directory importable, so `import tseval` resolves
# regardless of where the interpreter was started from.
_module_root = os.path.dirname(os.path.realpath(__file__)) + "/.."
if _module_root not in sys.path:
    sys.path.insert(0, _module_root)

# Clean up helper names so they don't leak from the package namespace
del _module_root
del sys
del os
12 |
--------------------------------------------------------------------------------
/python/tseval/collector/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/collector/__init__.py
--------------------------------------------------------------------------------
/python/tseval/comgen/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/comgen/__init__.py
--------------------------------------------------------------------------------
/python/tseval/comgen/eval/CGEvalHelper.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | from seutil import IOUtils, LoggingUtils
5 |
6 | from tseval.comgen.eval import get_setup_cls
7 | from tseval.eval.EvalSetupBase import EvalSetupBase
8 | from tseval.Macros import Macros
9 | from tseval.Utils import Utils
10 |
11 | logger = LoggingUtils.get_logger(__name__)
12 |
13 |
class CGEvalHelper:
    """
    Command-line entry points for comment generation (CG) experiments:
    preparing a setup, training a model, running eval, and computing metrics.
    """

    def __init__(self):
        # All CG artifacts live under <work_dir>/CG
        self.work_subdir: Path = Macros.work_dir / "CG"

    def exp_prepare(self, setup: str, setup_name: str, **cmd_options):
        """Creates a fresh eval setup named setup_name and prepares its data."""
        # Start from a clean setup directory
        target_dir: Path = self.work_subdir / "setup" / setup_name
        IOUtils.rm_dir(target_dir)
        target_dir.mkdir(parents=True)

        # Resolve the setup class and split cmd_options into what it accepts
        setup_cls = get_setup_cls(setup)
        accepted, unknown, missing = Utils.parse_cmd_options_for_type(
            cmd_options,
            setup_cls,
            ["self", "work_dir", "work_subdir", "setup_name"],
        )
        if missing:
            raise KeyError(f"Missing options: {missing}")
        if unknown:
            logger.warning(f"Unrecognized options: {unknown}")
        setup_obj: EvalSetupBase = setup_cls(
            work_dir=Macros.work_dir,
            work_subdir=self.work_subdir,
            setup_name=setup_name,
            **accepted,
        )

        # Persist the setup configuration, including the setup class name
        accepted["setup"] = setup
        IOUtils.dump(target_dir / "setup_config.json", accepted, IOUtils.Format.jsonNoSort)

        # Prepare data
        setup_obj.prepare()

        # Remind the user to track new artifacts with dvc
        print(Utils.suggest_dvc_add(setup_obj.setup_dir))

    def load_setup(self, setup_dir: Path, setup_name: str) -> EvalSetupBase:
        """
        Loads the setup from its save directory, with updating setup_name.
        """
        config = IOUtils.load(setup_dir / "setup_config.json", IOUtils.Format.json)
        setup_cls = get_setup_cls(config.pop("setup"))
        return setup_cls(
            work_dir=Macros.work_dir,
            work_subdir=self.work_subdir,
            setup_name=setup_name,
            **config,
        )

    def _load_existing_setup(self, setup_name: str) -> EvalSetupBase:
        # Shared guard for train/eval/metrics: the setup must already exist
        # (or at least be pullable from dvc).
        setup_dir = self.work_subdir / "setup" / setup_name
        Utils.expect_dir_or_suggest_dvc_pull(setup_dir)
        return self.load_setup(setup_dir, setup_name)

    def exp_train(
        self,
        setup_name: str,
        exp_name: str,
        model_name: str,
        cont_train: bool,
        no_save: bool,
        **cmd_options,
    ):
        """Trains model_name under the given setup; wipes prior results unless cont_train."""
        setup = self._load_existing_setup(setup_name)

        if not cont_train:
            # Delete existing trained model
            IOUtils.rm_dir(setup.get_exp_dir(exp_name))

        # Invoke training
        setup.train(exp_name, model_name, cont_train, no_save, **cmd_options)

        # Print dvc commands
        print(Utils.suggest_dvc_add(setup.get_exp_dir(exp_name)))

    def exp_eval(
        self,
        setup_name: str,
        exp_name: str,
        action: Optional[str],
        gpu_id: int = 0,
    ):
        """Runs the trained model on the eval sets selected by action."""
        setup = self._load_existing_setup(setup_name)

        # Invoke eval
        setup.eval(exp_name, action, gpu_id=gpu_id)

        # Print dvc commands
        print(Utils.suggest_dvc_add(setup.get_result_dir(exp_name)))

    def exp_compute_metrics(
        self,
        setup_name: str,
        exp_name: str,
        action: Optional[str] = None,
    ):
        """Computes metrics from previously-saved eval results."""
        setup = self._load_existing_setup(setup_name)

        # Invoke metrics computation
        setup.compute_metrics(exp_name, action)

        # Print dvc commands
        print(Utils.suggest_dvc_add(setup.get_metric_dir(exp_name)))
120 |
--------------------------------------------------------------------------------
/python/tseval/comgen/eval/CGModelLoader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from seutil import IOUtils, LoggingUtils
4 |
5 | from tseval.comgen.model import get_model_cls
6 | from tseval.comgen.model.CGModelBase import CGModelBase
7 | from tseval.Utils import Utils
8 |
9 | logger = LoggingUtils.get_logger(__name__)
10 |
11 |
class CGModelLoader:
    """
    Creates new comment generation models, or restores saved ones, inside an
    experiment directory.
    """

    @classmethod
    def init_or_load_model(
        cls,
        model_name: str,
        exp_dir: Path,
        cont_train: bool,
        no_save: bool,
        cmd_options: dict,
    ) -> CGModelBase:
        """
        Initializes a model from cmd_options, or — when cont_train and a saved model
        exists — reloads it from exp_dir/model.

        :param model_name: the model class name (resolved via get_model_cls).
        :param exp_dir: the experiment directory holding model_name.txt, model_config.json,
            and the model/ subdirectory.
        :param cont_train: whether to continue training an existing saved model.
        :param no_save: if True, nothing is written to disk.
        :param cmd_options: command line options used to configure a new model.
        :raises ValueError: if the saved model's name differs from model_name.
        :raises KeyError: if required model options are missing.
        """
        model_cls = get_model_cls(model_name)
        model_work_dir = exp_dir / "model"

        if cont_train and model_work_dir.is_dir() and not no_save:
            # Restore model name
            loaded_model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt)
            if model_name != loaded_model_name:
                # Fixed: "saved" and "new" labels were swapped in the message
                raise ValueError(f"Contradicting model name (saved: {loaded_model_name}; new {model_name})")

            # Warning about any additional command line arguments
            if len(cmd_options) > 0:
                logger.warning(f"These options will not be used in cont_train mode: {cmd_options}")

            # Load existing model
            model: CGModelBase = model_cls.load(model_work_dir)
        else:

            if not no_save:
                exp_dir.mkdir(parents=True, exist_ok=True)

                # Save model name
                IOUtils.dump(exp_dir / "model_name.txt", model_name, IOUtils.Format.txt)

                # Prepare a clean directory for model
                IOUtils.rm(model_work_dir)
                model_work_dir.mkdir(parents=True)

            # Initialize the model, using command line arguments
            model_options, unk_options, missing_options = Utils.parse_cmd_options_for_type(
                cmd_options,
                model_cls,
                ["self", "model_work_dir"],
            )
            if len(missing_options) > 0:
                raise KeyError(f"Missing options: {missing_options}")
            if len(unk_options) > 0:
                logger.warning(f"Unrecognized options: {unk_options}")

            model: CGModelBase = model_cls(model_work_dir=model_work_dir, no_save=no_save, **model_options)

            if not no_save:
                # Save model configs
                IOUtils.dump(exp_dir / "model_config.json", model_options, IOUtils.Format.jsonNoSort)
        return model

    @classmethod
    def load_model(cls, exp_dir: Path) -> CGModelBase:
        """
        Loads a trained model from exp_dir. Gets the model name from model_name.txt.
        """
        # (fixed docstring: the name is read from model_name.txt, not train_config.json)
        Utils.expect_dir_or_suggest_dvc_pull(exp_dir)
        model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt)
        model_cls = get_model_cls(model_name)
        model_dir = exp_dir / "model"
        return model_cls.load(model_dir)
78 |
--------------------------------------------------------------------------------
/python/tseval/comgen/eval/StandardSetup.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import time
4 | from pathlib import Path
5 | from typing import Dict, List
6 |
7 | import nltk
8 | import numpy as np
9 | from seutil import IOUtils, LoggingUtils
10 | from tqdm import tqdm
11 |
12 | from tseval.comgen.eval.CGModelLoader import CGModelLoader
13 | from tseval.comgen.model.CGModelBase import CGModelBase
14 | from tseval.data.MethodData import MethodData
15 | from tseval.eval.EvalMetrics import EvalMetrics
16 | from tseval.eval.EvalSetupBase import EvalSetupBase
17 | from tseval.Macros import Macros
18 | from tseval.util.ModelUtils import ModelUtils
19 | from tseval.util.TrainConfig import TrainConfig
20 | from tseval.Utils import Utils
21 |
22 | logger = LoggingUtils.get_logger(__name__)
23 |
24 |
class StandardSetup(EvalSetupBase):
    """
    Eval setup for comment generation on one split type (mixed-project,
    cross-project, or temporal): copies split indexes, subtokenizes the data,
    cleans eval sets, then drives train / eval / metrics for a model.
    """

    # Validation set on self's split type
    EVAL_VAL = Macros.val
    # Test_standard set on self's split type
    EVAL_TESTS = Macros.test_standard
    # Test_common sets, pairwisely between self's split type and other split types
    EVAL_TESTC = Macros.test_common

    EVAL_ACTIONS = [EVAL_VAL, EVAL_TESTS, EVAL_TESTC]
    DEFAULT_EVAL_ACTION = EVAL_TESTC

    def __init__(
        self,
        work_dir: Path,
        work_subdir: Path,
        setup_name: str,
        split_name: str,
        split_type: str,
    ):
        # split_name selects the split directory; split_type is this setup's own
        # split methodology (e.g., mixed-project vs temporal)
        super().__init__(work_dir, work_subdir, setup_name)
        self.split_name = split_name
        self.split_type = split_type

    def prepare(self) -> None:
        """
        Builds this setup's data directory: copies split index files, loads the
        shared raw dataset, subtokenizes code and comments in place, de-duplicates
        the eval sets, and saves the processed dataset.
        """
        # Check and prepare directories
        split_dir = self.get_split_dir(self.split_name)
        Utils.expect_dir_or_suggest_dvc_pull(self.shared_data_dir)
        Utils.expect_dir_or_suggest_dvc_pull(split_dir)
        IOUtils.rm_dir(self.data_dir)
        self.data_dir.mkdir(parents=True)

        # Copy split indexes; all_indexes accumulates every id this setup touches
        all_indexes = []
        for sn in [Macros.train, Macros.val, Macros.test_standard]:
            ids = IOUtils.load(split_dir / f"{self.split_type}-{sn}.json", IOUtils.Format.json)
            all_indexes += ids
            IOUtils.dump(self.data_dir / f"split_{sn}.json", ids, IOUtils.Format.json)
        # test_common sets are shared pairwise between this split type and each other one
        for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
            ids = IOUtils.load(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", IOUtils.Format.json)
            all_indexes += ids
            IOUtils.dump(
                self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json",
                ids,
                IOUtils.Format.json,
            )
        all_indexes = list(sorted(set(all_indexes)))

        # Load raw data (only the fields this task needs)
        tbar = tqdm()
        dataset: List[MethodData] = MethodData.load_dataset(
            self.shared_data_dir,
            expected_ids=all_indexes,
            only=["code", "comment_summary", "misc"],
            tbar=tbar,
        )
        tbar.close()

        # Subtokenize code and comments
        tbar = tqdm()

        # First pass: stash the original texts in misc and tokenize comments with nltk
        tbar.set_description("Subtokenizing")
        tbar.reset(len(dataset))
        orig_code_list = []
        tokenized_comment_list = []
        for d in dataset:
            d.fill_none()
            d.misc["orig_code"] = d.code
            d.misc["orig_comment_summary"] = d.comment_summary
            orig_code_list.append(d.code)
            tokenized_comment_list.append(nltk.word_tokenize(d.comment_summary))
            tbar.update(1)

        # Code is tokenized in one batched javaparser call for speed
        tokenized_code_list = ModelUtils.tokenize_javaparser_batch(orig_code_list, dup_share=False, tbar=tbar)

        # Second pass: subtokenize both sides, replacing code/comment in place;
        # *_src_ids map each subtoken back to its source token
        tbar.set_description("Subtokenizing")
        tbar.reset(len(dataset))
        for d, tokenized_code, tokenized_comment in zip(dataset, tokenized_code_list, tokenized_comment_list):
            d.code, d.misc["code_src_ids"] = ModelUtils.subtokenize_batch(tokenized_code)
            d.comment_summary, d.misc["comment_summary_src_ids"] = ModelUtils.subtokenize_batch(tokenized_comment)

            # convert comment to lower case
            d.comment_summary = [t.lower() for t in d.comment_summary]
            tbar.update(1)
        tbar.close()

        # Clean eval ids (val, test_standard, and every test_common set), rewriting
        # the split files in place
        indexed_dataset = {d.id: d for d in dataset}
        for sn in [Macros.val, Macros.test_standard] + [f"{Macros.test_common}-{x}-{y}" for x, y in Macros.get_pairwise_split_types_with(self.split_type)]:
            eval_ids = IOUtils.load(self.data_dir / f"split_{sn}.json", IOUtils.Format.json)
            IOUtils.dump(self.data_dir / f"split_{sn}.json", self.clean_eval_set(indexed_dataset, eval_ids), IOUtils.Format.json)

        # Save dataset
        MethodData.save_dataset(dataset, self.data_dir)

    def clean_eval_set(self, indexed_dataset: Dict[int, MethodData], eval_ids: List[int]) -> List[int]:
        """
        Keeps the eval set absolutely clean by:
        - Remove duplicate (comment, code) pairs;
        - Remove pseudo-empty comment;
        indexed_dataset should already been subtokenized.

        :return: the surviving ids, in their original order.
        """
        seen_data = set()
        clean_eval_ids = []
        for i in eval_ids:
            data = indexed_dataset[i]

            # Remove comment like "."
            if len(data.comment_summary) == 1 and data.comment_summary[0] == ".":
                continue

            # Remove duplicate (comment, code) pairs
            data_key = (tuple(data.code), tuple(data.comment_summary))
            if data_key in seen_data:
                continue
            else:
                seen_data.add(data_key)

            clean_eval_ids.append(i)
        return clean_eval_ids

    def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None:
        """
        Trains model_name on this setup's train/val splits, recording the train
        config and wall-clock training time under the experiment directory.
        """
        # Init or load model
        exp_dir = self.get_exp_dir(exp_name)
        train_config = TrainConfig.get_train_config_from_cmd_options(options)
        model = CGModelLoader.init_or_load_model(model_name, exp_dir, cont_train, no_save, options)
        if not no_save:
            # append=True keeps the history of configs across cont_train runs
            IOUtils.dump(exp_dir / "train_config.jsonl", [IOUtils.jsonfy(train_config)], IOUtils.Format.jsonList, append=True)

        # Load data
        tbar = tqdm(desc="Loading data")
        dataset = MethodData.load_dataset(self.data_dir, tbar=tbar)
        indexed_dataset = {d.id: d for d in dataset}

        tbar.set_description("Loading data | take indexes")
        tbar.reset(2)

        train_ids = IOUtils.load(self.data_dir / f"split_{Macros.train}.json", IOUtils.Format.json)
        train_dataset = [indexed_dataset[i] for i in train_ids]
        tbar.update(1)

        val_ids = IOUtils.load(self.data_dir / f"split_{Macros.val}.json", IOUtils.Format.json)
        val_dataset = [indexed_dataset[i] for i in val_ids]
        tbar.update(1)

        tbar.close()

        # Train model (timed)
        start = time.time()
        model.train(train_dataset, val_dataset, resources_path=self.data_dir, train_config=train_config)
        end = time.time()

        if not no_save:
            model.save()
            IOUtils.dump(exp_dir / "train_time.json", end - start, IOUtils.Format.json)

    def eval_one(self, exp_name: str, eval_ids: List[int], prefix: str, indexed_dataset: Dict[int, MethodData], model: CGModelBase, gpu_id: int = 0):
        """
        Runs the model on one eval set and saves predictions, golds, and prediction
        time under the result directory, prefixed by `prefix`.

        NOTE: mutates the data in indexed_dataset (targets are replaced with
        dummies) — the caller passes a deep copy when evaluating multiple sets.
        """
        # Prepare output directory
        result_dir = self.get_result_dir(exp_name)
        result_dir.mkdir(parents=True, exist_ok=True)

        # Prepare eval data (remove target, so the model cannot peek at it)
        eval_dataset = [indexed_dataset[i] for i in eval_ids]
        golds = []
        for d in eval_dataset:
            golds.append(d.comment_summary)
            d.comment_summary = ["dummy"]
            d.misc["orig_comment_summary"] = "dummy"
            d.misc["comment_summary_src_ids"] = [0]

        # Perform batched queries (timed)
        tbar = tqdm(desc=f"Predicting | {prefix}")
        eval_start = time.time()
        predictions = model.batch_predict(eval_dataset, tbar=tbar, gpu_id=gpu_id)
        eval_end = time.time()
        tbar.close()

        eval_time = eval_end - eval_start

        # Save predictions & golds
        IOUtils.dump(result_dir / f"{prefix}_predictions.jsonl", predictions, IOUtils.Format.jsonList)
        IOUtils.dump(result_dir / f"{prefix}_golds.jsonl", golds, IOUtils.Format.jsonList)
        IOUtils.dump(result_dir / f"{prefix}_eval_time.json", eval_time, IOUtils.Format.json)

    def eval(self, exp_name: str, action: str = None, gpu_id: int = 0) -> None:
        """
        Evaluates the trained model on the set(s) selected by `action`
        (val, test_standard, or all pairwise test_common sets).
        """
        if action is None:
            action = self.DEFAULT_EVAL_ACTION
        if action not in self.EVAL_ACTIONS:
            raise RuntimeError(f"Unknown eval action {action}")

        # Load eval data
        tbar = tqdm(desc="Loading data")
        dataset = MethodData.load_dataset(self.data_dir, tbar=tbar)
        indexed_dataset = {d.id: d for d in dataset}
        tbar.close()

        # Load model
        exp_dir = self.get_exp_dir(exp_name)
        model: CGModelBase = CGModelLoader.load_model(exp_dir)
        if not model.is_train_finished():
            logger.warning(f"Model not finished training, at {exp_dir}")

        # Invoke eval_one with specific data ids
        if action in [self.EVAL_VAL, self.EVAL_TESTS]:
            self.eval_one(
                exp_name,
                IOUtils.load(self.data_dir / f"split_{action}.json", IOUtils.Format.json),
                action,
                indexed_dataset,
                model,
                gpu_id=gpu_id,
            )
        elif action == self.EVAL_TESTC:
            for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
                self.eval_one(
                    exp_name,
                    IOUtils.load(self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", IOUtils.Format.json),
                    f"{Macros.test_common}-{s1}-{s2}",
                    # deep copy: eval_one mutates the data it is given
                    copy.deepcopy(indexed_dataset),
                    model,
                    gpu_id=gpu_id,
                )
        else:
            raise RuntimeError(f"Unknown action {action}")

    def compute_metrics_one(self, exp_name: str, prefix: str):
        """
        Computes comment generation metrics (exact match, token accuracy, BLEU,
        ROUGE-L, METEOR, set match) for one eval set's saved results, then saves
        per-example lists and their averages under the metric directory.
        """
        # Prepare output directory
        metric_dir = self.get_metric_dir(exp_name)
        metric_dir.mkdir(parents=True, exist_ok=True)

        # Load golds and predictions
        result_dir = self.get_result_dir(exp_name)
        Utils.expect_dir_or_suggest_dvc_pull(result_dir)
        golds = IOUtils.load(result_dir / f"{prefix}_golds.jsonl", IOUtils.Format.jsonList)
        predictions = IOUtils.load(result_dir / f"{prefix}_predictions.jsonl", IOUtils.Format.jsonList)

        # Per-example metric values, keyed by metric name
        metrics_list: Dict[str, List] = collections.defaultdict(list)
        metrics_list["exact_match"] = EvalMetrics.batch_exact_match(golds, predictions)
        metrics_list["token_acc"] = EvalMetrics.batch_token_acc(golds, predictions)
        metrics_list["bleu"] = EvalMetrics.batch_bleu(golds, predictions)
        rouge_l_res = EvalMetrics.batch_rouge_l(golds, predictions)
        metrics_list["rouge_l_f"] = [x["f"] for x in rouge_l_res]
        metrics_list["rouge_l_p"] = [x["p"] for x in rouge_l_res]
        metrics_list["rouge_l_r"] = [x["r"] for x in rouge_l_res]
        metrics_list["meteor"] = EvalMetrics.batch_meteor(golds, predictions)
        set_match_res = EvalMetrics.batch_set_match(golds, predictions)
        metrics_list["set_match_f"] = [x["f"] for x in set_match_res]
        metrics_list["set_match_p"] = [x["p"] for x in set_match_res]
        metrics_list["set_match_r"] = [x["r"] for x in set_match_res]

        # Take average
        metrics = {}
        for k, l in metrics_list.items():
            metrics[k] = np.mean(l).item()

        # Save metrics (json for machines, txt for humans, pkl keeps per-example lists)
        IOUtils.dump(metric_dir / f"{prefix}_metrics.json", metrics, IOUtils.Format.jsonNoSort)
        IOUtils.dump(metric_dir / f"{prefix}_metrics.txt", [f"{k}: {v}" for k, v in metrics.items()], IOUtils.Format.txtList)
        IOUtils.dump(metric_dir / f"{prefix}_metrics_list.pkl", metrics_list, IOUtils.Format.pkl)

    def compute_metrics(self, exp_name: str, action: str = None) -> None:
        """
        Computes metrics for the eval set(s) selected by `action`, mirroring the
        dispatch logic of #eval.
        """
        if action is None:
            action = self.DEFAULT_EVAL_ACTION
        if action not in self.EVAL_ACTIONS:
            raise RuntimeError(f"Unknown eval action {action}")

        if action in [self.EVAL_VAL, self.EVAL_TESTS]:
            self.compute_metrics_one(
                exp_name,
                action,
            )
        elif action == self.EVAL_TESTC:
            for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
                self.compute_metrics_one(
                    exp_name,
                    f"{Macros.test_common}-{s1}-{s2}",
                )
        else:
            raise RuntimeError(f"Unknown action {action}")
303 | raise RuntimeError(f"Unknown action {action}")
304 |
--------------------------------------------------------------------------------
/python/tseval/comgen/eval/__init__.py:
--------------------------------------------------------------------------------
def get_setup_cls(name: str) -> type:
    """Resolves an eval setup class by its name; raises ValueError for unknown names."""
    if name != "StandardSetup":
        raise ValueError(f"No setup with name {name}")
    # Import deferred to call time, keeping package import lightweight
    from tseval.comgen.eval.StandardSetup import StandardSetup
    return StandardSetup
7 |
--------------------------------------------------------------------------------
/python/tseval/comgen/model/CGModelBase.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from pathlib import Path
3 | from typing import List, Optional, Tuple
4 |
5 | from seutil import IOUtils
6 | from tqdm import tqdm
7 |
8 | from tseval.data.MethodData import MethodData
9 | from tseval.util.TrainConfig import TrainConfig
10 |
11 |
class CGModelBase:
    """
    Base class for comment generation models: defines the train/predict contract
    and provides default batching and pickle-based persistence.
    """

    def __init__(self, model_work_dir: Path, no_save: bool = False):
        # Directory holding the model's files; no_save disables all persistence
        self.model_work_dir = model_work_dir
        self.no_save = no_save

    @abc.abstractmethod
    def train(
        self,
        train_dataset: List[MethodData],
        val_dataset: List[MethodData],
        resources_path: Optional[Path] = None,
        train_config: Optional[TrainConfig] = None,
    ):
        """
        Trains the model.

        :param train_dataset: training set.
        :param val_dataset: validation set.
        :param resources_path: path to resources that could be shared by multiple model's training process,
            e.g., pre-trained embeddings.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_train_finished(self) -> bool:
        """Tells whether this model has completed training."""
        raise NotImplementedError

    @abc.abstractmethod
    def predict(
        self,
        data: MethodData,
        gpu_id: int = 0,
    ) -> List[str]:
        """
        Predicts the comment summary given the context in data. The model should output
        results with a confidence score in [0, 1].
        :param data: the data, with its statements partially filled.
        :return: a list of predicted comment summary tokens.
        """
        raise NotImplementedError

    def batch_predict(
        self,
        dataset: List[MethodData],
        tbar: Optional[tqdm] = None,
        gpu_id: int = 0,
    ) -> List[List[str]]:
        """
        Performs batched predictions using given dataset as inputs.

        This default implementation just calls #predict once per input; subclasses
        can override it to exploit real batching.

        :param dataset: a list of inputs.
        :param tbar: an optional tqdm progress bar to show prediction progress.
        :return: a list of the return value of #predict.
        """
        if tbar is not None:
            tbar.reset(len(dataset))

        def predict_one(data: MethodData) -> List[str]:
            # Advance the progress bar after each individual prediction
            prediction = self.predict(data, gpu_id=gpu_id)
            if tbar is not None:
                tbar.update(1)
            return prediction

        return [predict_one(data) for data in dataset]

    def save(self) -> None:
        """
        Saves the current model at the work_dir.
        Default behavior is to serialize the entire object in model.pkl.
        """
        if self.no_save:
            return
        IOUtils.dump(self.model_work_dir / "model.pkl", self, IOUtils.Format.pkl)

    @classmethod
    def load(cls, work_dir) -> "CGModelBase":
        """
        Loads a model from the work_dir.
        Default behavior is to deserialize the object from model.pkl, with resetting its work_dir.
        """
        model = IOUtils.load(work_dir / "model.pkl", IOUtils.Format.pkl)
        model.model_work_dir = work_dir
        return model
--------------------------------------------------------------------------------
/python/tseval/comgen/model/TransformerACL20.py:
--------------------------------------------------------------------------------
1 | import stat
2 | import tempfile
3 | from pathlib import Path
4 | from subprocess import TimeoutExpired
5 | from typing import List, Optional
6 |
7 | from recordclass import RecordClass
8 | from seutil import BashUtils, IOUtils, LoggingUtils
9 | from tqdm import tqdm
10 |
11 | from tseval.comgen.model.CGModelBase import CGModelBase
12 | from tseval.data.MethodData import MethodData
13 | from tseval.Environment import Environment
14 | from tseval.Macros import Macros
15 | from tseval.util.ModelUtils import ModelUtils
16 | from tseval.util.TrainConfig import TrainConfig
17 | from tseval.Utils import Utils
18 |
19 | logger = LoggingUtils.get_logger(__name__)
20 |
21 |
class TransformerACL20Config(RecordClass):
    # Maximum number of code subtokens fed to the model (longer inputs are truncated).
    max_src_len: int = 150
    # Maximum number of comment subtokens used as target (longer targets are truncated).
    max_tgt_len: int = 50
    # If True, use the external codebase's RNN variant instead of the Transformer.
    use_rnn: bool = False
    # Random seed forwarded to the external training script (--random_seed).
    seed: Optional[int] = None
27 |
28 |
class TransformerACL20(CGModelBase):
    """
    Comment generation model wrapping an external Transformer-based code
    summarization codebase, which is run in its own conda environment via
    generated bash scripts.
    """

    # Name of the conda environment used by this model's scripts.
    ENV_NAME = "tseval-CG-TransformerACL20"
    # Location of the external source code (expected to be dvc-pulled here).
    SRC_DIR = Macros.work_dir / "src" / "CG-TransformerACL20"

    @classmethod
    def prepare_env(cls):
        """
        (Re)creates the conda environment and installs the external
        codebase's dependencies; prints whether the setup script succeeded.
        """
        Utils.expect_dir_or_suggest_dvc_pull(cls.SRC_DIR)
        s = "#!/bin/bash\n"
        s += "set -e\n"
        s += f"source {Environment.conda_init_path}\n"
        s += f"conda env remove --name {cls.ENV_NAME}\n"
        s += f"conda create --name {cls.ENV_NAME} python=3.7 pip -y\n"
        s += f"conda activate {cls.ENV_NAME}\n"
        s += "set -x\n"
        s += f"cd {cls.SRC_DIR}\n"
        # Pytorch 1.4.0 -> max cuda version 10.1
        cuda_toolkit_spec = Environment.get_cuda_toolkit_spec()
        if cuda_toolkit_spec in ["cudatoolkit=11.0", "cudatoolkit=10.2"]:
            cuda_toolkit_spec = "cudatoolkit=10.1"
        s += f"conda install -y pytorch==1.4.0 torchvision==0.5.0 {cuda_toolkit_spec} -c pytorch\n"
        s += f"pip install -r requirements.txt\n"
        t = Path(tempfile.mktemp(prefix="tseval"))
        IOUtils.dump(t, s, IOUtils.Format.txt)
        t.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
        print(f"Preparing env for {__name__}...")
        rr = BashUtils.run(t)
        if rr.return_code == 0:
            print("Success!")
        else:
            print("Failed!")
            print(f"STDOUT:\n{rr.stdout}")
            print(f"STDERR:\n{rr.stderr}")
            print(f"^^^ Preparing env for {__name__} failed!")

    def __init__(
        self,
        model_work_dir: Path,
        no_save: bool = False,
        max_src_len: int = 150,
        max_tgt_len: int = 50,
        use_rnn: bool = False,
        # NOTE: this default is evaluated once at import time, so all
        # instances created without an explicit seed share the same one.
        seed: int = ModelUtils.get_random_seed(),
    ):
        """
        :param model_work_dir: directory for this model's data/scripts/checkpoints.
        :param no_save: if True, saving is skipped (see CGModelBase).
        :param max_src_len: max number of code subtokens used as input.
        :param max_tgt_len: max number of comment subtokens used as target.
        :param use_rnn: use the RNN variant instead of the Transformer.
        :param seed: random seed passed to the external training script.
        :raises RuntimeError: if the external source code is not present.
        """
        super().__init__(model_work_dir, no_save)
        if not self.SRC_DIR.is_dir():
            raise RuntimeError(f"Environment missing (expected at {self.SRC_DIR})")

        self.config = TransformerACL20Config(
            max_src_len=max_src_len,
            max_tgt_len=max_tgt_len,
            use_rnn=use_rnn,
            seed=seed,
        )
        self.train_finished = False

    def prepare_data(self, dataset: List[MethodData], data_dir: Path, sn: str):
        """
        Writes dataset under data_dir/java/{sn} in the external codebase's
        expected format (one example per line): subtokenized code,
        subtokenized comment, and regrouped (fully tokenized) code.
        Any existing content of that directory is wiped first.
        """
        sn_dir = data_dir / "java" / sn
        IOUtils.rm_dir(sn_dir)
        sn_dir.mkdir(parents=True)

        # subtokenized code
        IOUtils.dump(
            sn_dir / "code.original_subtoken",
            [" ".join(d.code[:self.config.max_src_len]) for d in dataset],
            IOUtils.Format.txtList,
        )

        # subtokenized comment
        IOUtils.dump(
            sn_dir / "javadoc.original",
            [" ".join(d.comment_summary[:self.config.max_tgt_len]) for d in dataset],
            IOUtils.Format.txtList,
        )

        # tokenized code
        with open(sn_dir / "code.original", "w") as f:
            for d in dataset:
                tokens = ModelUtils.regroup_subtokens(d.code, d.misc["code_src_ids"])
                f.write(" ".join(tokens) + "\n")

    def train(
        self,
        train_dataset: List[MethodData],
        val_dataset: List[MethodData],
        resources_path: Optional[Path] = None,
        train_config: Optional[TrainConfig] = None,
    ):
        """
        Trains the model by generating and running a bash script that invokes
        the external codebase's train.py inside the model's conda env.

        Training is bounded by train_config.train_session_time; on timeout or
        keyboard interrupt, train_finished stays False so training can be
        resumed in a later session (data preparation is skipped when the data
        directory already exists).

        :param train_dataset: training set.
        :param val_dataset: validation set.
        :param resources_path: unused by this model.
        :param train_config: training configurations; defaults used if None.
        """
        if train_config is None:
            train_config = TrainConfig()

        # Prepare data (only once; kept across resumed training sessions)
        data_dir = self.model_work_dir / "data"
        if not data_dir.is_dir():
            data_dir.mkdir(parents=True)
            self.prepare_data(train_dataset, data_dir, "train")
            self.prepare_data(val_dataset, data_dir, "dev")

        # Prepare script
        model_dir = self.model_work_dir / "model"
        s = "#!/bin/bash\n"
        s += "set -e\n"
        s += f"source {Environment.conda_init_path}\n"
        s += f"conda activate {self.ENV_NAME}\n"
        s += "set -x\n"
        s += f"cd {self.SRC_DIR}\n"
        s += f"""MKL_THREADING_LAYER=GNU CUDA_VISIBLE_DEVICES={train_config.gpu_id} timeout {train_config.train_session_time} python -W ignore '{self.SRC_DIR}/main/train.py' \\
            --data_workers 5 \\
            --dataset_name java \\
            --data_dir '{data_dir}/' \\
            --model_dir '{model_dir}' \\
            --model_name model \\
            --train_src train/code.original_subtoken \\
            --train_tgt train/javadoc.original \\
            --dev_src dev/code.original_subtoken \\
            --dev_tgt dev/javadoc.original \\
            --uncase True \\
            --use_src_word True \\
            --use_src_char False \\
            --use_tgt_word True \\
            --use_tgt_char False \\
            --max_src_len {self.config.max_src_len} \\
            --max_tgt_len {self.config.max_tgt_len} \\
            --emsize 512 \\
            --fix_embeddings False \\
            --src_vocab_size 50000 \\
            --tgt_vocab_size 30000 \\
            --share_decoder_embeddings True \\
            --max_examples -1 \\
            --batch_size 32 \\
            --test_batch_size 64 \\
            --num_epochs 200 \\
            --dropout_emb 0.2 \\
            --dropout 0.2 \\
            --copy_attn True \\
            --early_stop 20 \\
            --optimizer adam \\
            --lr_decay 0.99 \\
            --valid_metric bleu \\
            --checkpoint True \\
            --random_seed {self.config.seed} \\
            """

        # Architecture-specific flags
        if not self.config.use_rnn:
            s += """--model_type transformer \\
            --num_head 8 \\
            --d_k 64 \\
            --d_v 64 \\
            --d_ff 2048 \\
            --src_pos_emb False \\
            --tgt_pos_emb True \\
            --max_relative_pos 32 \\
            --use_neg_dist True \\
            --nlayers 6 \\
            --trans_drop 0.2 \\
            --warmup_steps 2000 \\
            --learning_rate 0.0001
            """
        else:
            s += """--model_type rnn \\
            --conditional_decoding False \\
            --nhid 512 \\
            --nlayers 2 \\
            --use_all_enc_layers False \\
            --dropout_rnn 0.2 \\
            --reuse_copy_attn True \\
            --learning_rate 0.002 \\
            --grad_clipping 5.0
            """

        script_path = self.model_work_dir / "train.sh"
        stdout_path = self.model_work_dir / "train.stdout"
        stderr_path = self.model_work_dir / "train.stderr"
        IOUtils.dump(script_path, s, IOUtils.Format.txt)
        script_path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        with IOUtils.cd(self.model_work_dir):
            try:
                logger.info(f"=====Starting train\nScript: {script_path}\nSTDOUT: {stdout_path}\nSTDERR: {stderr_path}\n=====")
                rr = BashUtils.run(f"{script_path} 1>{stdout_path} 2>{stderr_path}", timeout=train_config.train_session_time)
                if rr.return_code != 0:
                    raise RuntimeError(f"Train returned {rr.return_code}; check STDERR at {stderr_path}")
            except (TimeoutExpired, KeyboardInterrupt):
                # Ran out of the session time budget; leave resumable
                logger.warning("Training not finished")
                self.train_finished = False
                return
            except:
                # Bare except on purpose: log, then re-raise whatever happened
                logger.warning("Error during training!")
                raise

        # If we can reach here, training should be finished
        self.train_finished = True
        # Remove the big checkpoint file
        IOUtils.rm(model_dir / "model.mdl.checkpoint")
        return

    def is_train_finished(self) -> bool:
        """Whether the last training session ran to completion."""
        return self.train_finished

    def predict(self, data: MethodData, gpu_id: int = 0) -> List[str]:
        """
        Predicts comment summary tokens for one input, by delegating to
        #batch_predict with a singleton batch.
        """
        # Bug fix: forward gpu_id (it was previously dropped, so single
        # predictions always ran on GPU 0 regardless of the requested device).
        return self.batch_predict([data], gpu_id=gpu_id)[0]

    def batch_predict(
        self,
        dataset: List[MethodData],
        tbar: Optional[tqdm] = None,
        gpu_id: int = 0,
    ) -> List[List[str]]:
        """
        Predicts on all inputs at once via the external codebase's test.py.

        NOTE: overwrites each data's comment_summary with its index (a dummy
        value) so predictions can be re-aligned even if the external code
        drops some examples. tbar is accepted for interface compatibility but
        not updated here.
        """
        # Prepare data
        data_dir = Path(tempfile.mkdtemp(prefix="tseval"))

        # Use the dummy comment_summary field to carry id information, so that we know what ids are deleted
        for i, d in enumerate(dataset):
            d.comment_summary = [str(i)]
        self.prepare_data(dataset, data_dir, "test")

        # Prepare script
        model_dir = self.model_work_dir / "model"
        s = "#!/bin/bash\n"
        s += "set -e\n"
        s += f"source {Environment.conda_init_path}\n"
        s += f"conda activate {self.ENV_NAME}\n"
        s += "set -x\n"
        s += f"cd {self.SRC_DIR}\n"
        # Reducing test_batch_size to 4, otherwise it will delete some test data due to some bug
        s += f"""MKL_THREADING_LAYER=GNU CUDA_VISIBLE_DEVICES={gpu_id} python -W ignore '{self.SRC_DIR}/main/test.py' \\
            --data_workers 5 \\
            --dataset_name java \\
            --data_dir '{data_dir}/' \\
            --model_dir '{model_dir}' \\
            --model_name model \\
            --dev_src test/code.original_subtoken \\
            --dev_tgt test/javadoc.original \\
            --uncase True \\
            --max_examples -1 \\
            --max_src_len {self.config.max_src_len} \\
            --max_tgt_len {self.config.max_tgt_len} \\
            --test_batch_size 4 \\
            --beam_size 4 \\
            --n_best 1 \\
            --block_ngram_repeat 3 \\
            --stepwise_penalty False \\
            --coverage_penalty none \\
            --length_penalty none \\
            --beta 0 \\
            --gamma 0 \\
            --replace_unk
            """

        script_path = Path(tempfile.mktemp(prefix="tseval.test.sh-"))
        stdout_path = Path(tempfile.mktemp(prefix="tseval.test.stdout-"))
        stderr_path = Path(tempfile.mktemp(prefix="tseval.test.stderr-"))
        IOUtils.dump(script_path, s, IOUtils.Format.txt)
        script_path.chmod(stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

        # Run predictions
        with IOUtils.cd(self.model_work_dir):
            logger.info(f"=====Starting eval\nScript: {script_path}\nSTDOUT: {stdout_path}\nSTDERR: {stderr_path}\n=====")
            rr = BashUtils.run(f"{script_path} 1>{stdout_path} 2>{stderr_path}")
            if rr.return_code != 0:
                raise RuntimeError(f"Eval returned {rr.return_code}; check STDERR at {stderr_path}")

        # Load predictions; re-align using the index smuggled through the
        # references field (dropped examples keep an empty prediction)
        beam_res = IOUtils.load(model_dir / "model_beam.json", IOUtils.Format.jsonList)
        predictions = [[] for _ in range(len(dataset))]
        for x in beam_res:
            predictions[int(x["references"][0])] = x["predictions"][0].split(" ")

        # Delete temp files
        IOUtils.rm_dir(data_dir)
        IOUtils.rm(script_path)
        IOUtils.rm(stdout_path)
        IOUtils.rm(stderr_path)

        return predictions

    def save(self) -> None:
        """
        Saves this model's config and training status; the model weights are
        already checkpointed under the work dir by the training script.
        """
        # Save config and training status
        IOUtils.dump(self.model_work_dir / "config.json", IOUtils.jsonfy(self.config), IOUtils.Format.jsonPretty)
        IOUtils.dump(self.model_work_dir / "train_finished.json", self.train_finished)

        # Model should already be saved/checkpointed to the correct path
        return

    @classmethod
    def load(cls, model_work_dir) -> "CGModelBase":
        """Restores a model from the config/status files written by #save."""
        obj = TransformerACL20(model_work_dir)
        obj.config = IOUtils.dejsonfy(IOUtils.load(model_work_dir / "config.json"), TransformerACL20Config)
        obj.train_finished = IOUtils.load(model_work_dir / "train_finished.json")
        return obj
318 |
--------------------------------------------------------------------------------
/python/tseval/comgen/model/__init__.py:
--------------------------------------------------------------------------------
def get_model_cls(name: str) -> type:
    """
    Resolves a comment-generation model class from its name.

    Imports are done lazily, so selecting one model does not require the
    other models' dependencies.

    :param name: the model's name.
    :return: the model class.
    :raises ValueError: if no model has the given name.
    """
    if name == "DeepComHybridESE19":
        from tseval.comgen.model.DeepComHybridESE19 import DeepComHybridESE19
        return DeepComHybridESE19
    if name == "TransformerACL20":
        from tseval.comgen.model.TransformerACL20 import TransformerACL20
        return TransformerACL20
    raise ValueError(f"No model with name {name}")
10 |
--------------------------------------------------------------------------------
/python/tseval/data/MethodData.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import *
4 |
5 | from recordclass import RecordClass
6 | from seutil import IOUtils
7 | from tqdm import tqdm
8 |
9 |
class MethodData(RecordClass):
    """
    All collected information about one method. Datasets are persisted as
    one .jsonl file per field (one line per method) under a directory.
    """

    id: int = -1

    # Project name
    prj: str = None
    # Years that has this data
    years: List[int] = None

    # Method name
    name: str = None
    # Code (subtokenized version after preprocessing)
    code: Union[str, List[str]] = None
    # Code with masking its name
    code_masked: str = None
    # Comment (full, including tags)
    comment: str = None
    # Comment (summary first sentence) (subtokenized version after preprocessing)
    comment_summary: Union[str, List[str]] = None
    # Class name
    cname: str = None
    # Qualified class name
    qcname: str = None
    # File relative path
    path: str = None
    # Return type
    ret: str = None
    # Parameter types
    params: List[Tuple[str, str]] = None

    # Extra per-method information (e.g., subtoken grouping indexes)
    misc: dict = None

    def init(self):
        """Initializes the mutable container fields to empty containers."""
        self.years = []
        self.params = []
        self.misc = {}
        return

    def fill_none(self):
        """Replaces any None container field with an empty container."""
        if self.years is None:
            self.years = []
        if self.params is None:
            self.params = []
        if self.misc is None:
            self.misc = {}

    @classmethod
    def save_dataset(
        cls,
        dataset: List["MethodData"],
        save_dir: Path,
        exist_ok: bool = True,
        append: bool = False,
        only: Optional[Iterable[str]] = None,
    ):
        """
        Saves dataset to save_dir. Different fields are saved in different files in the
        directory.
        :param dataset: the list of data to save.
        :param save_dir: the path to save.
        :param exist_ok: if False, requires that save_dir doesn't exist; otherwise,
            existing files in save_dir will be modified.
        :param append: if True, append to current saved data (requires exist_ok=True);
            otherwise, wipes out existing data at save_dir.
        :param only: only save certain fields; the files corresponding to the other fields
            are not touched; id are always saved.
        """
        save_dir.mkdir(parents=True, exist_ok=exist_ok)

        # id is always saved
        IOUtils.dump(save_dir / "id.jsonl", [d.id for d in dataset], IOUtils.Format.jsonList, append=append)

        # Every other field goes to its own {field}.jsonl, honoring `only`
        # (field list must match the one in #iter_load_dataset)
        for f in ["prj", "years", "name", "code", "code_masked", "comment", "comment_summary", "cname", "qcname", "path", "ret", "params", "misc"]:
            if only is None or f in only:
                IOUtils.dump(save_dir / f"{f}.jsonl", [getattr(d, f) for d in dataset], IOUtils.Format.jsonList, append=append)

    @classmethod
    def iter_load_dataset(
        cls,
        save_dir: Path,
        only: Optional[Iterable[str]] = None,
    ) -> Generator["MethodData", None, None]:
        """
        Iteratively loads dataset from the save directory.
        :param save_dir: the directory to load data from.
        :param only: only load certain fields; the other fields are not filled in the
            loaded data; id is always loaded.
        :return: a generator iteratively loading the dataset.
        """
        if not save_dir.is_dir():
            raise FileNotFoundError(f"Not found saved data at {save_dir}")

        # First, load all ids
        ids = IOUtils.load(save_dir / "id.jsonl", IOUtils.Format.jsonList)

        # The types of some line-by-line loaded fields (currently none need
        # dejsonfy-ing, but the hook is kept for future typed fields)
        f2type = {}
        f2file = {}
        # try/finally so the per-field files are closed even when the caller
        # abandons the generator early (generator.close()/GC triggers
        # GeneratorExit, which runs the finally block) — previously the file
        # handles leaked in that case
        try:
            for f in ["prj", "years", "name", "code", "code_masked", "comment", "comment_summary", "cname", "qcname", "path", "ret", "params", "misc"]:
                if only is None or f in only:
                    f2file[f] = open(save_dir / f"{f}.jsonl", "r")

            for i in ids:
                d = MethodData(id=i)

                # Load line-by-line fields
                for f in f2file.keys():
                    o = json.loads(f2file[f].readline())
                    if f in f2type:
                        o = IOUtils.dejsonfy(o, f2type[f])
                    setattr(d, f, o)

                yield d
        finally:
            # Close all files
            for file in f2file.values():
                file.close()

    @classmethod
    def load_dataset(
        cls,
        save_dir: Path,
        only: Optional[List[str]] = None,
        expected_ids: List[int] = None,
        tbar: Optional[tqdm] = None,
    ) -> List["MethodData"]:
        """
        Loads the dataset from save_dir.

        :param expected_ids: if provided, the list of data ids to load; the returned dataset
            will only contain these data.
        :param tbar: an optional progress bar.
        Other parameters are the same as #iter_load_dataset.
        """
        dataset = []

        # Load all data by default
        if expected_ids is None:
            expected_ids = IOUtils.load(save_dir / "id.jsonl", IOUtils.Format.jsonList)

        # Convert to set to speed up checking "has" relation
        expected_ids = set(expected_ids)

        if tbar is not None:
            tbar.set_description("Loading dataset")
            tbar.reset(len(expected_ids))

        for d in cls.iter_load_dataset(save_dir=save_dir, only=only):
            if d.id in expected_ids:
                dataset.append(d)
                if tbar is not None:
                    tbar.update(1)

                # Early stop loading if all data have been loaded
                # (safe: iter_load_dataset closes its files on early exit)
                if len(dataset) == len(expected_ids):
                    break

        return dataset
201 |
--------------------------------------------------------------------------------
/python/tseval/data/RevisionIds.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from recordclass import RecordClass
4 |
5 |
class RevisionIds(RecordClass):
    """Maps one repository revision to the ids of the methods it contains."""

    # Revision identifier
    revision: str = None
    # Ids of the methods in this revision (presumably MethodData.id values —
    # confirm against the collector)
    method_ids: List[int] = None

    def init(self):
        # Initialize the mutable list field to empty
        self.method_ids = []
13 |
--------------------------------------------------------------------------------
/python/tseval/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/data/__init__.py
--------------------------------------------------------------------------------
/python/tseval/eval/EvalHelper.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import math
4 | import random
5 | from typing import Dict, List, Tuple
6 |
7 | from seutil import IOUtils
8 |
9 | from tseval.data.MethodData import MethodData
10 | from tseval.Macros import Macros
11 | from tseval.Utils import Utils
12 |
13 |
class EvalHelper:
    """
    Creates the dataset splits and drives the experiment phases
    (prepare/train/eval/compute-metrics) by delegating to the task-specific
    eval helpers.
    """

    def get_splits(
        self,
        split_name: str,
        seed: int,
        prj_val_ratio: float = 0.1,
        prj_test_ratio: float = 0.2,
        inprj_val_ratio: float = 0.1,
        inprj_test_ratio: float = 0.2,
        train_year: int = 2019,
        val_year: int = 2020,
        test_year: int = 2021,
        debug: bool = False,
    ):
        """
        Gets {mixed-project, cross-project, temporally} splits for the given seed and configurations.

        :param split_name: name of the output directory under work_dir/split
            (any existing content there is removed first).
        :param seed: seed for all random shuffles.
        :param prj_val_ratio: fraction of projects for cross-project val.
        :param prj_test_ratio: fraction of projects for cross-project test.
        :param inprj_val_ratio: per-project fraction of data for mixed-project val.
        :param inprj_test_ratio: per-project fraction of data for mixed-project test.
        :param train_year: data first seen in or before this year -> temporal train.
        :param val_year: data first seen after train_year up to this year -> temporal val.
        :param test_year: data first seen after val_year up to this year -> temporal
            test (data newer than test_year is dropped).
        :param debug: take maximum of 500/100/100 data in the result train/val/test sets.
        """
        split_dir = Macros.work_dir / "split" / split_name
        IOUtils.rm_dir(split_dir)
        split_dir.mkdir(parents=True)

        # Save configs
        IOUtils.dump(
            split_dir / "config.json",
            {
                "seed": seed,
                "prj_val_ratio": prj_val_ratio,
                "prj_test_ratio": prj_test_ratio,
                "inprj_val_ratio": inprj_val_ratio,
                "inprj_test_ratio": inprj_test_ratio,
                "train_year": train_year,
                "val_year": val_year,
                "test_year": test_year,
                "debug": debug,
            },
            IOUtils.Format.jsonNoSort,
        )
        stats = {}

        # Load shared data
        dataset: List[MethodData] = MethodData.load_dataset(Macros.work_dir / "shared")

        # Initialize random state
        random.seed(seed)

        all_prjs = list(sorted(set([d.prj for d in dataset])))
        prj2ids = collections.defaultdict(list)
        for d in dataset:
            prj2ids[d.prj].append(d.id)

        # Get project split (for the cross-project setting: whole projects
        # are assigned to train/val/test)
        prj_split_names: Dict[str, List[str]] = {sn: l for sn, l in zip(
            [Macros.train, Macros.val, Macros.test],
            self.split(all_prjs, prj_val_ratio, prj_test_ratio),
        )}
        prj_split: Dict[str, List[int]] = {sn: sum([prj2ids[n] for n in names], [])
                                           for sn, names in prj_split_names.items()}
        for sn in [Macros.train, Macros.val, Macros.test]:
            IOUtils.dump(split_dir / f"prj-{sn}.jsonl", prj_split_names[sn], IOUtils.Format.jsonList)
            IOUtils.dump(split_dir / f"prj-split-{sn}.jsonl", prj_split[sn], IOUtils.Format.jsonList)
            stats[f"num_prj_{sn}"] = len(prj_split_names[sn])
            stats[f"num_prj_split_{sn}"] = len(prj_split[sn])

        # Get in-project splits (for the mixed-project setting: each
        # project's data is split independently, then concatenated)
        prj2inprj_splits: Dict[str, Dict[str, List[int]]] = {}

        for prj in all_prjs:
            prj2inprj_splits[prj] = {sn: l for sn, l in zip(
                [Macros.train, Macros.val, Macros.test],
                self.split(prj2ids[prj], inprj_val_ratio, inprj_test_ratio),
            )}
        inprj_split: Dict[str, List[int]] = {sn: sum([prj2inprj_splits[prj][sn] for prj in all_prjs], [])
                                             for sn in [Macros.train, Macros.val, Macros.test]}
        for sn in [Macros.train, Macros.val, Macros.test]:
            IOUtils.dump(split_dir / f"inprj-split-{sn}.jsonl", inprj_split[sn], IOUtils.Format.jsonList)
            stats[f"num_inprj_split_{sn}"] = len(inprj_split[sn])

        # Get year splits (temporal setting: assign by the earliest year in
        # which each data appears)
        year_split: Dict[str, List[int]] = {sn: [] for sn in [Macros.train, Macros.val, Macros.test]}
        for d in dataset:
            min_year = min(d.years)
            if min_year <= train_year:
                year_split[Macros.train].append(d.id)
            elif min_year <= val_year:
                year_split[Macros.val].append(d.id)
            elif min_year <= test_year:
                year_split[Macros.test].append(d.id)
        for sn in [Macros.train, Macros.val, Macros.test]:
            IOUtils.dump(split_dir / f"year-split-{sn}.jsonl", year_split[sn], IOUtils.Format.jsonList)
            stats[f"num_year_split_{sn}"] = len(year_split[sn])

        # Get actual mixed-prj/cross-prj/temporally splits
        # All train sets are downsampled to the size of the smallest one, so
        # each split type trains on the same amount of data
        train_size = min(len(inprj_split[Macros.train]), len(prj_split[Macros.train]), len(year_split[Macros.train]))
        split_sn2split_ids: Dict[Tuple[str, str], List[int]] = {}
        for sn in [Macros.train, Macros.val, Macros.test]:
            split_sn2split_ids[(Macros.mixed_prj, sn)] = list(sorted(inprj_split[sn]))
            split_sn2split_ids[(Macros.cross_prj, sn)] = list(sorted(prj_split[sn]))
            split_sn2split_ids[(Macros.temporally, sn)] = list(sorted(year_split[sn]))

        for s1_i, s1 in enumerate(Macros.split_types):
            # train/eval/test_standard
            for sn in [Macros.train, Macros.val, Macros.test]:
                split_ids = split_sn2split_ids[(s1, sn)]

                # Downsample train set
                if sn == Macros.train:
                    IOUtils.dump(split_dir / f"{s1}-{sn}_full.json", split_ids, IOUtils.Format.json)
                    random.shuffle(split_ids)
                    split_ids = list(sorted(split_ids[:train_size]))

                # Debugging
                if debug:
                    if sn == Macros.train:
                        split_ids = split_ids[:500]
                    else:
                        split_ids = split_ids[:100]

                sn_oname = sn if sn != Macros.test else Macros.test_standard
                IOUtils.dump(split_dir / f"{s1}-{sn_oname}.json", split_ids, IOUtils.Format.json)
                stats[f"num_{s1}_{sn_oname}"] = len(split_ids)

            # test_common set (intersection of each pair of split types'
            # test sets)
            for s2_i in range(s1_i+1, len(Macros.split_types)):
                s2 = Macros.split_types[s2_i]
                split_ids = self.intersect(
                    split_sn2split_ids[(s1, Macros.test)],
                    split_sn2split_ids[(s2, Macros.test)],
                )

                # Debugging
                if debug:
                    split_ids = split_ids[:100]

                IOUtils.dump(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", split_ids, IOUtils.Format.json)
                stats[f"num_{s1}-{s2}_{Macros.test_common}"] = len(split_ids)

        # Save stats
        IOUtils.dump(split_dir / "stats.json", stats, IOUtils.Format.jsonNoSort)

        # Suggest dvc command
        print(Utils.suggest_dvc_add(split_dir))
        return

    @classmethod
    def split(
        cls,
        l: List,
        val_ratio: float,
        test_ratio: float,
    ) -> Tuple[List, List, List]:
        """
        Shuffles a copy of l (using the module-level random state) and cuts
        it into (train, val, test) pieces; ratios are applied with ceiling,
        so val and test each get at least one element when l is non-empty.
        """
        assert val_ratio > 0 and test_ratio > 0 and val_ratio + test_ratio < 1
        lcopy = copy.copy(l)
        random.shuffle(lcopy)
        test_val_split = int(math.ceil(len(lcopy) * test_ratio))
        val_train_split = int(math.ceil(len(lcopy) * (test_ratio + val_ratio)))
        return (
            lcopy[val_train_split:],
            lcopy[test_val_split:val_train_split],
            lcopy[:test_val_split],
        )

    @classmethod
    def intersect(cls, *lists: List[int])-> List[int]:
        """Returns the sorted intersection of the given id lists."""
        return list(sorted(set.intersection(*[set(l) for l in lists])))

    @classmethod
    def get_task_specific_eval_helper(cls, task):
        """Returns the eval helper for the given task (com_gen or met_nam)."""
        if task == Macros.com_gen:
            from tseval.comgen.eval.CGEvalHelper import CGEvalHelper
            return CGEvalHelper()
        elif task == Macros.met_nam:
            from tseval.metnam.eval.MNEvalHelper import MNEvalHelper
            return MNEvalHelper()
        else:
            raise KeyError(f"Invalid task {task}")

    def exp_prepare(self, **options):
        """Delegates experiment preparation to the task-specific helper."""
        task = options.pop("task")
        eh = self.get_task_specific_eval_helper(task)
        setup_cls_name = options.pop("setup")
        setup_name = options.pop("setup_name")
        eh.exp_prepare(setup_cls_name, setup_name, **options)

    def exp_train(self, **options):
        """Delegates model training to the task-specific helper."""
        task = options.pop("task")
        eh = self.get_task_specific_eval_helper(task)
        setup_name = options.pop("setup_name")
        exp_name = options.pop("exp_name")
        model_name = options.pop("model_name")
        cont_train = Utils.get_option_as_boolean(options, "cont_train", pop=True)
        no_save = Utils.get_option_as_boolean(options, "no_save", pop=True)
        eh.exp_train(setup_name, exp_name, model_name, cont_train, no_save, **options)

    def exp_eval(self, **options):
        """Delegates model evaluation to the task-specific helper."""
        task = options.pop("task")
        eh = self.get_task_specific_eval_helper(task)
        setup_name = options.pop("setup_name")
        exp_name = options.pop("exp_name")
        action = options.pop("action")
        gpu_id = Utils.get_option_and_pop(options, "gpu_id", 0)
        eh.exp_eval(setup_name, exp_name, action, gpu_id=gpu_id)

    def exp_compute_metrics(self, **options):
        """Delegates metrics computation to the task-specific helper."""
        task = options.pop("task")
        eh = self.get_task_specific_eval_helper(task)
        setup_name = options.pop("setup_name")
        exp_name = options.pop("exp_name")
        action = options.pop("action")
        eh.exp_compute_metrics(setup_name, exp_name, action)
226 |
--------------------------------------------------------------------------------
/python/tseval/eval/EvalMetrics.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import math
3 | from typing import *
4 |
5 | import nltk
6 | from rouge import Rouge
7 |
8 |
@functools.lru_cache(maxsize=128_000)
def bleu_cached(gold: Tuple[str], pred: Tuple[str]) -> float:
    """
    Sentence-level BLEU of pred against the single reference gold, cached on
    the (gold, pred) pair. Uses NLTK with smoothing method2 and auto_reweigh.
    Returns 0 for an empty prediction.
    """
    if len(pred) == 0:
        return 0
    return nltk.translate.bleu_score.sentence_bleu(
        [list(gold)],
        list(pred),
        smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )
19 |
20 |
@functools.lru_cache(maxsize=128_000)
def token_acc_cached(gold: Tuple[str], pred: Tuple[str]) -> float:
    """
    Positional token accuracy: the fraction of aligned positions where gold
    and pred agree, measured over the longer of the two sequences.
    """
    agreed = sum(1 for g, p in zip(gold, pred) if g == p)
    return agreed / max(len(gold), len(pred))
28 |
29 |
@functools.lru_cache(maxsize=128_000)
def rouge_l_cached(gold: Tuple[str], pred: Tuple[str]) -> Dict[str, float]:
    """
    ROUGE-L recall/precision/f of pred against gold, cached on the
    (gold, pred) pair. Returns all zeros when either side becomes empty
    after joining.
    """
    # Replace the "." characters (e.g., in identifier names), otherwise they'll always be considered as sentence boundaries
    hyp = " ".join(pred).replace(".", "")
    ref = " ".join(gold).replace(".", "")

    if len(hyp) == 0 or len(ref) == 0:
        return {'r': 0, 'p': 0, 'f': 0}

    rouge = Rouge()
    scores = rouge.get_scores(hyps=hyp, refs=ref, avg=True)
    return scores['rouge-l']
42 |
43 |
@functools.lru_cache(maxsize=128_000)
def set_match_cached(gold: Tuple[str], pred: Tuple[str]) -> Dict[str, float]:
    """
    Precision/recall/F1 over the sets of unique tokens in gold and pred
    (order-insensitive). Returns all zeros when either side is empty.
    """
    if not gold or not pred:
        return {"r": 0, "p": 0, "f": 0}

    gold_set = set(gold)
    pred_set = set(pred)
    overlap = gold_set & pred_set
    precision = min(len(overlap) / len(pred_set), 1)
    recall = min(len(overlap) / len(gold_set), 1)
    if precision == 0 or recall == 0:
        f1 = 0
    else:
        # Harmonic mean, written exactly as in the original to keep
        # bit-identical floating point results
        f1 = 2 / (1 / precision + 1 / recall)
    return {"r": recall, "p": precision, "f": f1}
59 |
60 |
@functools.lru_cache(maxsize=128_000)
def meteor_cached(gold: Tuple[str], pred: Tuple[str]) -> float:
    """
    METEOR score of pred against gold (NLTK single_meteor_score), cached.
    Returns 0 when either side is empty.

    NOTE(review): nltk changed single_meteor_score's expected argument type
    (string vs. pre-tokenized list) across versions; this passes space-joined
    strings — confirm against the pinned nltk version.
    """
    if len(gold) == 0 or len(pred) == 0:
        return 0

    return nltk.translate.meteor_score.single_meteor_score(
        " ".join(gold),
        " ".join(pred)
    )
69 | )
70 |
71 |
@functools.lru_cache(maxsize=128_000)
def near_duplicate_similarity_cached(
    gold: Tuple[str],
    pred: Tuple[str],
    threshold: float = 0.1,
) -> float:
    """
    Approximate token-level similarity between gold and pred, for detecting
    near duplicates.

    Returns:
        1 - mismatches/max_len  while the number of mismatching tokens stays
                                below ceil(threshold * min_len);
        0                       otherwise (including when the length
                                difference alone reaches that bound).
    """
    budget = int(math.ceil(threshold * min(len(gold), len(pred))))

    # A length difference alone can exhaust the mismatch budget
    if abs(len(gold) - len(pred)) >= budget:
        return 0

    longer = max(len(gold), len(pred))
    mismatches = 0
    for pos in range(longer):
        # Once one sequence is exhausted, every remaining token of the other
        # counts as a mismatch
        if pos >= len(gold):
            mismatches += len(pred) - pos
            break
        if pos >= len(pred):
            mismatches += len(gold) - pos
            break
        if gold[pos] != pred[pos]:
            mismatches += 1
            if mismatches >= budget:
                return 0
    # The post-loop check covers the break paths above, where the tail
    # mismatches were added without re-checking the budget
    if mismatches >= budget:
        return 0
    return 1 - mismatches / longer
109 |
110 |
class EvalMetrics:
    """Thin classmethod wrappers over the module-level cached metric functions."""

    @classmethod
    def batch_exact_match(cls, golds: List[Any], preds: List[Any]) -> List[float]:
        """
        return[i] = (golds[i] == preds[i]) ? 1 : 0
        """
        assert len(golds) == len(preds)
        return [1 if g == p else 0 for g, p in zip(golds, preds)]

    @classmethod
    def token_acc(cls, gold: List[str], pred: List[str]) -> float:
        """Token-level accuracy between one gold and one pred sequence."""
        return token_acc_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_token_acc(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]:
        """return[i] = #token_acc(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [token_acc_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]

    @classmethod
    def bleu(cls, gold: List[str], pred: List[str]) -> float:
        """
        return = BLEU([gold], pred)
        """
        return bleu_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_bleu(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]:
        """return[i] = #bleu(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [bleu_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]

    @classmethod
    def rouge_l(cls, gold: List[str], pred: List[str]) -> Dict[str, float]:
        """return = ROUGE-L scores ('r'/'p'/'f') for the given sequences"""
        return rouge_l_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_rouge_l(cls, golds: List[List[str]], preds: List[List[str]]) -> List[Dict[str, float]]:
        """return[i] = #rouge_l(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [rouge_l_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]

    @classmethod
    def set_match(cls, gold: List[str], pred: List[str]) -> Dict[str, float]:
        """Set-overlap precision/recall/F1 between the unique tokens of gold and pred."""
        return set_match_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_set_match(cls, golds: List[List[str]], preds: List[List[str]]) -> List[Dict[str, float]]:
        """return[i] = #set_match(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [set_match_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]

    @classmethod
    def batch_set_match_f1(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]:
        """return[i] = #set_match(golds[i], preds[i])["f"]"""
        assert len(golds) == len(preds)
        return [set_match_cached(tuple(g), tuple(p))["f"] for g, p in zip(golds, preds)]

    @classmethod
    def meteor(cls, gold: List[str], pred: List[str]) -> float:
        """METEOR score between one gold and one pred sequence."""
        return meteor_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_meteor(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]:
        """return[i] = #meteor(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [meteor_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]

    @classmethod
    def near_duplicate_similarity(cls, gold: List[str], pred: List[str]) -> float:
        """Approximate token accuracy if gold/pred are near-duplicates, else 0."""
        return near_duplicate_similarity_cached(tuple(gold), tuple(pred))

    @classmethod
    def batch_near_duplicate_similarity(cls, golds: List[List[str]], preds: List[List[str]]) -> List[float]:
        """return[i] = #near_duplicate_similarity(golds[i], preds[i])"""
        assert len(golds) == len(preds)
        return [near_duplicate_similarity_cached(tuple(g), tuple(p)) for g, p in zip(golds, preds)]
219 |
--------------------------------------------------------------------------------
/python/tseval/eval/EvalSetupBase.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from pathlib import Path
3 | from typing import List
4 |
5 |
class EvalSetupBase:
    """
    Base class of evaluation setups: fixes the directory layout and the
    prepare / train / eval / compute_metrics workflow that subclasses implement.

    NOTE(review): this class does not inherit abc.ABC, so @abc.abstractmethod
    is not enforced at instantiation -- confirm whether that is intended.
    """

    def __init__(self, work_dir: Path, work_subdir: Path, setup_name: str):
        """
        All setups require the work_dir parameter, which is configured by EvalHelper.
        Any other parameters for the implemented setup should be passed in via constructor.

        data_dir is a unit directory managed by dvc.
        """
        self.work_dir = work_dir
        self.work_subdir = work_subdir
        self.setup_name = setup_name

    @property
    def setup_dir(self):
        # Root of this setup's prepared artifacts
        return self.work_subdir / "setup" / self.setup_name

    @property
    def data_dir(self):
        # Processed data of this setup
        return self.setup_dir / "data"

    @property
    def shared_data_dir(self):
        # Data shared among all setups
        return self.work_dir / "shared"

    def get_split_dir(self, split_name: str):
        return self.work_dir / "split" / split_name

    def get_exp_dir(self, exp_name: str):
        return self.work_subdir / "exp" / self.setup_name / exp_name

    def get_result_dir(self, exp_name: str):
        return self.work_subdir / "result" / self.setup_name / exp_name

    def get_metric_dir(self, exp_name: str):
        return self.work_subdir / "metric" / self.setup_name / exp_name

    @abc.abstractmethod
    def prepare(self) -> None:
        """
        Prepares this eval setup, primarily obtains the training/validation/testing sets,
        reads data from shared_data_dir, saves any processed data to data_dir.

        setup_dir (which includes data_dir) is managed by dvc.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None:
        """
        Trains the model, loads data from data_dir, saves the model to
        get_exp_dir(exp_name) (intermediate files, e.g., logs, can also be saved there).

        get_exp_dir(exp_name) is managed by dvc.

        :param exp_name: name given to this experiment.
        :param model_name: the model's name.
        :param cont_train: if True and if there is already a partially trained model in
            the save directory, load that model and continue training; otherwise, ignore
            any possible partially trained models.
        :param no_save: if True, avoids saving anything during training.
        :param options: options for initializing the models.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def eval(self, exp_name: str, actions: List[str] = None, gpu_id: int = 0) -> None:
        """
        Evaluates the model (usually, on both validation and testing set),
        loads data from data_dir,
        loads trained model from get_exp_dir(exp_name),
        saves results to get_result_dir(exp_name).

        get_result_dir(exp_name) is managed by dvc.

        :param exp_name: name of the experiment.
        :param actions: a list of eval actions requested.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def compute_metrics(self, exp_name: str, actions: List[str] = None) -> None:
        """
        Computes metrics on the prediction results,
        loads data from data_dir,
        loads results from get_result_dir(exp_name),
        saves metrics to get_metric_dir(exp_name).

        get_metric_dir(exp_name) is managed by git.

        :param exp_name: name of the experiment.
        :param actions: a list of eval actions requested (to compute metrics for).
        """
        raise NotImplementedError
101 |
--------------------------------------------------------------------------------
/python/tseval/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/eval/__init__.py
--------------------------------------------------------------------------------
/python/tseval/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | import sys
3 | import time
4 | from pathlib import Path
5 |
6 | import pkg_resources
7 | from seutil import CliUtils, IOUtils, LoggingUtils
8 |
9 | from tseval.Environment import Environment
10 | from tseval.Macros import Macros
11 | from tseval.Utils import Utils
12 |
# Check seutil version.
# BUGFIX: compare parsed versions, not raw strings -- lexicographically
# "0.10.0" < "0.5.6" is True, which would wrongly reject newer versions.
EXPECTED_SEUTIL_VERSION = "0.5.6"
installed_seutil_version = pkg_resources.get_distribution("seutil").version
if pkg_resources.parse_version(installed_seutil_version) < pkg_resources.parse_version(EXPECTED_SEUTIL_VERSION):
    print(f"seutil version does not meet expectation! Expected version: {EXPECTED_SEUTIL_VERSION}, current installed version: {installed_seutil_version}", file=sys.stderr)
    print(f"Hint: either upgrade seutil, or modify the expected version (after confirmation that the version will work)", file=sys.stderr)
    sys.exit(-1)


# All commands log to a shared file unless overridden via --log_path
logging_file = Macros.python_dir / "experiment.log"
LoggingUtils.setup(filename=str(logging_file))

logger = LoggingUtils.get_logger(__name__)
25 |
26 |
27 | # ==========
28 | # Data collection, sample
29 |
def collect_repos(**options):
    """CLI entry; delegates to DataCollector.search_github_java_repos()."""
    from tseval.collector.DataCollector import DataCollector
    collector = DataCollector()
    collector.search_github_java_repos()
33 |
34 |
def filter_repos(**options):
    """CLI entry; delegates to DataCollector.filter_repos() with CLI defaults."""
    from tseval.collector.DataCollector import DataCollector
    collector = DataCollector()
    collector.filter_repos(
        year_end=options.get("year_end", 2021),
        year_cnt=options.get("year_cnt", 3),
        loc_min=options.get("loc_min", 1e6),
        loc_max=options.get("loc_max", 2e6),
        star_min=options.get("star_min", 20),
    )
44 |
45 |
def collect_raw_data(**options):
    """CLI entry; delegates to DataCollector.collect_raw_data_projects()."""
    from tseval.collector.DataCollector import DataCollector
    collector = DataCollector()
    collector.collect_raw_data_projects(
        year_end=options.get("year_end", 2021),
        year_cnt=options.get("year_cnt", 3),
        skip_collected=Utils.get_option_as_boolean(options, "skip_collected"),
        project_names=Utils.get_option_as_list(options, "projects"),
    )
54 |
55 |
def process_raw_data(**options):
    """CLI entry; delegates to DataCollector.process_raw_data()."""
    from tseval.collector.DataCollector import DataCollector
    collector = DataCollector()
    collector.process_raw_data(
        year_end=options.get("year_end", 2021),
        year_cnt=options.get("year_cnt", 3),
    )
62 |
63 |
def get_splits(**options):
    """CLI entry; delegates to EvalHelper.get_splits() with CLI defaults."""
    from tseval.eval.EvalHelper import EvalHelper
    helper = EvalHelper()
    helper.get_splits(
        split_name=options["split"],
        seed=options.get("seed", 7),
        prj_val_ratio=options.get("prj_val_ratio", 0.1),
        prj_test_ratio=options.get("prj_test_ratio", 0.2),
        inprj_val_ratio=options.get("inprj_val_ratio", 0.1),
        inprj_test_ratio=options.get("inprj_test_ratio", 0.2),
        train_year=options.get("train_year", 2019),
        val_year=options.get("val_year", 2020),
        test_year=options.get("test_year", 2021),
        debug=Utils.get_option_as_boolean(options, "debug"),
    )
78 |
79 |
80 | # ==========
81 | # Machine learning
82 |
def prepare_envs(**options):
    """CLI entry; calls prepare_env() on the selected models (all if --which absent)."""
    which = Utils.get_option_as_list(options, "which")

    def selected(name: str) -> bool:
        # No --which option means every model is selected
        return which is None or name in which

    if selected("TransformerACL20"):
        from tseval.comgen.model.TransformerACL20 import TransformerACL20
        TransformerACL20.prepare_env()
    if selected("DeepComHybridESE19"):
        from tseval.comgen.model.DeepComHybridESE19 import DeepComHybridESE19
        DeepComHybridESE19.prepare_env()
    if selected("Code2SeqICLR19"):
        from tseval.metnam.model.Code2SeqICLR19 import Code2SeqICLR19
        Code2SeqICLR19.prepare_env()
    if selected("Code2VecPOPL19"):
        from tseval.metnam.model.Code2VecPOPL19 import Code2VecPOPL19
        Code2VecPOPL19.prepare_env()
97 |
98 |
def exp_prepare(**options):
    """CLI entry; delegates to EvalHelper.exp_prepare()."""
    from tseval.eval.EvalHelper import EvalHelper
    helper = EvalHelper()
    helper.exp_prepare(**options)
102 |
103 |
def exp_train(**options):
    """CLI entry; delegates to EvalHelper.exp_train()."""
    from tseval.eval.EvalHelper import EvalHelper
    helper = EvalHelper()
    helper.exp_train(**options)
107 |
108 |
def exp_eval(**options):
    """CLI entry; delegates to EvalHelper.exp_eval()."""
    from tseval.eval.EvalHelper import EvalHelper
    helper = EvalHelper()
    helper.exp_eval(**options)
112 |
113 |
def exp_compute_metrics(**options):
    """CLI entry; delegates to EvalHelper.exp_compute_metrics()."""
    from tseval.eval.EvalHelper import EvalHelper
    helper = EvalHelper()
    helper.exp_compute_metrics(**options)
117 |
118 |
119 | # ==========
120 | # Table & Plot
121 |
def make_tables(**options):
    """CLI entry; delegates to Table.make_tables()."""
    from tseval.Table import Table
    table_maker = Table()
    table_maker.make_tables(options)
125 |
126 |
def make_plots(**options):
    """CLI entry; delegates to Plot.make_plots()."""
    from tseval.Plot import Plot
    plotter = Plot()
    plotter.make_plots(options)
130 |
131 |
132 | # ==========
133 | # Metrics collection
134 |
def collect_metrics(**options):
    """CLI entry; delegates to MetricsCollector.collect_metrics()."""
    from tseval.collector.MetricsCollector import MetricsCollector
    MetricsCollector().collect_metrics(**options)
141 |
142 |
143 | # ==========
144 | # Collect and analyze results
145 |
def analyze_check_files(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.check_files()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.check_files()
152 |
153 |
def analyze_recompute_metrics(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.recompute_metrics()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.recompute_metrics()
160 |
161 |
def analyze_extract_metrics(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.extract_metrics()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.extract_metrics()
168 |
169 |
def analyze_sign_test(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.sign_test_default()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.sign_test_default()
176 |
177 |
def analyze_make_tables(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.make_tables_default()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.make_tables_default()
184 |
185 |
def analyze_make_plots(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.make_plots_default()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.make_plots_default()
192 |
193 |
def analyze_sample_results(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.sample_results()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.sample_results(
        seed=options.get("seed", 7),
        count=options.get("count", 100),
    )
203 |
204 |
def analyze_extract_data_similarities(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.extract_data_similarities()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.extract_data_similarities()
211 |
212 |
def analyze_near_duplicates(**options):
    """CLI entry; delegates to ExperimentsAnalyzer.filter_near_duplicates_and_analyze()."""
    from tseval.collector.ExperimentsAnalyzer import ExperimentsAnalyzer
    analyzer = ExperimentsAnalyzer(
        exps_spec_path=Path(options["exps"]),
        output_prefix=options.get("output"),
    )
    analyzer.filter_near_duplicates_and_analyze(
        code_sim_threshold=options["code_sim"],
        nl_sim_threshold=options["nl_sim"],
        config_name=options["config"],
        only_tables_plots=Utils.get_option_as_boolean(options, "only_tables_plots", default=False),
    )
224 |
225 |
226 | # ==========
227 | # Main
228 |
def normalize_options(opts: dict) -> dict:
    """Normalizes CLI options shared by every command (logging, debug, parallel, seed)."""
    # Redirect logging to a custom file if requested
    if "log_path" in opts:
        logger.info(f"Switching to log file {opts['log_path']}")
        LoggingUtils.setup(filename=opts['log_path'])

    # Debug mode
    if "debug" in opts and str(opts["debug"]).lower() != "false":
        Environment.is_debug = True
        logger.debug("Debug mode on")
        logger.debug(f"Command line options: {opts}")

    # Parallel mode - all automatic installations are disabled
    if "parallel" in opts and str(opts["parallel"]).lower() != "false":
        Environment.is_parallel = True
        logger.warning(f"Parallel mode on")

    # Random seed: taken from the option if given, otherwise from the clock
    seed = int(opts["random_seed"]) if "random_seed" in opts else time.time_ns()
    Environment.random_seed = seed
    random.seed(seed)
    logger.info(f"Random seed is {seed}")
    return opts
255 |
256 |
if __name__ == "__main__":
    # CliUtils dispatches the first CLI argument to one of the top-level
    # functions in this module, after passing options through normalize_options.
    CliUtils.main(sys.argv[1:], globals(), normalize_options)
259 |
--------------------------------------------------------------------------------
/python/tseval/metnam/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/metnam/__init__.py
--------------------------------------------------------------------------------
/python/tseval/metnam/eval/MNEvalHelper.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Optional
3 |
4 | from seutil import IOUtils, LoggingUtils
5 |
6 | from tseval.eval.EvalSetupBase import EvalSetupBase
7 | from tseval.Macros import Macros
8 | from tseval.metnam.eval import get_setup_cls
9 | from tseval.Utils import Utils
10 |
11 | logger = LoggingUtils.get_logger(__name__)
12 |
13 |
class MNEvalHelper:
    """Coordinates method-naming (MN) experiments: setup, training, eval, metrics."""

    def __init__(self):
        # All MN artifacts are rooted at work_dir/MN
        self.work_subdir: Path = Macros.work_dir / "MN"

    def exp_prepare(self, setup: str, setup_name: str, **cmd_options):
        """Creates a fresh setup dir, instantiates the setup class, prepares its data."""
        # Start from a clean setup directory
        target_dir: Path = self.work_subdir / "setup" / setup_name
        IOUtils.rm_dir(target_dir)
        target_dir.mkdir(parents=True)

        # Split the command-line options into the setup's constructor options
        setup_cls = get_setup_cls(setup)
        ctor_options, unknown, missing = Utils.parse_cmd_options_for_type(
            cmd_options,
            setup_cls,
            ["self", "work_dir", "work_subdir", "setup_name"],
        )
        if len(missing) > 0:
            raise KeyError(f"Missing options: {missing}")
        if len(unknown) > 0:
            logger.warning(f"Unrecognized options: {unknown}")
        setup_inst: EvalSetupBase = setup_cls(
            work_dir=Macros.work_dir,
            work_subdir=self.work_subdir,
            setup_name=setup_name,
            **ctor_options,
        )

        # Record the configs, including which setup class was used
        ctor_options["setup"] = setup
        IOUtils.dump(target_dir / "setup_config.json", ctor_options, IOUtils.Format.jsonNoSort)

        # Prepare data
        setup_inst.prepare()

        # Suggest dvc tracking for the produced directory
        print(Utils.suggest_dvc_add(setup_inst.setup_dir))

    def load_setup(self, setup_dir: Path, setup_name: str) -> EvalSetupBase:
        """
        Loads the setup from its save directory, with updating setup_name.
        """
        saved_config = IOUtils.load(setup_dir / "setup_config.json", IOUtils.Format.json)
        setup_cls = get_setup_cls(saved_config.pop("setup"))
        return setup_cls(
            work_dir=Macros.work_dir,
            work_subdir=self.work_subdir,
            setup_name=setup_name,
            **saved_config,
        )

    def _load_existing_setup(self, setup_name: str) -> EvalSetupBase:
        # Loads a previously prepared setup, suggesting `dvc pull` if it is absent
        setup_dir = self.work_subdir / "setup" / setup_name
        Utils.expect_dir_or_suggest_dvc_pull(setup_dir)
        return self.load_setup(setup_dir, setup_name)

    def exp_train(
        self,
        setup_name: str,
        exp_name: str,
        model_name: str,
        cont_train: bool,
        no_save: bool,
        **cmd_options,
    ):
        """Trains model_name under the given setup; artifacts go to the exp dir."""
        setup = self._load_existing_setup(setup_name)

        if not cont_train:
            # Starting from scratch: drop any previously trained model
            IOUtils.rm_dir(setup.get_exp_dir(exp_name))

        setup.train(exp_name, model_name, cont_train, no_save, **cmd_options)

        print(Utils.suggest_dvc_add(setup.get_exp_dir(exp_name)))

    def exp_eval(
        self,
        setup_name: str,
        exp_name: str,
        action: Optional[str],
        gpu_id: int = 0,
    ):
        """Evaluates a trained model; predictions go to the result dir."""
        setup = self._load_existing_setup(setup_name)
        setup.eval(exp_name, action, gpu_id=gpu_id)
        print(Utils.suggest_dvc_add(setup.get_result_dir(exp_name)))

    def exp_compute_metrics(
        self,
        setup_name: str,
        exp_name: str,
        action: Optional[str] = None,
    ):
        """Computes metrics over saved predictions; output goes to the metric dir."""
        setup = self._load_existing_setup(setup_name)
        setup.compute_metrics(exp_name, action)
        print(Utils.suggest_dvc_add(setup.get_metric_dir(exp_name)))
120 |
--------------------------------------------------------------------------------
/python/tseval/metnam/eval/MNModelLoader.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from seutil import IOUtils, LoggingUtils
4 |
5 | from tseval.metnam.model import get_model_cls
6 | from tseval.metnam.model.MNModelBase import MNModelBase
7 | from tseval.Utils import Utils
8 |
9 | logger = LoggingUtils.get_logger(__name__)
10 |
11 |
class MNModelLoader:
    """Initializes new MN models or loads previously saved ones."""

    @classmethod
    def init_or_load_model(
        cls,
        model_name: str,
        exp_dir: Path,
        cont_train: bool,
        no_save: bool,
        cmd_options: dict,
    ) -> MNModelBase:
        """
        Initializes a model of type model_name under exp_dir, or, in cont_train
        mode with a saved model present, loads that saved model.

        :raises ValueError: if the saved model's name differs from model_name.
        :raises KeyError: if required model constructor options are missing.
        """
        model_cls = get_model_cls(model_name)
        model_work_dir = exp_dir / "model"

        if cont_train and model_work_dir.is_dir() and not no_save:
            # Restore model name
            loaded_model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt)
            if model_name != loaded_model_name:
                # BUGFIX: the saved/new names were swapped in this message
                # (model_name is the newly requested one, loaded_model_name
                # is the one read from disk)
                raise ValueError(f"Contradicting model name (saved: {loaded_model_name}; new {model_name})")

            # Warning about any additional command line arguments
            if len(cmd_options) > 0:
                logger.warning(f"These options will not be used in cont_train mode: {cmd_options}")

            # Load existing model
            model: MNModelBase = model_cls.load(model_work_dir)
        else:

            if not no_save:
                exp_dir.mkdir(parents=True, exist_ok=True)

                # Save model name
                IOUtils.dump(exp_dir / "model_name.txt", model_name, IOUtils.Format.txt)

                # Prepare directory for model
                IOUtils.rm(model_work_dir)
                model_work_dir.mkdir(parents=True)

            # Initialize the model, using command line arguments
            model_options, unk_options, missing_options = Utils.parse_cmd_options_for_type(
                cmd_options,
                model_cls,
                ["self", "model_work_dir"],
            )
            if len(missing_options) > 0:
                raise KeyError(f"Missing options: {missing_options}")
            if len(unk_options) > 0:
                logger.warning(f"Unrecognized options: {unk_options}")

            model: MNModelBase = model_cls(model_work_dir=model_work_dir, no_save=no_save, **model_options)

            if not no_save:
                # Save model configs
                IOUtils.dump(exp_dir / "model_config.json", model_options, IOUtils.Format.jsonNoSort)
        return model

    @classmethod
    def load_model(cls, exp_dir: Path) -> MNModelBase:
        """
        Loads a trained model from exp_dir. Gets the model name from train_config.json.
        """
        Utils.expect_dir_or_suggest_dvc_pull(exp_dir)
        model_name = IOUtils.load(exp_dir / "model_name.txt", IOUtils.Format.txt)
        model_cls = get_model_cls(model_name)
        model_dir = exp_dir / "model"
        return model_cls.load(model_dir)
78 |
--------------------------------------------------------------------------------
/python/tseval/metnam/eval/StandardSetup.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import time
4 | from pathlib import Path
5 | from typing import Dict, List
6 |
7 | import numpy as np
8 | from seutil import IOUtils, LoggingUtils
9 | from tqdm import tqdm
10 |
11 | from tseval.data.MethodData import MethodData
12 | from tseval.eval.EvalMetrics import EvalMetrics
13 | from tseval.eval.EvalSetupBase import EvalSetupBase
14 | from tseval.Macros import Macros
15 | from tseval.metnam.eval.MNModelLoader import MNModelLoader
16 | from tseval.metnam.model.MNModelBase import MNModelBase
17 | from tseval.util.ModelUtils import ModelUtils
18 | from tseval.util.TrainConfig import TrainConfig
19 | from tseval.Utils import Utils
20 |
21 | logger = LoggingUtils.get_logger(__name__)
22 |
23 |
24 | class StandardSetup(EvalSetupBase):
25 |
26 | # Validation set on self's split type
27 | EVAL_VAL = Macros.val
28 | # Test_standard set on self's split type
29 | EVAL_TESTS = Macros.test_standard
30 | # Test_common sets, pairwisely between self's split type and other split types
31 | EVAL_TESTC = Macros.test_common
32 |
33 | EVAL_ACTIONS = [EVAL_VAL, EVAL_TESTS, EVAL_TESTC]
34 | DEFAULT_EVAL_ACTION = EVAL_TESTC
35 |
    def __init__(
        self,
        work_dir: Path,
        work_subdir: Path,
        setup_name: str,
        split_name: str,
        split_type: str,
    ):
        """
        :param split_name: name of the split (a directory under work_dir/split).
        :param split_type: the split type within that split to evaluate on.
        """
        super().__init__(work_dir, work_subdir, setup_name)
        self.split_name = split_name
        self.split_type = split_type
47 |
    def prepare(self) -> None:
        """
        Prepares this setup's dataset: copies the split index files, loads the
        raw method data, subtokenizes code and names, cleans the eval id lists,
        and saves everything under data_dir.
        """
        # Check and prepare directories
        split_dir = self.get_split_dir(self.split_name)
        Utils.expect_dir_or_suggest_dvc_pull(self.shared_data_dir)
        Utils.expect_dir_or_suggest_dvc_pull(split_dir)
        IOUtils.rm_dir(self.data_dir)
        self.data_dir.mkdir(parents=True)

        # Copy split indexes
        all_indexes = []
        for sn in [Macros.train, Macros.val, Macros.test_standard]:
            ids = IOUtils.load(split_dir / f"{self.split_type}-{sn}.json", IOUtils.Format.json)
            all_indexes += ids
            IOUtils.dump(self.data_dir / f"split_{sn}.json", ids, IOUtils.Format.json)
        for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
            ids = IOUtils.load(split_dir / f"{s1}-{s2}-{Macros.test_common}.json", IOUtils.Format.json)
            all_indexes += ids
            IOUtils.dump(
                self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json",
                ids,
                IOUtils.Format.json,
            )
        # Deduplicate and sort the collected ids before loading the dataset
        all_indexes = list(sorted(set(all_indexes)))

        # Load raw data
        tbar = tqdm()
        dataset: List[MethodData] = MethodData.load_dataset(
            self.shared_data_dir,
            expected_ids=all_indexes,
            only=["code", "code_masked", "name", "misc"],
            tbar=tbar,
        )
        tbar.close()

        # Subtokenize code and comments
        tbar = tqdm()

        tbar.set_description("Subtokenizing")
        tbar.reset(len(dataset))
        orig_code_list = []
        name_list = []
        for d in dataset:
            # Replace the mask token such that the code remains parsable
            # NOTE(review): this replaces the EMPTY string, which inserts
            # METHODNAMEMASK between every character of code_masked -- the
            # mask-token literal looks lost; confirm against the original repo.
            code_masked = d.code_masked.replace("", "METHODNAMEMASK")
            d.fill_none()
            # Keep the original (pre-subtokenization) fields around in misc
            d.misc["orig_code"] = d.code
            d.misc["orig_code_masked"] = code_masked
            d.misc["orig_name"] = d.name
            orig_code_list.append(code_masked)
            name_list.append(d.name)
            tbar.update(1)

        tokenized_code_list = ModelUtils.tokenize_javaparser_batch(orig_code_list, dup_share=False, tbar=tbar)

        tbar.set_description("Subtokenizing")
        tbar.reset(len(dataset))
        for d, tokenized_code, name in zip(dataset, tokenized_code_list, name_list):
            d.code, d.misc["code_src_ids"] = ModelUtils.subtokenize_batch(tokenized_code)
            d.code_masked = None
            d.name = ModelUtils.subtokenize(name)

            # convert name to lower case
            d.name = [t.lower() for t in d.name]
            tbar.update(1)
        tbar.close()

        # Clean eval ids
        indexed_dataset = {d.id: d for d in dataset}
        for sn in [Macros.val, Macros.test_standard] + [f"{Macros.test_common}-{x}-{y}" for x, y in Macros.get_pairwise_split_types_with(self.split_type)]:
            eval_ids = IOUtils.load(self.data_dir / f"split_{sn}.json", IOUtils.Format.json)
            IOUtils.dump(self.data_dir / f"split_{sn}.json", self.clean_eval_set(indexed_dataset, eval_ids), IOUtils.Format.json)

        # Save dataset
        MethodData.save_dataset(dataset, self.data_dir)
122 |
123 | def clean_eval_set(self, indexed_dataset: Dict[int, MethodData], eval_ids: List[int]) -> List[int]:
124 | """
125 | Keeps the eval set absolutely clean by:
126 | - Remove duplicate (name, code) pairs;
127 | indexed_dataset should already been subtokenized.
128 | """
129 | seen_data = set()
130 | clean_eval_ids = []
131 | for i in eval_ids:
132 | data = indexed_dataset[i]
133 |
134 | # Remove duplicate (name, code) pairs
135 | data_key = (tuple(data.code), tuple(data.name))
136 | if data_key in seen_data:
137 | continue
138 | else:
139 | seen_data.add(data_key)
140 |
141 | clean_eval_ids.append(i)
142 | return clean_eval_ids
143 |
    def train(self, exp_name: str, model_name: str, cont_train: bool, no_save: bool, **options) -> None:
        """
        Trains model_name on this setup's train/val splits; the model and its
        configs are saved to the experiment directory (unless no_save).
        """
        # Init or load model
        exp_dir = self.get_exp_dir(exp_name)
        train_config = TrainConfig.get_train_config_from_cmd_options(options)
        model = MNModelLoader.init_or_load_model(model_name, exp_dir, cont_train, no_save, options)
        if not no_save:
            # append=True keeps the config history across cont_train invocations
            IOUtils.dump(exp_dir / "train_config.jsonl", [IOUtils.jsonfy(train_config)], IOUtils.Format.jsonList, append=True)

        # Load data
        tbar = tqdm(desc="Loading data")
        dataset = MethodData.load_dataset(self.data_dir, tbar=tbar)
        indexed_dataset = {d.id: d for d in dataset}

        tbar.set_description("Loading data | take indexes")
        tbar.reset(2)

        train_ids = IOUtils.load(self.data_dir / f"split_{Macros.train}.json", IOUtils.Format.json)
        train_dataset = [indexed_dataset[i] for i in train_ids]
        tbar.update(1)

        val_ids = IOUtils.load(self.data_dir / f"split_{Macros.val}.json", IOUtils.Format.json)
        val_dataset = [indexed_dataset[i] for i in val_ids]
        tbar.update(1)

        tbar.close()

        # Train model, measuring the wall-clock training time
        start = time.time()
        model.train(train_dataset, val_dataset, resources_path=self.data_dir, train_config=train_config)
        end = time.time()

        if not no_save:
            model.save()
            IOUtils.dump(exp_dir / "train_time.json", end - start, IOUtils.Format.json)
178 |
    def eval_one(self, exp_name: str, eval_ids: List[int], prefix: str, indexed_dataset: Dict[int, MethodData], model: MNModelBase, gpu_id: int = 0):
        """
        Runs the model on the data with the given ids and saves predictions,
        golds, and wall-clock eval time under the result dir, using prefix in
        the file names.

        NOTE(review): mutates the passed-in data (d.name and misc["orig_name"]
        are replaced with a dummy) -- callers pass a deepcopy when the same
        dataset is reused across calls.
        """
        # Prepare output directory
        result_dir = self.get_result_dir(exp_name)
        result_dir.mkdir(parents=True, exist_ok=True)

        # Prepare eval data (remove target)
        eval_dataset = [indexed_dataset[i] for i in eval_ids]
        golds = []
        for d in eval_dataset:
            golds.append(d.name)
            # Blank out the target so the model cannot see it
            d.name = ["dummy"]
            d.misc["orig_name"] = "dummy"

        # Perform batched queries
        tbar = tqdm(desc=f"Predicting | {prefix}")
        eval_start = time.time()
        predictions = model.batch_predict(eval_dataset, tbar=tbar, gpu_id=gpu_id)
        eval_end = time.time()
        tbar.close()

        eval_time = eval_end - eval_start

        # Save predictions & golds
        IOUtils.dump(result_dir / f"{prefix}_predictions.jsonl", predictions, IOUtils.Format.jsonList)
        IOUtils.dump(result_dir / f"{prefix}_golds.jsonl", golds, IOUtils.Format.jsonList)
        IOUtils.dump(result_dir / f"{prefix}_eval_time.json", eval_time, IOUtils.Format.json)
205 |
    def eval(self, exp_name: str, action: str = None, gpu_id: int = 0) -> None:
        """
        Evaluates the trained model of exp_name on the eval set selected by
        action (EVAL_VAL / EVAL_TESTS / EVAL_TESTC; default EVAL_TESTC).

        NOTE(review): takes a single `action` string, while the base class
        declares `actions: List[str]` -- confirm which contract is intended.
        """
        if action is None:
            action = self.DEFAULT_EVAL_ACTION
        if action not in self.EVAL_ACTIONS:
            raise RuntimeError(f"Unknown eval action {action}")

        # Load eval data
        tbar = tqdm(desc="Loading data")
        dataset = MethodData.load_dataset(self.data_dir, tbar=tbar)
        indexed_dataset = {d.id: d for d in dataset}
        tbar.close()

        # Load model
        exp_dir = self.get_exp_dir(exp_name)
        model: MNModelBase = MNModelLoader.load_model(exp_dir)
        if not model.is_train_finished():
            logger.warning(f"Model not finished training, at {exp_dir}")

        # Invoke eval_one with specific data ids
        if action in [self.EVAL_VAL, self.EVAL_TESTS]:
            self.eval_one(
                exp_name,
                IOUtils.load(self.data_dir / f"split_{action}.json", IOUtils.Format.json),
                action,
                indexed_dataset,
                model,
                gpu_id=gpu_id,
            )
        elif action == self.EVAL_TESTC:
            # One eval per pairwise test_common set with the other split types
            for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
                self.eval_one(
                    exp_name,
                    IOUtils.load(self.data_dir / f"split_{Macros.test_common}-{s1}-{s2}.json", IOUtils.Format.json),
                    f"{Macros.test_common}-{s1}-{s2}",
                    # deepcopy: eval_one mutates the data, and this loop reuses it
                    copy.deepcopy(indexed_dataset),
                    model,
                    gpu_id=gpu_id,
                )
        else:
            raise RuntimeError(f"Unknown action {action}")
246 |
247 | def compute_metrics_one(self, exp_name: str, prefix: str):
248 | # Prepare output directory
249 | metric_dir = self.get_metric_dir(exp_name)
250 | metric_dir.mkdir(parents=True, exist_ok=True)
251 |
252 | # Load golds and predictions
253 | result_dir = self.get_result_dir(exp_name)
254 | Utils.expect_dir_or_suggest_dvc_pull(result_dir)
255 | golds = IOUtils.load(result_dir / f"{prefix}_golds.jsonl", IOUtils.Format.jsonList)
256 | predictions = IOUtils.load(result_dir / f"{prefix}_predictions.jsonl", IOUtils.Format.jsonList)
257 |
258 | metrics_list: Dict[str, List] = collections.defaultdict(list)
259 | metrics_list["exact_match"] = EvalMetrics.batch_exact_match(golds, predictions)
260 | metrics_list["token_acc"] = EvalMetrics.batch_token_acc(golds, predictions)
261 | metrics_list["bleu"] = EvalMetrics.batch_bleu(golds, predictions)
262 | rouge_l_res = EvalMetrics.batch_rouge_l(golds, predictions)
263 | metrics_list["rouge_l_f"] = [x["f"] for x in rouge_l_res]
264 | metrics_list["rouge_l_p"] = [x["p"] for x in rouge_l_res]
265 | metrics_list["rouge_l_r"] = [x["r"] for x in rouge_l_res]
266 | metrics_list["meteor"] = EvalMetrics.batch_meteor(golds, predictions)
267 | set_match_res = EvalMetrics.batch_set_match(golds, predictions)
268 | metrics_list["set_match_f"] = [x["f"] for x in set_match_res]
269 | metrics_list["set_match_p"] = [x["p"] for x in set_match_res]
270 | metrics_list["set_match_r"] = [x["r"] for x in set_match_res]
271 |
272 | # Take average
273 | metrics = {}
274 | for k, l in metrics_list.items():
275 | metrics[k] = np.mean(l).item()
276 |
277 | # Save metrics
278 | IOUtils.dump(metric_dir / f"{prefix}_metrics.json", metrics, IOUtils.Format.jsonNoSort)
279 | IOUtils.dump(metric_dir / f"{prefix}_metrics.txt", [f"{k}: {v}" for k, v in metrics.items()], IOUtils.Format.txtList)
280 | IOUtils.dump(metric_dir / f"{prefix}_metrics_list.pkl", metrics_list, IOUtils.Format.pkl)
281 |
282 | def compute_metrics(self, exp_name: str, action: str = None) -> None:
283 | if action is None:
284 | action = self.DEFAULT_EVAL_ACTION
285 | if action not in self.EVAL_ACTIONS:
286 | raise RuntimeError(f"Unknown eval action {action}")
287 |
288 | if action in [self.EVAL_VAL, self.EVAL_TESTS]:
289 | self.compute_metrics_one(
290 | exp_name,
291 | action,
292 | )
293 | elif action == self.EVAL_TESTC:
294 | for s1, s2 in Macros.get_pairwise_split_types_with(self.split_type):
295 | self.compute_metrics_one(
296 | exp_name,
297 | f"{Macros.test_common}-{s1}-{s2}",
298 | )
299 | else:
300 | raise RuntimeError(f"Unknown action {action}")
301 |
--------------------------------------------------------------------------------
/python/tseval/metnam/eval/__init__.py:
--------------------------------------------------------------------------------
def get_setup_cls(name: str) -> type:
    """Resolves an eval setup name to its class, importing the module lazily."""
    if name != "StandardSetup":
        raise ValueError(f"No setup with name {name}")
    from tseval.metnam.eval.StandardSetup import StandardSetup
    return StandardSetup
7 |
--------------------------------------------------------------------------------
/python/tseval/metnam/model/MNModelBase.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from pathlib import Path
3 | from typing import List, Optional
4 |
5 | from seutil import IOUtils
6 | from tqdm import tqdm
7 |
8 | from tseval.data.MethodData import MethodData
9 | from tseval.util.TrainConfig import TrainConfig
10 |
11 |
class MNModelBase:
    """
    Base class for method naming models.

    NOTE(review): this class does not inherit from abc.ABC, so @abc.abstractmethod is
    not enforced at instantiation time; the abstract methods below raise
    NotImplementedError when called instead. Kept as-is for backward compatibility.
    """

    def __init__(self, model_work_dir: Path, no_save: bool = False):
        # Directory where this model's artifacts are saved/loaded
        self.model_work_dir = model_work_dir
        # If True, save() becomes a no-op
        self.no_save = no_save

    @abc.abstractmethod
    def train(
        self,
        train_dataset: List[MethodData],
        val_dataset: List[MethodData],
        resources_path: Optional[Path] = None,
        train_config: Optional[TrainConfig] = None,
    ):
        """
        Trains the model.

        :param train_dataset: training set.
        :param val_dataset: validation set.
        :param resources_path: path to resources that could be shared by multiple model's training process,
            e.g., pre-trained embeddings.
        :param train_config: optional training configuration (e.g., session time, gpu id).
        """
        raise NotImplementedError

    @abc.abstractmethod
    def is_train_finished(self) -> bool:
        """Returns True if this model's training process has finished."""
        raise NotImplementedError

    @abc.abstractmethod
    def predict(
        self,
        data: MethodData,
        gpu_id: int = 0,
    ) -> List[str]:
        """
        Predicts the method name given the context in data.
        (Fixed: the previous docstring described comment-summary generation with a
        confidence score, copied from the comment generation model base.)

        :param data: the input method data, with its target name masked out by the caller.
        :param gpu_id: id of the GPU to use for prediction.
        :return: a list of predicted method name tokens.
        """
        raise NotImplementedError

    def batch_predict(
        self,
        dataset: List[MethodData],
        tbar: Optional[tqdm] = None,
        gpu_id: int = 0,
    ) -> List[List[str]]:
        """
        Performs batched predictions using given dataset as inputs.

        The default implementation invokes #predict multiple times. Subclass can override
        this method to speed up the prediction by using batching.

        :param dataset: a list of inputs.
        :param tbar: an optional tqdm progress bar to show prediction progress.
        :param gpu_id: id of the GPU to use for prediction.
        :return: a list of the return value of #predict.
        """
        if tbar is not None:
            tbar.reset(len(dataset))

        results = []
        for data in dataset:
            results.append(self.predict(data, gpu_id=gpu_id))
            if tbar is not None:
                tbar.update(1)

        return results

    def save(self) -> None:
        """
        Saves the current model at the work_dir.
        Default behavior is to serialize the entire object in model.pkl.
        """
        if not self.no_save:
            IOUtils.dump(self.model_work_dir / "model.pkl", self, IOUtils.Format.pkl)

    @classmethod
    def load(cls, work_dir: Path) -> "MNModelBase":
        """
        Loads a model from the work_dir.
        Default behavior is to deserialize the object from model.pkl, with resetting its work_dir.
        """
        obj = IOUtils.load(work_dir / "model.pkl", IOUtils.Format.pkl)
        obj.model_work_dir = work_dir
        return obj
98 |
--------------------------------------------------------------------------------
/python/tseval/metnam/model/__init__.py:
--------------------------------------------------------------------------------
def get_model_cls(name: str) -> type:
    """Resolves a method-naming model name to its class, importing the module lazily."""
    if name == "Code2SeqICLR19":
        from tseval.metnam.model.Code2SeqICLR19 import Code2SeqICLR19
        return Code2SeqICLR19
    if name == "Code2VecPOPL19":
        from tseval.metnam.model.Code2VecPOPL19 import Code2VecPOPL19
        return Code2VecPOPL19
    raise ValueError(f"No model with name {name}")
10 |
--------------------------------------------------------------------------------
/python/tseval/util/ModelUtils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import copy
3 | import re
4 | import tempfile
5 | import time
6 | from pathlib import Path
7 | from typing import Dict, List, Optional, OrderedDict, Tuple
8 |
9 | from seutil import BashUtils, IOUtils, LoggingUtils
10 | from tqdm import tqdm
11 |
12 | from tseval.Environment import Environment
13 |
14 | logger = LoggingUtils.get_logger(__name__)
15 |
16 |
class ModelUtils:
    """Utilities shared by models: JavaParser tokenization, sub-tokenization, seeding."""

    # Number of code snippets sent to the JavaParser tokenizer per jar invocation
    TOKENIZE_JAVAPARSER_BATCH_SIZE = 10000

    @classmethod
    def get_random_seed(cls) -> int:
        """
        Generates a random int as seed, within the range of [0, 2**32)
        The seed is generated based on current time
        """
        return time.time_ns() % (2**32)

    @classmethod
    def tokenize_javaparser(cls, code: str) -> List[str]:
        """Tokenizes a single piece of code using JavaParser."""
        return cls.tokenize_javaparser_batch([code])[0]

    @classmethod
    def tokenize_javaparser_batch(
        cls,
        code_list: List[str],
        dup_share: bool = True,
        tbar: Optional[tqdm] = None,
    ) -> List[List[str]]:
        """
        Tokenizes a list of code using JavaParser.
        :param code_list: a list of code to be tokenized.
        :param dup_share: if True (default), the returned lists of tokens will be shared across duplicate code
            (thus modifying one of them will affect others).
        :param tbar: optional tqdm progress bar.
        :returns a list of tokenized code.
        """
        # get an unique list of code to tokenize, maintain a backward mapping
        code_2_id: OrderedDict[str, int] = collections.OrderedDict()
        ids: List[int] = []
        for c in code_list:
            ids.append(code_2_id.setdefault(c, len(code_2_id)))

        unique_code_list = list(code_2_id.keys())
        if tbar is not None:
            tbar.set_description(f"JavaParser Tokenize ({len(unique_code_list)}U/{len(code_list)})")
            tbar.reset(len(unique_code_list))

        # Tokenize (with batching)
        unique_tokens_list = []
        for beg in range(0, len(unique_code_list), cls.TOKENIZE_JAVAPARSER_BATCH_SIZE):
            unique_tokens_list += cls.tokenize_javaparser_batch_(
                unique_code_list[beg:beg + cls.TOKENIZE_JAVAPARSER_BATCH_SIZE], tbar=tbar,
            )

        if dup_share:
            return [unique_tokens_list[i] for i in ids]
        else:
            # Each duplicate gets its own (shallow-copied) token list
            return [copy.copy(unique_tokens_list[i]) for i in ids]

    @classmethod
    def tokenize_javaparser_batch_(
        cls,
        code_list: List[str],
        tbar: Optional[tqdm] = None,
    ):
        """Tokenizes one batch of code by invoking the collector jar (org.tseval.ExtractToken)."""
        # Use JavaParser to tokenize
        Environment.require_collector()

        tokenizer_inputs = []
        for code in code_list:
            tokenizer_inputs.append({
                "index": len(tokenizer_inputs),
                "code": code,
            })

        # NOTE(review): tempfile.mktemp is deprecated/race-prone; kept to preserve behavior
        inputs_file = Path(tempfile.mktemp())
        IOUtils.dump(inputs_file, tokenizer_inputs, IOUtils.Format.json)
        outputs_file = Path(tempfile.mktemp())

        BashUtils.run(
            f"java -cp {Environment.collector_jar} org.tseval.ExtractToken '{inputs_file}' '{outputs_file}'",
            expected_return_code=0,
        )

        tokenizer_outputs = IOUtils.load(outputs_file, IOUtils.Format.json)
        IOUtils.rm(inputs_file)
        IOUtils.rm(outputs_file)

        # Check for tokenizer failures (warn only; failed entries keep empty token lists)
        for code, output in zip(code_list, tokenizer_outputs):
            if len(code.strip()) == 0:
                logger.warning(f"Empty code: {code}")
                continue
            if len(output["tokens"]) == 0:
                logger.warning(f"Tokenizer failed: {code}")

        if tbar is not None:
            tbar.update(len(code_list))

        return [d["tokens"] for d in tokenizer_outputs]

    # Zero-width split points for sub-tokenization: around _ and $, at CamelCase
    # boundaries, and between letters and digits.
    # NOTE(review): reconstructed — the original line was truncated in this copy of
    # the source (it ended at 'r"(?<=[_$])(?!$)|(?'); verify the exact boundaries
    # against the upstream repository.
    RE_SUBTOKENIZE = re.compile(
        r"(?<=[_$])(?!$)"             # split after _ or $, unless at end of string
        r"|(?<!^)(?=[_$])"            # split before _ or $, unless at start of string
        r"|(?<=[a-z])(?=[A-Z])"       # aB -> a | B (camelCase)
        r"|(?<=[A-Z])(?=[A-Z][a-z])"  # ABc -> A | Bc (acronym followed by word)
        r"|(?<=[a-zA-Z])(?=[0-9])"    # letter | digit boundary
        r"|(?<=[0-9])(?=[a-zA-Z])"    # digit | letter boundary
    )
    # Special separator token inserted between adjacent identifiers by
    # subtokenize_space_batch.
    # NOTE(review): reconstructed — the original assignment was lost in this copy
    # (its "<...>" literal resembles an HTML tag); confirm the exact value upstream.
    SPACE_TOKEN = "<space>"

    @classmethod
    def is_identifier(cls, token: str) -> bool:
        """Checks if token is an identifier: starts with a letter/_/$ and contains only alnum/_/$."""
        return len(token) > 0 and \
            (token[0].isalpha() or token[0] in "_$") and \
            all([c.isalnum() or c in "_$" for c in token])

    @classmethod
    def subtokenize(cls, token: str) -> List[str]:
        """
        Subtokenizes an identifier name into subtokens, by CamelCase and snake_case.
        Non-identifier tokens are returned unchanged as a singleton list.
        """
        # Only subtokenize identifier words (starts with letter _$, contains only alnum and _$)
        if cls.is_identifier(token):
            return cls.RE_SUBTOKENIZE.split(token)
        else:
            return [token]

    @classmethod
    def subtokenize_batch(cls, tokens: List[str]) -> Tuple[List[str], List[int]]:
        """
        Subtokenizes list of tokens.
        :return a list of subtokens, and a list of pointers to the original token indices.
        """
        sub_tokens = []
        src_indices = []
        for i, token in enumerate(tokens):
            new_sub_tokens = cls.subtokenize(token)
            sub_tokens += new_sub_tokens
            src_indices += [i] * len(new_sub_tokens)
        return sub_tokens, src_indices

    @classmethod
    def subtokenize_space_batch(cls, tokens: List[str]) -> Tuple[List[str], List[int]]:
        """
        Subtokenizes list of tokens, and inserts special token when necessary
        (between two identifiers).
        :return a list of subtokens, and a list of pointers to the original token indices
            (-1 marks an inserted SPACE_TOKEN).
        """
        sub_tokens = []
        src_indices = []
        last_is_identifier = False
        for i, token in enumerate(tokens):
            is_identifier = cls.is_identifier(token)
            # Two identifiers in a row were space-separated in the source; keep that signal
            if last_is_identifier and is_identifier:
                sub_tokens.append(cls.SPACE_TOKEN)
                src_indices.append(-1)
            new_sub_tokens = cls.subtokenize(token)
            sub_tokens += new_sub_tokens
            src_indices += [i] * len(new_sub_tokens)
            last_is_identifier = is_identifier
        return sub_tokens, src_indices

    @classmethod
    def regroup_subtokens(cls, subtokens: List[str], src_indices: List[int]) -> List[str]:
        """
        Given a list of subtokens and the original token indices, groups them back to tokens.
        :param subtokens: a list of subtokens.
        :param src_indices: the i-th indice should point to the original token that
            sub_tokens[i] belongs to; -1 means it is a special sub_token.
        :return: a list of tokens after regrouping.
        """
        id2tokens: Dict[int, str] = collections.defaultdict(str)
        for subtoken, i in zip(subtokens, src_indices):
            if i >= 0:
                id2tokens[i] += subtoken
        return [id2tokens[i] for i in sorted(id2tokens.keys())]
182 |
--------------------------------------------------------------------------------
/python/tseval/util/TrainConfig.py:
--------------------------------------------------------------------------------
1 | from typing import get_type_hints
2 |
3 | from recordclass import RecordClass
4 |
5 |
class TrainConfig(RecordClass):
    # Maximum wall-clock time of one training session, in seconds
    train_session_time: int = 20 * 3600
    # Id of the GPU to train on
    gpu_id: int = 0

    @classmethod
    def get_train_config_from_cmd_options(cls, options: dict) -> "TrainConfig":
        """
        Gets a TrainConfig from the command line options (the options will be modified
        in place to remove the parsed fields).
        """
        # Pop each recognized field out of options, converting it to the hinted type
        field_values = {
            f: t(options.pop(f))
            for f, t in get_type_hints(cls).items()
            if f in options
        }
        return cls(**field_values)
21 |
--------------------------------------------------------------------------------
/python/tseval/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EngineeringSoftware/time-segmented-evaluation/df052dbed791b39dc95dab6d7e6e0e8fb6b76946/python/tseval/util/__init__.py
--------------------------------------------------------------------------------