├── .github └── workflows │ └── main.yml ├── .gitignore ├── LICENSE ├── README.md ├── appendix ├── README.md ├── dataset_percentiles.md ├── filters │ ├── README.md │ ├── first_sentence.py │ ├── length.py │ └── verb_direct_object.py ├── predictions │ ├── codereviewer_with-history_preds │ │ ├── context-ratio_0.0_with-history.jsonl.gz │ │ ├── context-ratio_0.25_with-history.jsonl.gz │ │ └── context-ratio_0.5_with-history.jsonl.gz │ ├── codereviewer_without-history_preds │ │ ├── context-ratio_0.0_without-history.jsonl.gz │ │ ├── context-ratio_0.25_without-history.jsonl.gz │ │ └── context-ratio_0.5_without-history.jsonl.gz │ ├── codet5_with-history_preds │ │ ├── context-ratio_0.0_with-history.jsonl.gz │ │ ├── context-ratio_0.25_with-history.jsonl.gz │ │ └── context-ratio_0.5_with-history.jsonl.gz │ ├── codet5_without-history_preds │ │ ├── context-ratio_0.0_without-history.jsonl.gz │ │ ├── context-ratio_0.25_without-history.jsonl.gz │ │ └── context-ratio_0.5_without-history.jsonl.gz │ ├── gpt-3.5-turbo_with-history_preds │ │ ├── context-ratio_0.0_with-history.jsonl.gz │ │ ├── context-ratio_0.25_with-history.jsonl.gz │ │ └── context-ratio_0.5_with-history.jsonl.gz │ ├── gpt-3.5-turbo_without-history_preds │ │ ├── context-ratio_0.0_without-history.jsonl.gz │ │ ├── context-ratio_0.25_without-history.jsonl.gz │ │ └── context-ratio_0.5_without-history.jsonl.gz │ ├── race_codet5_with-history_preds │ │ ├── context-ratio_0.0_with-history.jsonl.gz │ │ ├── context-ratio_0.25_with-history.jsonl.gz │ │ └── context-ratio_0.5_with-history.jsonl.gz │ └── race_codet5_without-history_preds │ │ ├── context-ratio_0.0_without-history.jsonl.gz │ │ ├── context-ratio_0.25_without-history.jsonl.gz │ │ └── context-ratio_0.5_without-history.jsonl.gz └── results │ ├── README.md │ ├── metrics │ ├── README.md │ ├── cmg_approaches │ │ ├── full_metrics.jsonl │ │ └── prefix_metrics.jsonl │ ├── filters │ │ ├── filtered │ │ │ ├── full_metrics.jsonl │ │ │ └── prefix_metrics.jsonl │ │ ├── out_of_filters │ │ │ ├── full_metrics.jsonl │ │ │ └── prefix_metrics.jsonl │ │ └── random │ │ │ ├── full_metrics.jsonl │ │ │ └── prefix_metrics.jsonl │ └── llm │ │ ├── full_metrics.jsonl │ │ └── prefix_metrics.jsonl │ └── plots │ ├── README.md │ ├── cmg_approaches │ ├── context_ratio_0.pdf │ ├── context_ratio_25.pdf │ └── context_ratio_50.pdf │ ├── filters │ ├── filtered │ │ ├── context_ratio_0.pdf │ │ ├── context_ratio_25.pdf │ │ └── context_ratio_50.pdf │ ├── out-of-filters │ │ ├── context_ratio_0.pdf │ │ ├── context_ratio_25.pdf │ │ └── context_ratio_50.pdf │ └── random │ │ ├── context_ratio_0.pdf │ │ ├── context_ratio_25.pdf │ │ └── context_ratio_50.pdf │ └── llm │ ├── context_ratio_0.pdf │ ├── context_ratio_25.pdf │ └── context_ratio_50.pdf ├── compute_metrics.py ├── conf ├── __init__.py ├── data │ ├── dataset_config.py │ └── input_config.py ├── eval_config.py ├── metrics_config.py ├── model │ ├── base_configs.py │ └── configs.py ├── retrieval_config.py ├── sweep.yaml └── train_config.py ├── eval.py ├── mypy.ini ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── retrieve.py ├── src ├── __init__.py ├── data_utils │ ├── __init__.py │ ├── cmc_data_module.py │ ├── cmc_dataset_w_history.py │ ├── data_collators │ │ ├── __init__.py │ │ ├── base_collator_utils.py │ │ ├── data_collator_retrieval.py │ │ ├── data_collator_test.py │ │ └── data_collator_train.py │ └── preprocessors │ │ ├── __init__.py │ │ ├── base_preprocessor.py │ │ ├── codereviewer_preprocessor.py │ │ ├── default_preprocessor.py │ │ ├── race_preprocessor.py │ │ └── 
reused_implementations │ │ ├── __init__.py │ │ └── race.py ├── metrics │ ├── __init__.py │ ├── accuracy.py │ ├── bleu_norm.py │ ├── edit_similarity.py │ ├── exact_match.py │ ├── log_mnext.py │ ├── mrr.py │ └── reused_implementations │ │ ├── __init__.py │ │ ├── b_norm.py │ │ └── log_mnext.py ├── model │ ├── __init__.py │ ├── cmc_module.py │ └── configurations │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── decoder_wrapper.py │ │ ├── encoder_decoder_wrapper.py │ │ ├── race_wrapper.py │ │ ├── seq2seq_wrapper.py │ │ └── utils │ │ └── race.py ├── retrieval │ ├── __init__.py │ ├── embedders │ │ ├── __init__.py │ │ └── transformer.py │ ├── search │ │ ├── __init__.py │ │ └── diff.py │ └── utils │ │ ├── __init__.py │ │ └── typing_utils.py └── utils │ ├── __init__.py │ ├── evaluation_metrics.py │ ├── model_utils.py │ ├── prefix_utils.py │ ├── typing_utils.py │ └── wandb_organize_utils.py ├── tests ├── __init__.py ├── test_accuracy.py ├── test_codereviewer_preprocessor.py ├── test_data_collator_base.py ├── test_data_collator_test.py ├── test_data_collator_train.py ├── test_default_preprocessor.py ├── test_diff_search.py ├── test_edit_similarity.py ├── test_exact_match.py ├── test_full_pipeline.sh ├── test_mrr.py ├── test_pipeline.py ├── test_prefix_utils.py ├── test_race.py ├── test_race_preprocessor.py └── test_transformer_embedder.py └── train.py /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Test & Lint 2 | 3 | on: push 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.10 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.10.4 14 | - name: Install dependencies 15 | run: | 16 | python -m pip install --upgrade pip 17 | pip install -r requirements.txt 18 | - name: Lint with isort 19 | run: | 20 | isort --profile black . 21 | - name: Lint with Black 22 | run: | 23 | black . --check -l 120 24 | - name: Lint with mypy 25 | run: | 26 | mypy . 27 | - name: Run unit tests with pytest 28 | run: | 29 | pytest -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /raw_data/ 2 | __pycache__/ 3 | .idea 4 | /data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Aleksandra Eliseeva 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /appendix/README.md: -------------------------------------------------------------------------------- 1 | # Supplementary Materials 2 | 3 | > :bulb: Note: each folder has its own README with further details! 4 | 5 | This folder contains: 6 | 7 | * [`results`](results): comprehensive metrics for all our experiments, stored as JSONLines files; 8 | * [`predictions`](predictions): model predictions for all our experiments, stored as JSONLines files; 9 | * [`filters`](filters): the implementations of common filters from CMG research that we used in our experiments. 10 | 11 | ### Other details referenced in the paper 12 | * The prompts for LLMs are available under [`appendix_llm` tag](https://github.com/JetBrains-Research/commit_message_generation/tree/appendix_llm) in [`src/data_utils/cmg_prompts.py`](https://github.com/JetBrains-Research/commit_message_generation/blob/appendix_llm/src/data_utils/cmg_prompts.py). 13 | * The regular expressions we used for commit message processing are available in [another repo](https://github.com/saridormi/commit_chronicle) in [`src/processing/message_processor.py`](https://github.com/saridormi/commit_chronicle/blob/appendix/src/processing/message_processor.py). 14 | * Specific percentiles that we used to drop outliers from our dataset are available in [`dataset_percentiles.md`](dataset_percentiles.md). 15 | -------------------------------------------------------------------------------- /appendix/dataset_percentiles.md: -------------------------------------------------------------------------------- 1 | # Specific percentile values 2 | 3 | During the "Filtering outliers" step of our dataset processing pipeline, we dropped examples outside the [5th percentile, 95th percentile] range. 4 | In this file, we provide the specific values for these percentiles. 5 | 6 | | Feature | 5th percentile | 95th percentile | 7 | |:-----------------------:|:-----------:|:--------------:| 8 | | Messages: # characters | 12 | 491 | 9 | | Messages: # tokens | 2 | 53 | 10 | | Diffs: # characters | 240 | 42 785 | 11 | | Diffs: # tokens | 20 | 3740 | 12 | | Diffs: # modified files | 1 | 16 | -------------------------------------------------------------------------------- /appendix/filters/README.md: -------------------------------------------------------------------------------- 1 | # Filters 2 | 3 | In this directory, we provide the implementations for the following filters: 4 | 5 | * **First Sentence** in [`first_sentence.py`](first_sentence.py) 6 | * **Verb-Direct Object** in [`verb_direct_object.py`](verb_direct_object.py) 7 | * **Message Length**/**Diff Length** in [`length.py`](length.py) 8 | -------------------------------------------------------------------------------- /appendix/filters/first_sentence.py: -------------------------------------------------------------------------------- 1 | def is_one_sentence(text: str, nl_char: str = "\n") -> bool: 2 | """Implements single sentence filter. 3 | 4 | Args: 5 | text: Input string. 6 | 7 | Returns: 8 | True if input string has only one sentence, False otherwise. 9 | 10 | Notes: 11 | Determines the number of sentences based on newline characters.
12 | """ 13 | lines = text.split(nl_char) 14 | return len(lines) == 1 15 | -------------------------------------------------------------------------------- /appendix/filters/length.py: -------------------------------------------------------------------------------- 1 | from nltk import wordpunct_tokenize 2 | 3 | 4 | def is_shorter_than_n_tokens(text: str, n: int) -> bool: 5 | """Implements filter by length. 6 | 7 | Args: 8 | text: Input string. 9 | n: Number of tokens to consider. 10 | 11 | Returns: 12 | True if input string has <= n tokens, False otherwise. 13 | 14 | Notes: 15 | Performs tokenization by whitespaces and punctuation. 16 | """ 17 | num_tokens = len(wordpunct_tokenize(text)) 18 | return num_tokens <= n 19 | -------------------------------------------------------------------------------- /appendix/filters/verb_direct_object.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | 3 | spacy_nlp = spacy.load("en_core_web_sm") 4 | 5 | 6 | def is_verb_direct_object(text: str): 7 | """Implements filter by Verb-Direct Object grammar structure via spaCy package. 8 | 9 | Args: 10 | text: Input string. 11 | 12 | Returns: 13 | True if input string starts with V-DO grammar structure, False otherwise. 14 | 15 | Notes: 16 | * Only the first sentence is considered. 17 | * Since past forms (e.g. fixed) and gerunds (e.g. fixing) are often not tagged as verbs, 18 | there is an extra preprocessing step: lemmatization of first word if it is a verb. 19 | * Current implementation supports not only Direct Objects consisting of single noun, 20 | but also clauses/phrases. 21 | """ 22 | first_word = text.split(" ")[0] 23 | processed_first_word = spacy_nlp(first_word)[0] 24 | if processed_first_word.pos_ == "VERB": 25 | text = " ".join([processed_first_word.lemma_] + text.split(" ")[1:]) 26 | 27 | doc = spacy_nlp(text) 28 | 29 | token = doc[0] 30 | if ( 31 | token.pos_ == "VERB" 32 | and token.dep_ == "ROOT" 33 | and len([t.dep_ for t in token.children]) 34 | and [t.dep_ for t in token.children][0] == "dobj" 35 | ): 36 | return True 37 | return False 38 | -------------------------------------------------------------------------------- /appendix/predictions/codereviewer_with-history_preds/context-ratio_0.0_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_with-history_preds/context-ratio_0.0_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codereviewer_with-history_preds/context-ratio_0.25_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_with-history_preds/context-ratio_0.25_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codereviewer_with-history_preds/context-ratio_0.5_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_with-history_preds/context-ratio_0.5_with-history.jsonl.gz 
-------------------------------------------------------------------------------- /appendix/predictions/codereviewer_without-history_preds/context-ratio_0.0_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_without-history_preds/context-ratio_0.0_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codereviewer_without-history_preds/context-ratio_0.25_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_without-history_preds/context-ratio_0.25_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codereviewer_without-history_preds/context-ratio_0.5_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codereviewer_without-history_preds/context-ratio_0.5_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_with-history_preds/context-ratio_0.0_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_with-history_preds/context-ratio_0.0_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_with-history_preds/context-ratio_0.25_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_with-history_preds/context-ratio_0.25_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_with-history_preds/context-ratio_0.5_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_with-history_preds/context-ratio_0.5_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_without-history_preds/context-ratio_0.0_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_without-history_preds/context-ratio_0.0_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_without-history_preds/context-ratio_0.25_without-history.jsonl.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_without-history_preds/context-ratio_0.25_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/codet5_without-history_preds/context-ratio_0.5_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/codet5_without-history_preds/context-ratio_0.5_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.0_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.0_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.25_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.25_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.5_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_with-history_preds/context-ratio_0.5_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.0_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.0_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.25_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.25_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.5_without-history.jsonl.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/gpt-3.5-turbo_without-history_preds/context-ratio_0.5_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_with-history_preds/context-ratio_0.0_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_with-history_preds/context-ratio_0.0_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_with-history_preds/context-ratio_0.25_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_with-history_preds/context-ratio_0.25_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_with-history_preds/context-ratio_0.5_with-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_with-history_preds/context-ratio_0.5_with-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_without-history_preds/context-ratio_0.0_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_without-history_preds/context-ratio_0.0_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_without-history_preds/context-ratio_0.25_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_without-history_preds/context-ratio_0.25_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/predictions/race_codet5_without-history_preds/context-ratio_0.5_without-history.jsonl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/predictions/race_codet5_without-history_preds/context-ratio_0.5_without-history.jsonl.gz -------------------------------------------------------------------------------- /appendix/results/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | In this directory, we provide comprehensive results of our experiments. 4 | 5 | * Plots with prefix-level metrics are available in [`plots`](plots) directory. 6 | * Metric values are available in [`metrics`](metrics) directory. 
7 | 8 | **Note:** each directory has its own README with further details. 9 | -------------------------------------------------------------------------------- /appendix/results/metrics/README.md: -------------------------------------------------------------------------------- 1 | # Metrics 2 | 3 | In this directory, we provide metric values for all our experiments. 4 | 5 | * [`cmg_approaches`](cmg_approaches) – results for CMG approaches on $CMG_{test}$. 6 | * [`llm`](llm) – results for CMG approaches and LLM GPT-3.5-turbo on $LLM_{test}$. 7 | * [`filters`](filters) – results for CMG approaches on Filtered, Out-of-Filters and Random subsets of $CMG_{test}$ with 10,385 examples. 8 | 9 | For each setting, we provide: 10 | * `full_metrics.jsonl` – metrics between full predictions and targets, stored in JSONLines format. 11 | * `prefix_metrics.jsonl` – metrics between prefixes of predictions and targets, stored in JSONLines format. 12 | -------------------------------------------------------------------------------- /appendix/results/metrics/cmg_approaches/full_metrics.jsonl: -------------------------------------------------------------------------------- 1 | {"context_ratio":"0%","model":"RACE","history_mode":"without-history","b-norm":15.32,"edit_similarity":29.02,"exact_match@1":11.37,"exact_match@2":3.07} 2 | {"context_ratio":"0%","model":"RACE","history_mode":"with-history","b-norm":16.91,"edit_similarity":31.15,"exact_match@1":17.95,"exact_match@2":4.36} 3 | {"context_ratio":"0%","model":"CodeT5","history_mode":"with-history","b-norm":16.8,"edit_similarity":30.91,"exact_match@1":17.68,"exact_match@2":4.27} 4 | {"context_ratio":"0%","model":"CodeT5","history_mode":"without-history","b-norm":15.12,"edit_similarity":28.71,"exact_match@1":10.9,"exact_match@2":3.03} 5 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"with-history","b-norm":16.78,"edit_similarity":30.74,"exact_match@1":17.83,"exact_match@2":4.38} 6 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"without-history","b-norm":15.15,"edit_similarity":28.76,"exact_match@1":10.87,"exact_match@2":3.05} 7 | {"context_ratio":"25%","model":"RACE","history_mode":"without-history","b-norm":18.38,"edit_similarity":30.91,"exact_match@1":46.62,"exact_match@2":13.45} 8 | {"context_ratio":"25%","model":"RACE","history_mode":"with-history","b-norm":22.16,"edit_similarity":33.78,"exact_match@1":45.36,"exact_match@2":13.4} 9 | {"context_ratio":"25%","model":"CodeT5","history_mode":"with-history","b-norm":21.94,"edit_similarity":33.31,"exact_match@1":44.98,"exact_match@2":13.1} 10 | {"context_ratio":"25%","model":"CodeT5","history_mode":"without-history","b-norm":17.91,"edit_similarity":30.54,"exact_match@1":45.35,"exact_match@2":12.92} 11 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"with-history","b-norm":21.84,"edit_similarity":32.9,"exact_match@1":45.58,"exact_match@2":13.28} 12 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"without-history","b-norm":18.1,"edit_similarity":30.93,"exact_match@1":46.05,"exact_match@2":13.35} 13 | {"context_ratio":"50%","model":"RACE","history_mode":"without-history","b-norm":24.74,"edit_similarity":33.22,"exact_match@1":50.68,"exact_match@2":14.38} 14 | {"context_ratio":"50%","model":"RACE","history_mode":"with-history","b-norm":27.28,"edit_similarity":34.69,"exact_match@1":47.84,"exact_match@2":13.26} 15 | 
{"context_ratio":"50%","model":"CodeT5","history_mode":"with-history","b-norm":26.9,"edit_similarity":33.95,"exact_match@1":47.45,"exact_match@2":12.75} 16 | {"context_ratio":"50%","model":"CodeT5","history_mode":"without-history","b-norm":24.13,"edit_similarity":32.74,"exact_match@1":49.94,"exact_match@2":14.03} 17 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"with-history","b-norm":26.94,"edit_similarity":33.76,"exact_match@1":48.1,"exact_match@2":12.9} 18 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"without-history","b-norm":24.35,"edit_similarity":33.2,"exact_match@1":50.9,"exact_match@2":14.59} 19 | -------------------------------------------------------------------------------- /appendix/results/metrics/filters/filtered/full_metrics.jsonl: -------------------------------------------------------------------------------- 1 | {"context_ratio":"0%","model":"RACE","history_mode":"with-history","b-norm":22.3,"edit_similarity":36.14,"exact_match@1":32.2,"exact_match@2":7.62} 2 | {"context_ratio":"0%","model":"RACE","history_mode":"without-history","b-norm":19.5,"edit_similarity":33.03,"exact_match@1":17.03,"exact_match@2":4.56} 3 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"with-history","b-norm":22.14,"edit_similarity":35.54,"exact_match@1":31.8,"exact_match@2":7.61} 4 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"without-history","b-norm":18.88,"edit_similarity":32.26,"exact_match@1":16.1,"exact_match@2":4.24} 5 | {"context_ratio":"0%","model":"CodeT5","history_mode":"with-history","b-norm":22.21,"edit_similarity":35.36,"exact_match@1":31.58,"exact_match@2":7.54} 6 | {"context_ratio":"0%","model":"CodeT5","history_mode":"without-history","b-norm":19.17,"edit_similarity":32.28,"exact_match@1":16.08,"exact_match@2":4.36} 7 | {"context_ratio":"25%","model":"RACE","history_mode":"with-history","b-norm":29.39,"edit_similarity":40.25,"exact_match@1":52.35,"exact_match@2":17.0} 8 | {"context_ratio":"25%","model":"RACE","history_mode":"without-history","b-norm":23.48,"edit_similarity":35.85,"exact_match@1":49.91,"exact_match@2":16.28} 9 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"with-history","b-norm":29.22,"edit_similarity":39.35,"exact_match@1":51.99,"exact_match@2":16.89} 10 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"without-history","b-norm":22.88,"edit_similarity":35.56,"exact_match@1":49.04,"exact_match@2":15.67} 11 | {"context_ratio":"25%","model":"CodeT5","history_mode":"with-history","b-norm":29.21,"edit_similarity":39.36,"exact_match@1":51.55,"exact_match@2":16.29} 12 | {"context_ratio":"25%","model":"CodeT5","history_mode":"without-history","b-norm":22.77,"edit_similarity":35.11,"exact_match@1":48.51,"exact_match@2":15.55} 13 | {"context_ratio":"50%","model":"RACE","history_mode":"with-history","b-norm":35.78,"edit_similarity":41.36,"exact_match@1":54.34,"exact_match@2":17.25} 14 | {"context_ratio":"50%","model":"RACE","history_mode":"without-history","b-norm":31.46,"edit_similarity":39.02,"exact_match@1":58.19,"exact_match@2":17.89} 15 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"with-history","b-norm":35.28,"edit_similarity":40.29,"exact_match@1":53.98,"exact_match@2":16.5} 16 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"without-history","b-norm":30.82,"edit_similarity":38.9,"exact_match@1":57.92,"exact_match@2":17.2} 17 | 
{"context_ratio":"50%","model":"CodeT5","history_mode":"with-history","b-norm":35.2,"edit_similarity":40.17,"exact_match@1":53.4,"exact_match@2":16.25} 18 | {"context_ratio":"50%","model":"CodeT5","history_mode":"without-history","b-norm":30.76,"edit_similarity":38.39,"exact_match@1":57.3,"exact_match@2":16.9} 19 | -------------------------------------------------------------------------------- /appendix/results/metrics/filters/out_of_filters/full_metrics.jsonl: -------------------------------------------------------------------------------- 1 | {"context_ratio":"0%","model":"RACE","history_mode":"with-history","b-norm":5.74,"edit_similarity":20.83,"exact_match@1":10.81,"exact_match@2":3.04} 2 | {"context_ratio":"0%","model":"RACE","history_mode":"without-history","b-norm":5.79,"edit_similarity":20.15,"exact_match@1":9.55,"exact_match@2":2.79} 3 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"with-history","b-norm":5.63,"edit_similarity":20.75,"exact_match@1":10.47,"exact_match@2":3.17} 4 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"without-history","b-norm":6.22,"edit_similarity":20.51,"exact_match@1":9.86,"exact_match@2":3.16} 5 | {"context_ratio":"0%","model":"CodeT5","history_mode":"with-history","b-norm":5.81,"edit_similarity":21.11,"exact_match@1":10.3,"exact_match@2":3.23} 6 | {"context_ratio":"0%","model":"CodeT5","history_mode":"without-history","b-norm":6.09,"edit_similarity":20.34,"exact_match@1":9.98,"exact_match@2":3.11} 7 | {"context_ratio":"25%","model":"RACE","history_mode":"with-history","b-norm":5.52,"edit_similarity":15.54,"exact_match@1":44.48,"exact_match@2":10.02} 8 | {"context_ratio":"25%","model":"RACE","history_mode":"without-history","b-norm":6.16,"edit_similarity":16.21,"exact_match@1":47.68,"exact_match@2":11.47} 9 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"with-history","b-norm":5.41,"edit_similarity":15.42,"exact_match@1":45.38,"exact_match@2":10.4} 10 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"without-history","b-norm":6.37,"edit_similarity":16.71,"exact_match@1":48.79,"exact_match@2":11.64} 11 | {"context_ratio":"25%","model":"CodeT5","history_mode":"with-history","b-norm":5.61,"edit_similarity":16.0,"exact_match@1":44.55,"exact_match@2":10.52} 12 | {"context_ratio":"25%","model":"CodeT5","history_mode":"without-history","b-norm":6.21,"edit_similarity":16.44,"exact_match@1":47.41,"exact_match@2":10.98} 13 | {"context_ratio":"50%","model":"RACE","history_mode":"with-history","b-norm":6.89,"edit_similarity":15.9,"exact_match@1":39.09,"exact_match@2":7.4} 14 | {"context_ratio":"50%","model":"RACE","history_mode":"without-history","b-norm":7.74,"edit_similarity":16.97,"exact_match@1":40.8,"exact_match@2":8.88} 15 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"with-history","b-norm":6.73,"edit_similarity":15.58,"exact_match@1":40.42,"exact_match@2":7.59} 16 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"without-history","b-norm":8.04,"edit_similarity":17.78,"exact_match@1":42.67,"exact_match@2":9.32} 17 | {"context_ratio":"50%","model":"CodeT5","history_mode":"with-history","b-norm":6.79,"edit_similarity":15.62,"exact_match@1":39.33,"exact_match@2":7.15} 18 | {"context_ratio":"50%","model":"CodeT5","history_mode":"without-history","b-norm":7.84,"edit_similarity":17.32,"exact_match@1":40.4,"exact_match@2":8.8} 19 | -------------------------------------------------------------------------------- /appendix/results/metrics/filters/random/full_metrics.jsonl: 
-------------------------------------------------------------------------------- 1 | {"context_ratio":"0%","model":"CodeT5","history_mode":"with-history","b-norm":16.48,"edit_similarity":30.96,"exact_match@1":17.25,"exact_match@2":3.89} 2 | {"context_ratio":"0%","model":"RACE","history_mode":"with-history","b-norm":16.63,"edit_similarity":31.3,"exact_match@1":17.85,"exact_match@2":4.19} 3 | {"context_ratio":"0%","model":"RACE","history_mode":"without-history","b-norm":15.12,"edit_similarity":29.09,"exact_match@1":11.19,"exact_match@2":3.37} 4 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"with-history","b-norm":16.51,"edit_similarity":30.66,"exact_match@1":17.67,"exact_match@2":3.94} 5 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"without-history","b-norm":14.85,"edit_similarity":28.65,"exact_match@1":10.58,"exact_match@2":2.96} 6 | {"context_ratio":"0%","model":"CodeT5","history_mode":"without-history","b-norm":14.91,"edit_similarity":28.74,"exact_match@1":10.89,"exact_match@2":3.07} 7 | {"context_ratio":"25%","model":"RACE","history_mode":"with-history","b-norm":22.16,"edit_similarity":33.81,"exact_match@1":45.12,"exact_match@2":13.72} 8 | {"context_ratio":"25%","model":"RACE","history_mode":"without-history","b-norm":18.28,"edit_similarity":30.95,"exact_match@1":46.24,"exact_match@2":14.05} 9 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"with-history","b-norm":21.71,"edit_similarity":32.92,"exact_match@1":45.11,"exact_match@2":13.42} 10 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"without-history","b-norm":18.04,"edit_similarity":30.94,"exact_match@1":45.51,"exact_match@2":13.82} 11 | {"context_ratio":"25%","model":"CodeT5","history_mode":"with-history","b-norm":21.74,"edit_similarity":33.24,"exact_match@1":44.26,"exact_match@2":13.43} 12 | {"context_ratio":"25%","model":"CodeT5","history_mode":"without-history","b-norm":17.66,"edit_similarity":30.48,"exact_match@1":44.88,"exact_match@2":12.93} 13 | {"context_ratio":"50%","model":"RACE","history_mode":"with-history","b-norm":27.21,"edit_similarity":34.81,"exact_match@1":47.93,"exact_match@2":13.06} 14 | {"context_ratio":"50%","model":"RACE","history_mode":"without-history","b-norm":24.77,"edit_similarity":33.25,"exact_match@1":50.86,"exact_match@2":14.61} 15 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"with-history","b-norm":26.95,"edit_similarity":33.93,"exact_match@1":48.15,"exact_match@2":13.42} 16 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"without-history","b-norm":24.43,"edit_similarity":33.31,"exact_match@1":51.16,"exact_match@2":14.92} 17 | {"context_ratio":"50%","model":"CodeT5","history_mode":"with-history","b-norm":26.91,"edit_similarity":33.97,"exact_match@1":47.97,"exact_match@2":12.94} 18 | {"context_ratio":"50%","model":"CodeT5","history_mode":"without-history","b-norm":24.09,"edit_similarity":32.75,"exact_match@1":50.29,"exact_match@2":14.29} 19 | -------------------------------------------------------------------------------- /appendix/results/metrics/llm/full_metrics.jsonl: -------------------------------------------------------------------------------- 1 | {"context_ratio":"0%","model":"RACE","history_mode":"with-history","b-norm":15.62,"edit_similarity":30.14,"exact_match@1":18.48,"exact_match@2":3.85} 2 | {"context_ratio":"0%","model":"RACE","history_mode":"without-history","b-norm":14.48,"edit_similarity":28.45,"exact_match@1":12.67,"exact_match@2":3.26} 3 | 
{"context_ratio":"0%","model":"CodeReviewer","history_mode":"with-history","b-norm":16.02,"edit_similarity":30.29,"exact_match@1":17.79,"exact_match@2":4.1} 4 | {"context_ratio":"0%","model":"CodeReviewer","history_mode":"without-history","b-norm":14.36,"edit_similarity":28.23,"exact_match@1":12.12,"exact_match@2":3.16} 5 | {"context_ratio":"0%","model":"CodeT5","history_mode":"with-history","b-norm":15.91,"edit_similarity":30.25,"exact_match@1":17.91,"exact_match@2":3.8} 6 | {"context_ratio":"0%","model":"CodeT5","history_mode":"without-history","b-norm":14.05,"edit_similarity":27.92,"exact_match@1":11.85,"exact_match@2":2.64} 7 | {"context_ratio":"0%","model":"gpt-3.5-turbo","history_mode":"history","b-norm":11.14,"edit_similarity":25.86,"exact_match@1":12.55,"exact_match@2":2.81} 8 | {"context_ratio":"0%","model":"gpt-3.5-turbo","history_mode":"simple","b-norm":9.26,"edit_similarity":24.58,"exact_match@1":8.65,"exact_match@2":1.47} 9 | {"context_ratio":"25%","model":"RACE","history_mode":"with-history","b-norm":21.14,"edit_similarity":32.35,"exact_match@1":44.92,"exact_match@2":12.93} 10 | {"context_ratio":"25%","model":"RACE","history_mode":"without-history","b-norm":17.54,"edit_similarity":30.13,"exact_match@1":46.68,"exact_match@2":13.33} 11 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"with-history","b-norm":21.35,"edit_similarity":32.68,"exact_match@1":45.69,"exact_match@2":13.15} 12 | {"context_ratio":"25%","model":"CodeReviewer","history_mode":"without-history","b-norm":17.34,"edit_similarity":30.45,"exact_match@1":45.96,"exact_match@2":13.03} 13 | {"context_ratio":"25%","model":"CodeT5","history_mode":"with-history","b-norm":21.11,"edit_similarity":32.31,"exact_match@1":43.68,"exact_match@2":12.45} 14 | {"context_ratio":"25%","model":"CodeT5","history_mode":"without-history","b-norm":17.16,"edit_similarity":30.02,"exact_match@1":45.19,"exact_match@2":12.85} 15 | {"context_ratio":"25%","model":"gpt-3.5-turbo","history_mode":"simple","b-norm":11.48,"edit_similarity":26.35,"exact_match@1":21.84,"exact_match@2":5.99} 16 | {"context_ratio":"25%","model":"gpt-3.5-turbo","history_mode":"history","b-norm":13.24,"edit_similarity":27.83,"exact_match@1":34.34,"exact_match@2":10.09} 17 | {"context_ratio":"50%","model":"RACE","history_mode":"with-history","b-norm":27.08,"edit_similarity":34.12,"exact_match@1":49.19,"exact_match@2":13.24} 18 | {"context_ratio":"50%","model":"RACE","history_mode":"without-history","b-norm":23.64,"edit_similarity":32.13,"exact_match@1":50.39,"exact_match@2":14.12} 19 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"with-history","b-norm":27.36,"edit_similarity":34.5,"exact_match@1":49.76,"exact_match@2":13.69} 20 | {"context_ratio":"50%","model":"CodeReviewer","history_mode":"without-history","b-norm":23.65,"edit_similarity":32.69,"exact_match@1":50.76,"exact_match@2":14.81} 21 | {"context_ratio":"50%","model":"CodeT5","history_mode":"with-history","b-norm":26.72,"edit_similarity":34.05,"exact_match@1":48.55,"exact_match@2":13.24} 22 | {"context_ratio":"50%","model":"CodeT5","history_mode":"without-history","b-norm":23.06,"edit_similarity":32.09,"exact_match@1":49.69,"exact_match@2":13.72} 23 | {"context_ratio":"50%","model":"gpt-3.5-turbo","history_mode":"simple","b-norm":10.93,"edit_similarity":23.51,"exact_match@1":24.15,"exact_match@2":7.86} 24 | {"context_ratio":"50%","model":"gpt-3.5-turbo","history_mode":"history","b-norm":12.35,"edit_similarity":24.6,"exact_match@1":32.47,"exact_match@2":11.29} 25 | 
-------------------------------------------------------------------------------- /appendix/results/plots/README.md: -------------------------------------------------------------------------------- 1 | # Plots 2 | 3 | In this directory, we provide plots for prefix metrics for all our experiments. 4 | 5 | * [`cmg_approaches`](cmg_approaches) – plots for CMG approaches on $CMG_{test}$. 6 | * [`llm`](llm) – plots for CMG approaches and LLM GPT-3.5-turbo on $LLM_{test}$. 7 | * [`filters`](filters) – plots for CMG approaches on Filtered, Out-of-Filters and Random subsets of $CMG_{test}$ with 10,385 examples. 8 | -------------------------------------------------------------------------------- /appendix/results/plots/cmg_approaches/context_ratio_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/cmg_approaches/context_ratio_0.pdf -------------------------------------------------------------------------------- /appendix/results/plots/cmg_approaches/context_ratio_25.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/cmg_approaches/context_ratio_25.pdf -------------------------------------------------------------------------------- /appendix/results/plots/cmg_approaches/context_ratio_50.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/cmg_approaches/context_ratio_50.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/filtered/context_ratio_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/filtered/context_ratio_0.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/filtered/context_ratio_25.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/filtered/context_ratio_25.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/filtered/context_ratio_50.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/filtered/context_ratio_50.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/out-of-filters/context_ratio_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/out-of-filters/context_ratio_0.pdf -------------------------------------------------------------------------------- 
/appendix/results/plots/filters/out-of-filters/context_ratio_25.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/out-of-filters/context_ratio_25.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/out-of-filters/context_ratio_50.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/out-of-filters/context_ratio_50.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/random/context_ratio_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/random/context_ratio_0.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/random/context_ratio_25.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/random/context_ratio_25.pdf -------------------------------------------------------------------------------- /appendix/results/plots/filters/random/context_ratio_50.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/filters/random/context_ratio_50.pdf -------------------------------------------------------------------------------- /appendix/results/plots/llm/context_ratio_0.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/llm/context_ratio_0.pdf -------------------------------------------------------------------------------- /appendix/results/plots/llm/context_ratio_25.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/llm/context_ratio_25.pdf -------------------------------------------------------------------------------- /appendix/results/plots/llm/context_ratio_50.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/appendix/results/plots/llm/context_ratio_50.pdf -------------------------------------------------------------------------------- /conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .data.dataset_config import DatasetConfig 2 | from .data.input_config import InputConfig 3 | from .eval_config import EvalConfig 4 | from .metrics_config import MetricsConfig 5 | from .model.base_configs import ( 6 | BaseDecoderConfig, 7 | BaseEncoderDecoderConfig, 8 
| BaseModelConfig, 9 | BaseRACEConfig, 10 | BaseSeq2SeqConfig, 11 | ) 12 | from .retrieval_config import RetrievalConfig 13 | from .train_config import TrainConfig 14 | 15 | __all__ = [ 16 | "BaseDecoderConfig", 17 | "BaseEncoderDecoderConfig", 18 | "BaseModelConfig", 19 | "BaseRACEConfig", 20 | "BaseSeq2SeqConfig", 21 | "EvalConfig", 22 | "TrainConfig", 23 | "InputConfig", 24 | "DatasetConfig", 25 | "MetricsConfig", 26 | "RetrievalConfig", 27 | ] 28 | -------------------------------------------------------------------------------- /conf/data/dataset_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from omegaconf import MISSING 5 | 6 | 7 | @dataclass 8 | class DataLoaderConfig: 9 | """ 10 | DataLoader configuration. 11 | """ 12 | 13 | batch_size: int = MISSING 14 | num_workers: int = MISSING 15 | 16 | 17 | @dataclass 18 | class DatasetConfig: 19 | """ 20 | Basic data-related configuration. 21 | 22 | Attributes: 23 | dataset_root: Directory with data, should contain files `train.jsonl`, `val.jsonl`, `test.jsonl`. 24 | preprocessor_chunksize: When data is preprocessed, how many examples should be in single chunk. 25 | stage: Name of current stage, set to "sweep" to use correct logic for W&B sweep. 26 | add_history_to_inputs: True to save history for each input example, 27 | False to load history in RAM and build inputs on the fly. 28 | use_train_downsample: True to use downsampled version of train set. 29 | use_eval_downsample: True to use downsampled versions of validation and test sets. 30 | testing: True to generate random numbers instead of actual data (used for tuning batch size). 31 | use_cache: True to look for preprocessed files, False to relaunch preprocessing even if preprocessed files are present. 32 | line_sep: Newline separator used in data (should be the same for diffs and messages). 33 | train_dataloader_conf: Configuration for train dataloader. 34 | val_dataloader_conf: Configuration for val dataloader. 35 | test_dataloader_conf: Configuration for test dataloader. 36 | """ 37 | 38 | dataset_root: str = "raw_data/multilang" 39 | preprocessor_chunksize: int = 4096 40 | stage: Optional[str] = None 41 | add_history_to_inputs: bool = True 42 | use_train_downsample: bool = False 43 | use_eval_downsample: bool = True 44 | testing: bool = False 45 | use_cache: bool = False 46 | line_sep: str = "\n" 47 | train_dataloader_conf: DataLoaderConfig = DataLoaderConfig(batch_size=16, num_workers=4) 48 | val_dataloader_conf: DataLoaderConfig = DataLoaderConfig(batch_size=16, num_workers=4) 49 | test_dataloader_conf: DataLoaderConfig = DataLoaderConfig(batch_size=1, num_workers=1) 50 | -------------------------------------------------------------------------------- /conf/data/input_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from omegaconf import MISSING 4 | 5 | 6 | @dataclass 7 | class InputConfig: 8 | """ 9 | Input configuration. 10 | 11 | Attributes: 12 | generate_with_history: `True` to concatenate commit message history with current commit message in decoder context during generation, `False` otherwise (ignored when `encoder_input_type` is `history`). 13 | train_with_history: `True` to concatenate commit message history with current commit message in decoder context during training, `False` otherwise (ignored when `encoder_input_type` is `history`). 
14 | encoder_input_type: What type of input will be passed to the encoder. Currently, `history` and `diff` are supported. 15 | context_ratio: A ratio of characters from the input message to pass to the model context during generation (should be in the [0, 1] range). 16 | """ 17 | 18 | generate_with_history: bool = True 19 | train_with_history: bool = MISSING 20 | encoder_input_type: str = MISSING 21 | context_ratio: float = 0.0 22 | -------------------------------------------------------------------------------- /conf/eval_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List, Optional 3 | 4 | from hydra.core.config_store import ConfigStore 5 | from omegaconf import MISSING 6 | 7 | from .data.dataset_config import DatasetConfig 8 | from .data.input_config import InputConfig 9 | from .model.base_configs import BaseModelConfig 10 | from .model.configs import ( 11 | CodeReviewerConfig, 12 | CodeT5Config, 13 | DistilGPT2Config, 14 | RACEConfig, 15 | RandomTransformerConfig, 16 | ) 17 | 18 | 19 | @dataclass 20 | class ArtifactEvalConfig: 21 | """ 22 | Configuration for W&B artifact with model checkpoint. 23 | 24 | Artifact name is not provided, because it's automatically retrieved from model and input configuration. 25 | 26 | Attributes: 27 | project: W&B project. 28 | version: Version tag of W&B artifact. 29 | artifact_path: Path to model checkpoint in artifact. 30 | local_path: Path to save artifact locally. 31 | """ 32 | 33 | project: str = "saridormi/commit_message_completion" 34 | version: str = "latest" 35 | artifact_path: str = "last.ckpt" 36 | local_path: str = "artifacts" 37 | 38 | 39 | @dataclass 40 | class WandbEvalConfig: 41 | """ 42 | Configuration for W&B logging. 43 | 44 | What's logged during evaluation: 45 | * (optional) load model checkpoint from W&B artifact 46 | * model predictions 47 | 48 | Attributes: 49 | use_wandb: Whether W&B will be used for logging or not. 50 | project: Name of project this run will appear in. 51 | load_artifact: Whether model checkpoint should be loaded from W&B artifact or not. 52 | use_api_key: True to read an API key from a local file (expected to be stored in `wandb_api_key.txt`). 53 | """ 54 | 55 | use_wandb: bool = True 56 | project: str = "commit_message_completion" 57 | load_artifact: bool = True 58 | use_api_key: bool = False 59 | artifact_config: ArtifactEvalConfig = field(default_factory=ArtifactEvalConfig) 60 | 61 | 62 | @dataclass 63 | class TrainerEvalConfig: 64 | """ 65 | Configuration for pytorch_lightning.Trainer. All options will be passed to Trainer as kwargs. 66 | (refer to docs: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) 67 | """ 68 | 69 | accelerator: str = "gpu" 70 | devices: Any = 1 71 | limit_test_batches: Optional[int] = None 72 | 73 | 74 | @dataclass 75 | class GenerationConfig: 76 | """ 77 | Configuration for generation. 78 | 79 | All options will be passed to HuggingFace's generate() as kwargs. 80 | (refer to docs: https://huggingface.co/docs/transformers/main_classes/text_generation) 81 | """ 82 | 83 | num_beams: int = 10 84 | repetition_penalty: float = 1.0 85 | length_penalty: float = 1.0 86 | no_repeat_ngram_size: int = 3 87 | max_new_tokens: int = 15 88 | 89 | 90 | @dataclass 91 | class EvalConfig: 92 | """ 93 | Configuration for evaluation. 94 | 95 | Attributes: 96 | stage: Set to "sweep" if you want to use validation data for tuning hyperparameters.
97 | ckpt_path: Local path to model checkpoint. Instead of this, you can also define configuration for loading artifact at WandbEvalConfig. 98 | """ 99 | 100 | defaults: List[Any] = field(default_factory=lambda: ["_self_", {"dataset": "multilang"}]) 101 | 102 | stage: str = "test" 103 | ckpt_path: str = "" 104 | dataset: DatasetConfig = MISSING 105 | model: BaseModelConfig = MISSING 106 | input: InputConfig = field(default_factory=InputConfig) 107 | logger: WandbEvalConfig = field(default_factory=WandbEvalConfig) 108 | generation: GenerationConfig = field(default_factory=GenerationConfig) 109 | trainer: TrainerEvalConfig = field(default_factory=TrainerEvalConfig) 110 | 111 | 112 | cs = ConfigStore.instance() 113 | cs.store(name="eval_config", node=EvalConfig) 114 | cs.store(name="distilgpt2", group="model", node=DistilGPT2Config) 115 | cs.store(name="random_roberta_2_random_gpt2_2", group="model", node=RandomTransformerConfig) 116 | cs.store(name="codet5", group="model", node=CodeT5Config) 117 | cs.store(name="codereviewer", group="model", node=CodeReviewerConfig) 118 | cs.store(name="race", group="model", node=RACEConfig) 119 | cs.store(name="multilang", group="dataset", node=DatasetConfig) 120 | -------------------------------------------------------------------------------- /conf/metrics_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List, Optional 3 | 4 | from hydra.core.config_store import ConfigStore 5 | from omegaconf import MISSING 6 | 7 | 8 | @dataclass 9 | class FilterConfig: 10 | """ 11 | Configuration for additional data filtering when calculating metrics. 12 | 13 | Attributes: 14 | path: Path to file with filters metadata for a test set. 15 | use_filtering: True to use additional data filtering, False otherwise. 16 | filters_to_include: List of column names to consider. Each column should be boolean. 17 | logic: A logic to follow when multiple columns are given (`and` for logical and, `or` for logical or). 18 | fit_filters: If True, will consider examples that fit given columns with given logic. 19 | If False, will consider examples that DON'T FIT given columns with given logic. 20 | use_pos_in_file_filtering: True to use `pos_in_file` column and only consider lines present in a given file, 21 | False to use boolean filters logic. 22 | 23 | """ 24 | 25 | path: str = "raw_data/multilang/downsample/filters/test.jsonl" 26 | use_filtering: bool = False 27 | filters_to_include: List[str] = field( 28 | default_factory=lambda: ["is_vdo", "one_sentence_newline", "message_30_tokens", "diff_100_tokens"] 29 | ) 30 | logic: str = "and" 31 | fit_filters: bool = True 32 | use_pos_in_file_filtering: bool = False 33 | use_subset: bool = False 34 | subset_num_examples: Optional[int] = None 35 | 36 | 37 | @dataclass 38 | class ArtifactMetricConfig: 39 | """ 40 | Configuration for W&B artifact with model predictions. 41 | 42 | Attributes: 43 | project: W&B project. 44 | name: Name of W&B artifact. 45 | version: Version tag of W&B artifact. 46 | artifact_path: Path to model predictions in artifact. 47 | local_path: Path to save artifact locally. 48 | """ 49 | 50 | project: str = "saridormi/commit_message_completion" 51 | name: str = MISSING 52 | version: str = "latest" 53 | artifact_path: str = MISSING 54 | local_path: str = "artifacts" 55 | 56 | 57 | @dataclass 58 | class WandbMetricConfig: 59 | """ 60 | Configuration for W&B logging. 
61 | 62 | What's logged during metrics calculation: 63 | * (optional) load model predictions from W&B artifact 64 | * metrics 65 | 66 | Attributes: 67 | use_wandb: Whether W&B will be used for logging or not. 68 | project: Name of project this run will appear in. 69 | load_artifact: Whether model predictions should be loaded from W&B artifact or not. 70 | use_api_key: True to read an API key from a local file (expected to be stored in `wandb_api_key.txt`). 71 | """ 72 | 73 | use_wandb: bool = True 74 | project: str = "commit_message_completion" 75 | load_artifact: bool = True 76 | use_api_key: bool = False 77 | artifact_config: ArtifactMetricConfig = field(default_factory=ArtifactMetricConfig) 78 | 79 | 80 | @dataclass 81 | class MetricsConfig: 82 | """ 83 | Configuration for metrics calculation. 84 | 85 | Metrics are calculated: 86 | * between full predictions and targets 87 | * between all prefixes of N tokens of predictions of targets 88 | 89 | Attributes: 90 | preds_path: Local path to model predictions. Instead of this, you can also define configuration for loading artifact at WandbMetricConfig. 91 | include_short: False to only consider messages with >= i tokens when computing metrics for prefixes of i tokens, 92 | True to include all messages. 93 | max_n_tokens: Maximum number of tokens (for prefix-level metrics). 94 | """ 95 | 96 | preds_path: Optional[str] = None 97 | include_short: bool = False 98 | max_n_tokens: int = 15 99 | filter: FilterConfig = field(default_factory=FilterConfig) 100 | logger: WandbMetricConfig = field(default_factory=WandbMetricConfig) 101 | 102 | 103 | cs = ConfigStore.instance() 104 | cs.store(name="metrics_config", node=MetricsConfig) 105 | -------------------------------------------------------------------------------- /conf/model/base_configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | 4 | from omegaconf import MISSING 5 | 6 | 7 | @dataclass 8 | class BaseModelConfig: 9 | """ 10 | Basic model configuration. 11 | 12 | Attributes: 13 | configuration: What model architecture to use. Should be one of `decoder`, `encoder_decoder`, `seq2seq`, `race`. 14 | preprocessor_configuration: What diff processing strategy to use. Should be one of `default`, `codereviewer`, `race`. 15 | diff_tokenizer_name_or_path: Local path or name on HuggingFace Hub for diff tokenizer. 16 | msg_tokenizer_name_or_path: Local path or name on HuggingFace Hub for message tokenizer. 17 | encoder_context_max_length: Maximum allowed number of tokens for encoder context. 18 | decoder_context_max_length: Maximum allowed number of tokens for decoder context. 19 | """ 20 | 21 | configuration: str = MISSING 22 | preprocessor_configuration: str = "default" 23 | diff_tokenizer_name_or_path: str = MISSING 24 | msg_tokenizer_name_or_path: str = MISSING 25 | encoder_context_max_len: int = MISSING 26 | decoder_context_max_len: int = MISSING 27 | 28 | 29 | @dataclass 30 | class BaseDecoderConfig(BaseModelConfig): 31 | """ 32 | Base configuration for Transformer Decoder. 33 | """ 34 | 35 | configuration: str = "decoder" 36 | decoder_name_or_path: str = MISSING 37 | 38 | 39 | @dataclass 40 | class BaseEncoderDecoderConfig(BaseModelConfig): 41 | """ 42 | Base configuration for Transformer initialized with pretrained encoder/decoder. 
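A sketch of the prefix-level evaluation that MetricsConfig above describes (a hypothetical helper, not the repository's compute_metrics.py):

    def token_prefixes(pred: str, target: str, max_n_tokens: int = 15, include_short: bool = False):
        # Yield (prediction prefix, target prefix) pairs of 1..max_n_tokens whitespace-separated tokens.
        pred_words, target_words = pred.split(), target.split()
        for i in range(1, max_n_tokens + 1):
            if not include_short and len(target_words) < i:
                break  # skip targets shorter than the current prefix length
            yield " ".join(pred_words[:i]), " ".join(target_words[:i])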
43 | """ 44 | 45 | configuration: str = "encoder_decoder" 46 | num_layers_encoder: Optional[int] = None 47 | encoder_model_type: Optional[str] = None 48 | encoder_name_or_path: Optional[str] = None 49 | num_layers_decoder: Optional[int] = None 50 | decoder_model_type: Optional[str] = None 51 | decoder_name_or_path: Optional[str] = None 52 | tie_encoder_decoder: bool = MISSING 53 | tie_word_embeddings: bool = MISSING 54 | 55 | 56 | @dataclass 57 | class BaseSeq2SeqConfig(BaseModelConfig): 58 | """ 59 | Base configuration for pretrained seq2seq Transformer. 60 | """ 61 | 62 | configuration: str = "seq2seq" 63 | name_or_path: str = MISSING 64 | 65 | 66 | @dataclass 67 | class BaseRACEConfig(BaseModelConfig): 68 | """ 69 | Base configuration for RACE model. 70 | """ 71 | 72 | configuration: str = "race" 73 | preprocessor_configuration: str = "race" 74 | name_or_path: str = MISSING 75 | -------------------------------------------------------------------------------- /conf/model/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from .base_configs import ( 4 | BaseDecoderConfig, 5 | BaseEncoderDecoderConfig, 6 | BaseRACEConfig, 7 | BaseSeq2SeqConfig, 8 | ) 9 | 10 | 11 | @dataclass 12 | class DistilGPT2Config(BaseDecoderConfig): 13 | diff_tokenizer_name_or_path: str = "distilgpt2" 14 | msg_tokenizer_name_or_path: str = "distilgpt2" 15 | encoder_context_max_len: int = 512 16 | decoder_context_max_len: int = 512 17 | decoder_name_or_path: str = "distilgpt2" 18 | 19 | 20 | @dataclass 21 | class RandomTransformerConfig(BaseEncoderDecoderConfig): 22 | diff_tokenizer_name_or_path: str = "raw_data/multilang/byte_level" 23 | msg_tokenizer_name_or_path: str = "raw_data/multilang/byte_level" 24 | encoder_context_max_len: int = 512 25 | decoder_context_max_len: int = 256 26 | 27 | num_layers_encoder: int = 2 28 | encoder_model_type: str = "roberta" 29 | 30 | num_layers_decoder: int = 2 31 | decoder_model_type: str = "gpt2" 32 | 33 | tie_encoder_decoder: bool = False 34 | tie_word_embeddings: bool = False 35 | 36 | 37 | @dataclass 38 | class CodeT5Config(BaseSeq2SeqConfig): 39 | diff_tokenizer_name_or_path: str = "Salesforce/codet5-base" 40 | msg_tokenizer_name_or_path: str = "Salesforce/codet5-base" 41 | encoder_context_max_len: int = 512 42 | decoder_context_max_len: int = 512 43 | 44 | name_or_path: str = "Salesforce/codet5-base" 45 | 46 | 47 | @dataclass 48 | class CodeReviewerConfig(BaseSeq2SeqConfig): 49 | preprocessor_configuration: str = "codereviewer" 50 | diff_tokenizer_name_or_path: str = "microsoft/codereviewer" 51 | msg_tokenizer_name_or_path: str = "microsoft/codereviewer" 52 | encoder_context_max_len: int = 512 53 | decoder_context_max_len: int = 512 54 | 55 | name_or_path: str = "microsoft/codereviewer" 56 | 57 | 58 | @dataclass 59 | class RACEConfig(BaseRACEConfig): 60 | diff_tokenizer_name_or_path: str = "Salesforce/codet5-base" 61 | msg_tokenizer_name_or_path: str = "Salesforce/codet5-base" 62 | encoder_context_max_len: int = 512 63 | decoder_context_max_len: int = 512 64 | 65 | name_or_path: str = "Salesforce/codet5-base" 66 | -------------------------------------------------------------------------------- /conf/retrieval_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List 3 | 4 | from hydra.core.config_store import ConfigStore 5 | from omegaconf import MISSING 6 | 7 | from 
.data.dataset_config import DatasetConfig 8 | from .data.input_config import InputConfig 9 | from .model.base_configs import BaseModelConfig 10 | from .model.configs import ( 11 | CodeReviewerConfig, 12 | CodeT5Config, 13 | DistilGPT2Config, 14 | RACEConfig, 15 | RandomTransformerConfig, 16 | ) 17 | 18 | 19 | @dataclass 20 | class SearchConfig: 21 | num_neighbors: int = 1 22 | num_trees: int = 100 23 | load_index: bool = False 24 | load_index_path: str = "" 25 | index_root_dir: str = "ann_indices" 26 | 27 | 28 | @dataclass 29 | class ArtifactRetrievalConfig: 30 | """ 31 | Configuration for W&B artifact with model checkpoint. 32 | 33 | Artifact name is not provided because it's automatically retrieved from model and input configuration. 34 | 35 | Attributes: 36 | project: W&B project. 37 | version: Version tag of W&B artifact. 38 | artifact_path: Path to model checkpoint in artifact. 39 | local_path: Path to save artifact locally. 40 | """ 41 | 42 | project: str = "saridormi/commit_message_completion" 43 | version: str = "latest" 44 | artifact_path: str = "last.ckpt" 45 | local_path: str = "artifacts" 46 | 47 | 48 | @dataclass 49 | class WandbRetrievalConfig: 50 | """ 51 | Configuration for W&B logging. 52 | 53 | What's logged during evaluation: 54 | * (optional) load model checkpoint from W&B artifact 55 | * model predictions 56 | 57 | Attributes: 58 | use_wandb: Whether W&B will be used for logging or not. 59 | project: Name of a project this run will appear in. 60 | use_api_key: True to read an API key from a local file (expected to be stored in `wandb_api_key.txt`). 61 | download_artifact: Whether model checkpoint should be downloaded from W&B artifact or not. 62 | input_artifact: Configuration for W&B artifact with model checkpoint. 63 | upload_artifact: Whether retrieved predictions should be uploaded to W&B artifact or not. 64 | """ 65 | 66 | use_wandb: bool = True 67 | project: str = "commit_message_completion" 68 | use_api_key: bool = False 69 | download_artifact: bool = True 70 | input_artifact: ArtifactRetrievalConfig = field(default_factory=ArtifactRetrievalConfig) 71 | upload_artifact: bool = True 72 | 73 | 74 | @dataclass 75 | class EmbedderConfig: 76 | """ 77 | Configuration for Transformer encoder that is used to construct embeddings. 78 | 79 | Args: 80 | device: Set to `cpu` to run model on CPU and `cuda` to run model on GPU. Currently, only single-GPU setting is supported; if your system has more than 1 GPU, make sure to set CUDA_VISIBLE_DEVICES enviromental variable to a single GPU. 81 | precision: Set to 16 to use native mixed precision from PyTorch. 82 | normalize_embeddings: Set to True to normalize embeddings, so that L2-norm is equal to 1. 83 | """ 84 | 85 | device: str = "cpu" 86 | precision: int = 16 87 | normalize_embeddings: bool = True 88 | 89 | 90 | @dataclass 91 | class RetrievalConfig: 92 | """ 93 | Configuration for retrieval. 94 | 95 | Args: 96 | ckpt_path: Local path to model checkpoint. Instead of this, you can also define a configuration for loading artifact at WandbEvalConfig. 
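A minimal sketch of the approximate nearest-neighbour search that SearchConfig above parameterizes (an illustration that calls the optional `annoy` dependency directly; the project's own wrapper lives in src/retrieval/search/diff.py, and the random vectors stand in for diff embeddings):

    import numpy as np
    from annoy import AnnoyIndex

    embeddings_dim = 768
    index = AnnoyIndex(embeddings_dim, "angular")
    for i in range(1_000):
        index.add_item(i, np.random.rand(embeddings_dim).tolist())
    index.build(100)  # SearchConfig.num_trees

    query = np.random.rand(embeddings_dim).tolist()
    neighbours = index.get_nns_by_vector(query, 1)  # SearchConfig.num_neighbors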
97 | """ 98 | 99 | defaults: List[Any] = field(default_factory=lambda: ["_self_", {"dataset": "multilang"}]) 100 | 101 | ckpt_path: str = "" 102 | dataset: DatasetConfig = MISSING 103 | model: BaseModelConfig = MISSING 104 | input: InputConfig = field(default_factory=InputConfig) 105 | search: SearchConfig = field(default_factory=SearchConfig) 106 | embedder: EmbedderConfig = field(default_factory=EmbedderConfig) 107 | logger: WandbRetrievalConfig = field(default_factory=WandbRetrievalConfig) 108 | 109 | 110 | cs = ConfigStore.instance() 111 | cs.store(name="retrieval_config", node=RetrievalConfig) 112 | cs.store(name="distilgpt2", group="model", node=DistilGPT2Config) 113 | cs.store(name="random_roberta_2_random_gpt2_2", group="model", node=RandomTransformerConfig) 114 | cs.store(name="codet5", group="model", node=CodeT5Config) 115 | cs.store(name="codereviewer", group="model", node=CodeReviewerConfig) 116 | cs.store(name="race", group="model", node=RACEConfig) 117 | cs.store(name="multilang", group="dataset", node=DatasetConfig) 118 | -------------------------------------------------------------------------------- /conf/sweep.yaml: -------------------------------------------------------------------------------- 1 | method: bayes 2 | 3 | metric: 4 | goal: minimize 5 | name: val_loss 6 | 7 | parameters: 8 | stage: 9 | distribution: constant 10 | value: sweep 11 | optimizer.learning_rate: 12 | min: !!float 1e-5 13 | max: !!float 1e-4 14 | optimizer.weight_decay: 15 | min: 0.0 16 | max: 0.1 17 | optimizer.ratio_warmup_steps: 18 | min: 0.0 19 | max: 0.05 20 | 21 | early_terminate: 22 | type: "hyperband" 23 | min_iter: 2 24 | 25 | command: 26 | - ${env} 27 | - ${interpreter} 28 | - ${program} 29 | - ${args_no_hyphens} 30 | 31 | program: train.py 32 | -------------------------------------------------------------------------------- /conf/train_config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List, Optional 3 | 4 | from hydra.core.config_store import ConfigStore 5 | from omegaconf import MISSING 6 | 7 | from .data.dataset_config import DatasetConfig 8 | from .data.input_config import InputConfig 9 | from .model.base_configs import BaseModelConfig 10 | from .model.configs import ( 11 | CodeReviewerConfig, 12 | CodeT5Config, 13 | DistilGPT2Config, 14 | RACEConfig, 15 | RandomTransformerConfig, 16 | ) 17 | 18 | 19 | @dataclass 20 | class OptimizerConfig: 21 | """ 22 | Configuration for optimizer. 23 | 24 | Attributes: 25 | learning_rate: Learning rate for AdamW. 26 | initial_batch_size: If given, learning rate will be recalculated as (given lr) * (actual bs) / (initial bs). 27 | weight_decay: Weight decay for AdamW. 28 | num_warmup_steps: Number of warmup steps for linear scheduler with warmup. 29 | ratio_warmup_steps: Ratio of warmup steps for linear scheduler with warmup (so ratio_warmup_steps * total_steps will be used). 30 | """ 31 | 32 | learning_rate: float = 1e-5 33 | initial_batch_size: Optional[int] = None 34 | weight_decay: float = 0.1 35 | num_warmup_steps: Optional[int] = None 36 | ratio_warmup_steps: Optional[float] = None 37 | 38 | 39 | @dataclass 40 | class ArtifactTrainConfig: 41 | """ 42 | Configuration for W&B artifact. 43 | 44 | Artifact name is not configurable because it's automatically retrieved from model and input configuration. 45 | 46 | Attributes: 47 | load_artifact: True to download artifact from W&B, False otherwise. 48 | project: W&B project. 
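A sketch of the learning-rate rescaling that OptimizerConfig above describes for initial_batch_size (an illustration only, not the training code itself):

    from typing import Optional

    def rescale_learning_rate(lr: float, actual_batch_size: int, initial_batch_size: Optional[int]) -> float:
        # (given lr) * (actual bs) / (initial bs); no rescaling when initial_batch_size is unset
        if initial_batch_size is None:
            return lr
        return lr * actual_batch_size / initial_batch_size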
49 | version: Version tag of W&B artifact. 50 | artifact_path: Path to download in artifact. 51 | """ 52 | 53 | load_artifact: bool = True 54 | project: str = "saridormi/commit_message_completion" 55 | version: str = "latest" 56 | artifact_path: str = "last.ckpt" 57 | 58 | 59 | @dataclass 60 | class WandbTrainConfig: 61 | """ 62 | Configuration for W&B logging. 63 | 64 | What's logged during training: 65 | * loss & validation metrics 66 | * gradients 67 | * (optionally) model checkpoints 68 | 69 | Attributes: 70 | use_wandb: Whether W&B will be used for logging or not. 71 | use_api_key: True to read an API key from a local file (expected to be stored in `wandb_api_key.txt`). 72 | project: Name of a project this run will appear in. 73 | save_artifact: True to upload model checkpoints to W&B as artifacts, False otherwise. 74 | checkpoint: Artifact configuration for fine-tuned model checkpoint (option for RACE). 75 | retrieval: Artifact configuration for retrieved predictions (option for RACE). 76 | 77 | """ 78 | 79 | use_wandb: bool = True 80 | use_api_key: bool = False 81 | project: str = "commit_message_completion" 82 | save_artifact: bool = True 83 | checkpoint: ArtifactTrainConfig = field(default_factory=ArtifactTrainConfig) 84 | retrieval: ArtifactTrainConfig = field(default_factory=ArtifactTrainConfig) 85 | 86 | 87 | @dataclass 88 | class TrainerTrainConfig: 89 | """ 90 | Configuration for pytorch_lightning.Trainer. All options will be passed to Trainer as kwargs. 91 | 92 | Refer to docs: https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html 93 | 94 | Note: 95 | Defined fields are just the most frequently use arguments. You can easily add new ones using Hydra's 96 | override logic. E.g. `python train.py ++trainer.devices=4 ++trainer.strategy=ddp` 97 | """ 98 | 99 | max_epochs: int = 5 100 | precision: int = 16 101 | amp_backend: str = "native" 102 | accumulate_grad_batches: int = 1 103 | num_sanity_val_steps: int = 100 104 | gradient_clip_val: float = 1.0 105 | accelerator: str = "gpu" 106 | devices: Any = 1 107 | val_check_interval: Any = 1.0 108 | limit_train_batches: Optional[int] = None 109 | limit_val_batches: Optional[int] = None 110 | 111 | 112 | @dataclass 113 | class TrainConfig: 114 | """ 115 | Configuration for training. For further information, refer to corresponding subconfig classes. 
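A sketch of how the trainer sub-config reaches Lightning (mirrors the entry points; `cfg`, `model`, and `dm` stand for a composed TrainConfig, a CMCModule, and a CMCDataModule and are hypothetical here):

    import pytorch_lightning as pl

    trainer = pl.Trainer(**cfg.trainer)  # every TrainerTrainConfig field becomes a Trainer kwarg
    trainer.fit(model, datamodule=dm)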
116 | """ 117 | 118 | defaults: List[Any] = field(default_factory=lambda: ["_self_", {"dataset": "multilang"}]) 119 | dataset: DatasetConfig = MISSING 120 | model: BaseModelConfig = MISSING 121 | input: InputConfig = field(default_factory=InputConfig) 122 | logger: WandbTrainConfig = field(default_factory=WandbTrainConfig) 123 | optimizer: OptimizerConfig = field(default_factory=OptimizerConfig) 124 | trainer: TrainerTrainConfig = field(default_factory=TrainerTrainConfig) 125 | 126 | 127 | cs = ConfigStore.instance() 128 | cs.store(name="train_config", node=TrainConfig) 129 | cs.store(name="distilgpt2", group="model", node=DistilGPT2Config) 130 | cs.store(name="random_roberta_2_random_gpt2_2", group="model", node=RandomTransformerConfig) 131 | cs.store(name="codet5", group="model", node=CodeT5Config) 132 | cs.store(name="codereviewer", group="model", node=CodeReviewerConfig) 133 | cs.store(name="race", group="model", node=RACEConfig) 134 | cs.store(name="multilang", group="dataset", node=DatasetConfig) 135 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import hydra 5 | import nltk 6 | import pytorch_lightning as pl 7 | import wandb 8 | from omegaconf import OmegaConf 9 | 10 | from conf import EvalConfig 11 | from src.data_utils import CMCDataModule 12 | from src.model import CMCModule 13 | from src.utils import WandbOrganizer 14 | 15 | nltk.download("omw-1.4") 16 | nltk.download("wordnet") 17 | 18 | 19 | @hydra.main(version_base="1.1", config_path="conf", config_name="eval_config") 20 | def main(cfg: EvalConfig) -> None: 21 | # ----------------------- 22 | # init - 23 | # ----------------------- 24 | pl.seed_everything(42) 25 | 26 | if cfg.model.diff_tokenizer_name_or_path == cfg.model.msg_tokenizer_name_or_path: 27 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 28 | 29 | dm = CMCDataModule( 30 | dataset_cfg=cfg.dataset, 31 | model_cfg=cfg.model, 32 | input_cfg=cfg.input, 33 | local_rank=int(os.environ.get("LOCAL_RANK", 0)), 34 | world_size=1, 35 | shift_labels=cfg.model.configuration != "decoder", 36 | process_retrieved=cfg.model.configuration == "race", 37 | ) 38 | 39 | if cfg.logger.use_wandb: 40 | if cfg.logger.use_api_key: 41 | with open(hydra.utils.to_absolute_path("wandb_api_key.txt"), "r") as f: 42 | os.environ["WANDB_API_KEY"] = f.read().strip() 43 | trainer_logger = pl.loggers.WandbLogger( 44 | name=f"context_ratio_{cfg.input.context_ratio}_{('with-history' if cfg.input.generate_with_history else 'without-history')}", 45 | project=cfg.logger.project, 46 | config=OmegaConf.to_container(cfg, resolve=True), 47 | job_type="eval", 48 | ) 49 | 50 | if cfg.model.configuration == "race": 51 | # download retrieved examples 52 | artifact = wandb.use_artifact( 53 | "codet5" 54 | + ("_with-history" if cfg.input.train_with_history else "_without-history") 55 | + "_retrieval:latest", 56 | type="retrieval", 57 | ) 58 | 59 | for part in ["train", "val", "test"]: 60 | artifact.get_path(f"{part}_predictions.jsonl").download( 61 | root=os.path.join( 62 | hydra.utils.to_absolute_path(dm.get_root_dir_for_part(cfg.dataset.dataset_root, part)), 63 | "retrieval" + ("_with_history" if cfg.input.train_with_history else "_without_history"), 64 | ) 65 | ) 66 | 67 | dm.prepare_data(stage="test") 68 | dm.setup(stage=cfg.stage) 69 | 70 | run_name = WandbOrganizer.get_run_name( 71 | cfg.model, 72 | 
encoder_input_type=cfg.input.encoder_input_type, 73 | train_with_history=cfg.input.train_with_history, 74 | ) 75 | 76 | if cfg.logger.use_wandb and cfg.logger.load_artifact: 77 | artifact_name = f"{cfg.logger.artifact_config.project}/{run_name}:{cfg.logger.artifact_config.version}" 78 | artifact = trainer_logger.experiment.use_artifact(artifact_name) 79 | if "tags" in artifact.metadata: 80 | trainer_logger.experiment.tags = artifact.metadata["tags"] + WandbOrganizer.get_tags_generate( 81 | generate_with_history=cfg.input.generate_with_history, context_ratio=cfg.input.context_ratio 82 | ) 83 | 84 | artifact.get_path(cfg.logger.artifact_config.artifact_path).download( 85 | root=hydra.utils.to_absolute_path(f"{cfg.logger.artifact_config.local_path}/{run_name}") 86 | ) 87 | 88 | cfg.ckpt_path = os.path.join( 89 | hydra.utils.to_absolute_path(f"{cfg.logger.artifact_config.local_path}/{run_name}"), 90 | cfg.logger.artifact_config.artifact_path, 91 | ) 92 | 93 | preds_table_tags = [f"context-ratio_{cfg.input.context_ratio}"] 94 | if cfg.input.encoder_input_type == "diff": 95 | if cfg.input.generate_with_history: 96 | preds_table_tags.append("with-history") 97 | else: 98 | preds_table_tags.append("without-history") 99 | preds_table_name = "_".join(preds_table_tags) 100 | 101 | if cfg.ckpt_path: 102 | # initialize from fine-tuned checkpoint 103 | PATH = os.path.join(hydra.utils.get_original_cwd(), cfg.ckpt_path) 104 | print("Checkpoint path\n", PATH, "\n") 105 | 106 | model = CMCModule.load_from_checkpoint( 107 | PATH, 108 | model_cfg=cfg.model, 109 | diff_tokenizer=dm.diff_tokenizer, 110 | msg_tokenizer=dm.msg_tokenizer, 111 | generation_kwargs=cfg.generation, # type: ignore[arg-type] 112 | preds_artifact_name=f"{run_name}_preds", 113 | preds_artifact_type="multilang preds", 114 | preds_table_name=preds_table_name, 115 | ) 116 | else: 117 | logging.info("Using zero-shot model") 118 | # use zero-shot pretrained model or even random model 119 | model = CMCModule( 120 | model_cfg=cfg.model, 121 | diff_tokenizer=dm.diff_tokenizer, 122 | msg_tokenizer=dm.msg_tokenizer, 123 | generation_kwargs=cfg.generation, # type: ignore[arg-type] 124 | preds_artifact_name=f"{run_name}_preds", 125 | preds_artifact_type="multilang preds", 126 | preds_table_name=preds_table_name, 127 | ) 128 | 129 | trainer = pl.Trainer(**cfg.trainer, logger=trainer_logger if cfg.logger.use_wandb else True) # type: ignore[arg-type] 130 | 131 | # ----------------------- 132 | # test - 133 | # ----------------------- 134 | trainer.test(datamodule=dm, model=model) 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | show_error_codes = True 4 | exclude = src/metrics/reused_implementations/ 5 | 6 | [mypy-src.metrics.reused_implementations.*] 7 | follow_imports = skip -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "commit-message-generation" 3 | version = "0.0.1" 4 | description = "" 5 | authors = ["Alexandra Eliseeva "] 6 | readme = "README.md" 7 | packages = [] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | torch = "1.12.1" 12 | pytorch-lightning = "1.7.7" 13 | transformers = "4.21.3" 14 | wandb = "0.13.9" 15 | hydra-core = "1.2.0" 16 | 
pandas = "*" 17 | jsonlines = "*" 18 | torchmetrics = "0.9.3" 19 | datasets = "2.4.0" 20 | nltk = "3.6.4" 21 | rouge_score = "0.0.4" 22 | sacrebleu = "2.0.0" 23 | rapidfuzz = "2.0.11" 24 | marisa-trie = "^0.8.0" 25 | 26 | [tool.poetry.group.dev.dependencies] 27 | black = "^22.6.0" 28 | isort = "^5.10.1" 29 | mypy = "^0.981" 30 | pytest = "^7.1.2" 31 | 32 | 33 | [tool.poetry.group.retrieval.dependencies] 34 | annoy = "^1.17.1" 35 | 36 | [build-system] 37 | requires = ["poetry-core"] 38 | build-backend = "poetry.core.masonry.api" 39 | -------------------------------------------------------------------------------- /retrieve.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import List, Optional 4 | 5 | import hydra 6 | import jsonlines 7 | import wandb 8 | from omegaconf import OmegaConf 9 | from tqdm import tqdm 10 | 11 | from conf import RetrievalConfig 12 | from src.data_utils import CMCDataModule 13 | from src.model import CMCModule 14 | from src.retrieval import DiffSearch, TransformerEmbedder 15 | from src.retrieval.utils import CommitEmbeddingExample, RetrievalPrediction 16 | from src.utils import WandbOrganizer 17 | 18 | 19 | def download_artifact(cfg: RetrievalConfig, run: wandb.wandb_sdk.wandb_run.Run, artifact_name: str) -> str: 20 | """Helper function to download relevant artifact from W&B. 21 | 22 | Args: 23 | cfg: Current configuration, necessary to find relevant artifact. 24 | run: Current W&B run. 25 | 26 | Returns: 27 | A local path to the artifact. 28 | """ 29 | full_artifact_name = f"{cfg.logger.input_artifact.project}/{artifact_name}:{cfg.logger.input_artifact.version}" 30 | artifact = run.use_artifact(full_artifact_name) 31 | if "tags" in artifact.metadata: 32 | run.tags = artifact.metadata["tags"] 33 | 34 | artifact.get_path(cfg.logger.input_artifact.artifact_path).download( 35 | root=hydra.utils.to_absolute_path(f"{cfg.logger.input_artifact.local_path}/{artifact_name}") 36 | ) 37 | 38 | return os.path.join( 39 | hydra.utils.to_absolute_path(f"{cfg.logger.input_artifact.local_path}/{artifact_name}"), 40 | cfg.logger.input_artifact.artifact_path, 41 | ) 42 | 43 | 44 | def export_model_checkpoint(cfg: RetrievalConfig) -> str: 45 | """Helper function to export model weights in a Transformers format from Lightning checkpoint. 46 | 47 | Returns: 48 | A local path to directory with checkpoint in a Transformers format. 
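A hypothetical follow-up, not part of retrieve.py: for the seq2seq configurations (e.g. CodeT5), the exported directory can presumably be loaded back with plain Transformers:

    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("last.ckpt/transformers_format")  # illustrative path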
49 | """ 50 | logging.info(f"Checkpoint path: {cfg.ckpt_path}") 51 | 52 | module = CMCModule.load_from_checkpoint( 53 | cfg.ckpt_path, 54 | model_cfg=cfg.model, 55 | ) 56 | 57 | transformers_ckpt_path = os.path.join(cfg.ckpt_path.split("/")[-1], "transformers_format") 58 | os.makedirs(transformers_ckpt_path, exist_ok=True) 59 | module.save_pretrained(transformers_ckpt_path) 60 | return transformers_ckpt_path 61 | 62 | 63 | @hydra.main(version_base="1.1", config_path="conf", config_name="retrieval_config") 64 | def main(cfg: RetrievalConfig) -> None: 65 | run_name = WandbOrganizer.get_run_name( 66 | cfg.model, 67 | encoder_input_type=cfg.input.encoder_input_type, 68 | train_with_history=cfg.input.train_with_history, 69 | ) 70 | 71 | # -------------------- 72 | # - init W&B - 73 | # -------------------- 74 | run: Optional[wandb.wandb_sdk.wandb_run.Run] 75 | if cfg.logger.use_wandb: 76 | if cfg.logger.use_api_key: 77 | with open(hydra.utils.to_absolute_path("wandb_api_key.txt"), "r") as f: 78 | os.environ["WANDB_API_KEY"] = f.read().strip() 79 | 80 | run = wandb.init( # type: ignore[assignment] 81 | project=cfg.logger.project, 82 | name=f"{run_name}_retrieval", 83 | config=OmegaConf.to_container(cfg, resolve=True), # type: ignore[arg-type] 84 | job_type="retrieval", 85 | ) 86 | assert run is not None 87 | 88 | if cfg.logger.download_artifact: 89 | logging.info("Downloading artifact from W&B") 90 | cfg.ckpt_path = download_artifact(run=run, cfg=cfg, artifact_name=run_name) 91 | else: 92 | run = None 93 | 94 | # ------------------------------ 95 | # - extract model weights - 96 | # ------------------------------ 97 | assert cfg.ckpt_path 98 | cfg.ckpt_path = hydra.utils.to_absolute_path(cfg.ckpt_path) 99 | transformers_ckpt_path = export_model_checkpoint(cfg) 100 | 101 | # ---------------------------- 102 | # - preprocess data - 103 | # ---------------------------- 104 | dm = CMCDataModule( 105 | dataset_cfg=cfg.dataset, 106 | model_cfg=cfg.model, 107 | input_cfg=cfg.input, 108 | local_rank=int(os.environ.get("LOCAL_RANK", 0)), 109 | world_size=1, 110 | shift_labels=False, 111 | process_retrieved=False, 112 | ) 113 | dm.prepare_data(stage="retrieve") 114 | dm.setup() 115 | 116 | # ----------------------------- 117 | # - build embeddings index - 118 | # ----------------------------- 119 | embedder = TransformerEmbedder( 120 | name_or_path=transformers_ckpt_path, 121 | device=cfg.embedder.device, 122 | precision=cfg.embedder.precision, 123 | normalize_embeddings=cfg.embedder.normalize_embeddings, 124 | ) 125 | 126 | os.makedirs(hydra.utils.to_absolute_path(cfg.search.index_root_dir), exist_ok=True) 127 | search = DiffSearch( 128 | num_trees=cfg.search.num_trees, 129 | embeddings_dim=embedder.embeddings_dim, 130 | load_index=cfg.search.load_index, 131 | index_root_dir=hydra.utils.to_absolute_path(cfg.search.index_root_dir), 132 | load_index_path=hydra.utils.to_absolute_path(cfg.search.load_index_path), 133 | ) 134 | 135 | if not cfg.search.load_index: 136 | for batch in tqdm(dm.retrieval_dataloader(part="train"), desc="Building embeddings index"): 137 | search.add_batch(embedder.transform(batch)) 138 | search.finalize() 139 | 140 | # ------------------------------ 141 | # - retrieve NNs - 142 | # ------------------------------ 143 | 144 | logging.info(f"Start processing train") 145 | 146 | open(f"train_predictions.jsonl", "w").close() 147 | predictions: List[RetrievalPrediction] = [] 148 | for batch in tqdm(dm.retrieval_dataloader(part="train"), desc="Retrieving predictions for train"): 149 | if 
len(predictions) > 10000: 150 | with jsonlines.open("train_predictions.jsonl", "a") as writer: 151 | writer.write_all( 152 | [{"pos_in_file": pred["pos_in_file"], "distance": pred["distance"]} for pred in predictions] 153 | ) 154 | predictions = [] 155 | 156 | predictions.extend(search.predict_batch_train([idx for idx in batch.pos_in_file])) 157 | 158 | if len(predictions) > 0: 159 | with jsonlines.open("train_predictions.jsonl", "a") as writer: 160 | writer.write_all( 161 | [{"pos_in_file": pred["pos_in_file"], "distance": pred["distance"]} for pred in predictions] 162 | ) 163 | 164 | logging.info(f"Finish processing train") 165 | 166 | for part in ["val", "test"]: 167 | logging.info(f"Start processing {part}") 168 | 169 | open(f"{part}_predictions.jsonl", "w").close() 170 | predictions: List[RetrievalPrediction] = [] # type: ignore[no-redef] 171 | for batch in tqdm(dm.retrieval_dataloader(part=part), desc=f"Retrieving predictions for {part}"): 172 | if len(predictions) > 10000: 173 | with jsonlines.open(f"{part}_predictions.jsonl", "a") as writer: 174 | writer.write_all( 175 | [{"pos_in_file": pred["pos_in_file"], "distance": pred["distance"]} for pred in predictions] 176 | ) 177 | predictions = [] 178 | 179 | predictions.extend(search.predict_batch(embedder.transform(batch))) 180 | 181 | if len(predictions) > 0: 182 | with jsonlines.open(f"{part}_predictions.jsonl", "a") as writer: 183 | writer.write_all( 184 | [{"pos_in_file": pred["pos_in_file"], "distance": pred["distance"]} for pred in predictions] 185 | ) 186 | 187 | logging.info(f"Finish processing {part}") 188 | 189 | # ------------------- 190 | # - log predictions - 191 | # ------------------- 192 | if run and cfg.logger.upload_artifact: 193 | logging.info("Uploading artifact to W&B") 194 | artifact = wandb.Artifact(f"{run_name}_retrieval", type="retrieval") 195 | artifact.add_file("train_predictions.jsonl") 196 | artifact.add_file("val_predictions.jsonl") 197 | artifact.add_file("test_predictions.jsonl") 198 | run.log_artifact(artifact) 199 | 200 | 201 | if __name__ == "__main__": 202 | main() 203 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/src/__init__.py -------------------------------------------------------------------------------- /src/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from src.data_utils.cmc_data_module import CMCDataModule 2 | 3 | __all__ = ["CMCDataModule"] 4 | -------------------------------------------------------------------------------- /src/data_utils/cmc_dataset_w_history.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Optional 4 | 5 | import torch 6 | from torch.utils.data import DataLoader, IterableDataset 7 | 8 | from src.utils import SingleExample 9 | 10 | 11 | class CMCDatasetWithHistory(IterableDataset): 12 | def __init__( 13 | self, 14 | filename: str, 15 | history_path: str, 16 | history_mode: Literal["ram", "io"], 17 | rank: int, 18 | world_size: int, 19 | retrieval_filename: Optional[str] = None, 20 | ): 21 | """ 22 | Defines an iterable-style dataset for a commit message completion task. 
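A hypothetical usage sketch (single process, history loaded into RAM; the paths and the `collator` instance are made up):

    dataset = CMCDatasetWithHistory.load_data(
        history_path="raw_data/history.json",
        data_path="raw_data/test.jsonl",
        rank=0,
        world_size=1,
        history_mode="ram",
    )
    dataloader = dataset.get_dataloader(batch_size=16, num_workers=2, collate_fn=collator)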
23 | This version expects input to be already tokenized and provides history for each commit. 24 | 25 | Args: 26 | filename: File to read diff, author ids and positions in history from. 27 | rank: Rank of the process in DDP (must be 0 if you have a single process). 28 | world_size: Number of processes in DDP (must be 1 if you have a single process). 29 | retrieval_filename: File to read retrieved diffs and messages from (optional). 30 | history_path: Path to JSON with full message history for each author. 31 | history_mode: If set to `io`, the history is expected to be already processed for each example in input file. 32 | If set to `ram`, will work by loading it into memory and providing correct slices for each example based 33 | on its author and its position in history. 34 | """ 35 | 36 | self._filename = filename 37 | self._retrieval_filename = retrieval_filename 38 | 39 | self._history_mode = history_mode 40 | self._history: Optional[Dict[str, List[List[int]]]] = None 41 | if history_mode == "ram": 42 | with open(history_path, "r") as infile: 43 | self._history = json.load(infile) 44 | 45 | self._len = None 46 | 47 | self._gpu_rank: int = rank 48 | self._gpu_world_size: int = world_size 49 | 50 | self._num_workers: int 51 | self._world_size: int 52 | self._process_rank: int 53 | 54 | def __len__(self): 55 | if self._len is None: 56 | logging.info("Calculating length of input file") 57 | with open(self._filename, "r") as f: 58 | self._len = sum(1 for _ in f) 59 | return self._len 60 | 61 | @staticmethod 62 | def _init_worker_fn(worker_id: int) -> None: 63 | """Init each worker for DataLoader in a proper way.""" 64 | worker_info = torch.utils.data.get_worker_info() 65 | assert worker_id == worker_info.id # type: ignore[union-attr] 66 | dataset: CMCDatasetWithHistory = worker_info.dataset # type: ignore[assignment, union-attr] 67 | dataset._process_rank = dataset._gpu_rank * dataset._num_workers + worker_info.id # type: ignore[union-attr] 68 | dataset._world_size = dataset._gpu_world_size * dataset._num_workers 69 | 70 | def _process_single_example(self, example: Dict[str, Any], pos_in_file: int) -> SingleExample: 71 | """Process a single row from input file.""" 72 | diff_input_ids: List[int] = example["diff_input_ids"] 73 | msg_input_ids: List[int] = example["msg_input_ids"] 74 | 75 | history_input_ids = [] 76 | if self._history_mode == "ram": 77 | assert self._history, "Configured to load history into memory, but it wasn't defined." 78 | author: str = str(example["author"]) 79 | pos_in_history: int = example["pos_in_history"] 80 | history_input_ids = self._history[str(author)][:pos_in_history] 81 | elif self._history_mode == "io": 82 | assert ( 83 | "history_input_ids" in example 84 | ), "Configured to read history inputs from input file, but they aren't present." 
85 | history_input_ids = example["history_input_ids"] 86 | 87 | return SingleExample( 88 | diff_input_ids=diff_input_ids, 89 | msg_input_ids=msg_input_ids, 90 | history_input_ids=history_input_ids, 91 | pos_in_file=pos_in_file, 92 | ) 93 | 94 | def _process_single_example_retrieval( 95 | self, original_example: Dict[str, Any], retrieved_example: Dict[str, Any], pos_in_file: int 96 | ) -> SingleExample: 97 | """Process a single row from input file + a single row from retrieval file.""" 98 | processed_example = self._process_single_example(example=original_example, pos_in_file=pos_in_file) 99 | 100 | retrieved_diff_input_ids: List[int] = retrieved_example["diff_input_ids"] 101 | retrieved_msg_input_ids: List[int] = retrieved_example["msg_input_ids"] 102 | 103 | processed_example.retrieved_diff_input_ids = retrieved_diff_input_ids 104 | processed_example.retrieved_msg_input_ids = retrieved_msg_input_ids 105 | return processed_example 106 | 107 | def _get_examples_generator(self) -> Generator[SingleExample, None, None]: 108 | """ 109 | For multiprocessing support: 110 | 111 | process_rank = current process id 112 | world_size = # of processes 113 | 114 | This function yields local_rank'th row from every world_size rows. 115 | """ 116 | if self._retrieval_filename is None: 117 | with open(self._filename) as f: 118 | for i, line in enumerate(f): 119 | if i % self._world_size == self._process_rank: 120 | example: Dict[str, Any] = json.loads(line.strip()) 121 | yield self._process_single_example(example, pos_in_file=i) 122 | else: 123 | with open(self._filename) as f: 124 | with open(self._retrieval_filename) as f_retrieval: 125 | for (i, line), (i_retrieval, line_retrieval) in zip(enumerate(f), enumerate(f_retrieval)): 126 | 127 | assert i == i_retrieval 128 | 129 | if i % self._world_size == self._process_rank: 130 | original_example: Dict[str, Any] = json.loads(line.strip()) 131 | retrieved_example: Dict[str, Any] = json.loads(line_retrieval.strip()) 132 | yield self._process_single_example_retrieval( 133 | original_example=original_example, retrieved_example=retrieved_example, pos_in_file=i 134 | ) 135 | 136 | def __iter__(self) -> Iterator[SingleExample]: 137 | assert self._num_workers is not None, f"You must access __iter__ through DataLoader" 138 | return iter(self._get_examples_generator()) 139 | 140 | def get_dataloader(self, batch_size: int, num_workers: int, collate_fn: Callable) -> DataLoader: 141 | """Creates DataLoader in a proper way.""" 142 | assert num_workers >= 0, "num_workers must be at least 0" 143 | if num_workers == 0: 144 | # We need to initialize at least 1 worker in order to call worker_init_fn 145 | num_workers = 1 146 | self._num_workers = num_workers 147 | 148 | return DataLoader( 149 | dataset=self, 150 | batch_size=batch_size, # TODO: https://pytorch.org/docs/stable/data.html#disable-automatic-batching (?) 151 | num_workers=num_workers, 152 | collate_fn=collate_fn, 153 | pin_memory=torch.cuda.is_available(), 154 | worker_init_fn=CMCDatasetWithHistory._init_worker_fn, 155 | ) 156 | 157 | @staticmethod 158 | def load_data( 159 | history_path: str, 160 | data_path: str, 161 | rank: int, 162 | world_size: int, 163 | history_mode: Literal["ram", "io"], 164 | retrieved_data_path: Optional[str] = None, 165 | ): 166 | """ 167 | Load dataset from files on disk. 168 | 169 | Args: 170 | history_path: Path to history file. 171 | data_path: Path to data file. 172 | rank: Rank of the process in DDP (must be 0 if you have a single process). 
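A sharding illustration for the generator above (world_size == 4 corresponds, for example, to 2 GPUs with 2 dataloader workers each):

    rows_for_worker_1 = [i for i in range(12) if i % 4 == 1]  # -> [1, 5, 9]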
173 | world_size: Number of processes in DDP (must be 1 if you have a single process). 174 | use_history: True to use history in a dataset, False to stick with current messages only. 175 | retrieved_data_path: Path to retrieved file (optional). 176 | """ 177 | 178 | return CMCDatasetWithHistory( 179 | filename=data_path, 180 | retrieval_filename=retrieved_data_path, 181 | history_path=history_path, 182 | rank=rank, 183 | world_size=world_size, 184 | history_mode=history_mode, 185 | ) 186 | -------------------------------------------------------------------------------- /src/data_utils/data_collators/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_collator_retrieval import DataCollatorRetrieval 2 | from .data_collator_test import DataCollatorTest 3 | from .data_collator_train import DataCollatorTrain 4 | 5 | __all__ = ["DataCollatorTrain", "DataCollatorTest", "DataCollatorRetrieval"] 6 | -------------------------------------------------------------------------------- /src/data_utils/data_collators/data_collator_retrieval.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import torch 5 | 6 | from src.utils import BatchRetrieval, SingleExample 7 | 8 | from .base_collator_utils import BaseCollatorUtils 9 | 10 | 11 | @dataclass 12 | class DataCollatorRetrieval(BaseCollatorUtils): 13 | def __call__(self, examples: List[SingleExample]): 14 | if not self.testing: 15 | (encoder_input_ids, encoder_attention_mask), _, _ = self._process_encoder_input(examples=examples) 16 | return BatchRetrieval( 17 | encoder_input_ids=encoder_input_ids, 18 | encoder_attention_mask=encoder_attention_mask, 19 | pos_in_file=[example.pos_in_file for example in examples], 20 | ) 21 | else: 22 | batch_size = len(examples) 23 | return BatchRetrieval( 24 | encoder_input_ids=torch.randint(1000, (batch_size, self.encoder_context_max_len), dtype=torch.int64), 25 | encoder_attention_mask=torch.ones(batch_size, self.encoder_context_max_len, dtype=torch.int64), 26 | pos_in_file=[i for i in range(batch_size)], 27 | ) 28 | -------------------------------------------------------------------------------- /src/data_utils/data_collators/data_collator_train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | 6 | from src.utils import BatchTrain, SingleExample 7 | 8 | from .base_collator_utils import BaseCollatorUtils 9 | 10 | 11 | @dataclass 12 | class DataCollatorTrain(BaseCollatorUtils): 13 | """This class is used to construct batches out of lists of examples in training/validation setting. 14 | 15 | There is an option to add message history to decoder context 16 | (but if history is used as encoder input, it will be ignored). 17 | 18 | - Format with history: `[BOS] history_1 [SEP] ... history_k [SEP] message [EOS]` 19 | 20 | - Format without history: `[BOS] message [EOS]` 21 | 22 | Attributes: 23 | shift_labels: True to mimic transformers' seq2seq models ids/labels construction logic, False otherwise 24 | (pass False for decoder class). 
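Schematic of the resulting decoder tensors for one example with history (a sketch; actual ids are tokenizer-dependent, and the optional shift for encoder-decoder models is applied afterwards):

    decoder_input_ids: [BOS] history_1 [SEP] ... history_k [SEP] message [EOS]
    labels:            [BOS]   -100    -100  ...   -100    -100  message [EOS]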
25 | """ 26 | 27 | shift_labels: bool 28 | decoder_start_token_id: Optional[int] = None 29 | 30 | def _shift_for_encoder_decoder( 31 | self, ids: List[List[int]], labels: List[List[int]] 32 | ) -> Tuple[List[List[int]], List[List[int]]]: 33 | """This method mimics transformers logic of ids and labels for EncoderDecoderModel 34 | (or T5ForConditionalGeneration). 35 | 36 | Starting from transformers v4.12, loss is now calculated in EncoderDecoderModel, not in decoder class. 37 | Also, decoder input ids are created automatically based on labels: labels are shifted and -100 is replaced 38 | with pad token. In our case, history ids are masked -100 in labels, but they are still 39 | meaningful ids. Therefore, we can't use the default approach. 40 | """ 41 | if self.decoder_start_token_id is None: 42 | ids = [[self.msg_bos_token_id]] + ids[:-1] 43 | else: 44 | ids = [[self.decoder_start_token_id]] + ids[:-1] 45 | return ids, labels 46 | 47 | def _process_decoder_input(self, examples: List[SingleExample]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 48 | """ 49 | Prepares decoder input for train/validation: 50 | * aggregates messages from history when configured accordingly 51 | * concatenates history with current message 52 | * constructs labels 53 | * pads, converts to tensors 54 | 55 | Args: 56 | examples: A list of inputs for current batch. 57 | 58 | Returns: 59 | Tuple of three tensors: input ids, attention masks, labels. 60 | """ 61 | message_inputs: List[List[int]] = [example.msg_input_ids for example in examples] 62 | history_inputs: List[List[List[int]]] = [example.history_input_ids for example in examples] 63 | 64 | all_msg_ids: List[torch.Tensor] = [] 65 | all_msg_masks: List[torch.Tensor] = [] 66 | all_msg_labels: List[torch.Tensor] = [] 67 | 68 | for message_ids, history_ids in zip(message_inputs, history_inputs): 69 | message_ids = message_ids[: self.decoder_context_max_len - 2] 70 | 71 | cur_history_ids = [] 72 | cur_history_labels = [] 73 | 74 | if self.encoder_input_type != "history" and self.with_history: 75 | cur_history_ids = self._get_history( 76 | cur_len=len(message_ids) + 2, 77 | history_ids=history_ids, 78 | ) 79 | cur_history_labels = [[-100 for _ in message] for message in cur_history_ids] 80 | 81 | cur_ids = [[self.msg_bos_token_id]] + cur_history_ids + [message_ids] + [[self.msg_eos_token_id]] 82 | cur_labels = [[self.msg_bos_token_id]] + cur_history_labels + [message_ids] + [[self.msg_eos_token_id]] 83 | 84 | if self.shift_labels: 85 | cur_ids, cur_labels = self._shift_for_encoder_decoder(cur_ids, cur_labels) 86 | 87 | cur_ids_tensor = torch.tensor([ex for sublist in cur_ids for ex in sublist], dtype=torch.int64) 88 | cur_labels_tensor = torch.tensor([ex for sublist in cur_labels for ex in sublist], dtype=torch.int64) 89 | cur_mask_tensor = torch.ones_like(cur_ids_tensor) 90 | 91 | all_msg_ids.append(cur_ids_tensor) 92 | all_msg_masks.append(cur_mask_tensor) 93 | all_msg_labels.append(cur_labels_tensor) 94 | 95 | msg_max_len = max(len(tensor) for tensor in all_msg_ids) 96 | all_msg_ids = [ 97 | self._pad_tensor( 98 | tensor, 99 | pad_len=msg_max_len - tensor.numel(), 100 | value=self.msg_pad_token_id, 101 | left=False, 102 | ) 103 | for tensor in all_msg_ids 104 | ] 105 | all_msg_masks = [ 106 | self._pad_tensor( 107 | tensor, 108 | pad_len=msg_max_len - tensor.numel(), 109 | value=0, 110 | left=False, 111 | ) 112 | for tensor in all_msg_masks 113 | ] 114 | all_msg_labels = [ 115 | self._pad_tensor( 116 | tensor, 117 | pad_len=msg_max_len - tensor.numel(), 118 
| value=-100, 119 | left=False, 120 | ) 121 | for tensor in all_msg_labels 122 | ] 123 | 124 | return torch.stack(all_msg_ids), torch.stack(all_msg_masks), torch.stack(all_msg_labels) 125 | 126 | def __call__(self, examples: List[SingleExample]) -> BatchTrain: 127 | if not self.testing: 128 | ( 129 | (encoder_input_ids, encoder_attention_mask), 130 | (retrieved_diff_input_ids, retrieved_diff_attention_mask), 131 | (retrieved_msg_input_ids, retrieved_msg_attention_mask), 132 | ) = self._process_encoder_input(examples=examples) 133 | 134 | decoder_input_ids, decoder_attention_mask, labels = self._process_decoder_input(examples=examples) 135 | 136 | return BatchTrain( 137 | encoder_input_ids=encoder_input_ids, 138 | encoder_attention_mask=encoder_attention_mask, 139 | decoder_input_ids=decoder_input_ids, 140 | decoder_attention_mask=decoder_attention_mask, 141 | labels=labels, 142 | retrieved_diff_input_ids=retrieved_diff_input_ids, 143 | retrieved_diff_attention_mask=retrieved_diff_attention_mask, 144 | retrieved_msg_input_ids=retrieved_msg_input_ids, 145 | retrieved_msg_attention_mask=retrieved_msg_attention_mask, 146 | ) 147 | else: 148 | batch_size = len(examples) 149 | return BatchTrain( 150 | encoder_input_ids=torch.randint(1000, (batch_size, self.encoder_context_max_len), dtype=torch.int64), 151 | encoder_attention_mask=torch.ones(batch_size, self.encoder_context_max_len, dtype=torch.int64), 152 | decoder_input_ids=torch.randint(1000, (batch_size, self.decoder_context_max_len), dtype=torch.int64), 153 | decoder_attention_mask=torch.ones(batch_size, self.decoder_context_max_len, dtype=torch.int64), 154 | labels=torch.randint(1000, (batch_size, self.decoder_context_max_len), dtype=torch.int64), 155 | retrieved_diff_input_ids=torch.randint( 156 | 1000, (batch_size, self.encoder_context_max_len), dtype=torch.int64 157 | ), 158 | retrieved_diff_attention_mask=torch.ones(batch_size, self.encoder_context_max_len, dtype=torch.int64), 159 | retrieved_msg_input_ids=torch.randint( 160 | 1000, (batch_size, self.encoder_context_max_len), dtype=torch.int64 161 | ), 162 | retrieved_msg_attention_mask=torch.ones(batch_size, self.encoder_context_max_len, dtype=torch.int64), 163 | ) 164 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_preprocessor import BasePreprocessor 2 | from .codereviewer_preprocessor import CodeReviewerPreprocessor 3 | from .default_preprocessor import DefaultPreprocessor 4 | from .race_preprocessor import RACEPreprocessor 5 | 6 | __all__ = ["BasePreprocessor", "DefaultPreprocessor", "RACEPreprocessor", "CodeReviewerPreprocessor"] 7 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/codereviewer_preprocessor.py: -------------------------------------------------------------------------------- 1 | from .default_preprocessor import DefaultPreprocessor 2 | 3 | 4 | class CodeReviewerPreprocessor(DefaultPreprocessor): 5 | def _preprocess_diff(self, diff: str, line_sep: str, **kwargs) -> str: 6 | """Helper method: add tags from CodeReviewer to single file diff.""" 7 | processed_lines = [] 8 | for line in diff.split(line_sep): 9 | line = line.strip() 10 | if not line: 11 | continue 12 | if line.startswith("+"): 13 | processed_lines.append("" + line[1:]) 14 | elif line.startswith("-"): 15 | processed_lines.append("" + line[1:]) 16 | else: 17 | 
processed_lines.append("" + line) 18 | return line_sep.join(processed_lines) + line_sep 19 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/default_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from .base_preprocessor import BasePreprocessor 4 | 5 | 6 | class DefaultPreprocessor(BasePreprocessor): 7 | def _preprocess_diff(self, diff: str, line_sep: str, **kwargs) -> str: 8 | """Return given diff without any changes.""" 9 | return diff 10 | 11 | def _preprocess_mods(self, mods: List[Dict[str, str]], line_sep: str = "[NL]", **kwargs) -> str: 12 | """ 13 | Transforms a list of all files modifications made in a commit into a single string representation. 14 | 15 | Specifically, adds a header to each file diff (https://git-scm.com/docs/git-diff#_generating_patch_text_with_p) 16 | and concatenates the results. 17 | 18 | Args: 19 | mods: A list of files modifications made in a commit. 20 | line_sep: Line separator in diffs. 21 | 22 | Returns: 23 | A single string representation of all files modifications made in a commit. 24 | """ 25 | diff = "" 26 | 27 | for mod in mods: 28 | if mod["change_type"] == "UNKNOWN": 29 | continue 30 | elif mod["change_type"] == "ADD": 31 | file_diff = f"new file {mod['new_path']}" 32 | elif mod["change_type"] == "DELETE": 33 | file_diff = f"deleted file {mod['old_path']}" 34 | elif mod["change_type"] == "RENAME": 35 | file_diff = f"rename from {mod['old_path']}{line_sep}rename to {mod['new_path']}" 36 | elif mod["change_type"] == "COPY": 37 | file_diff = f"copy from {mod['old_path']}{line_sep}copy to {mod['new_path']}" 38 | else: 39 | file_diff = f"{mod['new_path']}" 40 | diff += file_diff + line_sep + self._preprocess_diff(mod["diff"], line_sep=line_sep) 41 | 42 | return diff 43 | 44 | def _preprocess_message(self, message: str, **kwargs) -> str: 45 | """Returns given message without any changes.""" 46 | return message 47 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/race_preprocessor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from .base_preprocessor import BasePreprocessor 4 | from .reused_implementations import compute_code_diffs 5 | 6 | 7 | class RACEPreprocessor(BasePreprocessor): 8 | def _preprocess_diff(self, header: List[str], diff: str, line_sep: str) -> str: 9 | """Helper method: transforms single file diff to representation from RACE paper.""" 10 | old_lines = [header[0].strip()] 11 | new_lines = [header[1].strip() if len(header) == 2 else header[0]] 12 | 13 | for line in diff.split(line_sep): 14 | line = line.strip() 15 | if not line: 16 | continue 17 | if line.startswith("+"): 18 | new_lines.extend(line.split(" ")) 19 | elif line.startswith("-"): 20 | old_lines.extend(line.split(" ")) 21 | else: 22 | new_lines.extend(line.split(" ")) 23 | old_lines.extend(line.split(" ")) 24 | resulting_tokens: List[str] = compute_code_diffs(old_tokens=old_lines, new_tokens=new_lines) 25 | return " ".join(resulting_tokens) 26 | 27 | def _preprocess_mods(self, mods: List[Dict[str, str]], line_sep: str = "[NL]", *args, **kwargs) -> str: 28 | """Transforms a list of file modification made in a commit to a single diff representation from RACE paper.""" 29 | diff = [] 30 | 31 | for i, mod in enumerate(mods): 32 | if mod["change_type"] == "UNKNOWN": 33 | continue 34 | elif 
mod["change_type"] == "ADD": 35 | header = [f"new file {mod['new_path']}"] 36 | elif mod["change_type"] == "DELETE": 37 | header = [f"deleted file {mod['old_path']}"] 38 | elif mod["change_type"] == "RENAME": 39 | header = [f"rename from {mod['old_path']}", f"rename to {mod['new_path']}"] 40 | elif mod["change_type"] == "COPY": 41 | header = [f"copy from {mod['old_path']}", f"copy to {mod['new_path']}"] 42 | else: 43 | header = [f"{mod['new_path']}"] 44 | diff.append(self._preprocess_diff(header, mod["diff"], line_sep=line_sep)) 45 | return line_sep.join(diff) 46 | 47 | def _preprocess_message(self, message: str, **kwargs) -> str: 48 | """Returns given message without any changes.""" 49 | return message 50 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/reused_implementations/__init__.py: -------------------------------------------------------------------------------- 1 | from .race import compute_code_diffs 2 | 3 | __all__ = ["compute_code_diffs"] 4 | -------------------------------------------------------------------------------- /src/data_utils/preprocessors/reused_implementations/race.py: -------------------------------------------------------------------------------- 1 | # This code is taken from replication package of "RACE: Retrieval-Augmented Commit Message Generation", EMNLP, 2022. 2 | # https://github.com/DeepSoftwareAnalytics/RACE 3 | 4 | 5 | import difflib 6 | 7 | REPLACE_OLD = "" 8 | REPLACE_NEW = "" 9 | REPLACE_END = "" 10 | 11 | INSERT = "" 12 | INSERT_OLD = "" 13 | INSERT_NEW = "" 14 | INSERT_END = "" 15 | 16 | DELETE = "" 17 | DELETE_END = "" 18 | 19 | KEEP = "" 20 | KEEP_END = "" 21 | 22 | 23 | def compute_code_diffs(old_tokens, new_tokens): 24 | spans = [] 25 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher( 26 | None, old_tokens, new_tokens 27 | ).get_opcodes(): 28 | if edit_type == "equal": 29 | spans.extend([KEEP] + old_tokens[o_start:o_end] + [KEEP_END]) 30 | elif edit_type == "replace": 31 | spans.extend( 32 | [REPLACE_OLD] + old_tokens[o_start:o_end] + [REPLACE_NEW] + new_tokens[n_start:n_end] + [REPLACE_END] 33 | ) 34 | elif edit_type == "insert": 35 | spans.extend([INSERT] + new_tokens[n_start:n_end] + [INSERT_END]) 36 | else: 37 | spans.extend([DELETE] + old_tokens[o_start:o_end] + [DELETE_END]) 38 | 39 | return spans 40 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .accuracy import Accuracy 2 | from .bleu_norm import BLEUNorm 3 | from .edit_similarity import EditSimilarity 4 | from .exact_match import ExactMatch 5 | from .log_mnext import LogMNEXT 6 | from .mrr import MRR 7 | 8 | __all__ = ["EditSimilarity", "ExactMatch", "BLEUNorm", "LogMNEXT", "Accuracy", "MRR"] 9 | -------------------------------------------------------------------------------- /src/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class Accuracy(Metric): 6 | """Accuracy@k metric. 
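A hypothetical usage sketch (random tensors stand in for model logits and token labels):

    import torch

    from src.metrics import Accuracy

    metric = Accuracy(top_k=5, shift=True)
    logits = torch.randn(2, 4, 10)  # batch of 2, sequence length 4, vocabulary of 10
    labels = torch.randint(0, 10, (2, 4))
    metric.update(logits, labels)
    print(metric.compute())  # average top-5 token-level accuracy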
Returns a ratio of examples where reference is present among top k predictions.""" 7 | 8 | # https://devblog.pytorchlightning.ai/torchmetrics-v0-9-faster-forward-d595bb321e6d 9 | full_state_update: bool = False 10 | 11 | def __init__( 12 | self, top_k: int = 5, ignore_index: int = -100, shift: bool = True, dist_sync_on_step: bool = False 13 | ) -> None: 14 | super().__init__(dist_sync_on_step=dist_sync_on_step) 15 | 16 | self.top_k = top_k 17 | self.ignore_index = ignore_index 18 | self.shift = shift 19 | 20 | self.accuracy: torch.Tensor 21 | self.total: torch.Tensor 22 | 23 | self.add_state("accuracy", default=torch.tensor(0).float(), dist_reduce_fx="sum") 24 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 25 | 26 | def update(self, predictions: torch.Tensor, references: torch.Tensor) -> None: # type: ignore[override] 27 | assert predictions.ndimension() == references.ndimension() + 1 28 | assert predictions.size()[:-1] == references.size() 29 | assert predictions.size()[-1] >= self.top_k 30 | 31 | # for support of batches of size 1 32 | if len(references.shape) == 1: 33 | references = references.unsqueeze(0) 34 | predictions = predictions.unsqueeze(0) 35 | 36 | # shift scores and labels 37 | if self.shift: 38 | predictions = predictions[..., :-1, :] 39 | references = references[..., 1:] 40 | 41 | # labels = [batch_size x seq_len - 1] 42 | # scores = [batch_size x seq_len - 1 x vocab_size] 43 | # top_k_predictions = [batch_size x seq_len -1 x top_k] 44 | _, top_k_predictions = torch.topk(predictions, self.top_k) 45 | expanded_labels = references.unsqueeze(-1).expand_as(top_k_predictions) 46 | true_pos = torch.logical_and(expanded_labels == top_k_predictions, expanded_labels != self.ignore_index) 47 | 48 | acc_top_k_list = ( 49 | true_pos.sum(dim=-1).float() / (references != self.ignore_index).sum(dim=1).unsqueeze(1).float() 50 | ).sum(dim=1) 51 | 52 | try: 53 | self.accuracy += acc_top_k_list.sum() 54 | self.total += references.shape[0] 55 | except RuntimeError: 56 | self.accuracy = self.accuracy.to(acc_top_k_list.device) 57 | self.total = self.total.to(self.accuracy.device) 58 | 59 | def compute(self) -> torch.Tensor: 60 | return self.accuracy.float() / self.total 61 | -------------------------------------------------------------------------------- /src/metrics/bleu_norm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import datasets 4 | 5 | from src.metrics.reused_implementations import bleuFromMaps, splitPuncts 6 | 7 | _CITATION = """\ 8 | @inproceedings{tao2021evaluation, 9 | title={On the Evaluation of Commit Message Generation Models: An Experimental Study}, 10 | author={Tao, Wei and Wang, Yanlin and Shi, Ensheng and Du, Lun and Han, Shi and Zhang, Hongyu and Zhang, Dongmei and Zhang, Wenqiang}, 11 | booktitle={2021 IEEE International Conference on Software Maintenance and Evolution (ICSME)}, 12 | pages={126--136}, 13 | year={2021}, 14 | organization={IEEE} 15 | } 16 | @inproceedings{Papineni02bleu:a, 17 | author = {Kishore Papineni and Salim Roukos and Todd Ward and Wei-jing Zhu}, 18 | title = {BLEU: a Method for Automatic Evaluation of Machine Translation}, 19 | booktitle = {}, 20 | year = {2002}, 21 | pages = {311--318} 22 | } 23 | @inproceedings{lin-och-2004-orange, 24 | title = "{ORANGE}: a Method for Evaluating Automatic Evaluation Metrics for Machine Translation", 25 | author = "Lin, Chin-Yew and 26 | Och, Franz Josef", 27 | booktitle = "{COLING} 2004: Proceedings of the 
20th International Conference on Computational Linguistics", 28 | month = "aug 23{--}aug 27", 29 | year = "2004", 30 | address = "Geneva, Switzerland", 31 | publisher = "COLING", 32 | url = "https://www.aclweb.org/anthology/C04-1072", 33 | pages = "501--507", 34 | } 35 | """ 36 | 37 | _DESCRIPTION = """\ 38 | B-Norm is a variation of BLEU. It uses smoothing by Lin and Och, 2004 and does some additional preprocessing steps. 39 | It was recommended for evaluation of commit message generation approaches in the 40 | "On the Evaluation of Commit Message Generation Models: An Experimental Study" paper accepted to ICSME 2021. 41 | This class uses implementation provided in the replication package. 42 | """ 43 | 44 | 45 | class BLEUNorm(datasets.Metric): 46 | def _info(self): 47 | return datasets.MetricInfo( 48 | description=_DESCRIPTION, 49 | citation=_CITATION, 50 | features=datasets.Features( 51 | { 52 | "predictions": datasets.Value("string", id="sequence"), 53 | "references": datasets.Value("string", id="sequence"), 54 | } 55 | ), 56 | codebase_urls=["https://github.com/DeepSoftwareAnalytics/CommitMsgEmpirical/blob/main/metrics/B-Norm.py"], 57 | ) 58 | 59 | def _compute(self, predictions: List[str], references: List[str]) -> Dict[str, float]: # type: ignore[override] 60 | prediction_map = {i: [splitPuncts(pred.strip().lower())] for i, pred in enumerate(predictions)} 61 | gold_map = {i: [splitPuncts(ref.strip().lower())] for i, ref in enumerate(references)} 62 | return {"b_norm": bleuFromMaps(gold_map, prediction_map)[0] / 100.0} 63 | -------------------------------------------------------------------------------- /src/metrics/edit_similarity.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from rapidfuzz.distance.Levenshtein import normalized_similarity 5 | from torchmetrics import Metric 6 | 7 | 8 | class EditSimilarity(Metric): 9 | """Edit Similarity metric. It is a string similarity metric based on Levenshtein distance: 10 | 1 - edit_distance/max_len 11 | 12 | Final metric value is calculated as average sentence-level edit similarity. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | insertion_cost: int = 1, 18 | deletion_cost: int = 1, 19 | substitution_cost: int = 1, 20 | dist_sync_on_step: bool = False, 21 | ) -> None: 22 | super().__init__(dist_sync_on_step=dist_sync_on_step) 23 | 24 | self.weights = (insertion_cost, deletion_cost, substitution_cost) 25 | 26 | self.scores: torch.Tensor 27 | self.total: torch.Tensor 28 | self.add_state("scores", default=torch.tensor(0.0), dist_reduce_fx="sum") 29 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 30 | 31 | def update(self, predictions: List[str], references: List[str]) -> None: # type: ignore[override] 32 | for pred, ref in zip(predictions, references): 33 | e_sim = normalized_similarity( 34 | pred, 35 | ref, 36 | weights=self.weights, 37 | ) 38 | 39 | if not ref: 40 | self.scores = torch.tensor(float("nan")) 41 | else: 42 | self.scores += torch.tensor(e_sim) 43 | self.total += 1 44 | 45 | def compute(self) -> torch.Tensor: 46 | return self.scores / self.total 47 | -------------------------------------------------------------------------------- /src/metrics/exact_match.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class ExactMatch(Metric): 8 | """ 9 | ExactMatch@N metric. 
Given N, calculates the ratio of examples 10 | where first N generated words exactly match first N words from corresponding target. 11 | 12 | Words are obtained by splitting by whitespaces. Cases where target contains less than N words are skipped. 13 | 14 | Args: 15 | n: Number of words to compare. Optional, full sequences will be considered if n is not given. 16 | """ 17 | 18 | def __init__(self, n: Optional[int] = None, dist_sync_on_step: bool = False) -> None: 19 | super().__init__(dist_sync_on_step=dist_sync_on_step) 20 | 21 | self.n = n 22 | 23 | self.correct: torch.Tensor 24 | self.total: torch.Tensor 25 | self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") 26 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 27 | 28 | def update(self, predictions: List[str], references: List[str]) -> None: # type: ignore[override] 29 | for pred, ref in zip(predictions, references): 30 | pred_words, ref_words = pred.strip().split(), ref.strip().split() 31 | 32 | if self.n: 33 | # compare first n words 34 | if len(ref_words) >= self.n: 35 | if len(pred_words) >= self.n and all( 36 | pred_word == target_word 37 | for pred_word, target_word in zip(pred_words[: self.n], ref_words[: self.n]) 38 | ): 39 | self.correct += 1 40 | 41 | self.total += 1 42 | else: 43 | # compare full sequences 44 | if len(pred_words) == len(ref_words) and all( 45 | pred_word == target_word for pred_word, target_word in zip(pred_words, ref_words) 46 | ): 47 | self.correct += 1 48 | self.total += 1 49 | 50 | def compute(self) -> torch.Tensor: 51 | if not self.total.item(): 52 | return self.total 53 | 54 | return self.correct.float() / self.total 55 | -------------------------------------------------------------------------------- /src/metrics/log_mnext.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | from src.metrics.reused_implementations import log_mnext_score 7 | 8 | 9 | class LogMNEXT(Metric): 10 | """Log-MNEXT metric. It is a string similarity metric based on METEOR-NEXT. 11 | 12 | It was proposed in the "Evaluating Commit Message Generation: To BLEU Or Not To BLEU?" paper 13 | accepted to ICSE NIER 2022. This class uses original implementation from replication package. 14 | 15 | Final metric value is calculated as average sentence-level Log-MNEXT (replication package includes only 16 | sentence-level implementation). 17 | """ 18 | 19 | def __init__(self, dist_sync_on_step: Optional[bool] = False): 20 | super().__init__(dist_sync_on_step=dist_sync_on_step) 21 | self.scores: torch.Tensor 22 | self.total: torch.Tensor 23 | self.add_state("scores", default=torch.tensor(0.0), dist_reduce_fx="sum") 24 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 25 | 26 | def update(self, predictions: List[str], references: List[str]) -> None: # type: ignore[override] 27 | for pred, ref in zip(predictions, references): 28 | self.scores += torch.tensor(log_mnext_score([ref], pred)) 29 | self.total += 1 30 | 31 | def compute(self) -> torch.Tensor: 32 | return self.scores / self.total 33 | -------------------------------------------------------------------------------- /src/metrics/mrr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics import Metric 3 | 4 | 5 | class MRR(Metric): 6 | """Mean Reciprocal Rank (MRR)@k metric. 
In contrast with accuracy, it takes a position of correct prediction among 7 | top k into account.""" 8 | 9 | # https://devblog.pytorchlightning.ai/torchmetrics-v0-9-faster-forward-d595bb321e6d 10 | full_state_update: bool = False 11 | 12 | def __init__( 13 | self, top_k: int = 5, ignore_index: int = -100, shift: bool = True, dist_sync_on_step: bool = False 14 | ) -> None: 15 | super().__init__(dist_sync_on_step=dist_sync_on_step) 16 | 17 | self.top_k = top_k 18 | self.ignore_index = ignore_index 19 | self.shift = shift 20 | 21 | self.mrr: torch.Tensor 22 | self.total: torch.Tensor 23 | self.add_state("mrr", default=torch.tensor(0).float(), dist_reduce_fx="sum") 24 | self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") 25 | 26 | def update(self, predictions: torch.Tensor, references: torch.Tensor) -> None: # type: ignore[override] 27 | 28 | assert predictions.ndimension() == references.ndimension() + 1 29 | assert predictions.size()[:-1] == references.size() 30 | assert predictions.size()[-1] >= self.top_k 31 | 32 | # for support of batches of size 1 33 | if len(references.shape) == 1: 34 | references = references.unsqueeze(0) 35 | predictions = predictions.unsqueeze(0) 36 | 37 | # shift scores and labels 38 | if self.shift: 39 | predictions = predictions[..., :-1, :] 40 | references = references[..., 1:] 41 | 42 | # labels = [batch_size x seq_len - 1] 43 | # scores = [batch_size x seq_len - 1 x vocab_size] 44 | # top_k_predictions = [batch_size x seq_len - 1 x top_k] 45 | _, top_k_predictions = torch.topk(predictions, self.top_k) 46 | expanded_labels = references.unsqueeze(-1).expand_as(top_k_predictions) 47 | true_pos = torch.logical_and(expanded_labels == top_k_predictions, expanded_labels != self.ignore_index) 48 | # mrr depends on position of correct label in top k generated outputs 49 | true_pos_for_mrr = true_pos / torch.arange(1, true_pos.size(-1) + 1, dtype=torch.float, device=true_pos.device) 50 | mrr_top_k_list = ( 51 | true_pos_for_mrr.max(dim=-1)[0].sum(dim=-1) / (references != self.ignore_index).sum(dim=1).float() 52 | ) 53 | 54 | try: 55 | self.mrr += mrr_top_k_list.sum() 56 | self.total += references.shape[0] 57 | except RuntimeError: 58 | self.mrr = self.mrr.to(mrr_top_k_list.device) 59 | self.total = self.total.to(self.mrr.device) 60 | 61 | def compute(self) -> torch.Tensor: # type: ignore[override] 62 | return self.mrr.float() / self.total 63 | -------------------------------------------------------------------------------- /src/metrics/reused_implementations/__init__.py: -------------------------------------------------------------------------------- 1 | from .b_norm import bleuFromMaps, splitPuncts 2 | from .log_mnext import log_mnext_score 3 | 4 | __all__ = ["splitPuncts", "bleuFromMaps", "log_mnext_score"] 5 | -------------------------------------------------------------------------------- /src/metrics/reused_implementations/b_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is copied from https://github.com/DeepSoftwareAnalytics/CommitMsgEmpirical, 3 | the replication package for "On the Evaluation of Commit Message Generation Models: An Experimental Study" 4 | accepted to ICSME 2021. 5 | """ 6 | 7 | #!/usr/bin/python 8 | 9 | """ 10 | This script was adapted from the original version by hieuhoang1972 which is part of MOSES. 
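A minimal usage sketch (the example strings are made up): this is how `splitPuncts` and `bleuFromMaps` below are combined by the `BLEUNorm` metric in `src/metrics/bleu_norm.py`, with each id mapped to a single-element list of lowercased, tokenized strings:

    preds = ["Fix typo in readme"]
    refs = ["Fix typo in README file"]
    prediction_map = {i: [splitPuncts(p.strip().lower())] for i, p in enumerate(preds)}
    gold_map = {i: [splitPuncts(r.strip().lower())] for i, r in enumerate(refs)}
    b_norm = bleuFromMaps(gold_map, prediction_map)[0] / 100.0  # smoothed BLEU, value in [0, 1]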
11 | """ 12 | 13 | # $Id: bleu.py 1307 2007-03-14 22:22:36Z hieuhoang1972 $ 14 | 15 | """Provides: 16 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test(). 17 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked(). 18 | score_cooked(alltest, n=4): Score a list of cooked test sentences. 19 | score_set(s, testid, refids, n=4): Interface with dataset.py; calculate BLEU score of testid against refids. 20 | The reason for breaking the BLEU computation into three phases cook_refs(), cook_test(), and score_cooked() is to allow the caller to calculate BLEU scores for multiple test sets as efficiently as possible. 21 | """ 22 | 23 | import math 24 | import os 25 | import re 26 | import subprocess 27 | import sys 28 | import xml.sax.saxutils 29 | 30 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 31 | nonorm = 0 32 | 33 | preserve_case = False 34 | eff_ref_len = "shortest" 35 | 36 | normalize1 = [ 37 | ("", ""), # strip "skipped" tags 38 | (r"-\n", ""), # strip end-of-line hyphenation and join lines 39 | (r"\n", " "), # join lines 40 | # (r'(\d)\s+(?=\d)', r'\1'), # join digits 41 | ] 42 | normalize1 = [(re.compile(pattern), replace) for (pattern, replace) in normalize1] 43 | 44 | normalize2 = [ 45 | (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "), # tokenize punctuation. apostrophe is missing 46 | (r"([^0-9])([\.,])", r"\1 \2 "), # tokenize period and comma unless preceded by a digit 47 | (r"([\.,])([^0-9])", r" \1 \2"), # tokenize period and comma unless followed by a digit 48 | (r"([0-9])(-)", r"\1 \2 "), # tokenize dash when preceded by a digit 49 | ] 50 | normalize2 = [(re.compile(pattern), replace) for (pattern, replace) in normalize2] 51 | 52 | 53 | def normalize(s): 54 | """Normalize and tokenize text. 
This is lifted from NIST mteval-v11a.pl.""" 55 | # Added to bypass NIST-style pre-processing of hyp and ref files -- wade 56 | if nonorm: 57 | return s.split() 58 | if type(s) is not str: 59 | s = " ".join(s) 60 | # language-independent part: 61 | for (pattern, replace) in normalize1: 62 | s = re.sub(pattern, replace, s) 63 | s = xml.sax.saxutils.unescape(s, {"&quot;": '"'}) 64 | # language-dependent part (assuming Western languages): 65 | s = " %s " % s 66 | if not preserve_case: 67 | s = s.lower() # this might not be identical to the original 68 | for (pattern, replace) in normalize2: 69 | s = re.sub(pattern, replace, s) 70 | return s.split() 71 | 72 | 73 | def count_ngrams(words, n=4): 74 | counts = {} 75 | for k in range(1, n + 1): 76 | for i in range(len(words) - k + 1): 77 | ngram = tuple(words[i : i + k]) 78 | counts[ngram] = counts.get(ngram, 0) + 1 79 | return counts 80 | 81 | 82 | def cook_refs(refs, n=4): 83 | """Takes a list of reference sentences for a single segment 84 | and returns an object that encapsulates everything that BLEU 85 | needs to know about them.""" 86 | 87 | refs = [normalize(ref) for ref in refs] 88 | maxcounts = {} 89 | for ref in refs: 90 | counts = count_ngrams(ref, n) 91 | for (ngram, count) in counts.items(): 92 | maxcounts[ngram] = max(maxcounts.get(ngram, 0), count) 93 | return ([len(ref) for ref in refs], maxcounts) 94 | 95 | 96 | def cook_test(test, item, n=4): 97 | """Takes a test sentence and returns an object that 98 | encapsulates everything that BLEU needs to know about it.""" 99 | (reflens, refmaxcounts) = item 100 | test = normalize(test) 101 | result = {} 102 | result["testlen"] = len(test) 103 | 104 | # Calculate effective reference sentence length. 105 | 106 | if eff_ref_len == "shortest": 107 | result["reflen"] = min(reflens) 108 | elif eff_ref_len == "average": 109 | result["reflen"] = float(sum(reflens)) / len(reflens) 110 | elif eff_ref_len == "closest": 111 | min_diff = None 112 | for reflen in reflens: 113 | if min_diff is None or abs(reflen - len(test)) < min_diff: 114 | min_diff = abs(reflen - len(test)) 115 | result["reflen"] = reflen 116 | 117 | result["guess"] = [max(len(test) - k + 1, 0) for k in range(1, n + 1)] 118 | 119 | result["correct"] = [0] * n 120 | counts = count_ngrams(test, n) 121 | for (ngram, count) in counts.items(): 122 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 123 | 124 | return result 125 | 126 | 127 | def score_cooked(allcomps, n=4, ground=0, smooth=1): 128 | totalcomps = {"testlen": 0, "reflen": 0, "guess": [0] * n, "correct": [0] * n} 129 | for comps in allcomps: 130 | for key in ["testlen", "reflen"]: 131 | totalcomps[key] += comps[key] 132 | for key in ["guess", "correct"]: 133 | for k in range(n): 134 | totalcomps[key][k] += comps[key][k] 135 | logbleu = 0.0 136 | all_bleus = [] 137 | for k in range(n): 138 | correct = totalcomps["correct"][k] 139 | guess = totalcomps["guess"][k] 140 | addsmooth = 0 141 | if smooth == 1 and k > 0: 142 | addsmooth = 1 143 | logbleu += math.log(correct + addsmooth + sys.float_info.min) - math.log(guess + addsmooth + sys.float_info.min) 144 | if guess == 0: 145 | all_bleus.append(-10000000) 146 | else: 147 | all_bleus.append(math.log(correct + sys.float_info.min) - math.log(guess)) 148 | 149 | logbleu /= float(n) 150 | all_bleus.insert(0, logbleu) 151 | 152 | brevPenalty = min(0, 1 - float(totalcomps["reflen"] + 1) / (totalcomps["testlen"] + 1)) 153 | for i in range(len(all_bleus)): 154 | if i == 0: 155 | all_bleus[i] += brevPenalty 156 |
all_bleus[i] = math.exp(all_bleus[i]) 157 | return all_bleus 158 | 159 | 160 | def bleu(refs, candidate, ground=0, smooth=1): 161 | refs = cook_refs(refs) 162 | test = cook_test(candidate, refs) 163 | return score_cooked([test], ground=ground, smooth=smooth) 164 | 165 | 166 | def splitPuncts(line): 167 | return " ".join(re.findall(r"[\w]+|[^\s\w]", line)) 168 | 169 | 170 | def computeMaps(predictions, goldfile): 171 | predictionMap = {} 172 | goldMap = {} 173 | gf = open(goldfile, "r") 174 | 175 | for row in predictions: 176 | cols = row.strip().split("\t") 177 | if len(cols) == 1: 178 | (rid, pred) = (cols[0], "") 179 | else: 180 | (rid, pred) = (cols[0], cols[1]) 181 | predictionMap[rid] = [splitPuncts(pred.strip().lower())] 182 | 183 | for row in gf: 184 | (rid, pred) = row.split("\t") 185 | if rid in predictionMap: # Only insert if the id exists for the method 186 | if rid not in goldMap: 187 | goldMap[rid] = [] 188 | goldMap[rid].append(splitPuncts(pred.strip().lower())) 189 | 190 | return (goldMap, predictionMap) 191 | 192 | 193 | # m1 is the reference map 194 | # m2 is the prediction map 195 | def bleuFromMaps(m1, m2): 196 | score = [0] * 5 197 | num = 0.0 198 | 199 | for key in m1: 200 | if key in m2: 201 | bl = bleu(m1[key], m2[key][0]) 202 | score = [score[i] + bl[i] for i in range(0, len(bl))] 203 | num += 1 204 | return [s * 100.0 / num for s in score] 205 | 206 | 207 | if __name__ == "__main__": 208 | ref_sentence_lst = open(sys.argv[1]).read().split("\n") 209 | with open("tmp_ref.txt", "w") as f: 210 | for idx, ref_sentence in enumerate(ref_sentence_lst): 211 | f.write("{}\t{}\n".format(idx, ref_sentence)) 212 | 213 | reference_file = "tmp_ref.txt" 214 | predictions = [] 215 | for idx, row in enumerate(sys.stdin): 216 | predictions.append("{}\t{}".format(idx, row)) 217 | if len(predictions) == len(ref_sentence_lst) - 1: 218 | predictions.append("{}\t{}".format(len(predictions), "")) 219 | (goldMap, predictionMap) = computeMaps(predictions, reference_file) 220 | print(bleuFromMaps(goldMap, predictionMap)[0]) 221 | os.remove("tmp_ref.txt") 222 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.cmc_module import CMCModule 2 | 3 | __all__ = ["CMCModule"] 4 | -------------------------------------------------------------------------------- /src/model/configurations/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from .decoder_wrapper import DecoderWrapper 3 | from .encoder_decoder_wrapper import EncoderDecoderWrapper 4 | from .race_wrapper import RACEWrapper 5 | from .seq2seq_wrapper import Seq2SeqWrapper 6 | 7 | __all__ = ["BaseModel", "DecoderWrapper", "EncoderDecoderWrapper", "Seq2SeqWrapper", "RACEWrapper"] 8 | -------------------------------------------------------------------------------- /src/model/configurations/base_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from torch import nn 4 | 5 | from src.utils import Batch, BatchTest 6 | 7 | 8 | class BaseModel(nn.Module): 9 | def forward(self, batch: Batch) -> Any: 10 | raise NotImplementedError() 11 | 12 | def generate(self, batch: BatchTest, **kwargs) -> Any: 13 | raise NotImplementedError() 14 | 15 | def num_parameters(self, exclude_embeddings: bool): 16 | raise NotImplementedError() 17 | 18 | def 
save_pretrained(self, path: str) -> None: 19 | raise NotImplementedError() 20 | -------------------------------------------------------------------------------- /src/model/configurations/decoder_wrapper.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast 2 | 3 | from src.model.configurations.base_model import BaseModel 4 | from src.utils import Batch, BatchTest 5 | 6 | 7 | class DecoderWrapper(BaseModel): 8 | """This class serves as a GPT-2 wrapper for commit message completion task. 9 | 10 | Args: 11 | tokenizer: tokenizer for target sequences (messages) 12 | decoder_name_or_path: name or path for pretrained GPT-2 checkpoint 13 | """ 14 | 15 | def __init__( 16 | self, 17 | tokenizer: PreTrainedTokenizerFast, 18 | decoder_name_or_path: str, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | self._tokenizer = tokenizer 23 | self.model = AutoModelForCausalLM.from_pretrained(decoder_name_or_path) 24 | self.model.resize_token_embeddings(len(self._tokenizer)) # type: ignore[arg-type] 25 | 26 | def forward(self, batch: Batch): 27 | return self.model( 28 | input_ids=batch.decoder_input_ids, attention_mask=batch.decoder_attention_mask, labels=batch.labels 29 | ) 30 | 31 | def generate(self, batch: BatchTest, **generation_kwargs): 32 | return self.model.generate( 33 | input_ids=batch.decoder_input_ids, 34 | attention_mask=batch.decoder_attention_mask, 35 | **generation_kwargs, 36 | ) 37 | 38 | def num_parameters(self, exclude_embeddings: bool): 39 | return self.model.num_parameters(exclude_embeddings=exclude_embeddings) 40 | 41 | def save_pretrained(self, path: str) -> None: 42 | self.model.save_pretrained(path) 43 | -------------------------------------------------------------------------------- /src/model/configurations/race_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from src.model.configurations.base_model import BaseModel 4 | from src.utils import Batch, BatchTest 5 | 6 | from .utils.race import RACE 7 | 8 | 9 | class RACEWrapper(BaseModel): 10 | """This class serves as a wrapper of RACE model for commit message completion task. 11 | 12 | Args: 13 | name_or_path: Name on HuggingFace hub or path to pretrained checkpoint. 14 | tokenizer: Tokenizer for the checkpoint (it's initialized earlier to add special tokens when necessary). 
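    Example (a minimal sketch of how the wrapper is constructed and used; the CodeT5 checkpoint name is an assumption, and `batch` / `test_batch` stand for already-built `BatchTrain` / `BatchTest` objects with the `retrieved_*` tensors filled in):

        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained("Salesforce/codet5-base")   # special tokens are assumed to be added beforehand
        model = RACEWrapper(tokenizer=tok, name_or_path="Salesforce/codet5-base")
        loss = model(batch).loss                                        # teacher-forced forward pass
        ids = model.generate(test_batch, max_new_tokens=15)             # autoregressive generation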
15 | """ 16 | 17 | def __init__(self, tokenizer, name_or_path, **kwargs): 18 | super().__init__() 19 | self._tokenizer = tokenizer 20 | self.model = RACE.from_pretrained(name_or_path) 21 | self.model.resize_token_embeddings(len(self._tokenizer)) 22 | 23 | def forward(self, batch: Batch) -> Any: 24 | return self.model( 25 | input_ids=batch.encoder_input_ids, 26 | attention_mask=batch.encoder_attention_mask, 27 | decoder_input_ids=batch.decoder_input_ids, 28 | decoder_attention_mask=batch.decoder_attention_mask, 29 | retrieved_diff_input_ids=batch.retrieved_diff_input_ids, 30 | retrieved_diff_attention_mask=batch.retrieved_diff_attention_mask, 31 | retrieved_msg_input_ids=batch.retrieved_msg_input_ids, 32 | retrieved_msg_attention_mask=batch.retrieved_msg_attention_mask, 33 | labels=batch.labels, 34 | ) 35 | 36 | def generate(self, batch: BatchTest, **generation_kwargs) -> Any: 37 | return self.model.generate( 38 | input_ids=batch.encoder_input_ids, 39 | attention_mask=batch.encoder_attention_mask, 40 | decoder_input_ids=batch.decoder_input_ids, 41 | decoder_attention_mask=batch.decoder_attention_mask, 42 | retrieved_diff_input_ids=batch.retrieved_diff_input_ids, 43 | retrieved_diff_attention_mask=batch.retrieved_diff_attention_mask, 44 | retrieved_msg_input_ids=batch.retrieved_msg_input_ids, 45 | retrieved_msg_attention_mask=batch.retrieved_msg_attention_mask, 46 | **generation_kwargs, 47 | ) 48 | 49 | def num_parameters(self, exclude_embeddings: bool): 50 | return self.model.num_parameters(exclude_embeddings=exclude_embeddings) 51 | 52 | def save_pretrained(self, path: str) -> None: 53 | self.model.save_pretrained(path) 54 | -------------------------------------------------------------------------------- /src/model/configurations/seq2seq_wrapper.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from transformers import AutoModelForSeq2SeqLM 4 | 5 | from src.model.configurations.base_model import BaseModel 6 | from src.utils import Batch, BatchTest 7 | 8 | 9 | class Seq2SeqWrapper(BaseModel): 10 | """This class serves as a wrapper of Transformer-based models for commit message completion task. 11 | 12 | More specifically, this class relies on pretrained seq2seq models from HuggingFace Transformers. 13 | 14 | Args: 15 | name_or_path: Name on HuggingFace hub or path to pretrained checkpoint. 16 | tokenizer: Tokenizer for the checkpoint (it's initialized earlier to add special tokens when necessary). 
17 | """ 18 | 19 | def __init__(self, tokenizer, name_or_path, **kwargs): 20 | super().__init__() 21 | self._tokenizer = tokenizer 22 | self.model = AutoModelForSeq2SeqLM.from_pretrained(name_or_path) 23 | self.model.resize_token_embeddings(len(self._tokenizer)) 24 | 25 | def forward(self, batch: Batch) -> Any: 26 | return self.model( 27 | input_ids=batch.encoder_input_ids, 28 | attention_mask=batch.encoder_attention_mask, 29 | decoder_input_ids=batch.decoder_input_ids, 30 | decoder_attention_mask=batch.decoder_attention_mask, 31 | labels=batch.labels, 32 | ) 33 | 34 | def generate(self, batch: BatchTest, **generation_kwargs) -> Any: 35 | return self.model.generate( 36 | input_ids=batch.encoder_input_ids, 37 | attention_mask=batch.encoder_attention_mask, 38 | decoder_input_ids=batch.decoder_input_ids, 39 | decoder_attention_mask=batch.decoder_attention_mask, 40 | **generation_kwargs, 41 | ) 42 | 43 | def num_parameters(self, exclude_embeddings: bool): 44 | return self.model.num_parameters(exclude_embeddings=exclude_embeddings) 45 | 46 | def save_pretrained(self, path: str) -> None: 47 | self.model.save_pretrained(path) 48 | -------------------------------------------------------------------------------- /src/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | from .embedders import TransformerEmbedder 2 | from .search import DiffSearch 3 | 4 | __all__ = ["DiffSearch", "TransformerEmbedder"] 5 | -------------------------------------------------------------------------------- /src/retrieval/embedders/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import TransformerEmbedder 2 | 3 | __all__ = ["TransformerEmbedder"] 4 | -------------------------------------------------------------------------------- /src/retrieval/embedders/transformer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from transformers import AutoModel, AutoTokenizer 7 | 8 | from src.utils import BatchRetrieval 9 | 10 | from ..utils import CommitEmbeddingExample 11 | 12 | 13 | class TransformerEmbedder: 14 | """This class utilizes Transformer encoder to produce embeddings. 
15 | 16 | Currently, the following architectures are supported: 17 | * BERT/RoBERTa 18 | * T5 (in this case, its encoder is used) 19 | """ 20 | 21 | def __init__(self, name_or_path: str, precision: int, device: str, normalize_embeddings: bool): 22 | assert device in ["cpu", "cuda"] 23 | if device == "cuda" and torch.cuda.device_count() > 1: 24 | raise ValueError("Please, specify GPU by setting CUDA_VISIBLE_DEVICES environment variable.") 25 | 26 | self._device = device 27 | self._normalize_embeddings = normalize_embeddings 28 | self._precision = precision 29 | 30 | self.model = AutoModel.from_pretrained(name_or_path) 31 | if self.model.config.model_type == "t5": 32 | logging.info("T5 model is passed, extracting encoder") 33 | self.model = self.model.encoder 34 | self.model.to(self._device) 35 | self.model.eval() 36 | 37 | def _transform(self, batch: BatchRetrieval) -> torch.Tensor: 38 | outputs = self.model( 39 | input_ids=batch.encoder_input_ids.to(self._device), 40 | attention_mask=batch.encoder_attention_mask.to(self._device), 41 | ) 42 | embeddings = torch.mean(outputs.last_hidden_state, dim=1) 43 | if self._normalize_embeddings: 44 | embeddings = F.normalize(embeddings, p=2, dim=1) 45 | return embeddings 46 | 47 | @torch.no_grad() 48 | def transform(self, batch: BatchRetrieval) -> List[CommitEmbeddingExample]: 49 | """Return embeddings for given list of strings. 50 | 51 | It includes the following steps: 52 | * run through model, obtain last_hidden_state of shape (batch_size, sequence_length, hidden_size) 53 | * compute mean by sequence_length dimension and obtain embeddings of shape (batch_size, hidden_size) 54 | * (optional) normalize embeddings so that each embedding's L2 norm is equal to 1 55 | """ 56 | if self._precision == 16 and self._device == "cuda": 57 | with torch.autocast(device_type="cuda", dtype=torch.float16): 58 | embeddings = self._transform(batch) 59 | else: 60 | embeddings = self._transform(batch) 61 | 62 | np_embeddings = embeddings.cpu().numpy() 63 | return [ 64 | CommitEmbeddingExample(diff_embedding=np_embedding, pos_in_file=pos_in_file) 65 | for np_embedding, pos_in_file in zip(np_embeddings, batch.pos_in_file) 66 | ] 67 | 68 | @property 69 | def embeddings_dim(self): 70 | if self.model.config.model_type == "t5": 71 | return self.model.config.d_model 72 | 73 | if self.model.config.model_type in ["bert", "roberta"]: 74 | return self.model.config.hidden_size 75 | -------------------------------------------------------------------------------- /src/retrieval/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .diff import DiffSearch 2 | 3 | __all__ = ["DiffSearch"] 4 | -------------------------------------------------------------------------------- /src/retrieval/search/diff.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Literal 3 | 4 | import annoy 5 | import numpy.typing as npt 6 | 7 | from ..utils import CommitEmbeddingExample, RetrievalPrediction 8 | 9 | 10 | class DiffSearch: 11 | """This class is used to retrieve the nearest neighbor via the Annoy library.""" 12 | 13 | def __init__( 14 | self, 15 | num_trees: int, 16 | embeddings_dim: int, 17 | load_index: bool, 18 | load_index_path: str = ".", 19 | index_root_dir: str = ".", 20 | metric: Literal["angular", "euclidean", "manhattan", "hamming", "dot"] = "angular", 21 | ) -> None: 22 | self._num_trees = num_trees 23 | 24 | self._index = annoy.AnnoyIndex(embeddings_dim, 
metric) 25 | self._index.set_seed(42) 26 | 27 | if load_index: 28 | self._index.load(load_index_path) 29 | else: 30 | self._index.on_disk_build(os.path.join(index_root_dir, f"index_{num_trees}.ann")) 31 | 32 | def add(self, example: CommitEmbeddingExample) -> None: 33 | """Adds a single item to the index.""" 34 | self._index.add_item(example["pos_in_file"], example["diff_embedding"]) 35 | 36 | def add_batch(self, batch: List[CommitEmbeddingExample]) -> None: 37 | """Adds a batch of items to the index. 38 | 39 | Note: Simply iterates over batch, because annoy doesn't support batch processing. 40 | """ 41 | for example in batch: 42 | if len(example["diff_embedding"].shape) > 1: 43 | assert example["diff_embedding"].shape[0] == 1 44 | example["diff_embedding"] = example["diff_embedding"].flatten() 45 | self.add(example) 46 | 47 | def finalize(self) -> None: 48 | self._index.build(self._num_trees) 49 | 50 | def predict_train(self, idx: int) -> RetrievalPrediction: 51 | """Retrieves the closest neighbor for given idx of embedding already present in index.""" 52 | # we are interested in the nearest neighbor, but for vectors from index it will always be themselves 53 | # so, we search for 2 neighbors and skip the first one 54 | retrieved_idxs, retrieved_distances = self._index.get_nns_by_item(idx, 2, include_distances=True) 55 | retrieved_idxs, retrieved_distances = retrieved_idxs[1:], retrieved_distances[1:] 56 | return RetrievalPrediction( 57 | distance=float(retrieved_distances[0]), 58 | pos_in_file=retrieved_idxs[0], 59 | ) 60 | 61 | def predict(self, diff_embedding: npt.NDArray) -> RetrievalPrediction: 62 | """Retrieves the closest neighbor from index for given embedding.""" 63 | 64 | if len(diff_embedding.shape) > 1: 65 | assert ( 66 | diff_embedding.shape[0] == 1 67 | ), "This method is used to process single example. Use `predict_batch` to process several examples." 68 | diff_embedding = diff_embedding.flatten() 69 | 70 | retrieved_idxs, retrieved_distances = self._index.get_nns_by_vector(diff_embedding, 1, include_distances=True) 71 | 72 | return RetrievalPrediction( 73 | distance=float(retrieved_distances[0]), 74 | pos_in_file=retrieved_idxs[0], 75 | ) 76 | 77 | def predict_batch(self, batch: List[CommitEmbeddingExample]) -> List[RetrievalPrediction]: 78 | """Retrieves the closest neighbors for each example in a batch. 79 | 80 | Note: Simply iterates over batch, because annoy doesn't support batch processing. 81 | """ 82 | return [self.predict(diff_embedding=example["diff_embedding"]) for example in batch] 83 | 84 | def predict_batch_train(self, batch_idxs: List[int]) -> List[RetrievalPrediction]: 85 | """Retrieves the closest neighbors for each example in a batch. Intended for examples present in index. 86 | 87 | Note: Simply iterates over batch, because annoy doesn't support batch processing. 
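        Example of the overall workflow this method relies on (a minimal sketch; the dimensions, directory, and ids are made up):

            search = DiffSearch(num_trees=10, embeddings_dim=768, load_index=False, index_root_dir=".")
            for batch in embedded_batches:                   # batches of CommitEmbeddingExample from an embedder
                search.add_batch(batch)
            search.finalize()                                # builds the Annoy index on disk
            predictions = search.predict_batch_train([0, 1, 2])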
88 | """ 89 | return [self.predict_train(idx=example) for example in batch_idxs] 90 | -------------------------------------------------------------------------------- /src/retrieval/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .typing_utils import CommitEmbeddingExample, RetrievalPrediction 2 | 3 | __all__ = ["CommitEmbeddingExample", "RetrievalPrediction"] 4 | -------------------------------------------------------------------------------- /src/retrieval/utils/typing_utils.py: -------------------------------------------------------------------------------- 1 | import numpy.typing as npt 2 | from typing_extensions import TypedDict 3 | 4 | 5 | class CommitEmbeddingExample(TypedDict): 6 | diff_embedding: npt.NDArray 7 | pos_in_file: int 8 | 9 | 10 | class RetrievalPrediction(TypedDict): 11 | pos_in_file: int 12 | distance: float 13 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation_metrics import EvaluationMetrics 2 | from .model_utils import get_decoder_start_token_id, remove_layers_from_model 3 | from .prefix_utils import PrefixAllowedTokens, VocabPrefixTree 4 | from .typing_utils import Batch, BatchRetrieval, BatchTest, BatchTrain, SingleExample 5 | from .wandb_organize_utils import WandbOrganizer 6 | 7 | __all__ = [ 8 | "SingleExample", 9 | "BatchTrain", 10 | "BatchTest", 11 | "BatchRetrieval", 12 | "EvaluationMetrics", 13 | "PrefixAllowedTokens", 14 | "VocabPrefixTree", 15 | "WandbOrganizer", 16 | "Batch", 17 | "remove_layers_from_model", 18 | "get_decoder_start_token_id", 19 | ] 20 | -------------------------------------------------------------------------------- /src/utils/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import torch 4 | from datasets import Metric, load_metric 5 | from torch import Tensor 6 | from torchmetrics import MetricCollection 7 | 8 | from src.metrics import MRR, Accuracy, BLEUNorm, EditSimilarity, ExactMatch, LogMNEXT 9 | 10 | 11 | class EvaluationMetrics: 12 | """This class is used to compute all evaluation metrics for commit message completion task. 13 | 14 | Currently, it includes the following: 15 | 16 | * string similarity metrics: BLEU, B-NORM, ROUGE, METEOR, LogM-Next, ChrF 17 | * completion metrics: Accuracy@k, MRR@k, Exact Match@k, Edit Similarity 18 | 19 | Accuracy@k and MRR@k are calculated on raw model output (tensors), all other metrics are calculated on 20 | generated and decoded strings. 21 | 22 | Args: 23 | do_tensors: True to compute Accuracy@k and MRR@k and False otherwise. 24 | do_strings: True to compute string similarity metrics and False otherwise. 25 | n: if an integer is given, ExactMatch metrics will be computed for first n tokens. Otherwise, it is computed for the whole sequences. 
26 | """ 27 | 28 | def __init__( 29 | self, 30 | do_strings: bool, 31 | do_tensors: bool, 32 | shift: bool, 33 | n: Optional[int] = None, 34 | prefix: Optional[str] = None, 35 | ): 36 | 37 | self.tensors_metrics: MetricCollection = MetricCollection({}) 38 | self.datasets_metrics: Dict[str, Metric] = {} 39 | self.strings_metrics: MetricCollection = MetricCollection({}) 40 | 41 | if do_tensors: 42 | self.tensors_metrics = MetricCollection( 43 | { 44 | "acc_top1": Accuracy(top_k=1, shift=shift), 45 | "acc_top5": Accuracy(top_k=5, shift=shift), 46 | "MRR_top5": MRR(top_k=5, shift=shift), 47 | }, 48 | ) 49 | 50 | if do_strings: 51 | self.datasets_metrics = { 52 | "b_norm": BLEUNorm(), 53 | "bleu": load_metric("bleu"), 54 | "rouge": load_metric("rouge"), 55 | "meteor": load_metric("meteor"), 56 | "chrf": load_metric("chrf"), 57 | } 58 | self.strings_metrics = MetricCollection( 59 | { 60 | "exact_match": ExactMatch(n=n), 61 | "edit_similarity": EditSimilarity(), 62 | "log_mnext": LogMNEXT(), 63 | } 64 | ) 65 | 66 | self.prefix = prefix 67 | 68 | def add_batch( 69 | self, 70 | predictions: Optional[List[str]] = None, 71 | references: Optional[List[str]] = None, 72 | predictions_tensor: Optional[Tensor] = None, 73 | references_tensor: Optional[Tensor] = None, 74 | ) -> Dict[str, torch.Tensor]: 75 | cur_metrics: Dict[str, torch.Tensor] = {} 76 | 77 | if self.datasets_metrics: 78 | assert predictions is not None and references is not None 79 | for key in self.datasets_metrics: 80 | if key == "bleu": 81 | self.datasets_metrics[key].add_batch( 82 | predictions=[[token.lower() for token in line.split()] for line in predictions], 83 | references=[[[token.lower() for token in line.split()]] for line in references], 84 | ) 85 | elif key == "chrf": 86 | self.datasets_metrics[key].add_batch( 87 | predictions=predictions, references=[[line] for line in references] 88 | ) 89 | else: 90 | self.datasets_metrics[key].add_batch(predictions=predictions, references=references) 91 | 92 | if self.tensors_metrics: 93 | assert predictions_tensor is not None and references_tensor is not None 94 | cur_metrics = self.tensors_metrics(predictions_tensor, references_tensor) 95 | 96 | if self.strings_metrics: 97 | assert predictions is not None and references is not None 98 | cur_string_metrics = self.strings_metrics(predictions, references) 99 | if cur_metrics: 100 | cur_metrics.update(cur_string_metrics) 101 | else: 102 | cur_metrics = cur_string_metrics 103 | 104 | return cur_metrics 105 | 106 | def compute(self) -> Dict[str, torch.Tensor]: 107 | results: Dict[str, torch.Tensor] = {} 108 | if self.datasets_metrics: 109 | for key in self.datasets_metrics: 110 | if key == "bleu": 111 | results[key] = self.datasets_metrics[key].compute(smooth=True)["bleu"] # type: ignore[index] 112 | elif key == "rouge": 113 | rouge = self.datasets_metrics[key].compute() 114 | results["rouge1"] = rouge["rouge1"].mid.fmeasure # type: ignore[index] 115 | results["rouge2"] = rouge["rouge2"].mid.fmeasure # type: ignore[index] 116 | results["rougeL"] = rouge["rougeL"].mid.fmeasure # type: ignore[index] 117 | elif key == "meteor": 118 | results[key] = self.datasets_metrics[key].compute()["meteor"] # type: ignore[index] 119 | elif key == "b_norm": 120 | results[key] = self.datasets_metrics[key].compute()["b_norm"] # type: ignore[index] 121 | elif key == "chrf": 122 | results[key] = self.datasets_metrics[key].compute()["score"] / 100 # type: ignore[index] 123 | 124 | for metrics in (self.tensors_metrics, self.strings_metrics): 125 | if metrics: 126 | 
metrics_results = metrics.compute() 127 | results.update({key: metrics_results[key] for key in metrics_results}) 128 | 129 | if self.prefix: 130 | results = {f"{self.prefix}_{key}": results[key] for key in results} 131 | return results 132 | -------------------------------------------------------------------------------- /src/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from copy import copy 2 | from typing import Optional, Union, no_type_check 3 | 4 | from transformers import AutoConfig, GPT2LMHeadModel, RobertaForCausalLM, RobertaModel 5 | 6 | from conf import BaseModelConfig, BaseRACEConfig, BaseSeq2SeqConfig 7 | 8 | 9 | def get_decoder_start_token_id(model_cfg: BaseModelConfig) -> Optional[int]: 10 | if model_cfg.configuration == "encoder_decoder": 11 | return None 12 | elif model_cfg.configuration == "decoder": 13 | return None 14 | elif model_cfg.configuration == "seq2seq": 15 | seq2seq_cfg = BaseSeq2SeqConfig(**model_cfg) # type: ignore[arg-type] 16 | name_or_path = seq2seq_cfg.name_or_path 17 | elif model_cfg.configuration == "race": 18 | race_cfg = BaseRACEConfig(**model_cfg) # type: ignore[arg-type] 19 | name_or_path = race_cfg.name_or_path 20 | else: 21 | return None 22 | 23 | config = AutoConfig.from_pretrained(name_or_path) 24 | return config.decoder_start_token_id 25 | 26 | 27 | @no_type_check # a lot of attr-defined errors from transformers 28 | def remove_layers_from_model( 29 | teacher: Union[RobertaModel, RobertaForCausalLM, GPT2LMHeadModel], num_layers: int 30 | ) -> Union[RobertaModel, RobertaForCausalLM, GPT2LMHeadModel]: 31 | if isinstance(teacher, RobertaForCausalLM): 32 | student_config = copy(teacher.config) 33 | student_config.num_hidden_layers = num_layers 34 | roberta_lm = RobertaForCausalLM(config=student_config) 35 | 36 | # copy all embeddings 37 | roberta_lm.roberta.embeddings.word_embeddings = teacher.roberta.embeddings.word_embeddings 38 | roberta_lm.roberta.embeddings.position_embeddings = teacher.roberta.embeddings.position_embeddings 39 | roberta_lm.roberta.embeddings.token_type_embeddings = teacher.roberta.embeddings.token_type_embeddings 40 | roberta_lm.roberta.embeddings.LayerNorm = teacher.roberta.embeddings.LayerNorm 41 | roberta_lm.roberta.embeddings.dropout = teacher.roberta.embeddings.dropout 42 | 43 | # uniformly pick from middle layers from teacher 44 | # it is basically np.linspace(0, teacher_config.num_hidden_layers, 45 | # num=student_config.num_hidden_layers, endpoint=True) 46 | step = (teacher.config.num_hidden_layers - 1) / (student_config.num_hidden_layers - 1) 47 | for student_layer, teacher_layer in enumerate(int(i * step) for i in range(student_config.num_hidden_layers)): 48 | roberta_lm.roberta.encoder.layer[student_layer] = teacher.roberta.encoder.layer[teacher_layer] 49 | return roberta_lm 50 | elif isinstance(teacher, RobertaModel): 51 | student_config = copy(teacher.config) 52 | student_config.num_hidden_layers = num_layers 53 | roberta = RobertaModel(config=student_config) # type: ignore[assignment] 54 | 55 | # copy all embeddings 56 | roberta.embeddings.word_embeddings = teacher.embeddings.word_embeddings 57 | roberta.embeddings.position_embeddings = teacher.embeddings.position_embeddings 58 | roberta.embeddings.token_type_embeddings = teacher.embeddings.token_type_embeddings 59 | roberta.embeddings.LayerNorm = teacher.embeddings.LayerNorm 60 | roberta.embeddings.dropout = teacher.embeddings.dropout 61 | 62 | # uniformly pick from middle layers from teacher 63 | # it 
is basically np.linspace(0, teacher_config.num_hidden_layers, 64 | # num=student_config.num_hidden_layers, endpoint=True) 65 | step = (teacher.config.num_hidden_layers - 1) / (student_config.num_hidden_layers - 1) 66 | for student_layer, teacher_layer in enumerate(int(i * step) for i in range(student_config.num_hidden_layers)): 67 | roberta.encoder.layer[student_layer] = teacher.encoder.layer[teacher_layer] 68 | return roberta 69 | elif isinstance(teacher, GPT2LMHeadModel): 70 | student_config = copy(teacher.config) 71 | student_config.n_layer = num_layers 72 | gpt2_lm = GPT2LMHeadModel(config=student_config) 73 | 74 | # Copying all embeddings 75 | gpt2_lm.transformer.wte = teacher.transformer.wte 76 | gpt2_lm.transformer.wpe = teacher.transformer.wpe 77 | gpt2_lm.transformer.drop = teacher.transformer.drop 78 | 79 | # Specific thing for GPT2LMHead 80 | gpt2_lm.tie_weights() 81 | # Uniformly pick from middle layers from teacher 82 | # It is basically np.linspace(0, teacher_config.n_layer, num=student_config.n_layer, endpoint=True) 83 | step = (teacher.config.n_layer - 1) / (student_config.n_layer - 1) 84 | for student_layer, teacher_layer in enumerate(int(i * step) for i in range(student_config.n_layer)): 85 | gpt2_lm.transformer.h[student_layer] = teacher.transformer.h[teacher_layer] 86 | return gpt2_lm 87 | -------------------------------------------------------------------------------- /src/utils/prefix_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import marisa_trie 4 | import torch 5 | from transformers import PreTrainedTokenizerFast 6 | 7 | 8 | class VocabPrefixTree: 9 | def __init__(self, tokenizer: PreTrainedTokenizerFast): 10 | self._vocab = tokenizer.get_vocab() # type: ignore[attr-defined] 11 | self._trie = marisa_trie.Trie([key.replace("Ġ", " ") for key in self._vocab]) 12 | 13 | def get_tokens(self, prefix: str) -> List[int]: 14 | """Uses the trie to find all the tokens that either: 15 | 16 | * start with the given prefix 17 | * are prefixes for the given prefix 18 | 19 | Args: 20 | prefix: Current prefix. 21 | 22 | Returns: 23 | A list of tokens ids. 
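        Example (a minimal sketch; the GPT-2 tokenizer and the exact tokens returned are illustrative):

            from transformers import AutoTokenizer

            trie = VocabPrefixTree(AutoTokenizer.from_pretrained("gpt2"))
            token_ids = trie.get_tokens(" fix")   # ids of tokens such as " f", " fi", " fix", " fixed", ...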
24 | """ 25 | tokens = [] 26 | for token in self._trie.prefixes(prefix): 27 | try: 28 | tokens.append(self._vocab[token]) 29 | except KeyError: 30 | tokens.append(self._vocab[token.replace(" ", "Ġ")]) 31 | 32 | for token in self._trie.keys(prefix): 33 | try: 34 | tokens.append(self._vocab[token]) 35 | except KeyError: 36 | tokens.append(self._vocab[token.replace(" ", "Ġ")]) 37 | 38 | return tokens 39 | 40 | @property 41 | def vocab(self): 42 | return self._vocab 43 | 44 | 45 | class PrefixAllowedTokens: 46 | def __init__( 47 | self, 48 | context_len: Dict[int, int], 49 | prefix: Dict[int, str], 50 | tokenizer: PreTrainedTokenizerFast, 51 | trie: Optional[VocabPrefixTree] = None, 52 | ): 53 | self._context_len = context_len 54 | self._prefix = prefix 55 | 56 | self._tokenizer = tokenizer 57 | if not trie: 58 | trie = VocabPrefixTree(tokenizer) 59 | self._trie = trie 60 | 61 | def __call__(self, batch_id: int, sentence: torch.Tensor) -> List[int]: 62 | decoded_sentence = self._tokenizer.decode(sentence[self._context_len[batch_id] :]) # type: ignore[attr-defined] 63 | 64 | # when given prefix is empty, we can generate any token 65 | if not self._prefix[batch_id]: 66 | return list(self._trie.vocab.values()) # type: ignore[attr-defined] 67 | 68 | # if we haven't generated prefix or its part yet, we can: 69 | # 1) generate tokens starting with the prefix 70 | # 2) generate tokens which are prefixes for the prefix 71 | if len(decoded_sentence) == 0: 72 | res = self._trie.get_tokens(self._prefix[batch_id]) 73 | # if we've already generated the prefix, we can generate any token 74 | elif decoded_sentence.startswith(self._prefix[batch_id]): 75 | res = list(self._trie.vocab.values()) # type: ignore[attr-defined] 76 | # if we've generated only part of the prefix, we can: 77 | # 1) generate tokens starting with its remaining part 78 | # 2) generate tokens which are prefixes for its remaining part 79 | else: 80 | res = self._trie.get_tokens(self._prefix[batch_id][len(decoded_sentence) :]) 81 | 82 | return res 83 | -------------------------------------------------------------------------------- /src/utils/typing_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Optional 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class SingleExample: 9 | diff_input_ids: List[int] 10 | msg_input_ids: List[int] 11 | history_input_ids: List[List[int]] 12 | pos_in_file: int 13 | retrieved_diff_input_ids: Optional[List[int]] = None 14 | retrieved_msg_input_ids: Optional[List[int]] = None 15 | 16 | 17 | @dataclass 18 | class BatchRetrieval: 19 | encoder_input_ids: torch.Tensor 20 | encoder_attention_mask: torch.Tensor 21 | pos_in_file: List[int] 22 | 23 | def pin_memory(self): 24 | self.encoder_input_ids = self.encoder_input_ids.pin_memory() 25 | self.encoder_attention_mask = self.encoder_attention_mask.pin_memory() 26 | return self 27 | 28 | 29 | @dataclass 30 | class Batch: 31 | encoder_input_ids: torch.Tensor 32 | encoder_attention_mask: torch.Tensor 33 | decoder_input_ids: torch.Tensor 34 | decoder_attention_mask: torch.Tensor 35 | labels: Optional[torch.Tensor] 36 | retrieved_diff_input_ids: Optional[torch.Tensor] 37 | retrieved_diff_attention_mask: Optional[torch.Tensor] 38 | retrieved_msg_input_ids: Optional[torch.Tensor] 39 | retrieved_msg_attention_mask: Optional[torch.Tensor] 40 | 41 | def pin_memory(self): 42 | self.encoder_input_ids = self.encoder_input_ids.pin_memory() 43 | self.encoder_attention_mask = 
self.encoder_attention_mask.pin_memory() 44 | self.decoder_input_ids = self.decoder_input_ids.pin_memory() 45 | self.decoder_attention_mask = self.decoder_attention_mask.pin_memory() 46 | if self.labels is not None: 47 | self.labels = self.labels.pin_memory() 48 | if self.retrieved_diff_input_ids is not None: 49 | self.retrieved_diff_input_ids = self.retrieved_diff_input_ids.pin_memory() 50 | if self.retrieved_diff_attention_mask is not None: 51 | self.retrieved_diff_attention_mask = self.retrieved_diff_attention_mask.pin_memory() 52 | if self.retrieved_msg_input_ids is not None: 53 | self.retrieved_msg_input_ids = self.retrieved_msg_input_ids.pin_memory() 54 | if self.retrieved_msg_attention_mask is not None: 55 | self.retrieved_msg_attention_mask = self.retrieved_msg_attention_mask.pin_memory() 56 | return self 57 | 58 | 59 | @dataclass 60 | class BatchTrain(Batch): 61 | labels: torch.Tensor 62 | 63 | 64 | @dataclass 65 | class BatchTest(Batch): 66 | targets: List[str] 67 | prefixes: List[str] 68 | -------------------------------------------------------------------------------- /src/utils/wandb_organize_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from conf import ( 4 | BaseDecoderConfig, 5 | BaseEncoderDecoderConfig, 6 | BaseModelConfig, 7 | BaseRACEConfig, 8 | BaseSeq2SeqConfig, 9 | ) 10 | 11 | 12 | class WandbOrganizer: 13 | @staticmethod 14 | def _prepare_pretrained_name(name_or_path: str) -> str: 15 | name_or_path = name_or_path.split("/")[-1] 16 | if "-" in name_or_path: 17 | name_or_path = name_or_path.split("-")[0] 18 | if "_" in name_or_path: 19 | name_or_path = name_or_path.split("_")[0] 20 | return name_or_path 21 | 22 | @staticmethod 23 | def _get_model_tags(model_cfg: BaseModelConfig) -> List[str]: 24 | tags = [model_cfg.configuration] 25 | 26 | if model_cfg.configuration == "seq2seq": 27 | model_cfg = BaseSeq2SeqConfig(**model_cfg) # type: ignore[arg-type] 28 | tags.append(WandbOrganizer._prepare_pretrained_name(model_cfg.name_or_path)) 29 | elif model_cfg.configuration == "race": 30 | model_cfg = BaseRACEConfig(**model_cfg) # type: ignore[arg-type] 31 | tags.append(WandbOrganizer._prepare_pretrained_name(model_cfg.name_or_path)) 32 | elif model_cfg.configuration == "decoder": 33 | model_cfg = BaseDecoderConfig(**model_cfg) # type: ignore[arg-type] 34 | tags.append(WandbOrganizer._prepare_pretrained_name(model_cfg.decoder_name_or_path)) 35 | elif model_cfg.configuration == "encoder_decoder": 36 | model_cfg = BaseEncoderDecoderConfig(**model_cfg) # type: ignore[arg-type] 37 | 38 | if model_cfg.encoder_name_or_path: 39 | tags.append(f"[encoder]: {WandbOrganizer._prepare_pretrained_name(model_cfg.encoder_name_or_path)}") 40 | if model_cfg.encoder_model_type: 41 | tags.append(f"[encoder]: random_{model_cfg.encoder_model_type}") 42 | if model_cfg.num_layers_encoder: 43 | tags.append(f"[encoder]: {model_cfg.num_layers_encoder} layers") 44 | 45 | if model_cfg.decoder_name_or_path: 46 | tags.append(f"[decoder]: {WandbOrganizer._prepare_pretrained_name(model_cfg.decoder_name_or_path)}") 47 | if model_cfg.decoder_model_type: 48 | tags.append(f"[decoder]: random_{model_cfg.decoder_model_type}") 49 | if model_cfg.num_layers_decoder: 50 | tags.append(f"[decoder]: {model_cfg.num_layers_decoder} layers") 51 | 52 | if model_cfg.tie_encoder_decoder: 53 | tags.append("shared weights") 54 | elif model_cfg.tie_word_embeddings: 55 | tags.append("shared embeddings") 56 | 57 | return tags 58 | 59 | @staticmethod 60 
| def get_run_name(model_cfg: BaseModelConfig, encoder_input_type: str, train_with_history: bool) -> str: 61 | name = [] 62 | 63 | if model_cfg.configuration == "seq2seq": 64 | model_cfg = BaseSeq2SeqConfig(**model_cfg) # type: ignore[arg-type] 65 | name.append(WandbOrganizer._prepare_pretrained_name(model_cfg.name_or_path)) 66 | elif model_cfg.configuration == "race": 67 | model_cfg = BaseRACEConfig(**model_cfg) # type: ignore[arg-type] 68 | name.append("race_" + WandbOrganizer._prepare_pretrained_name(model_cfg.name_or_path)) 69 | elif model_cfg.configuration == "decoder": 70 | model_cfg = BaseDecoderConfig(**model_cfg) # type: ignore[arg-type] 71 | name.append(WandbOrganizer._prepare_pretrained_name(model_cfg.decoder_name_or_path)) 72 | elif model_cfg.configuration == "encoder_decoder": 73 | model_cfg = BaseEncoderDecoderConfig(**model_cfg) # type: ignore[arg-type] 74 | if model_cfg.encoder_name_or_path: 75 | name.append(WandbOrganizer._prepare_pretrained_name(model_cfg.encoder_name_or_path)) 76 | if model_cfg.encoder_model_type: 77 | name.append(f"random_{model_cfg.encoder_model_type}") 78 | if model_cfg.num_layers_encoder: 79 | name.append(str(model_cfg.num_layers_encoder)) 80 | 81 | if model_cfg.decoder_name_or_path: 82 | name.append(WandbOrganizer._prepare_pretrained_name(model_cfg.decoder_name_or_path)) 83 | if model_cfg.decoder_model_type: 84 | name.append(f"random_{model_cfg.decoder_model_type}") 85 | if model_cfg.num_layers_decoder: 86 | name.append(str(model_cfg.num_layers_decoder)) 87 | 88 | if model_cfg.tie_encoder_decoder: 89 | name.append("shared-weights") 90 | elif model_cfg.tie_word_embeddings: 91 | name.append("shared-embeddings") 92 | 93 | if encoder_input_type == "diff": 94 | name.append("with-history" if train_with_history else "without-history") 95 | elif encoder_input_type == "history": 96 | name.append("history-input") 97 | 98 | return "_".join(name) 99 | 100 | @staticmethod 101 | def get_tags_train(model_cfg: BaseModelConfig, encoder_input_type: str, train_with_history: bool) -> List[str]: 102 | tags = WandbOrganizer._get_model_tags(model_cfg) 103 | tags.append("train with history" if train_with_history else "train without history") 104 | tags.append(encoder_input_type) 105 | return tags 106 | 107 | @staticmethod 108 | def get_tags_generate(generate_with_history: bool, context_ratio: float) -> List[str]: 109 | tags = [ 110 | "generate with history" if generate_with_history else "generate without history", 111 | f"context ratio = {context_ratio}", 112 | ] 113 | return tags 114 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/commit_message_generation/073f8e9d501fb4e35876b73743906fd2b258a7ba/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import pytorch_lightning as pl 4 | import torch 5 | from torchmetrics import MetricCollection 6 | from torchmetrics.utilities import check_forward_full_state_property 7 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 8 | 9 | from src.metrics import Accuracy 10 | 11 | pl.seed_everything(123) 12 | 13 | 14 | @pytest.fixture 15 | def metrics_collection(): 16 | return MetricCollection( 17 | { 18 | "accuracy@1": 
Accuracy(top_k=1, ignore_index=-100), 19 | "accuracy@2": Accuracy(top_k=2, ignore_index=-100), 20 | "accuracy@3": Accuracy(top_k=3, ignore_index=-100), 21 | } 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def metrics_collection_no_shift(): 27 | return MetricCollection( 28 | { 29 | "accuracy@1": Accuracy(top_k=1, shift=False, ignore_index=-100), 30 | "accuracy@2": Accuracy(top_k=2, shift=False, ignore_index=-100), 31 | "accuracy@3": Accuracy(top_k=3, shift=False, ignore_index=-100), 32 | } 33 | ) 34 | 35 | 36 | def test_full_top1_match(metrics_collection): 37 | scores = torch.tensor( 38 | [[0.5, 0, 0.3, 0.2, 0.4], [1.0, -100, -200, -101, 0], [-1, -1, -1, -1, -1]], dtype=torch.float 39 | ) 40 | labels = torch.tensor([-1, 0, 0], dtype=torch.long) 41 | results = metrics_collection(scores, labels) 42 | assert np.allclose([results["accuracy@1"].item(), results["accuracy@3"].item()], [1.0, 1.0], rtol=1e-05, atol=1e-08) 43 | 44 | 45 | def test_full_top3_match(metrics_collection): 46 | scores = torch.tensor( 47 | [[0.5, 0, 0.3, 0.2, 0.4], [1.0, -100, -200, -101, 0], [-1, -1, -1, -1, -1]], dtype=torch.float 48 | ) 49 | labels = torch.tensor([-1, 2, 1], dtype=torch.long) 50 | results = metrics_collection(scores, labels) 51 | assert np.allclose([results["accuracy@1"].item(), results["accuracy@3"].item()], [0.0, 1.0], rtol=1e-05, atol=1e-08) 52 | 53 | 54 | def test_different_batch_sizes(metrics_collection): 55 | # batch size 4 56 | scores = torch.tensor( 57 | [ 58 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 59 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 60 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 61 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 62 | ], 63 | dtype=torch.float, 64 | ) 65 | labels = torch.tensor([[-1, 0, 2, -100], [-1, 1, 1, -100], [-1, 2, 1, 3], [-1, 1, 2, -100]], dtype=torch.long) 66 | results = metrics_collection(scores, labels) 67 | assert np.allclose( 68 | [results["accuracy@1"].item(), results["accuracy@2"].item()], [0.5 / 4, (1 + 2 / 3) / 4], rtol=1e-05, atol=1e-08 69 | ) 70 | 71 | # batch size 2 72 | metrics_collection.reset() 73 | first_half = metrics_collection(scores[:2], labels[:2]) 74 | second_half = metrics_collection(scores[2:], labels[2:]) 75 | total = metrics_collection.compute() 76 | assert np.allclose( 77 | [ 78 | (first_half["accuracy@1"].item() + second_half["accuracy@1"].item()) / 2, 79 | (first_half["accuracy@2"].item() + second_half["accuracy@2"].item()) / 2, 80 | ], 81 | [0.5 / 4, (1 + 2 / 3) / 4], 82 | rtol=1e-05, 83 | atol=1e-08, 84 | ) 85 | assert np.allclose( 86 | [total["accuracy@1"].item(), total["accuracy@2"].item()], [0.5 / 4, (1 + 2 / 3) / 4], rtol=1e-05, atol=1e-08 87 | ) 88 | 89 | # batch size 1 90 | metrics_collection.reset() 91 | results = [] 92 | for i in range(4): 93 | results.append(metrics_collection(scores[i], labels[i])) 94 | total = metrics_collection.compute() 95 | mean_acc1 = np.mean([res["accuracy@1"].item() for res in results]) 96 | mean_acc2 = np.mean([res["accuracy@2"].item() for res in results]) 97 | assert np.allclose([mean_acc1, mean_acc2], [0.5 / 4, (1 + 2 / 3) / 4], rtol=1e-05, atol=1e-08) 98 | assert np.allclose( 99 | [total["accuracy@1"].item(), total["accuracy@2"].item()], [0.5 / 4, (1 + 2 / 3) / 4], rtol=1e-05, atol=1e-08 100 | ) 101 | 102 | 103 | def test_gpt2_shift(metrics_collection, metrics_collection_no_shift): 104 | model = AutoModelForCausalLM.from_pretrained("gpt2") 105 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 106 | 107 | 
input = "My name is John" 108 | tokenized_input = tokenizer(input, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 109 | logits = model(input_ids=tokenized_input, labels=tokenized_input).logits 110 | preds_tokens = tokenizer.convert_ids_to_tokens(torch.topk(logits, 1, dim=-1)[1].squeeze(-1).squeeze(0)) 111 | input_tokens = tokenizer.convert_ids_to_tokens(tokenized_input.squeeze(0)) 112 | 113 | assert input_tokens[1:][-1] == preds_tokens[:-1][-1] 114 | 115 | results = metrics_collection(logits, tokenized_input) 116 | results_no_shift = metrics_collection_no_shift(logits, tokenized_input) 117 | 118 | for key in results: 119 | assert results[key] > results_no_shift[key] 120 | 121 | 122 | def test_t5_shift(metrics_collection, metrics_collection_no_shift): 123 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") 124 | tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") 125 | 126 | input = "Q: Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering." 127 | target = "George Washington is a fictional character. Geoffrey Hinton is a fictional character." 128 | tokenized_input = tokenizer(input, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 129 | tokenized_target = tokenizer(target, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 130 | 131 | logits = model(input_ids=tokenized_input, labels=tokenized_target).logits 132 | preds_tokens = tokenizer.convert_ids_to_tokens(torch.topk(logits, 1, dim=-1)[1].squeeze(-1).squeeze(0)) 133 | target_tokens = tokenizer.convert_ids_to_tokens(tokenized_target.squeeze(0)) 134 | 135 | assert preds_tokens[:-1] == target_tokens[:-1] 136 | 137 | results = metrics_collection(logits, tokenized_target) 138 | results_no_shift = metrics_collection_no_shift(logits, tokenized_target) 139 | 140 | for key in results: 141 | assert results[key] < results_no_shift[key] 142 | 143 | 144 | def test_full_state_update(): 145 | scores = torch.tensor( 146 | [ 147 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 148 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 149 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 150 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 151 | ], 152 | dtype=torch.float, 153 | ) 154 | labels = torch.tensor([[-1, 0, 2, -100], [-1, 1, 1, -100], [-1, 2, 1, 3], [-1, 1, 2, -100]], dtype=torch.long) 155 | 156 | for top_k in list(range(1, 4)): 157 | check_forward_full_state_property( 158 | Accuracy, 159 | init_args=dict(top_k=top_k, ignore_index=-100), 160 | input_args={"predictions": scores, "references": labels}, 161 | ) 162 | -------------------------------------------------------------------------------- /tests/test_codereviewer_preprocessor.py: -------------------------------------------------------------------------------- 1 | from src.data_utils.preprocessors import CodeReviewerPreprocessor 2 | 3 | 4 | def test_preprocess_diff(): 5 | preprocessor = CodeReviewerPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 6 | 7 | assert ( 8 | preprocessor._preprocess_diff( 9 | "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", line_sep="[NL]" 10 | ) 11 | == "context 1[NL]context 2[NL]context 3[NL]old line[NL]new line[NL]" 12 | ) 13 | -------------------------------------------------------------------------------- /tests/test_default_preprocessor.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | 3 | import jsonlines 4 | 5 | from src.data_utils.preprocessors import DefaultPreprocessor 6 | 7 | 8 | def test_preprocess_mods(): 9 | preprocessor = DefaultPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 10 | 11 | # check that all mods types work correctly 12 | modify_mod = { 13 | "change_type": "MODIFY", 14 | "old_path": "fname", 15 | "new_path": "fname", 16 | "diff": "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", 17 | } 18 | assert preprocessor._preprocess_mods([modify_mod], line_sep="[NL]") == "fname[NL]" + modify_mod["diff"] 19 | 20 | add_mod = { 21 | "change_type": "ADD", 22 | "old_path": None, 23 | "new_path": "fname", 24 | "diff": "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", 25 | } 26 | assert preprocessor._preprocess_mods([add_mod], line_sep="[NL]") == "new file fname[NL]" + add_mod["diff"] 27 | 28 | delete_mod = { 29 | "change_type": "DELETE", 30 | "old_path": "fname", 31 | "new_path": None, 32 | "diff": "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", 33 | } 34 | assert preprocessor._preprocess_mods([delete_mod], line_sep="[NL]") == "deleted file fname[NL]" + delete_mod["diff"] 35 | 36 | rename_mod = { 37 | "change_type": "RENAME", 38 | "old_path": "fname1", 39 | "new_path": "fname2", 40 | "diff": "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", 41 | } 42 | assert ( 43 | preprocessor._preprocess_mods([rename_mod], line_sep="[NL]") 44 | == "rename from fname1[NL]rename to fname2[NL]" + rename_mod["diff"] 45 | ) 46 | 47 | copy_mod = { 48 | "change_type": "COPY", 49 | "old_path": "fname1", 50 | "new_path": "fname2", 51 | "diff": "context 1[NL]context 2[NL]context 3[NL]-old line[NL]+new line[NL]", 52 | } 53 | assert ( 54 | preprocessor._preprocess_mods([copy_mod], line_sep="[NL]") 55 | == "copy from fname1[NL]copy to fname2[NL]" + copy_mod["diff"] 56 | ) 57 | 58 | # check some mods together 59 | assert preprocessor._preprocess_mods([modify_mod, modify_mod, add_mod], line_sep="[NL]") == ( 60 | "fname[NL]" + modify_mod["diff"] + "fname[NL]" + modify_mod["diff"] + "new file fname[NL]" + add_mod["diff"] 61 | ) 62 | 63 | 64 | def test_shuffle(tmp_path): 65 | preprocessor = DefaultPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 66 | 67 | data = [{"hash": f"hash{i}", "data": f"row{i}"} for i in range(10)] 68 | with jsonlines.open(f"{tmp_path}/test_file.jsonl", "w") as writer: 69 | writer.write_all(data) 70 | 71 | retrieved_data = [{"sim": f"sim{i}", "data": f"row{i}"} for i in range(10)] 72 | with jsonlines.open(f"{tmp_path}/test_retrieved_file.jsonl", "w") as writer: 73 | writer.write_all(retrieved_data) 74 | 75 | for i in range(5): 76 | preprocessor._shuffle(f"{tmp_path}/test_file.jsonl", f"{tmp_path}/test_file_shuffled_{i}.jsonl") 77 | preprocessor._shuffle( 78 | f"{tmp_path}/test_retrieved_file.jsonl", f"{tmp_path}/test_retrieved_file_shuffled_{i}.jsonl" 79 | ) 80 | with jsonlines.open(f"{tmp_path}/test_file_shuffled_{i}.jsonl", "r") as reader: 81 | shuffled_data = [line for line in reader] 82 | 83 | with jsonlines.open(f"{tmp_path}/test_retrieved_file_shuffled_{i}.jsonl", "r") as reader: 84 | retrieved_shuffled_data = [line for line in reader] 85 | 86 | assert shuffled_data != data 87 | assert [row["data"] for row in shuffled_data] == [row["data"] for row in retrieved_shuffled_data] 88 | 89 | 90 | def test_get_pos_in_history(): 91 
| preprocessor = DefaultPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 92 | positions = preprocessor._get_pos_in_history([1, 1, 2, 2, 3]) 93 | assert positions == [0, 1, 0, 1, 0] 94 | assert preprocessor._num_commits == {1: 2, 2: 2, 3: 1} 95 | 96 | positions = preprocessor._get_pos_in_history([2, 1, 2, 55]) 97 | assert positions == [2, 2, 3, 0] 98 | assert preprocessor._num_commits == {1: 3, 2: 4, 3: 1, 55: 1} 99 | 100 | 101 | def test_process_history(tmp_path): 102 | preprocessor = DefaultPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 103 | 104 | with jsonlines.open(f"{tmp_path}/test_file.jsonl", "w") as writer: 105 | writer.write_all( 106 | [{"author": i, "msg_input_ids": [i]} for i in range(10)] 107 | + [{"author": i, "msg_input_ids": [i + 100]} for i in range(5, 15)] 108 | ) 109 | 110 | preprocessor._process_history(input_path=f"{tmp_path}/test_file.jsonl", output_path=f"{tmp_path}/test_history.json") 111 | with open(f"{tmp_path}/test_history.json", "r") as f: 112 | history = json.load(f) 113 | 114 | assert set(history.keys()) == set([f"{i}" for i in range(15)]) 115 | for i in range(5): 116 | assert history[f"{i}"] == [[i]] 117 | for i in range(5, 10): 118 | assert history[f"{i}"] == [[i], [i + 100]] 119 | for i in range(10, 15): 120 | assert history[f"{i}"] == [[i + 100]] 121 | 122 | 123 | def test_add_history_to_inputs(tmp_path): 124 | preprocessor = DefaultPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 125 | 126 | data = [{"msg_input_ids": f"msg{i}", "author": 0, "pos_in_history": i} for i in range(10)] 127 | data += [{"msg_input_ids": f"msg{i + 100}", "author": 1, "pos_in_history": i} for i in range(10)] 128 | with jsonlines.open(f"{tmp_path}/test.jsonl", "w") as writer: 129 | writer.write_all(data) 130 | 131 | history = {0: [f"msg{i}" for i in range(10)], 1: [f"msg{i + 100}" for i in range(10)]} 132 | with open(f"{tmp_path}/test_history.json", "w") as f: 133 | json.dump(history, f) 134 | 135 | preprocessor._add_history_to_inputs( 136 | input_path=f"{tmp_path}/test.jsonl", 137 | history_path=f"{tmp_path}/test_history.json", 138 | output_path=f"{tmp_path}/test_w_history.jsonl", 139 | part="test", 140 | decoder_context_max_length=100, 141 | ) 142 | with jsonlines.open(f"{tmp_path}/test_w_history.jsonl", "r") as reader: 143 | results = [line for line in reader] 144 | assert all( 145 | [ 146 | line["history_input_ids"] == [f"msg{i}" for i in range(line["pos_in_history"])] 147 | for line in results 148 | if line["author"] == 0 149 | ] 150 | ) 151 | assert all( 152 | [ 153 | line["history_input_ids"] == [f"msg{i + 100}" for i in range(line["pos_in_history"])] 154 | for line in results 155 | if line["author"] == 1 156 | ] 157 | ) 158 | 159 | preprocessor._add_history_to_inputs( 160 | input_path=f"{tmp_path}/test.jsonl", 161 | history_path=f"{tmp_path}/test_history.json", 162 | output_path=f"{tmp_path}/test_w_history.jsonl", 163 | part="test", 164 | decoder_context_max_length=16, 165 | ) 166 | with jsonlines.open(f"{tmp_path}/test_w_history.jsonl", "r") as reader: 167 | results = [line for line in reader] 168 | assert all( 169 | [ 170 | line["history_input_ids"] 171 | == [f"msg{i}" for i in range(line["pos_in_history"] - 2, line["pos_in_history"]) if i >= 0] 172 | for line in results 173 | if line["author"] == 0 174 | ] 175 | ) 176 | assert all( 177 | [ 178 | line["history_input_ids"] 179 | == [f"msg{i + 100}" for i in range(line["pos_in_history"] - 1, line["pos_in_history"]) if i 
>= 0] 180 | for line in results 181 | if line["author"] == 1 182 | ] 183 | ) 184 | -------------------------------------------------------------------------------- /tests/test_diff_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import numpy as np 5 | import pytest 6 | 7 | from src.retrieval import DiffSearch 8 | from src.retrieval.utils import CommitEmbeddingExample, RetrievalPrediction 9 | 10 | 11 | def cosine_sim(x: List[float], y: List[float]) -> float: 12 | """A simple helper function to compute cosine similarity between two 1D lists.""" 13 | assert len(x) == len(y) 14 | xy = sum(x_item * y_item for x_item, y_item in zip(x, y)) 15 | x_norm = sum(x_item**2 for x_item in x) ** 0.5 16 | y_norm = sum(y_item**2 for y_item in y) ** 0.5 17 | return xy / (x_norm * y_norm) 18 | 19 | 20 | def angular_dist(x: List[float], y: List[float]) -> float: 21 | """A simple helper function to compute angular distance between two 1D lists.""" 22 | assert len(x) == len(y) 23 | return (2 * (1 - cosine_sim(x, y))) ** 0.5 24 | 25 | 26 | def test_idxs_out_of_order(tmp_path): 27 | search = DiffSearch(embeddings_dim=3, num_trees=3, metric="angular", load_index=False) 28 | search.add(CommitEmbeddingExample(diff_embedding=np.array([0, 0, 1]), pos_in_file=1)) 29 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 1, 1]), pos_in_file=2)) 30 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 0, 1]), pos_in_file=0)) 31 | search.finalize() 32 | 33 | assert search._index.get_item_vector(0) == [1, 0, 1] 34 | assert search._index.get_item_vector(1) == [0, 0, 1] 35 | assert search._index.get_item_vector(2) == [1, 1, 1] 36 | 37 | 38 | def test_train_logic(tmp_path): 39 | search = DiffSearch(embeddings_dim=3, num_trees=3, metric="angular", load_index=False) 40 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 0, 1]), pos_in_file=0)) 41 | search.add(CommitEmbeddingExample(diff_embedding=np.array([0, 0, 1]), pos_in_file=1)) 42 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 1, 1]), pos_in_file=2)) 43 | search.finalize() 44 | 45 | # searching for the nearest neighbor for embedding present in index returns itself 46 | prediction = search.predict(np.array([1, 0, 1])) 47 | assert prediction == RetrievalPrediction( 48 | distance=pytest.approx(angular_dist([1, 0, 1], [1, 0, 1]), abs=1e-7), 49 | pos_in_file=0, 50 | ) 51 | 52 | # the expected way: use `predict_train` and pass idx of vector in index 53 | prediction = search.predict_train(0) 54 | assert prediction == RetrievalPrediction( 55 | distance=pytest.approx(angular_dist([1, 0, 1], [1, 1, 1]), abs=1e-7), 56 | pos_in_file=2, 57 | ) 58 | 59 | 60 | def test_nn_search(tmp_path): 61 | search = DiffSearch(embeddings_dim=3, num_trees=3, metric="angular", load_index=False) 62 | search.add(CommitEmbeddingExample(diff_embedding=np.array([0, 0, 1]), pos_in_file=1)) 63 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 1, 1]), pos_in_file=2)) 64 | search.add(CommitEmbeddingExample(diff_embedding=np.array([1, 0, 1]), pos_in_file=0)) 65 | search.finalize() 66 | 67 | prediction = search.predict(np.array([1, 0, 1])) 68 | assert prediction == RetrievalPrediction( 69 | distance=pytest.approx(angular_dist([1, 0, 1], [1, 0, 1]), abs=1e-7), 70 | pos_in_file=0, 71 | ) 72 | -------------------------------------------------------------------------------- /tests/test_edit_similarity.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.metrics import EditSimilarity 4 | 5 | 6 | @pytest.fixture 7 | def edit_similarity(): 8 | return EditSimilarity() 9 | 10 | 11 | @pytest.mark.parametrize( 12 | "input_str", 13 | [("hello"), ('@pytest.mark.parametrize(\n "input_str",'), ("def test_same_string(edit_similarity, input_str):")], 14 | ) 15 | def test_same_string(edit_similarity, input_str): 16 | assert edit_similarity([input_str], [input_str]) == pytest.approx(1.0) 17 | 18 | 19 | def test_empty_pred(edit_similarity): 20 | assert edit_similarity([""], ["hello"]) == pytest.approx(0.0) 21 | 22 | 23 | def test_empty_ref(edit_similarity): 24 | assert edit_similarity([""], [""]).isnan().item() 25 | assert edit_similarity(["hello"], [""]).isnan().item() 26 | -------------------------------------------------------------------------------- /tests/test_exact_match.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from src.metrics import ExactMatch 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "predictions,references,n", 8 | [ 9 | (["hello\n\n\n"], ["hello world"], 1), 10 | (["full match"], ["full match"], 2), 11 | (["a a a a a x y z"], ["a a a a a b c d"], 5), 12 | ], 13 | ) 14 | def test_full_match(predictions, references, n): 15 | assert ExactMatch(n=n)(predictions, references) == pytest.approx(1.0) 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "predictions,references,n", 20 | [ 21 | (["random words"], ["something else then random words again"], 2), 22 | ([""], ["anything"], 1), 23 | ], 24 | ) 25 | def test_no_match(predictions, references, n): 26 | assert ExactMatch(n=n)(predictions, references) == pytest.approx(0.0) 27 | 28 | 29 | def test_different_lengths(): 30 | predictions = ["a b c", "d f"] 31 | references = ["a b c", "d e"] 32 | assert ExactMatch(n=1)(predictions, references) == pytest.approx(1.0) 33 | assert ExactMatch(n=2)(predictions, references) == pytest.approx(0.5) 34 | assert ExactMatch(n=3)(predictions, references) == pytest.approx(1.0) 35 | 36 | for n in range(4, 100): 37 | assert ExactMatch(n=n)(predictions, references) == pytest.approx(0.0) 38 | -------------------------------------------------------------------------------- /tests/test_full_pipeline.sh: -------------------------------------------------------------------------------- 1 | # a simple script that launches pipeline for all kinds of models 2 | # run to make sure that artifacts & everything else in W&B project is setup as intended 3 | echo "W&B username: $1" 4 | echo "Accelerator: $2 (should be one of 'gpu','cpu')" 5 | 6 | # run codet5 pipeline with diffs & history on first 10 examples 7 | python train.py +model=codet5 ++input.train_with_history=true ++logger.project=test ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_train_batches=10 ++trainer.limit_val_batches=10 ++dataset.use_cache=true ++dataset.train_dataloader_conf.batch_size=4 ++dataset.val_dataloader_conf.batch_size=4 ++input.encoder_input_type=diff 8 | python eval.py +model=codet5 ++input.train_with_history=true ++input.context_ratio=0.25 ++input.generate_with_history=true ++logger.project=test ++logger.artifact_config.project="$1/test" ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_test_batches=10 ++dataset.use_cache=true ++input.encoder_input_type=diff 9 | python compute_metrics.py ++logger.project=test ++logger.artifact_config.name=codet5_with-history_preds 
++logger.artifact_config.project="$1/test" ++logger.artifact_config.version=context-ratio_0.25_with-history 10 | 11 | # run codereviewer pipeline with diffs & history on first 10 examples 12 | python train.py +model=codereviewer ++input.train_with_history=true ++logger.project=test ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_train_batches=10 ++trainer.limit_val_batches=10 ++dataset.use_cache=true ++dataset.train_dataloader_conf.batch_size=4 ++dataset.val_dataloader_conf.batch_size=4 ++input.encoder_input_type=diff 13 | python eval.py +model=codereviewer ++input.train_with_history=true ++input.context_ratio=0.25 ++input.generate_with_history=true ++logger.project=test ++logger.artifact_config.project="$1/test" ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_test_batches=10 ++dataset.use_cache=true ++input.encoder_input_type=diff 14 | python compute_metrics.py ++logger.project=test ++logger.artifact_config.name=codereviewer_with-history_preds ++logger.artifact_config.project="$1/test" ++logger.artifact_config.version=context-ratio_0.25_with-history 15 | 16 | # run race pipeline with diffs & history on first 10 examples 17 | python train.py +model=race ++input.train_with_history=true ++logger.project=test ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_train_batches=10 ++trainer.limit_val_batches=10 ++dataset.use_cache=false ++dataset.train_dataloader_conf.batch_size=4 ++dataset.val_dataloader_conf.batch_size=4 ++input.encoder_input_type=diff 18 | python eval.py +model=race ++input.train_with_history=true ++input.context_ratio=0.25 ++input.generate_with_history=false ++logger.project=test ++logger.artifact_config.project="$1/test" ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_test_batches=10 ++dataset.use_cache=true ++input.encoder_input_type=diff 19 | python compute_metrics.py ++logger.project=test ++logger.artifact_config.name=race_with-history_preds ++logger.artifact_config.project="$1/test" ++logger.artifact_config.version=context-ratio_0.25_with-history 20 | 21 | # run distilgpt2 pipeline on first 10 examples 22 | python train.py +model=distilgpt2 ++input.train_with_history=true ++input.encoder_input_type=diff ++logger.project=test ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_train_batches=10 ++trainer.limit_val_batches=10 ++dataset.use_cache=false ++dataset.train_dataloader_conf.batch_size=4 ++dataset.val_dataloader_conf.batch_size=4 ++input.encoder_input_type=diff 23 | python eval.py +model=distilgpt2 ++input.train_with_history=true ++input.encoder_input_type=diff ++input.generate_with_history=true ++input.context_ratio=0.25 ++logger.project=test ++logger.artifact_config.project="$1/test" ++trainer.accelerator="$2" ++trainer.devices=1 ++trainer.limit_test_batches=10 ++dataset.use_cache=true ++input.encoder_input_type=diff 24 | python compute_metrics.py ++logger.project=test ++logger.artifact_config.name=distilgpt2_with-history_preds ++logger.artifact_config.project="$1/test" ++logger.artifact_config.version=context-ratio_0.25 -------------------------------------------------------------------------------- /tests/test_mrr.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import pytorch_lightning as pl 4 | import torch 5 | from torchmetrics import MetricCollection 6 | from torchmetrics.utilities import check_forward_full_state_property 7 | from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer 8 | 
9 | from src.metrics.mrr import MRR 10 | 11 | pl.seed_everything(123) 12 | 13 | 14 | @pytest.fixture 15 | def metrics_collection(): 16 | return MetricCollection( 17 | { 18 | "MRR@1": MRR(top_k=1, ignore_index=-100), 19 | "MRR@2": MRR(top_k=2, ignore_index=-100), 20 | "MRR@3": MRR(top_k=3, ignore_index=-100), 21 | } 22 | ) 23 | 24 | 25 | @pytest.fixture 26 | def metrics_collection_no_shift(): 27 | return MetricCollection( 28 | { 29 | "MRR@1": MRR(top_k=1, shift=False, ignore_index=-100), 30 | "MRR@2": MRR(top_k=2, shift=False, ignore_index=-100), 31 | "MRR@3": MRR(top_k=3, shift=False, ignore_index=-100), 32 | } 33 | ) 34 | 35 | 36 | def test_full_top1_match(metrics_collection): 37 | scores = torch.tensor( 38 | [[0.5, 0, 0.3, 0.2, 0.4], [1.0, -100, -200, -100, 0], [-1, -1, -1, -1, -1]], dtype=torch.float 39 | ) 40 | labels = torch.tensor([-1, 0, 0], dtype=torch.long) 41 | results = metrics_collection(scores, labels) 42 | assert np.allclose([results["MRR@1"].item(), results["MRR@3"].item()], [1.0, 1.0], rtol=1e-05, atol=1e-08) 43 | 44 | 45 | def test_full_top3_match(metrics_collection): 46 | scores = torch.tensor( 47 | [[0.5, 0, 0.3, 0.2, 0.4], [1.0, -100, -200, -101, 0], [-1, -1, -1, -1, -1]], dtype=torch.float 48 | ) 49 | labels = torch.tensor([-1, 2, 1], dtype=torch.long) 50 | results = metrics_collection(scores, labels) 51 | print(results["MRR@1"].item(), results["MRR@3"].item()) 52 | assert np.allclose( 53 | [results["MRR@1"].item(), results["MRR@3"].item()], [0.0, (1 / 3 + 1 / 3) / 2], rtol=1e-05, atol=1e-08 54 | ) 55 | 56 | 57 | def test_different_batch_sizes(metrics_collection): 58 | # batch size 4 59 | scores = torch.tensor( 60 | [ 61 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 62 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 63 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 64 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 65 | ], 66 | dtype=torch.float, 67 | ) 68 | labels = torch.tensor([[-1, 0, 2, -100], [-1, 1, 1, -100], [-1, 2, 1, 3], [-1, 1, 2, -100]], dtype=torch.long) 69 | results = metrics_collection(scores, labels) 70 | assert np.allclose([results["MRR@2"].item()], [(0.75 + 1 / 3) / 4], rtol=1e-05, atol=1e-08) 71 | 72 | # batch size 2 73 | metrics_collection.reset() 74 | first_half = metrics_collection(scores[:2], labels[:2]) 75 | second_half = metrics_collection(scores[2:], labels[2:]) 76 | total = metrics_collection.compute() 77 | assert np.allclose( 78 | [(first_half["MRR@2"].item() + second_half["MRR@2"].item()) / 2], [(0.75 + 1 / 3) / 4], rtol=1e-05, atol=1e-08 79 | ) 80 | assert np.allclose([total["MRR@2"].item()], [(0.75 + 1 / 3) / 4], rtol=1e-05, atol=1e-08) 81 | 82 | # batch size 1 83 | metrics_collection.reset() 84 | results = [] 85 | for i in range(4): 86 | results.append(metrics_collection(scores[i], labels[i])) 87 | total = metrics_collection.compute() 88 | assert np.allclose( 89 | [np.mean([res["MRR@2"].item() for res in results])], [(0.75 + 1 / 3) / 4], rtol=1e-05, atol=1e-08 90 | ) 91 | assert np.allclose([total["MRR@2"].item()], [(0.75 + 1 / 3) / 4], rtol=1e-05, atol=1e-08) 92 | 93 | 94 | def test_gpt2_shift(metrics_collection, metrics_collection_no_shift): 95 | model = AutoModelForCausalLM.from_pretrained("gpt2") 96 | tokenizer = AutoTokenizer.from_pretrained("gpt2") 97 | 98 | input = "My name is John" 99 | tokenized_input = tokenizer(input, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 100 | logits = model(input_ids=tokenized_input, 
labels=tokenized_input).logits 101 | preds_tokens = tokenizer.convert_ids_to_tokens(torch.topk(logits, 1, dim=-1)[1].squeeze(-1).squeeze(0)) 102 | input_tokens = tokenizer.convert_ids_to_tokens(tokenized_input.squeeze(0)) 103 | 104 | assert input_tokens[1:][-1] == preds_tokens[:-1][-1] 105 | 106 | results = metrics_collection(logits, tokenized_input) 107 | results_no_shift = metrics_collection_no_shift(logits, tokenized_input) 108 | 109 | for key in results: 110 | assert results[key] > results_no_shift[key] 111 | 112 | 113 | def test_t5_shift(metrics_collection, metrics_collection_no_shift): 114 | model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small") 115 | tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small") 116 | 117 | input = "Q: Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering." 118 | target = "George Washington is a fictional character. Geoffrey Hinton is a fictional character." 119 | tokenized_input = tokenizer(input, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 120 | tokenized_target = tokenizer(target, add_special_tokens=True, padding=False, return_tensors="pt").input_ids 121 | 122 | logits = model(input_ids=tokenized_input, labels=tokenized_target).logits 123 | preds_tokens = tokenizer.convert_ids_to_tokens(torch.topk(logits, 1, dim=-1)[1].squeeze(-1).squeeze(0)) 124 | target_tokens = tokenizer.convert_ids_to_tokens(tokenized_target.squeeze(0)) 125 | 126 | assert preds_tokens[:-1] == target_tokens[:-1] 127 | 128 | results = metrics_collection(logits, tokenized_target) 129 | results_no_shift = metrics_collection_no_shift(logits, tokenized_target) 130 | 131 | for key in results: 132 | assert results[key] < results_no_shift[key] 133 | 134 | 135 | def test_full_state_update(): 136 | scores = torch.tensor( 137 | [ 138 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 139 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 140 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 141 | [[0.5, 0, 0.3], [1.0, -100, -200], [0.5, 0, 0.3], [-1, -1, -1]], 142 | ], 143 | dtype=torch.float, 144 | ) 145 | labels = torch.tensor([[-1, 0, 2, -100], [-1, 1, 1, -100], [-1, 2, 1, 3], [-1, 1, 2, -100]], dtype=torch.long) 146 | 147 | for top_k in list(range(1, 4)): 148 | check_forward_full_state_property( 149 | MRR, 150 | init_args=dict(top_k=top_k, ignore_index=-100), 151 | input_args={"predictions": scores, "references": labels}, 152 | ) 153 | -------------------------------------------------------------------------------- /tests/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import subprocess 5 | 6 | 7 | def prepare_data(tmp_path): 8 | for i, part in enumerate(["train", "val", "test"]): 9 | with open(f"{tmp_path}/{part}.jsonl", "w") as f: 10 | json.dump( 11 | { 12 | "author": i, 13 | "repo": f"sample_{part}_repo", 14 | "hash": "sample hash", 15 | "mods": [ 16 | {"change_type": "MODIFY", "old_path": "fname", "new_path": "fname", "diff": "sample diff"} 17 | ], 18 | "message": "sample commit message", 19 | "license": "MIT License", 20 | "language": "Python", 21 | }, 22 | f, 23 | ) 24 | return tmp_path 25 | 26 | 27 | def test_train_pipeline(tmp_path): 28 | root_dir = prepare_data(tmp_path) 29 | 30 | if "train.py" not in os.listdir(): 31 | os.chdir("..") 32 | 33 | for use_history in ["true", "false"]: 34 | command = ( 35 | "python train.py +model=codet5 " 36 | 
"++input.encoder_input_type=diff " 37 | f"++input.train_with_history={use_history} " 38 | "++trainer.accelerator=cpu " 39 | "++trainer.devices=1 " 40 | "++trainer.max_epochs=1 " 41 | "++dataset.use_cache=false " 42 | "++dataset.use_eval_downsample=false " 43 | f'++dataset.dataset_root="{root_dir}" ' 44 | "++dataset.train_dataloader_conf.batch_size=1 " 45 | "++dataset.val_dataloader_conf.batch_size=1 " 46 | "++logger.use_wandb=false " 47 | "++optimizer.learning_rate=0.00002 ++optimizer.weight_decay=0.0 ++optimizer.num_warmup_steps=100" 48 | ) 49 | process = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) 50 | stdout, stderr = process.communicate() 51 | output_lines = re.split(r"[\n\r]", stdout.decode("utf-8")) 52 | assert any(line.startswith("Epoch 0: 100%") for line in output_lines) 53 | 54 | 55 | def test_eval_pipeline(tmp_path): 56 | root_dir = prepare_data(tmp_path) 57 | 58 | if "eval.py" not in os.listdir(): 59 | os.chdir("..") 60 | 61 | for use_history in ["true", "false"]: 62 | for context_ratio in [0.0, 0.5]: 63 | command = ( 64 | "python eval.py +model=codet5 " 65 | "++input.encoder_input_type=diff " 66 | f"++input.train_with_history={use_history} " 67 | f"++input.generate_with_history={use_history} " 68 | f"++input.context_ratio={context_ratio} " 69 | "++trainer.accelerator=cpu " 70 | "++trainer.devices=1 " 71 | "++dataset.use_eval_downsample=false " 72 | "++dataset.use_cache=false " 73 | f'++dataset.dataset_root="{root_dir}" ' 74 | "++dataset.test_dataloader_conf.batch_size=1 " 75 | "++logger.use_wandb=false" 76 | ) 77 | process = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) 78 | stdout, stderr = process.communicate() 79 | output_lines = re.split(r"[\n\r]", stdout.decode("utf-8")) 80 | assert any(line.startswith("Testing DataLoader 0: 100%") for line in output_lines) 81 | -------------------------------------------------------------------------------- /tests/test_prefix_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from omegaconf import OmegaConf 4 | from transformers import AutoTokenizer 5 | 6 | from conf import BaseEncoderDecoderConfig 7 | from src.model import CMCModule 8 | from src.utils import BatchTest, PrefixAllowedTokens 9 | 10 | torch.manual_seed(42) 11 | 12 | 13 | @pytest.fixture(scope="module") 14 | def default_setting(): 15 | conf = OmegaConf.structured( 16 | BaseEncoderDecoderConfig( 17 | configuration="encoder_decoder", 18 | encoder_name_or_path="distilbert-base-uncased", 19 | decoder_name_or_path="distilgpt2", 20 | encoder_context_max_len=512, 21 | decoder_context_max_len=256, 22 | diff_tokenizer_name_or_path="distilbert-base-uncased", 23 | msg_tokenizer_name_or_path="distilgpt2", 24 | tie_word_embeddings=False, 25 | tie_encoder_decoder=False, 26 | ) 27 | ) 28 | 29 | model = CMCModule( 30 | model_cfg=conf, 31 | diff_tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"), 32 | msg_tokenizer=AutoTokenizer.from_pretrained("distilgpt2"), 33 | batch_size=1, 34 | ) 35 | return model, {"num_beams": 4, "num_return_sequences": 4} 36 | 37 | 38 | @pytest.mark.parametrize( 39 | "context,prefix,expected", 40 | [ 41 | ("", "Firs", "First"), 42 | ("GPT-2 is a generative", " lan", " language"), 43 | ("My twitter", " userna", " username"), 44 | ], 45 | ) 46 | def test_with_and_without_prefix_fn(default_setting, context, prefix, expected): 47 | model, generation_kwargs = default_setting 48 | 49 | if not context: 50 | 
context = model._msg_tokenizer.eos_token 51 | 52 | tokenized_context = model._msg_tokenizer(context, return_tensors="pt").input_ids 53 | tokenized_context_w_prefix = model._msg_tokenizer(context + prefix, return_tensors="pt").input_ids 54 | 55 | min_len = 2 56 | max_len = 2 57 | 58 | results_with_prefix_fn = model.generate( 59 | batch=BatchTest( 60 | encoder_input_ids=torch.zeros_like(tokenized_context), 61 | encoder_attention_mask=torch.zeros_like(tokenized_context), 62 | decoder_input_ids=tokenized_context, 63 | decoder_attention_mask=torch.ones_like(tokenized_context), 64 | prefixes=[prefix], 65 | targets=[], 66 | labels=None, 67 | retrieved_diff_input_ids=None, 68 | retrieved_diff_attention_mask=None, 69 | retrieved_msg_input_ids=None, 70 | retrieved_msg_attention_mask=None, 71 | ), 72 | min_length=min_len + tokenized_context.shape[1], 73 | max_length=max_len + tokenized_context.shape[1], 74 | **generation_kwargs 75 | ) 76 | sequences_with_prefix_fn = model._msg_tokenizer.batch_decode( 77 | results_with_prefix_fn[:, tokenized_context.shape[1] :] 78 | ) 79 | 80 | results_without_prefix_fn = model.generate( 81 | batch=BatchTest( 82 | encoder_input_ids=torch.zeros_like(tokenized_context_w_prefix), 83 | encoder_attention_mask=torch.zeros_like(tokenized_context_w_prefix), 84 | decoder_input_ids=tokenized_context_w_prefix, 85 | decoder_attention_mask=torch.ones_like(tokenized_context_w_prefix), 86 | prefixes=[""], 87 | targets=[], 88 | labels=None, 89 | retrieved_diff_input_ids=None, 90 | retrieved_diff_attention_mask=None, 91 | retrieved_msg_input_ids=None, 92 | retrieved_msg_attention_mask=None, 93 | ), 94 | min_length=min_len + tokenized_context_w_prefix.shape[1], 95 | max_length=max_len + tokenized_context_w_prefix.shape[1], 96 | **generation_kwargs 97 | ) 98 | sequences_without_prefix_fn = model._msg_tokenizer.batch_decode( 99 | results_without_prefix_fn[:, tokenized_context.shape[1] :] 100 | ) 101 | 102 | assert any([seq.startswith(expected) for seq in sequences_with_prefix_fn]) 103 | assert not any([seq.startswith(expected) for seq in sequences_without_prefix_fn]) 104 | 105 | 106 | @pytest.mark.parametrize( 107 | "context,prefix,generated", 108 | [ 109 | ("GPT-2 is a generative", " la", ""), 110 | ("My twitter", " userna", ""), 111 | ("Hello", " wor", ""), 112 | ], 113 | ) 114 | def test_generation_empty_prefix(default_setting, context, prefix, generated): 115 | model, generation_kwargs = default_setting 116 | 117 | beam_sentence = context + generated 118 | 119 | tokenized_beam_sentence = model._msg_tokenizer(beam_sentence, return_tensors="pt").input_ids[0] 120 | prefix_fn = PrefixAllowedTokens( 121 | prefix={0: ""}, context_len={0: 0}, tokenizer=model._msg_tokenizer, trie=model.vocab_trie 122 | ) 123 | 124 | allowed_tokens = model._msg_tokenizer.batch_decode(prefix_fn(0, sentence=tokenized_beam_sentence)) 125 | assert len(allowed_tokens) == len(model._msg_tokenizer.get_vocab().keys()) 126 | 127 | 128 | @pytest.mark.parametrize( 129 | "context,prefix,generated", 130 | [ 131 | ("", "Firs", ""), 132 | ("GPT-2 is a generative", " la", ""), 133 | ("My twitter", " userna", ""), 134 | ("Hello", " wor", ""), 135 | ], 136 | ) 137 | def test_generation_start(default_setting, context, prefix, generated): 138 | model, generation_kwargs = default_setting 139 | 140 | beam_sentence = context + generated 141 | 142 | tokenized_context = model._msg_tokenizer(context, return_tensors="pt").input_ids 143 | tokenized_beam_sentence = model._msg_tokenizer(beam_sentence, return_tensors="pt").input_ids[0] 144 
| prefix_fn = PrefixAllowedTokens( 145 | prefix={0: prefix}, 146 | context_len={0: tokenized_context.shape[1]}, 147 | tokenizer=model._msg_tokenizer, 148 | trie=model.vocab_trie, 149 | ) 150 | 151 | allowed_tokens = model._msg_tokenizer.batch_decode(prefix_fn(0, sentence=tokenized_beam_sentence)) 152 | for token in allowed_tokens: 153 | assert token.startswith(prefix) or prefix.startswith(token) 154 | 155 | 156 | @pytest.mark.parametrize( 157 | "context,prefix,generated,remaining", 158 | [ 159 | ("update to version", " 3.0", " 3", ".0"), 160 | ("", "GPT-2", "GPT", "-2"), 161 | ("GPT-2 is a generative", " la", " l", "a"), 162 | ("My twitter", " userna", " user", "na"), 163 | ("Hello", " wor", " wo", "r"), 164 | ], 165 | ) 166 | def test_generation_prefix_part(default_setting, context, prefix, generated, remaining): 167 | model, generation_kwargs = default_setting 168 | 169 | beam_sentence = context + generated 170 | 171 | tokenized_context = model._msg_tokenizer(context, return_tensors="pt").input_ids 172 | tokenized_beam_sentence = model._msg_tokenizer(beam_sentence, return_tensors="pt").input_ids[0] 173 | prefix_fn = PrefixAllowedTokens( 174 | prefix={0: prefix}, 175 | context_len={0: tokenized_context.shape[1]}, 176 | tokenizer=model._msg_tokenizer, 177 | trie=model.vocab_trie, 178 | ) 179 | 180 | allowed_tokens = model._msg_tokenizer.batch_decode(prefix_fn(0, sentence=tokenized_beam_sentence)) 181 | assert all(token.startswith(remaining) or remaining.startswith(token) for token in allowed_tokens) 182 | 183 | 184 | @pytest.mark.parametrize( 185 | "context,prefix,generated", 186 | [ 187 | ("GPT-2 is a generative", " la", " language model"), 188 | ("My twitter", " userna", " username"), 189 | ("Hello", " wor", " wor"), 190 | ], 191 | ) 192 | def test_generation_whole_prefix(default_setting, context, prefix, generated): 193 | model, generation_kwargs = default_setting 194 | 195 | beam_sentence = context + generated 196 | 197 | tokenized_context = model._msg_tokenizer(context, return_tensors="pt").input_ids 198 | tokenized_beam_sentence = model._msg_tokenizer(beam_sentence, return_tensors="pt").input_ids[0] 199 | prefix_fn = PrefixAllowedTokens( 200 | prefix={0: prefix}, 201 | context_len={0: tokenized_context.shape[1]}, 202 | tokenizer=model._msg_tokenizer, 203 | trie=model.vocab_trie, 204 | ) 205 | 206 | allowed_tokens = model._msg_tokenizer.batch_decode(prefix_fn(0, sentence=tokenized_beam_sentence)) 207 | assert len(allowed_tokens) == len(model._msg_tokenizer.get_vocab().keys()) 208 | -------------------------------------------------------------------------------- /tests/test_race.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from transformers import AutoTokenizer, T5ForConditionalGeneration 4 | 5 | from src.model.configurations.utils.race import RACE 6 | 7 | 8 | def test_forward_without_retrieval(): 9 | model_name = "t5-small" 10 | race = RACE.from_pretrained(model_name) 11 | t5 = T5ForConditionalGeneration.from_pretrained(model_name) 12 | tok = AutoTokenizer.from_pretrained(model_name) 13 | 14 | input = "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)" 15 | target = "Amanda baked cookies and will bring Jerry some tomorrow." 
16 | input_encodings = tok( 17 | "summarize: " + input, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt" 18 | ) 19 | target_encodings = tok(target, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt") 20 | 21 | with torch.no_grad(): 22 | race_outputs = race( 23 | input_ids=input_encodings.input_ids, 24 | attention_mask=input_encodings.attention_mask, 25 | labels=target_encodings.input_ids, 26 | ) 27 | t5_outputs = t5( 28 | input_ids=input_encodings.input_ids, 29 | attention_mask=input_encodings.attention_mask, 30 | labels=target_encodings.input_ids, 31 | ) 32 | 33 | assert race_outputs.keys() == t5_outputs.keys() 34 | assert race_outputs.loss == pytest.approx(t5_outputs.loss) 35 | assert race_outputs.logits.numpy() == pytest.approx(t5_outputs.logits.numpy()) 36 | 37 | 38 | def test_generation_without_retrieval(): 39 | model_name = "t5-small" 40 | race = RACE.from_pretrained(model_name) 41 | t5 = T5ForConditionalGeneration.from_pretrained(model_name) 42 | tok = AutoTokenizer.from_pretrained(model_name) 43 | 44 | input = "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)" 45 | input_encodings = tok( 46 | "summarize: " + input, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt" 47 | ) 48 | with torch.no_grad(): 49 | race_preds = race.generate( 50 | input_ids=input_encodings.input_ids, 51 | attention_mask=input_encodings.attention_mask, 52 | max_new_tokens=20, 53 | num_beams=10, 54 | ) 55 | t5_preds = t5.generate( 56 | input_ids=input_encodings.input_ids, 57 | attention_mask=input_encodings.attention_mask, 58 | max_new_tokens=20, 59 | num_beams=10, 60 | ) 61 | 62 | assert (race_preds.numpy() == t5_preds.numpy()).all() 63 | 64 | race_preds_str = tok.batch_decode(race_preds, skip_special_tokens=True)[0] 65 | t5_preds_str = tok.batch_decode(t5_preds, skip_special_tokens=True)[0] 66 | 67 | assert race_preds_str == t5_preds_str 68 | 69 | 70 | def test_forward_with_retrieval(): 71 | model_name = "t5-small" 72 | race = RACE.from_pretrained(model_name) 73 | t5 = T5ForConditionalGeneration.from_pretrained(model_name) 74 | tok = AutoTokenizer.from_pretrained(model_name) 75 | 76 | input = "Amanda: I baked cookies. Do you want some? Jerry: Sure! Amanda: I'll bring you tomorrow :-)" 77 | similar_input = "Jerry: I baked cookies. Do you want some? Amanda: Sure! Jerry: I'll bring you tomorrow :-)" 78 | similar_target = "Jerry baked cookies and will bring Amanda some tomorrow." 79 | target = "Amanda baked cookies and will bring Jerry some tomorrow." 
80 | 81 | input_encodings = tok(input, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt") 82 | similar_input_encodings = tok( 83 | similar_input, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt" 84 | ) 85 | similar_target_encodings = tok( 86 | similar_target, truncation=False, padding=False, add_special_tokens=True, return_tensors="pt" 87 | ) 88 | 89 | with torch.no_grad(): 90 | race_preds = race.generate( 91 | input_ids=input_encodings.input_ids, 92 | attention_mask=input_encodings.attention_mask, 93 | retrieved_diff_input_ids=similar_input_encodings.input_ids, 94 | retrieved_diff_attention_mask=similar_input_encodings.attention_mask, 95 | retrieved_msg_input_ids=similar_target_encodings.input_ids, 96 | retrieved_msg_attention_mask=similar_target_encodings.attention_mask, 97 | max_new_tokens=20, 98 | num_beams=10, 99 | ) 100 | t5_preds = t5.generate( 101 | input_ids=input_encodings.input_ids, 102 | attention_mask=input_encodings.attention_mask, 103 | max_new_tokens=20, 104 | num_beams=10, 105 | ) 106 | 107 | race_preds_str = tok.batch_decode(race_preds, skip_special_tokens=True)[0] 108 | t5_preds_str = tok.batch_decode(t5_preds, skip_special_tokens=True)[0] 109 | 110 | 111 | def test_params(): 112 | model_name = "t5-small" 113 | race = RACE.from_pretrained(model_name) 114 | t5 = T5ForConditionalGeneration.from_pretrained(model_name) 115 | race_params = race.num_parameters() 116 | t5_params = t5.num_parameters() 117 | # we add a linear layer from 2 * hidden_size to 1 118 | # it has 2 * hidden_size parameters for weight and 1 parameter for bias 119 | assert race_params - t5_params == race.config.d_model * 2 + 1 120 | -------------------------------------------------------------------------------- /tests/test_race_preprocessor.py: -------------------------------------------------------------------------------- 1 | from src.data_utils.preprocessors import RACEPreprocessor 2 | 3 | 4 | def test_preprocess_diff(): 5 | preprocessor = RACEPreprocessor(diff_tokenizer=None, msg_tokenizer=None) # tokenizers are not relevant 6 | assert ( 7 | preprocessor._preprocess_diff(header=["fname"], diff="context[NL]-old line[NL]+new line[NL]", line_sep="[NL]") 8 | == " fname context -old +new line " 9 | ) 10 | 11 | assert ( 12 | preprocessor._preprocess_diff( 13 | header=["new file fname"], diff="context[NL]-old line[NL]+new line[NL]", line_sep="[NL]" 14 | ) 15 | == " new file fname context -old +new line " 16 | ) 17 | 18 | assert ( 19 | preprocessor._preprocess_diff( 20 | header=["deleted file fname"], diff="context[NL]-old line[NL]+new line[NL]", line_sep="[NL]" 21 | ) 22 | == " deleted file fname context -old +new line " 23 | ) 24 | 25 | assert ( 26 | preprocessor._preprocess_diff( 27 | header=["rename from fname1", "rename to fname2"], 28 | diff="context[NL]-old line[NL]+new line[NL]", 29 | line_sep="[NL]", 30 | ) 31 | == " rename from fname1 rename to fname2 context -old +new line " 32 | ) 33 | 34 | assert ( 35 | preprocessor._preprocess_diff( 36 | header=["copy from fname1", "copy to fname2"], diff="context[NL]-old line[NL]+new line[NL]", line_sep="[NL]" 37 | ) 38 | == " copy from fname1 copy to fname2 context -old +new line " 39 | ) 40 | -------------------------------------------------------------------------------- /tests/test_transformer_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import pytest 5 | from transformers import AutoTokenizer, 
PreTrainedTokenizerFast 6 | 7 | from src.retrieval import TransformerEmbedder 8 | from src.retrieval.utils import CommitEmbeddingExample 9 | from src.utils import BatchRetrieval 10 | 11 | 12 | def create_batch(inputs: List[str], tokenizer: PreTrainedTokenizerFast) -> BatchRetrieval: 13 | encoded_inputs = tokenizer(inputs, truncation=True, padding=True, return_tensors="pt") # type: ignore[operator] 14 | return BatchRetrieval( 15 | encoder_input_ids=encoded_inputs.input_ids, 16 | encoder_attention_mask=encoded_inputs.attention_mask, 17 | pos_in_file=[i for i, _ in enumerate(inputs)], 18 | ) 19 | 20 | 21 | def test_bert_embedder(): 22 | embedder = TransformerEmbedder( 23 | name_or_path="bert-base-uncased", device="cpu", precision=123, normalize_embeddings=True 24 | ) 25 | assert embedder.embeddings_dim == embedder.model.config.hidden_size 26 | 27 | inputs = ["example input", "another example input"] 28 | batch = create_batch(inputs, AutoTokenizer.from_pretrained("bert-base-uncased")) 29 | 30 | embeddings: List[CommitEmbeddingExample] = embedder.transform(batch) 31 | for i, embedding in enumerate(embeddings): 32 | assert embedding["diff_embedding"].shape == (embedder.embeddings_dim,) 33 | assert np.linalg.norm(embedding["diff_embedding"]) == pytest.approx(1) 34 | assert embedding["pos_in_file"] == i 35 | 36 | 37 | def test_t5_embedder(): 38 | embedder = TransformerEmbedder(name_or_path="t5-small", device="cpu", precision=123, normalize_embeddings=True) 39 | assert embedder.embeddings_dim == embedder.model.config.d_model 40 | 41 | inputs = ["example input", "another example input"] 42 | batch = create_batch(inputs, AutoTokenizer.from_pretrained("t5-small")) 43 | 44 | embeddings: List[CommitEmbeddingExample] = embedder.transform(batch) 45 | for i, embedding in enumerate(embeddings): 46 | assert embedding["diff_embedding"].shape == (embedder.embeddings_dim,) 47 | assert np.linalg.norm(embedding["diff_embedding"]) == pytest.approx(1) 48 | assert embedding["pos_in_file"] == i 49 | --------------------------------------------------------------------------------
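For reference, the two `*_shift` tests in `tests/test_accuracy.py` and `tests/test_mrr.py` above both hinge on the same detail: for a decoder-only model such as GPT-2, the logits at position i score the token at position i + 1, so token-level metrics must drop the last logit and the first label before comparing (which is why the shifted metrics come out higher there), whereas for an encoder-decoder such as T5 the Hugging Face model already right-shifts the labels internally, so its logits are aligned with the labels as-is (and shifting only hurts). The following sketch is not part of the repository: it is a minimal, hypothetical stand-in for the `shift` flag of `src.metrics.Accuracy`, with the helper name `top1_accuracy` and the toy tensors invented purely for illustration.

import torch

def top1_accuracy(logits: torch.Tensor, labels: torch.Tensor, shift: bool = True, ignore_index: int = -100) -> float:
    """Token-level top-1 accuracy for a single sequence (simplified stand-in for src.metrics.Accuracy)."""
    if shift:
        logits = logits[:-1]  # logits at position i predict the token at position i + 1
        labels = labels[1:]
    mask = labels != ignore_index  # positions marked with ignore_index do not count
    preds = logits.argmax(dim=-1)
    return (preds[mask] == labels[mask]).float().mean().item()

# Toy example: 3 tokens, vocabulary of size 4. Each position's logits favour the *next* label,
# mimicking a causal LM, so the shifted accuracy is 1.0 while the unshifted one is 0.0.
labels = torch.tensor([0, 2, 1])
logits = torch.tensor(
    [
        [0.1, 0.2, 0.9, 0.3],  # argmax 2 == labels[1]
        [0.2, 0.8, 0.1, 0.0],  # argmax 1 == labels[2]
        [0.9, 0.0, 0.0, 0.0],  # last position has no next token; dropped by the shift
    ]
)
assert top1_accuracy(logits, labels, shift=True) == 1.0
assert top1_accuracy(logits, labels, shift=False) == 0.0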