├── .github
│   └── workflows
│       └── black-lint.yml
├── .gitignore
├── LICENSE.md
├── README.md
├── python
│   ├── .gitignore
│   ├── configs
│   │   ├── codeT5.yaml
│   │   └── coditT5.yaml
│   ├── deltr
│   │   ├── Macros.py
│   │   ├── coditT5
│   │   │   ├── CodeT5.py
│   │   │   ├── prediction.py
│   │   │   ├── save_pretrained.py
│   │   │   └── utils.py
│   │   └── collector
│   │       ├── DataCollector.py
│   │       ├── DataProcessor.py
│   │       ├── RealDiffCollector.py
│   │       └── diff_utils.py
│   ├── prepare-conda-env.sh
│   └── requirements.txt
└── requirements.txt

/.github/workflows/black-lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 | 
3 | on: [push, pull_request]
4 | 
5 | jobs:
6 |   lint:
7 |     runs-on: ubuntu-latest
8 |     steps:
9 |       - uses: actions/checkout@v3
10 |       - name: Check files using the black formatter
11 |         uses: rickstaa/action-black@v1
12 |         id: action_black
13 |         with:
14 |           black_args: "."
15 |       - name: Check for modified files
16 |         id: git-check
17 |         run: echo "modified=$(if git diff-index --quiet HEAD --; then echo "false"; else echo "true"; fi)" >> $GITHUB_OUTPUT
18 |       - name: Push changes
19 |         if: steps.git-check.outputs.modified == 'true'
20 |         run: |
21 |           git config --global user.name 'jiyang'
22 |           git config --global user.email 'jiyang.zhang@utexas.edu'
23 |           git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/${{ github.repository }}
24 |           git commit -am "Automated black format fixes"
25 |           git push
26 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp file
2 | *~
3 | \#*\#
4 | __pycache__/
5 | 
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 EngineeringSoftware
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multilingual Code Co-Evolution Using Large Language Models
2 | 
3 | This repo hosts the code and data for the following FSE 2023 paper:
4 | 
5 | Title: [Multilingual Code Co-Evolution Using Large Language Models](https://arxiv.org/abs/2307.14991)
6 | 
7 | Authors: [Jiyang Zhang](https://jiyangzhang.github.io/), [Pengyu Nie](https://pengyunie.github.io/), [Junyi Jessy Li](https://jessyli.com/), [Milos Gligoric](http://users.ece.utexas.edu/~gligoric/)
8 | 
9 | ```bibtex
10 | @inproceedings{ZhangETAL23Codeditor,
11 |   author = {Zhang, Jiyang and Nie, Pengyu and Li, Junyi Jessy and Gligoric, Milos},
12 |   title = {Multilingual Code Co-Evolution Using Large Language Models},
13 |   booktitle = {Joint European Software Engineering Conference and Symposium on the Foundations of Software Engineering},
14 |   year = {2023},
15 | }
16 | ```
17 | 
18 | ## News
19 | May 2024
20 | The fine-tuned EditsTranslation models are released on 🤗! 🔥 [cs2java](https://huggingface.co/EngineeringSoftware/EditsTranslation-cs2java) and [java2cs](https://huggingface.co/EngineeringSoftware/EditsTranlation-java2cs)
21 | 
22 | ## How to Use
23 | 
24 | [sec-howto]: #how-to-use
25 | 
26 | ```python
27 | from transformers import T5ForConditionalGeneration, AutoTokenizer
28 | 
29 | checkpoint = "EngineeringSoftware/EditsTranlation-java2cs"
30 | 
31 | tokenizer = AutoTokenizer.from_pretrained(checkpoint)
32 | model = T5ForConditionalGeneration.from_pretrained(checkpoint)
33 | 
34 | code_input = """class HelloWorld { public static void main(String[] args) { System.out.println("Hello, World!")"""
35 | 
36 | input_ids = tokenizer(code_input, return_tensors="pt").input_ids
37 | generated_ids = model.generate(input_ids, max_length=200)
38 | print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
39 | # output: ; } } ; class HelloWorld { public static void main(String[] args) { System.out.println("Hello, World!") ; } } ;
40 | ```
41 | 
42 | 
43 | 
44 | 
45 | ## Introduction
46 | 
47 | This repo contains the code and artifacts for reproducing the experiments in [Multilingual Code Co-Evolution Using Large Language Models](https://arxiv.org/abs/2307.14991).
48 | In this work, we introduce Codeditor for co-evolving software implemented in multiple programming languages.
49 | 
50 | The code includes:
51 | 
52 | - scripts for processing the dataset
53 | - scripts for training and evaluating codeditor models
54 | 
55 | The artifacts include:
56 | 
57 | - Java to C# raw paired changes
58 | - Java to C# translation dataset processed for codeditor models
59 | 
60 | ## Data Downloads
61 | 
62 | [sec-downloads]: #data-downloads
63 | 
64 | All our data is hosted on UTBox via [a shared folder](https://utexas.box.com/s/iwcvwgx23g9xvowu9joa661rz74k9eea).
65 | 
66 | 
67 | ## Code for Processing Fine-tuning Data
68 | 
69 | [sec-process]: #code-for-processing-fine-tuning-data
70 | 
71 | We provide a sample script to process the datasets for edit-translation. It requires the raw data files to be placed at `raw_data/`.
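The command below produces the processed cs2java dataset, which the training and inference steps later in this README read from `data/${model_name}/${dataset}/`. Based on the file names used by `CodeT5DataModule` in `python/deltr/coditT5/CodeT5.py` (the `edit-translation` model name is taken from the training example below), the expected layout is:

```
data/edit-translation/cs2java/
├── train.cs2java.src
├── train.cs2java.tgt
├── valid.cs2java.src
├── valid.cs2java.tgt
├── test.cs2java.src
└── test.cs2java.tgt
```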
72 | 
73 | ```
74 | cd python/
75 | python -m deltr.collector.DataProcessor edit_translation_data_process --exp cs2java --src_lang cs --tgt_lang java
76 | 
77 | ```
78 | 
79 | ## Code for Training and Evaluating Models
80 | 
81 | [sec-traineval]: #code-for-training-and-evaluating-models
82 | 
83 | ### Train ML models
84 | 
85 | ```
86 | cd python/
87 | python -m deltr.coditT5.CodeT5 fit --exp_dir ${MODELS_DIR}/${model_name}/${dataset} --data.dataset ${dataset} --data.model ${model_name} --config configs/coditT5.yaml
88 | 
89 | # Example: python -m deltr.coditT5.CodeT5 fit --exp_dir models/edit-translation/java2cs --data.dataset java2cs --data.model edit-translation --config configs/coditT5.yaml
90 | ```
91 | 
92 | Results are generated in `models/${model_name}/${dataset}/`, where:
93 | 
94 | - `model/`: stores the trained model.
95 | 
96 | - `logs/`: stores logs during training.
97 | 
98 | ### Run ML models for inference
99 | 
100 | Requires the processed dataset at `data/${model_name}/${dataset}/` and the trained model at `models/${model_name}/${dataset}/model/`.
101 | 
102 | ```
103 | cd python/
104 | python -m deltr.coditT5.CodeT5 predict --exp_dir ${MODELS_DIR}/${model_name}/${dataset} --data.dataset ${dataset} --data.model ${model_name} --config configs/coditT5.yaml
105 | 
106 | ```
107 | 
108 | Results are generated in `models/${model_name}/${dataset}/`, where:
109 | 
110 | - `test.${dataset}.hyp`: the predictions (the prediction writer names the file `${infer_data}.${dataset}.hyp`).
111 | 
--------------------------------------------------------------------------------
/python/.gitignore:
--------------------------------------------------------------------------------
1 | # Temp files
2 | 
3 | *~
4 | \#*\#
5 | .DS_Store
6 | *.class
7 | *.pyc
8 | .pdf
9 | 
10 | experiments.log
11 | 
12 | target/
13 | 
14 | .idea/
15 | *.iml
16 | 
17 | /raw_data
18 | 
19 | # logs
20 | tacc-logs/
21 | 
22 | *.code-workspace
23 | 
24 | # downloads dir
25 | _downloads/
26 | data/
27 | 
28 | repo-data/
29 | 
30 | models/
31 | .vscode/settings.json
32 | 
--------------------------------------------------------------------------------
/python/configs/codeT5.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   batch_size: 1
3 |   eval_batch_size: 1
4 | 
5 | model:
6 |   pretrained_model: Salesforce/codet5-base
7 |   pretrained_tokenizer: Salesforce/codet5-base
8 |   skip_special_token_when_generate: False
9 |   beam_size: 20
10 | 
11 | trainer:
12 |   auto_select_gpus: true
13 |   gpus: -1
14 |   strategy: ddp
15 |   # find_unused_parameters: false
16 |   precision: 16
17 | 
18 |   # max_steps: 50_000
19 |   # fast_dev_run: true
20 |   max_epochs: 30
21 |   accumulate_grad_batches: 4 # effective batch size 1*4(gpu)*4(accumulate) = 16
22 | 
23 |   callbacks:
24 |     - class_path: pytorch_lightning.callbacks.EarlyStopping
25 |       init_args:
26 |         monitor: bleu/val
27 |         mode: max
28 |         min_delta: 0
29 |         patience: 5
30 |         verbose: true
31 |     # - class_path: pytorch_lightning.callbacks.StochasticWeightAveraging # Incompatible with EarlyStopping
32 |     - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor
33 |       init_args:
34 |         logging_interval: step
35 | 
36 | optimizer:
37 |   class_path: transformers.optimization.AdamW
38 |   init_args:
39 |     lr: 0.00005
40 |     eps: 1e-8
41 |     weight_decay: 0.01
42 | 
43 | lr_scheduler:
44 |   class_path: torch.optim.lr_scheduler.OneCycleLR
45 |   init_args:
46 |     max_lr: 0.00005
47 |     pct_start: 0.1
48 |     div_factor: 1
49 |     total_steps: 30
50 |     anneal_strategy: linear
51 | 
52 | ckpt:
53 |   save_top_k: 1
54 |   monitor: bleu/val
55 |   mode: max
56 | 
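The config above (and `coditT5.yaml` below) is consumed via `--config` by the LightningCLI entry point in `deltr/coditT5/CodeT5.py`. As a sketch (the `codeT5` model/dataset directory names are placeholders, and the dot-notation overrides rely on standard LightningCLI/jsonargparse behavior rather than anything specific to this repo), individual config values can also be overridden on the command line:

```
cd python/
python -m deltr.coditT5.CodeT5 fit \
    --exp_dir models/codeT5/java2cs \
    --data.dataset java2cs --data.model codeT5 \
    --config configs/codeT5.yaml \
    --trainer.max_epochs 10 --model.beam_size 5  # override values from the config file
```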
-------------------------------------------------------------------------------- /python/configs/coditT5.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | batch_size: 1 3 | eval_batch_size: 1 4 | 5 | model: 6 | pretrained_model: ../models/pretrain/model/ 7 | pretrained_tokenizer: ../models/codeT5Tokenizer 8 | beam_size: 20 9 | skip_special_token_when_generate: False 10 | 11 | trainer: 12 | auto_select_gpus: true 13 | gpus: -1 14 | strategy: ddp 15 | # find_unused_parameters: false 16 | precision: 16 17 | 18 | # max_steps: 50_000 19 | # fast_dev_run: true 20 | max_epochs: 30 21 | accumulate_grad_batches: 12 # effective batch size 1*4(gpu)*12(accumulate) = 48 22 | 23 | callbacks: 24 | - class_path: pytorch_lightning.callbacks.EarlyStopping 25 | init_args: 26 | monitor: bleu/val 27 | mode: max 28 | min_delta: 0 29 | patience: 5 30 | verbose: true 31 | # - class_path: pytorch_lightning.callbacks.StochasticWeightAveraging # Incompatible with EarlyStopping 32 | - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor 33 | init_args: 34 | logging_interval: step 35 | 36 | optimizer: 37 | class_path: transformers.optimization.AdamW 38 | init_args: 39 | lr: 0.00005 40 | eps: 1e-8 41 | weight_decay: 0.01 42 | 43 | lr_scheduler: 44 | class_path: torch.optim.lr_scheduler.OneCycleLR 45 | init_args: 46 | max_lr: 0.00005 47 | pct_start: 0.1 48 | div_factor: 1 49 | total_steps: 50 50 | anneal_strategy: linear 51 | 52 | ckpt: 53 | save_top_k: 1 54 | monitor: bleu/val 55 | mode: max 56 | -------------------------------------------------------------------------------- /python/deltr/Macros.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | class Macros: 6 | this_dir: Path = Path(os.path.dirname(os.path.realpath(__file__))) 7 | python_dir: Path = this_dir.parent 8 | project_dir: Path = python_dir.parent 9 | model_dir: Path = project_dir / "models" 10 | data_dir: Path = project_dir / "data" 11 | raw_data_dir: Path = project_dir / "raw_data" 12 | script_dir: Path = project_dir / "scripts" 13 | tacc_log_dir: Path = project_dir / "tacc-logs" 14 | results_dir: Path = project_dir / "results" 15 | model_results_dir: Path = results_dir / "model-results" 16 | log_file: Path = python_dir / "experiments.log" 17 | paper_dir: Path = project_dir / "papers" / "paper" 18 | config_dir: Path = python_dir / "configs" 19 | doc_dir: Path = project_dir / "docs" 20 | gleu_dir: Path = this_dir / "gleu" 21 | 22 | downloads_dir: Path = project_dir / "_downloads" 23 | repos_downloads_dir: Path = downloads_dir / "repos" 24 | repos_results_dir: Path = downloads_dir / "repos_results" 25 | collector_dir: Path = project_dir / "collector" 26 | collector_version = "1.0-SNAPSHOT" 27 | 28 | train: str = "train" 29 | valid: str = "valid" 30 | test: str = "test" 31 | -------------------------------------------------------------------------------- /python/deltr/coditT5/CodeT5.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import ( 3 | RobertaTokenizer, 4 | T5ForConditionalGeneration, 5 | T5EncoderModel, 6 | ) 7 | from typing import List, Tuple, Dict, Optional, Union, Sequence 8 | from jsonargparse.typing import Path_dc, Path_drw 9 | import os 10 | from pathlib import Path 11 | from seutil import LoggingUtils 12 | import torch 13 | import torch.utils.data 14 | import pytorch_lightning as pl 15 | from 
pytorch_lightning.utilities.cli import ( 16 | LR_SCHEDULER_REGISTRY, 17 | OPTIMIZER_REGISTRY, 18 | instantiate_class, 19 | SaveConfigCallback, 20 | ) 21 | import collections 22 | import numpy as np 23 | 24 | from .utils import ( 25 | DefaultLightningCLI, 26 | ExampleDataset, 27 | PredictDataset, 28 | Prediction, 29 | ) 30 | from deltr.Macros import Macros 31 | from deltr.eval.evaluate import compute_bleu_scores 32 | from deltr.collector.diff_utils import EDIT_TOKENS 33 | 34 | from deltr.coditT5.prediction import PredictionWriter 35 | 36 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) 37 | 38 | MAX_LENGTH = 512 39 | 40 | 41 | class CodeT5DataModule(pl.LightningDataModule): 42 | def __init__( 43 | self, 44 | dataset: str = "java2cs", 45 | model: str = "CodeT5", 46 | infer_data: str = "test", 47 | batch_size: int = 2, 48 | eval_batch_size: int = 8, 49 | ): 50 | """ 51 | :model_outputs: {model_name: {train: Path, test: Path}} 52 | """ 53 | super().__init__() 54 | 55 | pl.seed_everything(42) 56 | self.data_dir = Macros.data_dir / model / dataset 57 | self.dataset = dataset 58 | self.infer_data = infer_data 59 | self.model = model 60 | self.save_hyperparameters() 61 | logger.info(f"Data Module params: \n{self.hparams}") 62 | 63 | def setup(self, stage: Optional[str] = None): 64 | """Load and encode train/valid/test dataset""" 65 | 66 | self.tokenizer = self.trainer.lightning_module.tokenizer 67 | self.stage = stage 68 | if stage == "fit" or stage is None: 69 | # Process training data 70 | train_source_file = self.data_dir / f"train.{self.dataset}.src" 71 | train_target_file = self.data_dir / f"train.{self.dataset}.tgt" 72 | self.train_dataset = ExampleDataset(train_source_file, train_target_file) 73 | 74 | # Process validatoin data 75 | valid_source_file = self.data_dir / f"valid.{self.dataset}.src" 76 | valid_target_file = self.data_dir / f"valid.{self.dataset}.tgt" 77 | self.valid_dataset = ExampleDataset(valid_source_file, valid_target_file) 78 | 79 | if stage == "predict": 80 | test_source_file = self.data_dir / f"{self.infer_data}.{self.dataset}.src" 81 | test_target_file = self.data_dir / f"{self.infer_data}.{self.dataset}.tgt" 82 | logger.info("Start to process prediction data...") 83 | self.test_dataset = PredictDataset(test_source_file, test_target_file) 84 | 85 | if stage == "validate": 86 | valid_source_file = self.data_dir / f"valid.{self.dataset}.src" 87 | valid_target_file = self.data_dir / f"valid.{self.dataset}.tgt" 88 | self.valid_dataset = ExampleDataset(valid_source_file, valid_target_file) 89 | 90 | def tokenizer_collate_fn( 91 | self, batch_data: List[Tuple[str, str]] 92 | ) -> Sequence[torch.Tensor]: 93 | """Customize collate function""" 94 | source_batch = [self.tokenize_sequence(t[0]) for t in batch_data] 95 | target_batch = [self.tokenize_sequence(t[1]) for t in batch_data] 96 | max_length = MAX_LENGTH 97 | batch_size = len(source_batch) 98 | 99 | batched_input_ids, batched_labels_ids = [], [] 100 | for i in range(batch_size): 101 | batched_input_ids.append( 102 | self.tokenizer.encode( 103 | source_batch[i], 104 | max_length=max_length, 105 | truncation=True, 106 | padding="max_length", 107 | ) 108 | ) 109 | batched_labels_ids.append( 110 | self.tokenizer.encode( 111 | target_batch[i], 112 | max_length=max_length, 113 | truncation=True, 114 | padding="max_length", 115 | ) 116 | ) 117 | 118 | return ( 119 | torch.LongTensor(batched_input_ids), 120 | torch.LongTensor(batched_labels_ids), 121 | ) 122 | 123 | def tokenize_collate_fn_predict(self, batch_data: 
List[Tuple[str, str, int]]): 124 | source_batch = [self.tokenize_sequence(t[0]) for t in batch_data] 125 | target_batch = [self.tokenize_sequence(t[1]) for t in batch_data] 126 | index_batch = [t[2] for t in batch_data] 127 | max_length = MAX_LENGTH 128 | batch_size = len(source_batch) 129 | 130 | ( 131 | batched_input_ids, 132 | batched_labels_ids, 133 | ) = ( 134 | [], 135 | [], 136 | ) 137 | for i in range(batch_size): 138 | batched_input_ids.append( 139 | self.tokenizer.encode( 140 | source_batch[i], 141 | max_length=max_length, 142 | truncation=True, 143 | padding="longest", 144 | ) 145 | ) 146 | batched_labels_ids.append( 147 | self.tokenizer.encode( 148 | target_batch[i], 149 | max_length=max_length, 150 | truncation=True, 151 | padding="longest", 152 | ) 153 | ) 154 | 155 | return ( 156 | torch.LongTensor(batched_input_ids), 157 | torch.LongTensor(batched_labels_ids), 158 | index_batch, 159 | ) 160 | 161 | def tokenize_sequence(self, seq: str) -> List[str]: 162 | """Given string sequence should be able to be split by space.""" 163 | 164 | space_split_tokens = seq.split() 165 | new_subtokens = [] 166 | for token in space_split_tokens: 167 | new_subtokens += self.tokenizer.tokenize(" " + token) 168 | return new_subtokens 169 | 170 | def train_dataloader(self): 171 | return torch.utils.data.DataLoader( 172 | self.train_dataset, 173 | shuffle=True, 174 | batch_size=self.hparams.batch_size, 175 | num_workers=16, 176 | collate_fn=self.tokenizer_collate_fn, 177 | persistent_workers=True, 178 | ) 179 | 180 | def val_dataloader(self): 181 | return torch.utils.data.DataLoader( 182 | self.valid_dataset, 183 | shuffle=False, 184 | batch_size=self.hparams.batch_size, 185 | num_workers=1, 186 | collate_fn=self.tokenizer_collate_fn, 187 | persistent_workers=True, 188 | ) 189 | 190 | def test_dataloader(self): 191 | return torch.utils.data.DataLoader( 192 | self.test_dataset, 193 | shuffle=False, 194 | batch_size=self.hparams.eval_batch_size, 195 | num_workers=0, 196 | collate_fn=self.tokenizer_collate_fn, 197 | ) 198 | 199 | def predict_dataloader(self): 200 | return torch.utils.data.DataLoader( 201 | self.test_dataset, 202 | shuffle=False, 203 | batch_size=self.hparams.eval_batch_size, 204 | num_workers=0, 205 | collate_fn=self.tokenize_collate_fn_predict, 206 | ) 207 | 208 | 209 | class CodeT5Module(pl.LightningModule): 210 | # Instantiate the model 211 | def __init__( 212 | self, 213 | pretrained_tokenizer: Union[Path_drw, str], 214 | pretrained_model: Union[Path_drw, str], 215 | optimizer_init: dict, 216 | lr_scheduler_init: dict, 217 | output_dir=None, 218 | skip_special_token_when_generate: bool = True, 219 | beam_size=5, 220 | num_return_sequences=1, 221 | ): 222 | super(CodeT5Module, self).__init__() 223 | 224 | pl.seed_everything(42) 225 | if isinstance(pretrained_tokenizer, Path_drw): 226 | pretrained_tokenizer = os.path.relpath( 227 | Path(pretrained_tokenizer.abs_path), Path.cwd() 228 | ) 229 | if isinstance(pretrained_model, Path_drw): 230 | pretrained_model = os.path.relpath( 231 | Path(pretrained_model.abs_path), Path.cwd() 232 | ) 233 | 234 | self.save_hyperparameters() 235 | self.beam_size = beam_size 236 | self.num_return_sequences = num_return_sequences 237 | 238 | self.tokenizer = RobertaTokenizer.from_pretrained( 239 | self.hparams.pretrained_tokenizer 240 | ) 241 | 242 | self.model = T5ForConditionalGeneration.from_pretrained( 243 | self.hparams.pretrained_model 244 | ) 245 | self.skip_special_token_when_generate = skip_special_token_when_generate 246 | 
self.model.resize_token_embeddings(len(self.tokenizer)) 247 | logger.info(f"Model Module params: \n{self.hparams}") 248 | 249 | def forward(self, *args, **kwargs): 250 | return self.model(*args, **kwargs) 251 | 252 | def configure_optimizers(self): 253 | if "weight_decay" in self.hparams.optimizer_init["init_args"]: 254 | no_decay = ["bias", "LayerNorm.weight"] 255 | parameters = [ 256 | { 257 | "params": [ 258 | p 259 | for n, p in self.named_parameters() 260 | if not any(nd in n for nd in no_decay) 261 | ], 262 | "weight_decay": self.hparams.optimizer_init["init_args"][ 263 | "weight_decay" 264 | ], 265 | }, 266 | { 267 | "params": [ 268 | p 269 | for n, p in self.named_parameters() 270 | if any(nd in n for nd in no_decay) 271 | ], 272 | "weight_decay": 0.0, 273 | }, 274 | ] 275 | else: 276 | parameters = self.parameters() 277 | optimizer = instantiate_class(parameters, self.hparams.optimizer_init) 278 | lr_scheduler = instantiate_class(optimizer, self.hparams.lr_scheduler_init) 279 | return { 280 | "optimizer": optimizer, 281 | "lr_scheduler": lr_scheduler, 282 | } 283 | 284 | def training_step(self, batch: List[torch.Tensor], batch_idx=-1): 285 | inputs, labels = batch 286 | attention_masks = ~(inputs == self.tokenizer.pad_token_id) 287 | outputs = self.model( 288 | inputs, labels=labels, attention_mask=attention_masks, return_dict=True 289 | ) 290 | train_loss = outputs.loss 291 | self.log_dict({"loss/train": train_loss.item()}, on_step=True) 292 | 293 | return train_loss 294 | 295 | def validation_step(self, batch: List[torch.Tensor], batch_idx=-1): 296 | inputs, labels = batch 297 | attention_masks = ~(inputs == self.tokenizer.pad_token_id) 298 | batch_size = inputs.shape[0] 299 | outputs = self.model( 300 | inputs, attention_mask=attention_masks, labels=labels, return_dict=True 301 | ) 302 | val_loss = outputs.loss 303 | output_sequences = self.model.generate( 304 | input_ids=inputs, 305 | attention_mask=attention_masks, 306 | num_beams=5, 307 | num_return_sequences=self.num_return_sequences, 308 | max_length=MAX_LENGTH, 309 | ) 310 | pred_sequences = [] 311 | target_sequences = [] 312 | srcs = [] 313 | for input_ids, output_ids, label in zip(inputs, output_sequences, labels): 314 | pred = self.detokenize(output_ids) 315 | if pred == "": 316 | pred = "" 317 | target = self.detokenize(label) 318 | pred_sequences.append(pred) 319 | target_sequences.append(target) 320 | _, bleu_score_list = compute_bleu_scores(target_sequences, pred_sequences) 321 | if self.trainer.datamodule.stage == "validate": 322 | return pred_sequences 323 | metrics_list = {"bleu/val": bleu_score_list} 324 | metrics_list["loss/val"] = [val_loss.item()] * batch_size 325 | 326 | # log the prediction of model 327 | s = "" 328 | for i in range(batch_size): 329 | s += f"# Example {i}\n\n" 330 | s += f"- gold\n```\n{target_sequences[i]}\n```\n\n" 331 | s += f"- pred\n```\n{pred_sequences[i]}\n```\n\n" 332 | s += f"- metrics\n\n" 333 | for k, v in metrics_list.items(): 334 | s += f"{k}: {v[i]}\n" 335 | s += "\n" 336 | 337 | self.logger.experiment.add_text("examples/val", s, global_step=self.global_step) 338 | # self.logger.log_text( 339 | # key="validation", 340 | # columns=["examples/val"], 341 | # data=[[s]], 342 | # step=self.global_step, 343 | # ) 344 | 345 | return metrics_list 346 | 347 | def predict_step(self, batch: List[torch.Tensor], batch_idx=-1): 348 | inputs, labels, indexs = batch 349 | attention_masks = ~(inputs == self.tokenizer.pad_token_id) 350 | batch_size = inputs.shape[0] 351 | pred_sequences = [] 
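        # Note on the decoding below: model.generate() runs beam search with
        # `self.beam_size` beams and returns `num_return_sequences` candidates per
        # input; each decoded string is wrapped in Prediction(index, text) so that
        # PredictionWriter can merge the per-rank outputs and restore the original
        # dataset order by index.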
352 | 353 | output_sequences = self.model.generate( 354 | input_ids=inputs, 355 | attention_mask=attention_masks, 356 | num_beams=self.beam_size, 357 | num_return_sequences=self.num_return_sequences, 358 | max_length=MAX_LENGTH, 359 | ) 360 | 361 | for index, output_ids in zip(indexs, output_sequences): 362 | pred = self.tokenizer.convert_tokens_to_string( 363 | self.post_process_edit_sequences( 364 | self.tokenizer.convert_ids_to_tokens( 365 | output_ids, 366 | skip_special_tokens=self.skip_special_token_when_generate, 367 | ) 368 | ) 369 | ) 370 | pred_sequences.append(Prediction(index, pred)) 371 | 372 | return pred_sequences 373 | 374 | def validation_epoch_end(self, outputs: Union[List[Dict], List[List[str]]]): 375 | dataset_name = self.trainer.datamodule.dataset 376 | if self.trainer.datamodule.stage == "validate": 377 | all_valid_preds = [] 378 | for batch_pred in outputs: 379 | all_valid_preds.extend(batch_pred) 380 | output_file = ( 381 | f"valid.{dataset_name}.hyp" 382 | if self.num_return_sequences == 1 383 | else f"valid.{dataset_name}.{self.num_return_sequences}.hyp" 384 | ) 385 | with open(f"{self.hparams.output_dir}/{output_file}", "w") as f: 386 | for pred in all_valid_preds: 387 | f.write(f"{pred}\n") 388 | return 389 | metrics_list = collections.defaultdict(list) 390 | for o in outputs: 391 | for k in o: 392 | metrics_list[k] += o[k] 393 | metrics = summarize_metrics(metrics_list) 394 | self.log_dict(metrics) 395 | 396 | def detokenize(self, output_ids: torch.Tensor) -> str: 397 | pred = ( 398 | self.tokenizer.convert_tokens_to_string( 399 | self.post_process_edit_sequences( 400 | self.tokenizer.convert_ids_to_tokens( 401 | output_ids, 402 | skip_special_tokens=self.skip_special_token_when_generate, 403 | ) 404 | ) 405 | ) 406 | .replace("", "") 407 | .replace("", "") 408 | .replace("", "") 409 | ) 410 | return pred 411 | 412 | def save_pretrained(self, save_dir: Union[str, Path, Path_drw, Path_dc]): 413 | if isinstance(save_dir, (Path_drw, Path_dc)): 414 | save_dir = Path(save_dir.abs_path) 415 | self.model.save_pretrained(save_dir) 416 | self.tokenizer.save_pretrained(save_dir) 417 | 418 | def post_process_edit_sequences(self, token_list: List[str]) -> List[str]: 419 | """Post process token list with edit keywords, manually add space.""" 420 | token_list_after_process = [] 421 | for tk in token_list: 422 | if tk in self.tokenizer.additional_special_tokens or tk in EDIT_TOKENS: 423 | token_list_after_process.append(f"Ġ{tk}Ġ") 424 | else: 425 | token_list_after_process.append(tk) 426 | return token_list_after_process 427 | 428 | 429 | def summarize_metrics( 430 | metrics: Dict[str, Union[float, List[float]]], 431 | ) -> Dict[str, float]: 432 | metrics_summary = {} 433 | for k, v in metrics.items(): 434 | if isinstance(v, list): 435 | metrics_summary[k] = float(np.mean([float(x) for x in v])) 436 | else: 437 | metrics_summary[k] = float(v) 438 | return metrics_summary 439 | 440 | 441 | if __name__ == "__main__": 442 | LoggingUtils.setup(LoggingUtils.INFO, Macros.log_file) 443 | 444 | OPTIMIZER_REGISTRY.register_classes( 445 | transformers.optimization, torch.optim.Optimizer, override=True 446 | ) 447 | LR_SCHEDULER_REGISTRY.register_classes( 448 | transformers.optimization, torch.optim.lr_scheduler._LRScheduler, override=True 449 | ) 450 | 451 | DefaultLightningCLI( 452 | CodeT5Module, 453 | CodeT5DataModule, 454 | save_config_callback=SaveConfigCallback, 455 | prediction_writer=PredictionWriter, 456 | optimizers=[(None, "optimizer", "model.optimizer_init")], 457 | 
lr_schedulers=[(None, "lr_scheduler", "model.lr_scheduler_init")], 458 | ) 459 | -------------------------------------------------------------------------------- /python/deltr/coditT5/prediction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from pathlib import Path 4 | from jsonargparse.typing import Path_dc, Path_drw 5 | from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union, Any 6 | 7 | import pytorch_lightning as pl 8 | import seutil as su 9 | from pytorch_lightning.callbacks import BasePredictionWriter 10 | 11 | from deltr.eval.evaluate import run_evaluation 12 | 13 | logger = su.LoggingUtils.get_logger(__name__, su.LoggingUtils.DEBUG) 14 | 15 | 16 | class PredictionWriter(BasePredictionWriter): 17 | def __init__( 18 | self, 19 | output_dir: Union[Path, str], 20 | no_compute_metrics: bool = True, 21 | dataset: str = "", 22 | model: str = "", 23 | infer_data: str = "test", 24 | ): 25 | super().__init__(write_interval="epoch") 26 | self.no_compute_metrics = no_compute_metrics 27 | self.output_dir = Path(output_dir) 28 | su.io.mkdir(self.output_dir) 29 | self.temp_dir = self.output_dir / "temp" 30 | su.io.mkdir(self.temp_dir) 31 | self.dataset = dataset 32 | self.model_name = model 33 | self.infer_data = infer_data 34 | 35 | def write_on_epoch_end( 36 | self, 37 | trainer: pl.Trainer, 38 | pl_module: pl.LightningModule, 39 | results: List[List[List[Any]]], 40 | batch_indices: Optional[Sequence[Sequence[Sequence[int]]]], 41 | ): 42 | # Collect preds, and put into a file according to current global rank 43 | 44 | preds: List[str] = [] 45 | for dl_batch_preds in results: 46 | for batch_preds in dl_batch_preds: 47 | if isinstance(batch_preds, list): 48 | for pred in batch_preds: 49 | preds.append(pred) 50 | else: 51 | preds.append(batch_preds) 52 | 53 | su.io.dump( 54 | self.temp_dir / f"{pl_module.global_rank}.pkl", 55 | preds, 56 | ) 57 | 58 | # Wait all processes to finish prediction 59 | trainer.training_type_plugin.barrier("prediction") 60 | 61 | if pl_module.global_rank == 0: 62 | id2pred = {} 63 | for rank in range(trainer.world_size): 64 | for pred in su.io.load(self.temp_dir / f"{rank}.pkl"): 65 | id2pred[pred.id] = pred.data 66 | if sorted(id2pred.keys()) != list(range(len(id2pred))): 67 | logger.warning(f"Prediction ids are not continuous") 68 | preds = [id2pred[i] for i in sorted(id2pred.keys())] 69 | 70 | # Dump predictions 71 | logger.info("Saving predictions") 72 | with open( 73 | self.output_dir / f"{self.infer_data}.{self.dataset}.hyp", "w+" 74 | ) as f: 75 | for pred in preds: 76 | f.write(f"{pred}\n") 77 | 78 | if not self.no_compute_metrics: 79 | # Compute metrics 80 | logger.info("Computing and saving metrics") 81 | 82 | run_evaluation( 83 | dataset=self.dataset, 84 | model=self.model_name, 85 | ) 86 | 87 | # Delete temp directory 88 | su.io.rmdir(self.temp_dir) 89 | -------------------------------------------------------------------------------- /python/deltr/coditT5/save_pretrained.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, Union 3 | from jsonargparse import CLI 4 | from jsonargparse.typing import Path_dc, Path_drw, Path_fr 5 | from seutil import LoggingUtils 6 | 7 | from deltr.Macros import Macros 8 | 9 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) 10 | 11 | 12 | def locate_ckpt(ckpt_dir: Path) -> Optional[Path]: 13 | ckpt_files = 
list(ckpt_dir.glob("*.ckpt")) 14 | if len(ckpt_files) == 0: 15 | ckpt_file = None 16 | logger.info(f"No checkpoint found in {ckpt_dir}") 17 | elif len(ckpt_files) == 1: 18 | ckpt_file = ckpt_files[0] 19 | logger.info(f"Found one checkpoint in {ckpt_dir}: {ckpt_file.name}") 20 | else: 21 | ckpt_files = [f for f in ckpt_files if f.name != "last.ckpt"] 22 | ckpt_file = sorted(ckpt_files, key=lambda x: x.stat().st_mtime)[-1] 23 | logger.warning( 24 | f"Multiple checkpoints found in {ckpt_dir}: {[x.name for x in ckpt_files]}; picking the latest modified: {ckpt_file.name}" 25 | ) 26 | return ckpt_file 27 | 28 | 29 | def add_tokens_to_tokenizer(): 30 | from transformers import RobertaTokenizer 31 | from deltr.collector.diff_utils import EDIT_TOKENS 32 | 33 | lowercase_edit_tokens = [tk.lower() for tk in EDIT_TOKENS] 34 | tokenizer = RobertaTokenizer.from_pretrained("Salesforce/codet5-base") 35 | special_tokens_dict = { 36 | "additional_special_tokens": EDIT_TOKENS + lowercase_edit_tokens 37 | } 38 | tokenizer.add_special_tokens(special_tokens_dict) 39 | print(f"Size of codeT5 tokenizer is {len(tokenizer)}") 40 | tokenizer.save_pretrained(f"{Macros.model_dir}/EditModelTokenizer") 41 | 42 | 43 | def save_pretrained( 44 | model_cls: str, 45 | ckpt_dir: Path_drw, 46 | ckpt_name: str = None, 47 | output_dir: Optional[Union[Path_drw, Path_dc]] = None, 48 | ): 49 | ckpt_dir = Path_drw(ckpt_dir) 50 | ckpt_dir = Path(ckpt_dir.abs_path) 51 | if ckpt_name: 52 | ckpt_path = ckpt_dir / ckpt_name 53 | else: 54 | ckpt_path = locate_ckpt(ckpt_dir) 55 | if output_dir is not None: 56 | output_dir = Path(output_dir.abs_path) 57 | else: 58 | output_dir = ckpt_dir 59 | if model_cls == "CodeT5": 60 | from deltr.coditT5.CodeT5 import CodeT5Module 61 | 62 | model = CodeT5Module.load_from_checkpoint(ckpt_path) 63 | model.save_pretrained(output_dir) 64 | elif model_cls == "T5Encoder": 65 | from deltr.coditT5.SeqClassifier import SeqClassifierModule 66 | 67 | model = SeqClassifierModule.load_from_checkpoint(ckpt_path) 68 | model.save_pretrained(output_dir) 69 | else: 70 | raise ValueError(f"Unknown model class: {model_cls}") 71 | 72 | 73 | if __name__ == "__main__": 74 | CLI(add_tokens_to_tokenizer, as_positional=False) 75 | -------------------------------------------------------------------------------- /python/deltr/coditT5/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import time 4 | from pathlib import Path 5 | from typing import ( 6 | Any, 7 | Dict, 8 | List, 9 | Optional, 10 | Tuple, 11 | Type, 12 | Union, 13 | ) 14 | from tqdm import tqdm 15 | import torch 16 | import numpy as np 17 | from jsonargparse.typing import Path_dc, Path_drw, Path_dw, Path_fc, Path_fr 18 | from pytorch_lightning.callbacks.base import Callback 19 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 20 | from pytorch_lightning.utilities.cli import ( 21 | LR_SCHEDULER_REGISTRY, 22 | OPTIMIZER_REGISTRY, 23 | LightningArgumentParser, 24 | LightningCLI, 25 | ) 26 | import pytorch_lightning as pl 27 | from recordclass import RecordClass 28 | 29 | 30 | from seutil.LoggingUtils import LoggingUtils 31 | import seutil as su 32 | 33 | 34 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) 35 | 36 | 37 | class DefaultLightningCLI(LightningCLI): 38 | def __init__( 39 | self, 40 | *args, 41 | optimizers: Optional[ 42 | List[Tuple[Optional[Union[Type, List[Type]]], str, str]] 43 | ] = None, 44 | lr_schedulers: Optional[ 45 | 
List[Tuple[Optional[Union[Type, List[Type]]], str, str]] 46 | ] = None, 47 | prediction_writer: Optional[Callback] = None, 48 | **kwargs, 49 | ): 50 | self.optimizers = optimizers 51 | self.lr_schedulers = lr_schedulers 52 | self.prediction_writer = prediction_writer 53 | kwargs.setdefault("save_config_overwrite", True) 54 | super().__init__(*args, **kwargs) 55 | 56 | def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: 57 | super().add_arguments_to_parser(parser) 58 | parser.add_argument( 59 | "--exp_dir", 60 | required=True, 61 | help="Path to experiment directory", 62 | type=Union[Path_drw, Path_dc], 63 | ) 64 | 65 | parser.add_argument( 66 | "--resume", 67 | required=False, 68 | help="When training, what to do if a checkpoint already exists: unset (default) = error; True = resume; False = remove (all existing checkpoints)", 69 | type=bool, 70 | ) 71 | 72 | parser.add_argument( 73 | "--ckpt_name", 74 | required=False, 75 | help="The checkpoint file name to load (under regular ckpt directory); if unset, the latest checkpoint will be loaded", 76 | type=str, 77 | ) 78 | 79 | parser.add_argument( 80 | "--no_compute_metrics", 81 | required=False, 82 | help="When predicting, do not compute metrics and only collect predictions", 83 | type=bool, 84 | default=True, 85 | ) 86 | 87 | parser.add_argument( 88 | "--no_ckpt_ok", 89 | required=False, 90 | help="When predicting, what to do if no checkpoint exists: False (default) = error; True = predict from scratch", 91 | type=bool, 92 | default=False, 93 | ) 94 | 95 | parser.add_argument( 96 | "--output_dir", 97 | required=False, 98 | help="Path to the output directory during prediction", 99 | type=Path_dc, 100 | ) 101 | 102 | parser.add_lightning_class_args(ModelCheckpoint, "ckpt") 103 | parser.set_defaults( 104 | { 105 | "ckpt.save_last": True, 106 | "ckpt.verbose": True, 107 | } 108 | ) 109 | 110 | if self.optimizers is not None: 111 | for types, nested_key, link_to in self.optimizers: 112 | if types is None: 113 | types = OPTIMIZER_REGISTRY.classes 114 | parser.add_optimizer_args(types, nested_key, link_to) 115 | 116 | if self.lr_schedulers is not None: 117 | for types, nested_key, link_to in self.lr_schedulers: 118 | if types is None: 119 | types = LR_SCHEDULER_REGISTRY.classes 120 | parser.add_lr_scheduler_args(types, nested_key, link_to) 121 | 122 | def before_instantiate_classes(self) -> None: 123 | super().before_instantiate_classes() 124 | config = self.config[self.config["subcommand"]] 125 | # In ddp mode, default disable find_unused_parameters 126 | if config["trainer"]["strategy"] == "ddp": 127 | config["trainer"]["strategy"] = pl.plugins.DDPPlugin( 128 | find_unused_parameters=False, 129 | ) 130 | 131 | # # Don't save config in non-fit mode 132 | if self.config["subcommand"] != "fit": 133 | self.save_config_callback = None 134 | 135 | # Set up experiment directory and logger 136 | exp_dir = Path(config["exp_dir"].abs_path).resolve() 137 | 138 | config["trainer"]["default_root_dir"] = os.path.relpath(exp_dir, Path.cwd()) 139 | ckpt_dir = exp_dir / "model" 140 | su.io.mkdir(ckpt_dir) 141 | config["ckpt"]["dirpath"] = os.path.relpath(ckpt_dir, Path.cwd()) 142 | 143 | # locate checkpoint file 144 | if config["ckpt_path"] is None: 145 | if config["ckpt_name"] is not None: 146 | ckpt_file = ckpt_dir / config["ckpt_name"] 147 | else: 148 | ckpt_file = self.locate_ckpt(ckpt_dir, self.config["subcommand"]) 149 | else: 150 | ckpt_file = Path(os.path.abspath(config["ckpt_path"])).resolve() 151 | 152 | if 
self.config["subcommand"] == "fit": 153 | # If a checkpoint path is specified, assume we want to resume from it 154 | if config["ckpt_path"] is not None or config["ckpt_name"] is not None: 155 | config.setdefault("resume", True) 156 | 157 | # If there is a checkpoint, we must decide what to do with it 158 | if ckpt_file is not None: 159 | if config["resume"] is None: 160 | raise RuntimeError( 161 | f"A checkpoint is present at {ckpt_file}, but I'm not sure what to do with it. Either set `--resume True` to use it or `--resume False` to overwrite it." 162 | ) 163 | elif config["resume"] is True: 164 | logger.info(f"Resuming from checkpoint {ckpt_file}") 165 | config["ckpt_path"] = str(ckpt_file.resolve()) 166 | else: 167 | logger.info(f"Removing checkpoints under {ckpt_dir}") 168 | su.io.mkdir(ckpt_dir, fresh=True) 169 | config["ckpt_path"] = None 170 | 171 | if ( 172 | self.config["subcommand"] == "predict" 173 | or self.config["subcommand"] == "validate" 174 | or self.config["subcommand"] == "test" 175 | ): 176 | if ( 177 | self.config["subcommand"] == "test" 178 | or self.config["subcommand"] == "validate" 179 | ): 180 | config["trainer"]["gpus"] = 1 181 | config["model"]["output_dir"] = os.path.relpath(exp_dir, Path.cwd()) 182 | if ckpt_file is not None: 183 | config["ckpt_path"] = str(ckpt_file.resolve()) 184 | print("Checkpoint path", config["ckpt_path"]) 185 | else: 186 | if config["no_ckpt_ok"] is False: 187 | raise RuntimeError( 188 | f"No checkpoint found, cannot predict (unless using `--no_ckpt_ok True` to allow predicting from scratch)" 189 | ) 190 | else: 191 | logger.info("No checkpoint found, predicting from scratch") 192 | 193 | if self.prediction_writer is None: 194 | logger.warning( 195 | "No prediction writer specified. " 196 | "Will not write predictions to disk." 197 | ) 198 | elif config["model"]["output_dir"] is None: 199 | logger.warning( 200 | "No output directory specified." 201 | "Will not write predictions to disk." 
202 | ) 203 | elif self.config["subcommand"] == "predict": 204 | config["trainer"]["callbacks"].append( 205 | self.prediction_writer( 206 | config["model"]["output_dir"], 207 | config["no_compute_metrics"], 208 | config["data"]["dataset"], 209 | config["data"]["model"], 210 | config["data"]["infer_data"], 211 | ) 212 | ) 213 | 214 | (exp_dir / "logs").mkdir(parents=True, exist_ok=True) 215 | logger_save_dir = exp_dir / "logs" / self.config["subcommand"] 216 | logger_version = datetime.datetime.now().strftime("%y%m%d-%H%M%S") 217 | while (logger_save_dir / logger_version).exists(): 218 | time.sleep(1) 219 | logger_version = datetime.datetime.now().strftime("%y%m%d-%H%M%S") 220 | su.io.mkdir(logger_save_dir) 221 | config["trainer"]["logger"] = { 222 | "class_path": "pytorch_lightning.loggers.tensorboard.TensorBoardLogger", 223 | "init_args": { 224 | "save_dir": os.path.relpath(logger_save_dir, Path.cwd()), 225 | "name": None, 226 | # "project": "delta-translation", 227 | "version": logger_version, 228 | }, 229 | } 230 | 231 | @classmethod 232 | def locate_ckpt(cls, ckpt_dir: Path, mode: str) -> Optional[Path]: 233 | ckpt_files = list(ckpt_dir.glob("*.ckpt")) 234 | if len(ckpt_files) == 0: 235 | ckpt_file = None 236 | logger.info(f"No checkpoint found in {ckpt_dir}") 237 | elif len(ckpt_files) == 1: 238 | ckpt_file = ckpt_files[0] 239 | logger.info(f"Found one checkpoint in {ckpt_dir}: {ckpt_file.name}") 240 | else: 241 | if (ckpt_dir / "last.ckpt").is_file() and mode == "fit": 242 | ckpt_file = ckpt_dir / "last.ckpt" 243 | logger.info( 244 | f"Found the last checkpoint in {ckpt_dir}: {ckpt_file.name}" 245 | ) 246 | else: 247 | for f in ckpt_files: 248 | if f.name == "last.ckpt": 249 | ckpt_files.remove(f) 250 | ckpt_file = sorted(ckpt_files, key=lambda x: x.stat().st_mtime)[-1] 251 | logger.warning( 252 | f"Multiple checkpoints found in {ckpt_dir}: {[x.name for x in ckpt_files]}; picking the latest modified: {ckpt_file.name}" 253 | ) 254 | return ckpt_file 255 | 256 | 257 | class SequenceLabelingDataset(torch.utils.data.Dataset): 258 | "Characterizes a dataset for PyTorch" 259 | 260 | def __init__( 261 | self, 262 | source_file_path: Path, 263 | context_file_path: Path, 264 | label_file_path: Path, 265 | tokenizer: Any, 266 | ): 267 | """Read data from jsonl files.""" 268 | self.source_code = [ 269 | code.strip() 270 | for code in open(source_file_path, "r", encoding="utf-8").readlines() 271 | ] 272 | self.context = [ 273 | ctx.strip() 274 | for ctx in open(context_file_path, "r", encoding="utf-8").readlines() 275 | ] 276 | self.labels = [ 277 | [int(label) for label in lb.strip().split()] 278 | for lb in open(label_file_path, "r", encoding="utf-8").readlines() 279 | ] 280 | self.tokenized_labels = tokenize_and_align_labels( 281 | self.source_code, self.labels, tokenizer 282 | ) 283 | 284 | def __len__(self): 285 | return len(self.source_code) 286 | 287 | def __getitem__(self, index: int): 288 | return { 289 | "code": self.source_code[index], 290 | "context": self.context[index], 291 | "labels": self.tokenized_labels[index], 292 | } 293 | 294 | 295 | class SequenceLabelingChunkDataset(torch.utils.data.Dataset): 296 | """Dataset for sequence labeling and chunk the data""" 297 | 298 | def __init__( 299 | self, 300 | source_file_path: Path, 301 | context_file_path: Path, 302 | label_file_path: Path, 303 | tokenizer: Any, 304 | ): 305 | """Read data from jsonl files.""" 306 | 307 | self.JAVA_CHUNK_LEN = 240 308 | self.CS_CHUNK_LEN = 255 309 | self.tokenizer = tokenizer 310 | source_code = [ 311 
| code.strip() 312 | for code in open(source_file_path, "r", encoding="utf-8").readlines() 313 | ] 314 | context = [ 315 | ctx.strip() 316 | for ctx in open(context_file_path, "r", encoding="utf-8").readlines() 317 | ] 318 | labels = [ 319 | [int(label) for label in lb.strip().split()] 320 | for lb in open(label_file_path, "r", encoding="utf-8").readlines() 321 | ] 322 | tokenized_labels = tokenize_and_align_labels(source_code, labels, tokenizer) 323 | self.__split_data_to_chunks__(source_code, context, tokenized_labels) 324 | 325 | def __len__(self): 326 | return len(self.tokenized_code_input) 327 | 328 | def __split_data_to_chunks__(self, source_code, context, tokenized_labels): 329 | """Split examples into chunks if too long.""" 330 | 331 | self.tokenized_code_input = [] 332 | self.tokenized_context_input = [] 333 | self.data_index = [] 334 | self.labels = [] 335 | too_long_context = 0 336 | 337 | for index in tqdm(range(len(source_code)), total=len(source_code)): 338 | tokenized_code = self.tokenizer.tokenize(source_code[index]) 339 | tokenized_context = self.tokenizer.tokenize(context[index]) 340 | tokenized_label = tokenized_labels[index] 341 | assert len(tokenized_code) == len(tokenized_labels[index]) 342 | if ( 343 | len(tokenized_code) + len(tokenized_context) + 1 344 | > self.tokenizer.model_max_length 345 | ): 346 | # context_length = min(self.MAX_CTX_LEN, len(tokenized_context)) 347 | # if context_length == self.MAX_CTX_LEN: 348 | too_long_context += 1 349 | # start to cut 350 | code_start_id, code_end_id = 0, 0 351 | context_start_id, context_end_id = 0, 0 352 | while code_start_id < len(tokenized_code): 353 | code_end_id = self.CS_CHUNK_LEN + code_start_id 354 | context_end_id = self.JAVA_CHUNK_LEN + context_start_id 355 | self.tokenized_code_input.append( 356 | tokenized_code[code_start_id:code_end_id] 357 | ) 358 | self.tokenized_context_input.append( 359 | tokenized_context[context_start_id:context_end_id] 360 | ) 361 | self.labels.append(tokenized_label[code_start_id:code_end_id]) 362 | self.data_index.append(index) 363 | code_start_id = code_end_id 364 | context_start_id = context_end_id 365 | 366 | else: 367 | self.tokenized_code_input.append(tokenized_code) 368 | self.tokenized_context_input.append(tokenized_context) 369 | self.data_index.append(index) 370 | self.labels.append(tokenized_label) 371 | 372 | return 373 | 374 | def __getitem__(self, index: int): 375 | return { 376 | "code": self.tokenized_code_input[index], 377 | "context": self.tokenized_context_input[index], 378 | "labels": self.labels[index], 379 | "index": self.data_index[index], 380 | } 381 | 382 | 383 | def tokenize_and_align_labels( 384 | source_code: List[str], labels: List[int], tokenizer: Any 385 | ) -> List[List[int]]: 386 | tokenized_labels = [] 387 | 388 | for code, label in zip(source_code, labels): 389 | tokenized_inputs = tokenizer( 390 | code.split(), is_split_into_words=True, add_special_tokens=False 391 | ) 392 | word_ids = tokenized_inputs.word_ids() 393 | previous_word_idx = None 394 | label_ids = [] 395 | for word_idx in word_ids: 396 | # Special tokens have a word id that is None. We set the label to -100 so they are automatically 397 | # ignored in the loss function. 398 | if word_idx is None: 399 | label_ids.append(-100) 400 | # We set the label for the first token of each word. 
401 | elif word_idx != previous_word_idx: 402 | label_ids.append(label[word_idx]) 403 | # For the other tokens in a word, we set the label to either the current label or -100, depending on 404 | # the label_all_tokens flag. 405 | else: 406 | label_ids.append(-100) 407 | previous_word_idx = word_idx 408 | 409 | tokenized_labels.append(label_ids) 410 | 411 | return tokenized_labels 412 | 413 | 414 | class ExampleDataset(torch.utils.data.Dataset): 415 | def __init__(self, source_file_path: Path, target_file_path: Path): 416 | self.source_file_path = source_file_path 417 | self.target_file_path = target_file_path 418 | self.source_offset = [] 419 | self.target_offset = [] 420 | self.n_data = 0 421 | 422 | with open(source_file_path, "rb") as fp: 423 | self.source_offset = [0] 424 | while fp.readline(): 425 | self.source_offset.append(fp.tell()) 426 | self.source_offset = self.source_offset[:-1] 427 | 428 | with open(target_file_path, "rb") as fp: 429 | self.target_offset = [0] 430 | while fp.readline(): 431 | self.target_offset.append(fp.tell()) 432 | self.target_offset = self.target_offset[:-1] 433 | 434 | assert len(self.target_offset) == len(self.source_offset) 435 | 436 | self.n_data = len(self.target_offset) 437 | 438 | def __len__(self) -> int: 439 | return self.n_data 440 | 441 | def __getitem__(self, index: int) -> Tuple: 442 | if index < 0: 443 | index = self.n_data + index 444 | 445 | with open(self.source_file_path, "r", errors="replace") as sf, open( 446 | self.target_file_path, "r", errors="replace" 447 | ) as tf: 448 | sf.seek(self.source_offset[index]) 449 | source_line = sf.readline() 450 | tf.seek(self.target_offset[index]) 451 | target_line = tf.readline() 452 | 453 | return (source_line.strip(), target_line.strip()) 454 | 455 | 456 | class Prediction(RecordClass): 457 | """Prediction at one data""" 458 | 459 | id: int = -1 460 | data: str = "" 461 | 462 | 463 | class PredictDataset(torch.utils.data.Dataset): 464 | def __init__(self, source_file_path: Path, target_file_path: Path): 465 | self.source_file_path = source_file_path 466 | self.target_file_path = target_file_path 467 | self.source_offset = [] 468 | self.target_offset = [] 469 | self.n_data = 0 470 | 471 | with open(source_file_path, "rb") as fp: 472 | self.source_offset = [0] 473 | while fp.readline(): 474 | self.source_offset.append(fp.tell()) 475 | self.source_offset = self.source_offset[:-1] 476 | 477 | with open(target_file_path, "rb") as fp: 478 | self.target_offset = [0] 479 | while fp.readline(): 480 | self.target_offset.append(fp.tell()) 481 | self.target_offset = self.target_offset[:-1] 482 | 483 | assert len(self.target_offset) == len(self.source_offset) 484 | 485 | self.n_data = len(self.target_offset) 486 | 487 | def __len__(self) -> int: 488 | return self.n_data 489 | 490 | def __getitem__(self, index: int) -> Tuple: 491 | if index < 0: 492 | index = self.n_data + index 493 | 494 | with open(self.source_file_path, "r", errors="replace") as sf, open( 495 | self.target_file_path, "r", errors="replace" 496 | ) as tf: 497 | sf.seek(self.source_offset[index]) 498 | source_line = sf.readline() 499 | tf.seek(self.target_offset[index]) 500 | target_line = tf.readline() 501 | return (source_line.strip(), target_line.strip(), index) 502 | -------------------------------------------------------------------------------- /python/deltr/collector/DataCollector.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | import re 4 
| from typing import * 5 | from pathlib import Path 6 | import sys 7 | from tqdm import tqdm 8 | from jsonargparse import CLI 9 | from seutil import ( 10 | LoggingUtils, 11 | IOUtils, 12 | BashUtils, 13 | TimeUtils, 14 | io, 15 | TimeoutException, 16 | bash, 17 | ) 18 | import difflib 19 | import math 20 | 21 | from deltr.Environment import Environment 22 | from deltr.collector.ProjectData import ProjectData 23 | from deltr.Macros import Macros 24 | 25 | 26 | projects_map = { 27 | "antlr_antlr4": "tunnelvisionlabs_antlr4cs", 28 | "apache_lucene": "apache_lucenenet", 29 | "apache_poi": "nissl-lab_npoi", 30 | "itext_itext7": "itext_itext7-dotnet", 31 | "formicary_fpml-toolkit-java": "formicary_fpml-toolkit-csharp", 32 | "eclipse_jgit": "mono_ngit", 33 | "quartz-scheduler_quartz": "quartznet_quartznet", 34 | "terabyte_jgit": "mono_ngit", 35 | # "locationtech_jts": "NetTopologySuite_NetTopologySuite", 36 | } 37 | 38 | cs_port_date = { 39 | "tunnelvisionlabs_antlr4cs": "Feb 16, 2013", 40 | "apache_lucenenet": "Nov 21, 2005", 41 | "nissl-lab_npoi": "May 8, 2011", 42 | "mono_ngit": "Oct 7, 2010", 43 | "itext_itext7-dotnet": "Apr 8, 2016", 44 | "NetTopologySuite_NetTopologySuite": "Jul 29, 2006", 45 | "formicary_fpml-toolkit-csharp": "July 11, 2006", 46 | "apache_logging-log4net": "Jan 24, 2004", 47 | "nhibernate_nhibernate-core": "Feb 18, 2003", 48 | "nant_nant": "Aug 11, 2001", 49 | "quartznet_quartznet": "Dec 8, 2006", 50 | "nHapiNET_nHapi": "Mar 8, 2014", 51 | } 52 | 53 | project_branch = { 54 | "antlr_antlr4": "master", 55 | "tunnelvisionlabs_antlr4cs": "master", 56 | "apache_lucene": "main", 57 | "apache_lucenenet": "master", 58 | "apache_poi": "trunk", 59 | "nissl-lab_npoi": "master", 60 | "eclipse_jgit": "master", 61 | "mono_ngit": "master", 62 | "itext_itext7": "develop", # develop? 
63 | "itext_itext7-dotnet": "develop", 64 | "locationtech_jts": "master", 65 | "NetTopologySuite_NetTopologySuite": "develop", 66 | "formicary_fpml-toolkit-java": "master", 67 | "formicary_fpml-toolkit-csharp": "master", 68 | "terabyte_jgit": "master", 69 | "apache_logging-log4j": "release-2.x", 70 | "apache_logging-log4net": "master", 71 | "nhibernate_nhibernate-core": "master", 72 | "hibernate_hibernate-orm": "main", 73 | "apache_ant": "master", 74 | "nant_nant": "master", 75 | "quartz-scheduler_quartz": "master", 76 | "quartznet_quartznet": "main", 77 | "hapifhir_hapi-hl7v2": "master", 78 | "nHapiNET_nHapi": "master", 79 | } 80 | 81 | 82 | class DataCollector: 83 | logger = LoggingUtils.get_logger( 84 | __name__, LoggingUtils.DEBUG if Environment.is_debug else LoggingUtils.INFO 85 | ) 86 | 87 | def __init__(self): 88 | self.repos_downloads_dir: Path = Macros.repos_downloads_dir 89 | self.repos_results_dir: Path = Macros.repos_results_dir 90 | self.repos_dir: Path = Macros.project_dir / "repos" 91 | self.results_dir = Macros.results_dir / "repo-data" 92 | self.raw_data_dir = Macros.data_dir / "raw" 93 | self.collected_projects_list = [] 94 | self.collected_projects_file = ( 95 | Macros.results_dir / "collected-github-java-repos.json" 96 | ) 97 | if self.collected_projects_file.exists(): 98 | self.collected_projects_list = IOUtils.load(self.collected_projects_file) 99 | else: 100 | self.collected_projects_list = [] 101 | io.mkdir(self.repos_downloads_dir) 102 | io.mkdir(self.repos_results_dir) 103 | self.token_sim_threshold = 0.4 104 | self.overlap_token_sim_threshold = 0.6 105 | self.line_sim_threshold = 0.5 106 | 107 | return 108 | 109 | def collect_java_method_diff_for_data(self): 110 | """Collect java historical method changes for dataset augmentation.""" 111 | java_projects = projects_map.keys() 112 | total_collected_data = [] 113 | 114 | for java_prj in tqdm(java_projects, total=len(java_projects)): 115 | java_diffs = self.collect_java_changed_history(java_prj) 116 | total_collected_data.extend(java_diffs) 117 | 118 | io.dump( 119 | Macros.data_dir / "raw" / "java-only-history-diff.jsonl", 120 | total_collected_data, 121 | io.Fmt.jsonList, 122 | ) 123 | 124 | def collect_java_project_data(self): 125 | """Collect Java project methods data.""" 126 | 127 | java_projects_file = self.repos_dir / "java-projects.json" 128 | java_projects = io.load(java_projects_file) 129 | 130 | for prj_name, prj_url in java_projects.items(): 131 | if prj_name in self.collected_projects_list: 132 | continue 133 | self.logger.info(f"Start to collect methods from project {prj_name}") 134 | self.collect_method_data( 135 | prj_url, prj_name 136 | ) # change the code if SHA is specified 137 | self.collected_projects_list.append(prj_name) 138 | # end for 139 | 140 | io.dump(self.collected_projects_file, self.collected_projects_list) 141 | 142 | def collect_csharp_project_data(self): 143 | """Collect C# project methods data.""" 144 | csharp_projects_file = self.repos_dir / "csharp-projects.json" 145 | csharp_projects = io.load(csharp_projects_file) 146 | 147 | for prj_name, prj_url in csharp_projects.items(): 148 | if prj_name in self.collected_projects_list: 149 | continue 150 | self.logger.info(f"Start to collect methods from project {prj_name}") 151 | self.collect_method_data( 152 | prj_url, prj_name, lang="csharp" 153 | ) # change the code if SHA is specified 154 | self.collected_projects_list.append(prj_name) 155 | # end for 156 | 157 | io.dump(self.collected_projects_file, self.collected_projects_list) 158 
| 159 | def build_translation_augment_data(self): 160 | """Aggregate translation data for augmentation.""" 161 | 162 | method_map = io.load( 163 | Macros.results_dir / "stats" / "stats-method-mapping-augment.json" 164 | ) 165 | augment_data = [] 166 | 167 | for j_prj in method_map: 168 | c_prj = projects_map[j_prj] 169 | java_methods = io.load( 170 | self.repos_results_dir / j_prj / "collector" / "java-method-data.json" 171 | ) 172 | cs_methods = io.load( 173 | self.repos_results_dir / c_prj / "collector" / "csharp-method-data.json" 174 | ) 175 | for j_id, c_id in method_map[j_prj]["map"].items(): 176 | augment_data.append( 177 | { 178 | "project": j_prj, 179 | "java-SHA": method_map[j_prj]["java-SHA"], 180 | "java-old": "", 181 | "java-new": java_methods[int(j_id)], 182 | "cs-SHA": method_map[j_prj]["cs-SHA"], 183 | "cs-old": "", 184 | "cs-new": cs_methods[int(c_id)], 185 | } 186 | ) 187 | 188 | io.dump( 189 | Macros.data_dir / "raw" / "translation-augment-data.jsonl", 190 | augment_data, 191 | io.Fmt.jsonList, 192 | ) 193 | 194 | def collect_project_commit_history(self): 195 | """Collect projects commit history with interested files.""" 196 | 197 | for java_project, cs_project in projects_map.items(): 198 | if cs_project == "antlr_antlr4-cs": 199 | cs_project = "antlr_antlr4" 200 | # 1. mine history from java project 201 | self.logger.info(f"Start to collect commit history of {java_project}") 202 | self.collect_project_history(java_project, lang="java") 203 | # 2. mine history from csharp project 204 | self.logger.info(f"Start to collect commit history of {cs_project}") 205 | self.collect_project_history(cs_project, lang="cs") 206 | # end for 207 | 208 | def check_miss_method(self, project_name: str, align_methods: dict): 209 | """Check miss method.""" 210 | 211 | map = io.load(Macros.results_dir / f"{project_name}-method-hash-map.json") 212 | for align_method in align_methods: 213 | align_class_name = ".".join(align_method.split(".")[:-1]) 214 | # print(align_class_name) 215 | if align_class_name not in map: 216 | self.logger.info(f"Miss method: {align_class_name}") 217 | 218 | def augment_translation_data(self): 219 | """Augment training data by collecting pure translation data.""" 220 | 221 | augment_data_stats = {} 222 | for java_project in projects_map: 223 | augment_data_stats[ 224 | java_project 225 | ] = self.build_java_csharp_method_map_for_project(java_project) 226 | io.dump( 227 | Macros.results_dir / "stats" / "stats-method-mapping-augment.json", 228 | augment_data_stats, 229 | io.Fmt.jsonPretty, 230 | ) 231 | 232 | def build_java_csharp_method_map_for_project(self, java_project: str): 233 | """Build java and c# aligned methods map for a particular project.""" 234 | 235 | dataset_test_date = io.load( 236 | Macros.results_dir / "stats" / "stats-data-split-date.json" 237 | ) 238 | j_prj, c_prj = java_project, projects_map[java_project] 239 | 240 | # first find the correct commit 241 | java_test_date = dataset_test_date[java_project]["valid"]["java"] 242 | cs_test_date = dataset_test_date[java_project]["valid"]["cs"] 243 | earlier_date = java_test_date if java_test_date < cs_test_date else cs_test_date 244 | self.logger.info(f"Earlier date is {earlier_date} for {java_project}") 245 | # mine methods in projects 246 | with io.cd(Macros.repos_downloads_dir / java_project): 247 | bash.run("git checkout $(git branch --show-current)") 248 | java_sha = bash.run( 249 | f"git log --until='{earlier_date}' --first-parent --no-merges --pretty=format:'%H'" 250 | ).stdout.split("\n")[1][:8] 
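            # Note: `git log --until='{earlier_date}'` lists matching commits newest-first;
            # [1] picks the second entry of that list and [:8] keeps the abbreviated
            # 8-character SHA. The C# project below is handled the same way.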
251 | self.logger.info(f"Mining java SHA {java_sha}") 252 | self.collect_method_data("", java_project, java_sha, "java") 253 | with io.cd(Macros.repos_downloads_dir / c_prj): 254 | bash.run("git checkout $(git branch --show-current)") 255 | cs_sha = bash.run( 256 | f"git log --until='{earlier_date}' --first-parent --no-merges --pretty=format:'%H'" 257 | ).stdout.split("\n")[1][:8] 258 | self.logger.info(f"Mining c# SHA {cs_sha}") 259 | self.collect_method_data("", c_prj, cs_sha, "cs") 260 | 261 | java_method_file = ( 262 | self.repos_results_dir / j_prj / "collector" / f"java-method-data.json" 263 | ) 264 | csharp_method_file = ( 265 | self.repos_results_dir / c_prj / "collector" / f"csharp-method-data.json" 266 | ) 267 | java_methods = io.load(java_method_file) 268 | csharp_methods = io.load(csharp_method_file) 269 | j2c_map = self.map_java_cs_methods( 270 | java_methods=java_methods, csharp_methods=csharp_methods 271 | ) 272 | 273 | self.logger.info(f"Size of project {j_prj} map is {len(j2c_map)}.") 274 | 275 | results_dict = { 276 | "java-SHA": java_sha, 277 | "cs-SHA": cs_sha, 278 | "map": j2c_map, 279 | } 280 | 281 | return results_dict 282 | 283 | def map_java_cs_methods(self, java_methods: List[dict], csharp_methods: List[dict]): 284 | """Find java and c# method map, return the method id dict.""" 285 | j2c_map = {} 286 | for j_m in tqdm(java_methods, total=(len(java_methods))): 287 | for c_m in csharp_methods: 288 | if ( 289 | c_m["name"].lower() == j_m["name"].lower() 290 | and c_m["path"].split("/")[-1].split(".")[0].lower() 291 | == j_m["path"].split("/")[-1].split(".")[0].lower() 292 | and [p.lower() for p_list in c_m["params"] for p in p_list] 293 | == [p.lower() for p_list in j_m["params"] for p in p_list] 294 | ): 295 | if ( 296 | c_m["class_name"] is not None 297 | and j_m["class_name"] is not None 298 | and c_m["class_name"].lower() != j_m["class_name"].lower() 299 | ): 300 | continue 301 | if str(j_m["id"]) not in j2c_map: 302 | j2c_map[str(j_m["id"])] = str(c_m["id"]) 303 | else: 304 | self.logger.info("Find duplicate mapping, ignore this one.") 305 | break 306 | # end if 307 | return j2c_map 308 | 309 | def build_java_csharp_method_map(self): 310 | """Build java and c# aligned methods map.""" 311 | 312 | method_hash_list = set() 313 | if (Macros.results_dir / "java-csharp-method-map.json").exists(): 314 | j2c_prj_map = io.load(Macros.results_dir / "java-csharp-method-map.json") 315 | else: 316 | j2c_prj_map = {} 317 | # j2c_prj_map.pop("apache_poi-nissl-lab_npoi") 318 | for j_prj, c_prj in [("apache_logging-log4j", "apache_logging-log4net")]: 319 | duplicate_count = 0 320 | # if f"{j_prj}-{c_prj}" in j2c_prj_map: 321 | # continue 322 | j2c_map = {} 323 | java_method_file = ( 324 | self.repos_results_dir / j_prj / "collector" / "java-method-data.json" 325 | ) 326 | csharp_method_file = ( 327 | self.repos_results_dir / c_prj / "collector" / "csharp-method-data.json" 328 | ) 329 | java_methods = io.load(java_method_file) 330 | csharp_methods = io.load(csharp_method_file) 331 | 332 | # rules: 1. same function names 333 | # 2. same parameters (types and names) 334 | # 3. same file name 335 | # 4. 
same class name 336 | for j_m in tqdm(java_methods, total=(len(java_methods))): 337 | for c_m in csharp_methods: 338 | if ( 339 | c_m["name"].lower() == j_m["name"].lower() 340 | and c_m["path"].split("/")[-1].split(".")[0].lower() 341 | == j_m["path"].split("/")[-1].split(".")[0].lower() 342 | # and [p.lower() for p_list in c_m["params"] for p in p_list] 343 | # == [p.lower() for p_list in j_m["params"] for p in p_list] 344 | ): 345 | # if ( 346 | # c_m["class_name"] is not None 347 | # and j_m["class_name"] is not None 348 | # and c_m["class_name"].lower() != j_m["class_name"].lower() 349 | # ): 350 | # continue 351 | if j_m["id"] not in j2c_map: 352 | j2c_map[j_m["id"]] = str(c_m["id"]) 353 | java_method_hash = ( 354 | j_m["path"].split("/")[-1].replace(".java", "").lower() 355 | + f".{j_m['class_name'].lower()}" 356 | + f".{j_m['name'].lower()}" 357 | ) 358 | cs_method_hash = ( 359 | c_m["path"].split("/")[-1].replace(".cs", "").lower() 360 | + f".{c_m['class_name'].lower()}" 361 | ) 362 | 363 | method_hash_list.add(cs_method_hash) 364 | else: 365 | self.logger.info("Find duplicate mapping, ignore this one.") 366 | duplicate_count += 1 367 | break 368 | # end if 369 | 370 | self.logger.info( 371 | f"Size of project {j_prj} map is {len(j2c_map)}. Duplicate cases are ignored: {duplicate_count}" 372 | ) 373 | j2c_prj_map[f"{j_prj}-{c_prj}"] = j2c_map 374 | io.dump( 375 | Macros.results_dir / f"{j_prj}-method-hash-map.json", method_hash_list 376 | ) 377 | io.dump(Macros.results_dir / "java-csharp-method-map.json", j2c_prj_map) 378 | 379 | def build_java_csharp_file_map(self): 380 | """Build map between Java and C# files.""" 381 | 382 | method_maps = io.load(Macros.results_dir / "java-csharp-method-map.json") 383 | if (Macros.results_dir / "java-csharp-file-map.json").exists(): 384 | prj_file_maps = io.load(Macros.results_dir / "java-csharp-file-map.json") 385 | else: 386 | prj_file_maps = {} 387 | if (Macros.results_dir / "java-csharp-mapped-files.json").exists(): 388 | prj_mapped_files = io.load( 389 | Macros.results_dir / "java-csharp-mapped-files.json" 390 | ) 391 | prj_mapped_files = defaultdict(list, prj_mapped_files) 392 | else: 393 | prj_mapped_files = defaultdict(list) 394 | 395 | for j_prj, cs_prj in projects_map.items(): 396 | if f"{j_prj}-{cs_prj}" in prj_file_maps: 397 | continue 398 | prj_file_maps[f"{j_prj}-{cs_prj}"] = {} 399 | method_id_map = method_maps[f"{j_prj}-{cs_prj}"] 400 | java_methods = io.load( 401 | self.repos_results_dir / j_prj / "collector" / "java-method-data.json" 402 | ) 403 | csharp_methods = io.load( 404 | self.repos_results_dir 405 | / cs_prj 406 | / "collector" 407 | / "csharp-method-data.json" 408 | ) 409 | for j_mid, c_mid in method_id_map.items(): 410 | j_m_file = java_methods[int(j_mid)]["path"] 411 | c_m_file = csharp_methods[int(c_mid)]["path"] 412 | prj_file_maps[f"{j_prj}-{cs_prj}"][j_m_file] = c_m_file 413 | prj_mapped_files[j_prj].append(j_m_file) 414 | prj_mapped_files[cs_prj].append(c_m_file) 415 | # remove duplicate file name 416 | prj_mapped_files[j_prj] = list(set(prj_mapped_files[j_prj])) 417 | prj_mapped_files[cs_prj] = list(set(prj_mapped_files[cs_prj])) 418 | # end for 419 | io.dump(Macros.results_dir / "java-csharp-file-map.json", prj_file_maps) 420 | io.dump(Macros.results_dir / "java-csharp-mapped-files.json", prj_mapped_files) 421 | 422 | def remove_comments_from_csharp(self): 423 | """Remove comments from csharp code.""" 424 | 425 | line_comment_pattern = r"//(.*?)\n" 426 | block_comment_pattern = r"/\*(.*?)\*/" 427 | for prj in 
projects_map.values(): 428 | csharp_method_file = ( 429 | self.repos_results_dir / prj / "collector" / "csharp-method-data.json" 430 | ) 431 | csharp_methods = io.load(csharp_method_file) 432 | for m in csharp_methods: 433 | if "//" in m["code"]: 434 | m["code"] = re.sub(line_comment_pattern, "", m["code"]) 435 | if "/*" in m["code"]: 436 | m["code"] = re.sub(block_comment_pattern, "", m["code"]) 437 | # end for 438 | io.dump(csharp_method_file, csharp_methods, io.Fmt.jsonPretty) 439 | 440 | # -- Helper functions ----------------------------------------------------- 441 | 442 | def sample_git_history(self): 443 | """Sample git history of methods for manually checking.""" 444 | K = 5 445 | 446 | history_to_check = defaultdict(list) 447 | for j_prj in tqdm(projects_map, total=len(projects_map)): 448 | aligned_method_history = io.load( 449 | self.results_dir / f"{j_prj}-method-aligned-history.json" 450 | ) 451 | total_size = len(aligned_method_history) 452 | sample_ids = random.choices(range(total_size), k=K) 453 | for i, dt in enumerate(aligned_method_history): 454 | if i in sample_ids: 455 | history_to_check[j_prj].append((dt, aligned_method_history[dt])) 456 | io.dump( 457 | Macros.data_dir / "raw" / "manually-check-history.json", 458 | history_to_check, 459 | io.Fmt.jsonPretty, 460 | ) 461 | 462 | def collect_commit_date(self): 463 | """Collect commit date for each data""" 464 | 465 | data_list = io.load(Macros.data_dir / "raw" / "delta-translation-dataset.jsonl") 466 | for dt in tqdm(data_list, total=len(data_list)): 467 | # add java commit date 468 | sha = dt["java-SHA"].split("-")[1] 469 | prj = dt["project"] 470 | branch_name = project_branch[prj] 471 | with io.cd(Macros.repos_downloads_dir / prj): 472 | bash.run(f"git checkout {branch_name} -f") 473 | commit_date = bash.run( 474 | f"git show -s --format=%cd --date=format:'%Y-%m-%d %H:%M:%S' {sha}", 475 | check_returncode=0, 476 | ).stdout.strip() 477 | dt["java-commit-date"] = commit_date 478 | 479 | # add csharp commit date 480 | sha = dt["cs-SHA"].split("-")[1] 481 | prj = projects_map[dt["project"]] 482 | branch_name = project_branch[prj] 483 | with io.cd(Macros.repos_downloads_dir / prj): 484 | bash.run(f"git checkout {branch_name} -f") 485 | commit_date = bash.run( 486 | f"git show -s --format=%cd --date=format:'%Y-%m-%d %H:%M:%S' {sha}", 487 | check_returncode=0, 488 | ).stdout.strip() 489 | dt["cs-commit-date"] = commit_date 490 | # end for 491 | io.dump( 492 | Macros.data_dir / "delta-translation-dataset-w-date.jsonl", 493 | data_list, 494 | io.Fmt.jsonList, 495 | ) 496 | 497 | def compare_overlap_similarity( 498 | self, 499 | java_diff: dict, 500 | cs_diff: dict, 501 | ) -> Tuple[float, float]: 502 | """Algorim to find the aligned pair""" 503 | 504 | add_java_tks, del_java_tks = java_diff["add-tokens"], java_diff["del-tokens"] 505 | add_cs_tks, del_cs_tks = cs_diff["add-tokens"], cs_diff["del-tokens"] 506 | # 1. 
token level similarity 507 | if len(add_java_tks) == 0 and len(add_cs_tks) == 0: 508 | add_tokens_similarity = self.overlap_token_sim_threshold 509 | else: 510 | if len(add_java_tks) == 0 or len(add_cs_tks) == 0: 511 | add_tokens_similarity = 0.0 512 | else: 513 | add_tokens_similarity = self.compute_diff_similarity( 514 | add_java_tks, add_cs_tks, task="inclusion" 515 | ) 516 | if len(del_java_tks) == 0 and len(del_cs_tks) == 0: 517 | del_tokens_similarity = self.overlap_token_sim_threshold 518 | else: 519 | if len(del_java_tks) == 0 or len(del_cs_tks) == 0: 520 | del_tokens_similarity = 0.0 521 | else: 522 | del_tokens_similarity = self.compute_diff_similarity( 523 | del_java_tks, del_cs_tks, task="inclusion" 524 | ) 525 | tokens_sim = add_tokens_similarity * (0.5) + del_tokens_similarity * (0.5) 526 | 527 | # 2. line level similarity 528 | add_line_similarity = self.compute_diff_similarity( 529 | java_diff["add-code"], cs_diff["add-code"], task="inclusion" 530 | ) 531 | del_line_similarity = self.compute_diff_similarity( 532 | java_diff["del-code"], cs_diff["del-code"], task="inclusion" 533 | ) 534 | line_sim = add_line_similarity * 0.5 + del_line_similarity * 0.5 535 | 536 | return tokens_sim, line_sim 537 | 538 | def compare_diff_similarity( 539 | self, 540 | java_diff: dict, 541 | cs_diff: dict, 542 | ) -> Tuple[float, float]: 543 | """Algorithm to find the aligned pair""" 544 | 545 | add_java_tks, del_java_tks = java_diff["add-tokens"], java_diff["del-tokens"] 546 | add_cs_tks, del_cs_tks = cs_diff["add-tokens"], cs_diff["del-tokens"] 547 | # 1. token level similarity 548 | if len(add_java_tks) == 0 and len(add_cs_tks) == 0: 549 | add_tokens_similarity = 0.4 550 | else: 551 | add_tokens_similarity = self.compute_diff_similarity( 552 | add_java_tks, 553 | add_cs_tks, 554 | ) 555 | if len(del_java_tks) == 0 and len(del_cs_tks) == 0: 556 | del_tokens_similarity = 0.4 557 | else: 558 | del_tokens_similarity = self.compute_diff_similarity( 559 | del_java_tks, del_cs_tks 560 | ) 561 | tokens_sim = add_tokens_similarity * (0.5) + del_tokens_similarity * (0.5) 562 | # 2. line level similarity 563 | add_line_similarity = self.compute_diff_similarity( 564 | java_diff["add-code"], cs_diff["add-code"] 565 | ) 566 | del_line_similarity = self.compute_diff_similarity( 567 | java_diff["del-code"], cs_diff["del-code"] 568 | ) 569 | line_sim = add_line_similarity * 0.5 + del_line_similarity * 0.5 570 | 571 | return tokens_sim, line_sim 572 | 573 | def collect_project_history(self, project_name: str, lang: str): 574 | """Collect git history of project and extract the changed files that have aligned methods. 575 | 576 | Args: 577 | project_name (str): project name 578 | lang (str): java or cs 579 | """ 580 | 581 | downloads_dir = self.repos_downloads_dir / project_name 582 | if not downloads_dir.exists(): 583 | self.logger.warning(f"Project {project_name} not found.") 584 | if project_name == "antlr_antlr4" and lang == "cs": 585 | project_name = "antlr_antlr4-cs" 586 | prj_mapped_files = io.load( 587 | Macros.results_dir / "java-csharp-mapped-files.json" 588 | )[project_name] 589 | 590 | with IOUtils.cd(downloads_dir): 591 | BashUtils.run("git checkout master") 592 | shalist = BashUtils.run( 593 | f"git log --color=never --since='May 15 2011' --first-parent --no-merges --pretty=format:'%H'" 594 | ).stdout.split("\n") 595 | shalist = [sha[:8] for sha in shalist] 596 | shalist = shalist[::-1] 597 | # 2. 
Check the changed files in each commit 598 | 599 | target_commits = OrderedDict() 600 | time_order_sha = [] 601 | with IOUtils.cd(downloads_dir): 602 | for i in tqdm(range(len(shalist) - 1)): 603 | cur_sha, pre_sha = shalist[i + 1], shalist[i] 604 | changed_files = BashUtils.run( 605 | f"git diff {pre_sha} {cur_sha} --name-only" 606 | ).stdout.split("\n") 607 | changed_files = [f for f in changed_files if f.split(".")[-1] == lang] 608 | if ( 609 | len(changed_files) 610 | > 0 611 | # and len(set(changed_files).intersection(set(prj_mapped_files))) > 0 612 | ): 613 | target_commits[f"{pre_sha}-{cur_sha}"] = changed_files 614 | 615 | self.logger.info( 616 | f"Collect {len(target_commits)} commits with {lang} files changed." 617 | ) 618 | io.dump( 619 | self.repos_results_dir / project_name / "git-history.json", 620 | target_commits, 621 | io.Fmt.jsonNoSort, 622 | ) 623 | 624 | def check_bad_examples(self): 625 | K = 50 626 | examples_to_check = io.load(Macros.data_dir / f"manual-check-{K}-examples.json") 627 | bad_examples = [] 628 | for dt in examples_to_check: 629 | added_java_tks, del_java_tks = dt["add-java-tks"], dt["del-java-tks"] 630 | added_cs_tks, del_cs_tks = dt["add-cs-tks"], dt["del-cs-tks"] 631 | if len(added_java_tks) == 0 and len(added_cs_tks) == 0: 632 | add_tokens_similarity = 0.4 633 | else: 634 | add_tokens_similarity = self.compute_diff_similarity( 635 | added_java_tks, added_cs_tks 636 | ) 637 | if len(del_java_tks) == 0 and len(del_cs_tks) == 0: 638 | del_tokens_similarity = 0.4 639 | else: 640 | del_tokens_similarity = self.compute_diff_similarity( 641 | del_java_tks, del_cs_tks 642 | ) 643 | tokens_sim = add_tokens_similarity * (0.5) + del_tokens_similarity * (0.5) 644 | dt["tokens-sim"] = tokens_sim 645 | 646 | if tokens_sim < 0.4: 647 | print(dt["id"]) 648 | # self.logger.info(f"Number of bad examples are {len(bad_examples)}") 649 | # io.dump( 650 | # Macros.data_dir / "bad-manual-check-examples.json", 651 | # bad_examples, 652 | # io.Fmt.jsonNoSort, 653 | # ) 654 | 655 | def code_tokenizer(self, raw_code: str, lang: str) -> Tuple[str, str]: 656 | """Tokenize both Java and C# code in the dataset""" 657 | from deltr.exe.CodeTokenizer import CodeTokenizer 658 | import atexit 659 | 660 | self.tokenizer = CodeTokenizer(main_class="org.csevo.Tokenizer") 661 | self.tokenizer.setup() 662 | atexit.register(self.tokenizer.teardown) 663 | 664 | # c# comment pattern 665 | line_comment_pattern = "\s//(.*?)\n" 666 | block_comment_pattern = r"/\*(.*?)\*/" 667 | 668 | if "//" in raw_code or "/*" in raw_code: 669 | code_no_comment = re.sub(line_comment_pattern, "\n", raw_code) 670 | code_no_comment = re.sub(block_comment_pattern, "", code_no_comment) 671 | else: 672 | code_no_comment = raw_code 673 | 674 | code_tok = self.tokenizer.tokenize(code_no_comment, lang).strip() 675 | return code_no_comment, code_tok 676 | 677 | 678 | def compute_minimal_code_diffs(old_tokens: List[str], new_tokens: List[str]): 679 | added_tokens = [] 680 | del_tokens = [] 681 | 682 | for edit_type, o_start, o_end, n_start, n_end in difflib.SequenceMatcher( 683 | None, old_tokens, new_tokens 684 | ).get_opcodes(): 685 | if edit_type == "equal": 686 | continue 687 | elif edit_type == "replace": 688 | added_tokens.extend(new_tokens[n_start:n_end]) 689 | del_tokens.extend(old_tokens[o_start:o_end]) 690 | 691 | elif edit_type == "insert": 692 | added_tokens.extend(new_tokens[n_start:n_end]) 693 | else: 694 | del_tokens.extend(old_tokens[o_start:o_end]) 695 | 696 | return added_tokens, del_tokens 697 | 698 | 699 
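# A minimal usage sketch of compute_minimal_code_diffs (defined above); the token
# lists are made-up examples rather than data from the collected repos, and this
# helper is illustrative only: nothing in the pipeline calls it.
def _example_compute_minimal_code_diffs():
    old_tokens = "int foo ( int x ) { return x ; }".split()
    new_tokens = "int foo ( int x ) { return x + 1 ; }".split()
    added, deleted = compute_minimal_code_diffs(old_tokens, new_tokens)
    # difflib reports a single "insert" opcode for this pair, so only the two new
    # tokens come back as additions and nothing is reported as deleted.
    assert added == ["+", "1"]
    assert deleted == []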
| if __name__ == "__main__": 700 | LoggingUtils.setup(LoggingUtils.INFO, Macros.log_file) 701 | CLI(DataCollector, as_positional=False) 702 | -------------------------------------------------------------------------------- /python/deltr/collector/DataProcessor.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from jsonargparse import CLI 3 | from seutil import LoggingUtils 4 | import seutil as su 5 | 6 | from deltr.Macros import Macros 7 | from deltr.collector.diff_utils import ( 8 | compute_code_diffs, 9 | remove_keep_span, 10 | compute_unique_edits, 11 | compute_minimal_comment_diffs, 12 | compute_minimal_replace_diffs, 13 | format_minimal_diff_spans, 14 | EDIT_TOKENS, 15 | EDIT_START_TOKENS, 16 | ) 17 | 18 | logger = su.log.get_logger(__name__, su.LoggingUtils.INFO) 19 | 20 | 21 | class DataProcessor: 22 | SPLITS = ["train", "valid", "test"] 23 | 24 | def meta_edit_data_process( 25 | self, 26 | exp: str, 27 | src_lang: str = "java", 28 | tgt_lang: str = "cs", 29 | model_name: str = "metaEdits", 30 | setting: str = "time-segmented", 31 | ): 32 | """ 33 | Process dataset for meta edit model. 34 | """ 35 | 36 | model_data_dir = Macros.data_dir / model_name / setting / exp 37 | su.io.mkdir(model_data_dir) 38 | 39 | raw_data_dir = Macros.data_dir / "raw" / setting 40 | 41 | for split in self.SPLITS: 42 | data_list = su.io.load(raw_data_dir / f"delta-translation-{split}.jsonl") 43 | 44 | model_input_file = model_data_dir / f"{split}.{exp}.src" 45 | model_output_file = model_data_dir / f"{split}.{exp}.tgt" 46 | model_ground_truth_file = model_data_dir / f"{split}.{exp}.seq" 47 | 48 | model_inputs, model_outputs, golds = [], [], [] 49 | 50 | for i, dt in enumerate(data_list): 51 | # src lang 52 | src_lang_edits_tokens = compute_minimal_replace_diffs( 53 | dt[f"{src_lang}-old"]["tokenized_code"].split() + [""], 54 | dt[f"{src_lang}-new"]["tokenized_code"].split() + [""], 55 | )[0] 56 | assert ( 57 | format_minimal_diff_spans( 58 | dt[f"{src_lang}-old"]["tokenized_code"].split() + [""], 59 | src_lang_edits_tokens, 60 | ) 61 | == dt[f"{src_lang}-new"]["tokenized_code"] + " " 62 | ), f"index {i} {dt[f'{src_lang}-old']['tokenized_code']} \n {dt[f'{src_lang}-new']['tokenized_code']}" 63 | 64 | src_lang_edits = " ".join(src_lang_edits_tokens) 65 | tgt_lang_edits_tokens = compute_minimal_replace_diffs( 66 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() + [""], 67 | dt[f"{tgt_lang}-new"]["tokenized_code"].split() + [""], 68 | )[0] 69 | assert ( 70 | format_minimal_diff_spans( 71 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() + [""], 72 | tgt_lang_edits_tokens, 73 | ) 74 | == dt[f"{tgt_lang}-new"]["tokenized_code"] + " " 75 | ), f"index {i} {dt[f'{tgt_lang}-old']['tokenized_code']} \n {dt[f'{tgt_lang}-new']['tokenized_code']}" 76 | 77 | tgt_lang_edits = " ".join(tgt_lang_edits_tokens) 78 | meta_edits, _, _ = compute_code_diffs( 79 | src_lang_edits_tokens, tgt_lang_edits_tokens 80 | ) 81 | meta_edits_plan = remove_keep_span(meta_edits) 82 | src_lang_target = " ".join( 83 | dt[f"{src_lang}-new"]["tokenized_code"].split() 84 | ) 85 | tgt_lang_source = " ".join( 86 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() 87 | ) 88 | tgt_lang_target = " ".join( 89 | dt[f"{tgt_lang}-new"]["tokenized_code"].split() 90 | ) 91 | golds.append(tgt_lang_target) 92 | model_inputs.append( 93 | f"{src_lang_edits} {tgt_lang_source} {src_lang_target}" 94 | ) 95 | model_outputs.append(f"{meta_edits_plan} {tgt_lang_edits}") 96 | # endfor 97 | 
su.io.dump(model_input_file, model_inputs, su.io.Fmt.txtList) 98 | su.io.dump(model_output_file, model_outputs, su.io.Fmt.txtList) 99 | su.io.dump(model_ground_truth_file, golds, su.io.Fmt.txtList) 100 | 101 | def edit_translation_data_process( 102 | self, 103 | exp: str, 104 | model_name: str = "edit-translation", 105 | src_lang: str = "java", 106 | tgt_lang: str = "cs", 107 | setting: str = "time-segmented", 108 | ): 109 | """ 110 | Process dataset for edit translation model. 111 | """ 112 | 113 | model_data_dir = Macros.data_dir / model_name / setting / exp 114 | su.io.mkdir(model_data_dir) 115 | raw_data_dir = Macros.data_dir / "raw" / setting 116 | 117 | for split in self.SPLITS: 118 | data_list = su.io.load(raw_data_dir / f"delta-translation-{split}.jsonl") 119 | 120 | model_input_file = model_data_dir / f"{split}.{exp}.src" 121 | model_output_file = model_data_dir / f"{split}.{exp}.tgt" 122 | model_ground_truth_file = model_data_dir / f"{split}.{exp}.seq" 123 | 124 | model_inputs, model_outputs, golds = [], [], [] 125 | 126 | for dt in data_list: 127 | # src lang 128 | src_lang_edits_tokens = compute_minimal_replace_diffs( 129 | dt[f"{src_lang}-old"]["tokenized_code"].split() + [""], 130 | dt[f"{src_lang}-new"]["tokenized_code"].split() + [""], 131 | )[0] 132 | assert ( 133 | format_minimal_diff_spans( 134 | dt[f"{src_lang}-old"]["tokenized_code"].split() + [""], 135 | src_lang_edits_tokens, 136 | ) 137 | == dt[f"{src_lang}-new"]["tokenized_code"] + " " 138 | ), f"{dt[f'{src_lang}-old']['tokenized_code']} \n {dt[f'{src_lang}-new']['tokenized_code']}" 139 | src_lang_edits = " ".join(src_lang_edits_tokens) 140 | 141 | tgt_lang_edits_tokens = compute_minimal_replace_diffs( 142 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() + [""], 143 | dt[f"{tgt_lang}-new"]["tokenized_code"].split() + [""], 144 | )[0] 145 | assert ( 146 | format_minimal_diff_spans( 147 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() + [""], 148 | tgt_lang_edits_tokens, 149 | ) 150 | == dt[f"{tgt_lang}-new"]["tokenized_code"] + " " 151 | ), f"{dt[f'{tgt_lang}-old']['tokenized_code']} \n {dt[f'{tgt_lang}-new']['tokenized_code']}" 152 | tgt_lang_edits = " ".join(tgt_lang_edits_tokens) 153 | 154 | src_lang_target = " ".join( 155 | dt[f"{src_lang}-new"]["tokenized_code"].split() 156 | ) 157 | tgt_lang_source = " ".join( 158 | dt[f"{tgt_lang}-old"]["tokenized_code"].split() 159 | ) 160 | tgt_lang_target = " ".join( 161 | dt[f"{tgt_lang}-new"]["tokenized_code"].split() 162 | ) 163 | golds.append(tgt_lang_target) 164 | assert tgt_lang_source != tgt_lang_target 165 | 166 | model_inputs.append( 167 | f"{src_lang_edits} {tgt_lang_source} {src_lang_target}" 168 | ) 169 | model_outputs.append(f"{tgt_lang_edits}") 170 | # endfor 171 | su.io.dump(model_input_file, model_inputs, su.io.Fmt.txtList) 172 | su.io.dump(model_output_file, model_outputs, su.io.Fmt.txtList) 173 | su.io.dump(model_ground_truth_file, golds, su.io.Fmt.txtList) 174 | 175 | 176 | if __name__ == "__main__": 177 | LoggingUtils.setup(LoggingUtils.INFO, Macros.log_file) 178 | CLI(DataProcessor, as_positional=False) 179 | -------------------------------------------------------------------------------- /python/deltr/collector/RealDiffCollector.py: -------------------------------------------------------------------------------- 1 | import re 2 | import copy 3 | import os 4 | from collections import defaultdict 5 | from difflib import SequenceMatcher 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import * 9 | import sys 10 | from tqdm 
import tqdm 11 | from jsonargparse import CLI 12 | from seutil import ( 13 | LoggingUtils, 14 | IOUtils, 15 | BashUtils, 16 | TimeUtils, 17 | io, 18 | TimeoutException, 19 | bash, 20 | ) 21 | import difflib 22 | import random 23 | 24 | from deltr.collector.DataCollector import ( 25 | projects_map, 26 | cs_port_date, 27 | compute_minimal_code_diffs, 28 | ) 29 | from deltr.Macros import Macros 30 | from deltr.collector.utils import ( 31 | include_jaccard, 32 | jaccard, 33 | get_commit_date, 34 | tokenize_code, 35 | ) 36 | from deltr.Environment import Environment 37 | from deltr.collector.ProjectData import ProjectData 38 | 39 | line_comment_pattern = r"//(.*?)\n" 40 | block_comment_pattern = r"/\*(.*?)\*/" 41 | 42 | cs_java_type_map = { 43 | "bool": "boolean", 44 | "bool?": "Boolean", 45 | "sbyte": "byte", 46 | "sbyte?": "Byte", 47 | "ushort": "short", 48 | "uint": "int", 49 | "uint?": "Integer", 50 | "int?": "Integer", 51 | "ulong": "long", 52 | "char?": "Character", 53 | "list": "ArrayList", 54 | } 55 | 56 | """ 57 | 1. Collect java and C# aligned diff: collect_aligned_data() # collect aligned raw diffs from projects 58 | 1.1 build_sha_changed_file_map() 59 | Build dict: SHA -> [changed file] # java and c# 60 | 1.2 mine_changed_methods() 61 | Build dict: SHA -> [changed methods] # java and c# 62 | 1.3 aggregate_methods_histories() 63 | Build dict: method_hash -> [method diff] # java and c# 64 | NOTE: method_hash = {m_path}.{class_name}.{method_name}-{params} 65 | 1.4 collect_aligned_method_history 66 | 1.4.0 build_unique_method_id() Build dict: method_id -> path.class_name.method_name-params # for java and c# 67 | NOTE: method_id is min_path.file_name.class_name.method_name-params # should be unique 68 | 1) first use the file name.class_name.method_name-params 69 | 2) if the id is not unique, add the prior dir name 70 | 1.4.1 align_project_java_csharp_method() Match java_method_id and cs_method_id 71 | if java_method_id == cs_method_id --> match 72 | else # match java params and csharp params by rules / manually 73 | 1.4.2 collect_method_diff_history() Build dict: method_id -> {java: [method diff], c#: [method diff]} 74 | 1.4.3 filter_aligned_diff() 75 | Build list: [java diff, c# diff] 76 | NOTE: 1) use commit date to find paired diffs. i.e. for each java diff, find the closest c# diff in time and distance should be < 90 days 77 | 2) use jaccard sim between add/del tokens and add/del lines 78 | 2. Find the exact method from collected diffs 79 | build_delta_translation_dataset() 80 | 3. 
Split by time (target language time split c#, java) 81 | problem: only consider commit time 82 | """ 83 | 84 | 85 | class RealDiffCollector: 86 | logger = LoggingUtils.get_logger(__name__, LoggingUtils.INFO) 87 | 88 | # Constants 89 | commit_date_threshold = 90 # days 90 | token_sim_threshold = 0.4 91 | overlap_token_sim_threshold = 0.6 92 | line_sim_threshold = 0.5 93 | 94 | def __init__(self) -> None: 95 | self.results_dir = Macros.results_dir / "repo-data" 96 | self.repos_downloads_dir = Macros.repos_downloads_dir 97 | self.repos_results_dir = Macros.repos_results_dir 98 | 99 | # data split params 100 | self.train_ratio = 0.7 101 | self.val_ratio = 0.1 102 | self.test_ratio = 0.2 103 | assert self.train_ratio + self.val_ratio + self.test_ratio == 1 104 | 105 | def collect_real_diff_for_projects(self): 106 | for java_proj in projects_map: 107 | self.collect_real_aligned_diff_for_project(java_project_name=java_proj) 108 | 109 | def collect_real_aligned_diff_for_project(self, java_project_name: str): 110 | """script for collecting data from real git history for a project""" 111 | 112 | for lang in ["java", "cs"]: # hard-code for java and c# only 113 | if lang == "cs": 114 | target_project_name = projects_map[java_project_name] 115 | else: 116 | target_project_name = java_project_name 117 | self.build_sha_changed_file_map(project_name=target_project_name, lang=lang) 118 | self.mine_changed_methods(project_name=target_project_name, lang=lang) 119 | self.aggregate_methods_histories( 120 | project_name=target_project_name, lang=lang 121 | ) 122 | 123 | self.align_java_csharp_method(java_project_name=java_project_name) 124 | _ = input( 125 | "Human should help aligning method id in two projects. Finished? (y/n)" 126 | ) 127 | self.collect_project_diff_history(java_project_name=java_project_name) 128 | self.filter_project_aligned_diff(java_project_name=java_project_name) 129 | 130 | # 1.1 131 | def build_sha_changed_file_map(self, project_name: str, lang: str): 132 | """ 133 | Build a dictionary where the key is the SHA and value is the list of changed files. 134 | Do not constrain the files we collect. 135 | """ 136 | downloads_dir = self.repos_downloads_dir / project_name 137 | if lang == "java": 138 | cs_project_name = projects_map[project_name] 139 | start_date = cs_port_date[cs_project_name] 140 | else: 141 | start_date = cs_port_date[project_name] 142 | 143 | # 0. download repos 144 | project_url = f"git@github.com:{project_name.split('_')[0]}/{project_name.split('_')[1]}.git" 145 | if not downloads_dir.exists(): 146 | self.logger.info(f"Cloning repo {project_name} ... ") 147 | with IOUtils.cd(self.repos_downloads_dir): 148 | try: 149 | with TimeUtils.time_limit(300): 150 | BashUtils.run( 151 | f"git clone {project_url} {project_name}", 152 | expected_return_code=0, 153 | ) 154 | # end with 155 | except TimeoutException: 156 | self.logger.info( 157 | f"{project_name} exceeds time limit, ignore this one." 158 | ) 159 | return 160 | except: 161 | self.logger.warning( 162 | f"Project {project_name} failed: {sys.exc_info()}" 163 | ) 164 | return 165 | # end with 166 | 167 | # 1. 
get all shas 168 | with IOUtils.cd(downloads_dir): 169 | branch_name = bash.run( 170 | "git rev-parse --abbrev-ref HEAD", check_returncode=0 171 | ).stdout.strip() 172 | BashUtils.run(f"git checkout {branch_name}", expected_return_code=0) 173 | shalist = BashUtils.run( 174 | f"git log --color=never --since='{start_date}' --first-parent --no-merges --pretty=format:'%H'" 175 | ).stdout.split("\n") 176 | shalist = [sha[:8] for sha in shalist] 177 | self.logger.info(f"{len(shalist)} commits found") 178 | shalist = shalist[::-1] # in chronological order (old SHA before new SHA) 179 | 180 | # 2. Check the changed files in each commit 181 | commits_to_files = defaultdict(list) 182 | with IOUtils.cd(downloads_dir): 183 | for i in tqdm(range(len(shalist) - 1)): 184 | cur_sha, pre_sha = shalist[i + 1], shalist[i] 185 | changed_files = BashUtils.run( 186 | f"git diff {pre_sha} {cur_sha} --name-only" 187 | ).stdout.split("\n") 188 | changed_files = [f for f in changed_files if f.split(".")[-1] == lang] 189 | # find intersection files 190 | for cf in changed_files: 191 | commits_to_files[f"{pre_sha}-{cur_sha}"].append(cf) 192 | self.logger.info(f"{len(commits_to_files)} commits remained") 193 | io.dump( 194 | self.repos_results_dir / project_name / f"{lang}-commits-to-files.json", 195 | commits_to_files, 196 | io.Fmt.jsonNoSort, 197 | ) 198 | 199 | # 1.2 200 | def mine_changed_methods_for_projects(self): 201 | for java_proj in projects_map: 202 | self.logger.info(f"Mining project {java_proj}") 203 | self.mine_changed_methods(java_proj, "java") 204 | # self.mine_changed_methods(projects_map[java_proj], "cs") 205 | 206 | def mine_changed_methods(self, project_name: str, lang: str): 207 | """Mine the changed methods in changed files between two consecutive commits.""" 208 | 209 | git_history = io.load( 210 | self.repos_results_dir / project_name / f"{lang}-commits-to-files.json" 211 | ) 212 | sha_to_files = defaultdict(set) 213 | for sha in git_history: 214 | changed_files = git_history[sha] 215 | prev_sha, cur_sha = sha.split("-")[0], sha.split("-")[1] 216 | sha_to_files[cur_sha] = sha_to_files[cur_sha].union(set(changed_files)) 217 | sha_to_files[prev_sha] = sha_to_files[prev_sha].union(set(changed_files)) 218 | # endfor 219 | repos_results_dir = self.repos_results_dir / project_name 220 | project_git_changed_methods_history = OrderedDict() 221 | for sha in tqdm(git_history, total=len(git_history)): 222 | prev_sha, cur_sha = sha.split("-")[0], sha.split("-")[1] 223 | prev_changed_files = sha_to_files[prev_sha] 224 | cur_changed_files = sha_to_files[cur_sha] 225 | assert len(prev_changed_files) > 0 226 | assert len(cur_changed_files) > 0 227 | # 1. collect methods from two shas 228 | prev_sha, cur_sha = sha.split("-")[0], sha.split("-")[1] 229 | if not ( 230 | repos_results_dir / "collector" / f"{lang}-method-data-{prev_sha}.json" 231 | ).exists(): 232 | self.collect_methods_for_commit( 233 | project_name, lang, prev_sha, list(prev_changed_files) 234 | ) 235 | if not ( 236 | repos_results_dir / "collector" / f"{lang}-method-data-{cur_sha}.json" 237 | ).exists(): 238 | self.collect_methods_for_commit( 239 | project_name, lang, cur_sha, list(cur_changed_files) 240 | ) 241 | # 2. 
extract changed methods 242 | changed_methods_list = self.collect_changed_methods_for_commit( 243 | project_name, lang, prev_sha, cur_sha 244 | ) 245 | if len(changed_methods_list) > 0: 246 | project_git_changed_methods_history[ 247 | f"{prev_sha}-{cur_sha}" 248 | ] = changed_methods_list 249 | # end for 250 | self.logger.info(f"Collect {len(project_git_changed_methods_history)} history.") 251 | io.dump( 252 | self.repos_results_dir 253 | / project_name 254 | / f"{lang}-changed-methods-in-git-history.json", 255 | project_git_changed_methods_history, 256 | io.Fmt.jsonNoSort, 257 | ) 258 | 259 | # 1.3 260 | def aggregate_methods_histories(self, project_name: str, lang: str): 261 | """ 262 | Aggregate each methods' change histories 263 | A dict where key is method name and value is list of dict which represents change. 264 | """ 265 | 266 | repos_results_dir = self.repos_results_dir / project_name 267 | sha_2_methods = io.load( 268 | repos_results_dir / f"{lang}-changed-methods-in-git-history.json" 269 | ) 270 | project_git_changed_methods_history = defaultdict(list) 271 | for sha in tqdm(sha_2_methods, total=len(sha_2_methods)): 272 | changed_methods = sha_2_methods[sha] 273 | prev_sha, cur_sha = sha.split("-")[0], sha.split("-")[1] 274 | for m in changed_methods: 275 | old_m = m[0] 276 | new_m = m[1] 277 | m_path = old_m["path"] 278 | class_name = old_m["class_name"] 279 | method_name = old_m["name"] 280 | params = str(old_m["params"]) 281 | method_hash = f"{m_path}.{class_name}.{method_name}-{params}" 282 | new_method_hash = f"{new_m['path']}.{class_name}.{method_name}-{params}" 283 | 284 | # get diff 285 | d = difflib.Differ() 286 | if "".join(new_m["code"].split()) == "".join(old_m["code"].split()): 287 | continue 288 | sha1_diff = tokenize_code(" ".join(old_m["code"].split())) 289 | sha2_diff = tokenize_code(" ".join(new_m["code"].split())) 290 | added_tokens, deled_tokens = compute_minimal_code_diffs( 291 | sha1_diff, sha2_diff 292 | ) 293 | 294 | old_code_lines: List[str] = [ 295 | " ".join(cl.split()) + "\n" 296 | for cl in old_m["code"].splitlines(keepends=True) 297 | ] 298 | new_code_lines: List[str] = [ 299 | " ".join(cl.split()) + "\n" 300 | for cl in new_m["code"].splitlines(keepends=True) 301 | ] 302 | code_diff = list(d.compare(old_code_lines, new_code_lines)) 303 | del_code = [code for code in code_diff if code[0] == "-"] 304 | add_code = [code for code in code_diff if code[0] == "+"] 305 | project_git_changed_methods_history[method_hash].append( 306 | { 307 | "old-sha": prev_sha, 308 | "old-method-hash": method_hash, 309 | "new-method-hash": new_method_hash, 310 | "new-sha": cur_sha, 311 | "add-tokens": added_tokens, 312 | "del-tokens": deled_tokens, 313 | "add-code": add_code, 314 | "del-code": del_code, 315 | } 316 | ) 317 | self.logger.info( 318 | f"{len(project_git_changed_methods_history)} {lang} methods found in {project_name}'s git history." 319 | ) 320 | io.dump( 321 | self.repos_results_dir 322 | / project_name 323 | / f"{lang}-method-change-history.json", 324 | project_git_changed_methods_history, 325 | io.Fmt.jsonNoSort, 326 | ) 327 | 328 | def build_unique_method_id( 329 | self, java_changed_methods_file: Path, cs_changed_methods_file: Path 330 | ): 331 | """Build a map for java and c# hash to unique id""" 332 | 333 | java_change_history = io.load(java_changed_methods_file) 334 | cs_change_history = io.load(cs_changed_methods_file) 335 | # 0. 
collect unique java method id 336 | java_id_map = defaultdict(list) 337 | for java_hash in java_change_history: 338 | java_method_id = java_hash.split("/")[-1].replace(".java", "").lower() 339 | java_id_map[java_method_id].append(java_hash) 340 | # endfor 341 | new_java_hash_map = {} 342 | for java_method_id, java_hash_list in java_id_map.items(): 343 | if len(java_hash_list) > 1: 344 | smallest_index = -1 345 | for index in range(1, 20): 346 | dir_index = -1 * index 347 | java_id_set = set() 348 | for java_hash in java_hash_list: 349 | new_hash = ( 350 | ".".join(java_hash.split("/")[dir_index:]) 351 | .replace(".java", "") 352 | .lower() 353 | ) 354 | java_id_set.add(new_hash) 355 | # endfor 356 | if len(java_id_set) == len(java_hash_list): 357 | smallest_index = dir_index 358 | break 359 | for java_hash in java_hash_list: 360 | new_hash = ( 361 | ".".join(java_hash.split("/")[smallest_index:]) 362 | .replace(".java", "") 363 | .lower() 364 | ) 365 | new_java_hash_map[new_hash] = java_hash 366 | # endfor 367 | else: 368 | new_java_hash_map[java_method_id] = java_hash_list[0] 369 | # endfor 370 | 371 | # 1. collect unique cs method id 372 | cs_id_map = defaultdict(list) 373 | for cs_hash in cs_change_history: 374 | cs_method_id = cs_hash.split("/")[-1].replace(".cs", "").lower() 375 | cs_id_map[cs_method_id].append(cs_hash) 376 | # endfor 377 | new_cs_hash_map = {} 378 | for cs_method_id, cs_hash_list in cs_id_map.items(): 379 | if len(cs_hash_list) > 1: 380 | smallest_index = -1 381 | for index in range(1, 20): 382 | dir_index = -1 * index 383 | cs_id_set = set() 384 | for cs_hash in cs_hash_list: 385 | new_hash = ( 386 | ".".join(cs_hash.split("/")[dir_index:]) 387 | .replace(".cs", "") 388 | .lower() 389 | ) 390 | cs_id_set.add(new_hash) 391 | # endfor 392 | if len(cs_id_set) == len(cs_hash_list): 393 | smallest_index = dir_index 394 | break 395 | for cs_hash in cs_hash_list: 396 | new_hash = ( 397 | ".".join(cs_hash.split("/")[smallest_index:]) 398 | .replace(".cs", "") 399 | .lower() 400 | ) 401 | new_cs_hash_map[new_hash] = cs_hash 402 | # endfor 403 | else: 404 | new_cs_hash_map[cs_method_id] = cs_hash_list[0] 405 | # endfor 406 | return new_java_hash_map, new_cs_hash_map 407 | 408 | # 1.4.1 409 | def align_project_java_csharp_method(self): 410 | """Align the java and csharp method for all the projects.""" 411 | for java_project in projects_map: 412 | self.align_java_csharp_method(java_project) 413 | 414 | # 1.4.2 Build dict: method_hash -> {java: [method diffs], c#: [method diffs]}= 415 | def collect_method_diff_history(self): 416 | """Collect aligned method diff history for all the projects""" 417 | for java_project in projects_map: 418 | self.collect_project_diff_history(java_project) 419 | 420 | # 1.4.3 filter the aligned diff 421 | def filter_aligned_diff(self, java_first: bool = True): 422 | """Filter the aligned diff for all the projects""" 423 | for java_project in projects_map: 424 | self.filter_project_aligned_diff(java_project, java_first=java_first) 425 | 426 | # 2 427 | def build_delta_translation_dataset(self, filter_type: str = "time+sim"): 428 | """Build the dataset: project, java-SHA, java-new, java-old, cs-SHA, cs-new, cs-old.""" 429 | 430 | projects = projects_map.keys() 431 | dataset = [] 432 | for project in projects: 433 | cs_project = projects_map[project] 434 | self.logger.info(f"Building dataset for {project}") 435 | no_diff_count = 0 436 | file_not_found_error = 0 437 | missing_cs_method_count = 0 438 | 439 | aligned_diff_data = io.load( 440 | 
self.results_dir / f"{filter_type}-{project}-aligned-method-diff.json" 441 | ) 442 | for method_info in tqdm(aligned_diff_data, total=len(aligned_diff_data)): 443 | java_method_info = method_info["java"] 444 | # 1. extract Java code 445 | java_old_method, java_new_method = None, None 446 | old_sha, new_sha = ( 447 | java_method_info["old-sha"], 448 | java_method_info["new-sha"], 449 | ) 450 | 451 | # 1.1 get old_method 452 | try: 453 | old_sha_methods = io.load( 454 | self.repos_results_dir 455 | / project 456 | / "collector" 457 | / f"java-method-data-{old_sha}.json" 458 | ) 459 | except FileNotFoundError: 460 | self.logger.error( 461 | f"Cannot find old SHA {old_sha} for java project {project}" 462 | ) 463 | file_not_found_error += 1 464 | continue 465 | 466 | java_del_code = [code[2:] for code in java_method_info["del-code"]] 467 | java_add_code = [code[2:] for code in java_method_info["add-code"]] 468 | 469 | for j_m in old_sha_methods: 470 | bad_code = False 471 | if ( 472 | f"{j_m['path']}.{j_m['class_name']}.{j_m['name']}-{j_m['params']}" 473 | == java_method_info["old-method-hash"] 474 | ): 475 | # to make sure the data is correct 476 | for del_code in java_del_code: 477 | if " ".join(del_code.split()) not in " ".join( 478 | j_m["code"].split() 479 | ): 480 | bad_code = True 481 | break 482 | if bad_code: 483 | continue 484 | java_old_method = j_m 485 | break 486 | # sanity check 487 | if java_old_method is None: 488 | raise RuntimeError( 489 | f"Cannot find old version of Java method: {java_method_info['new-method-hash']} on SHA {old_sha}." 490 | ) 491 | # 1.2 get new_method 492 | try: 493 | new_sha_methods = io.load( 494 | self.repos_results_dir 495 | / project 496 | / "collector" 497 | / f"java-method-data-{new_sha}.json" 498 | ) 499 | except FileNotFoundError: 500 | self.logger.error( 501 | f"Cannot find new SHA {new_sha} for java project {project}" 502 | ) 503 | file_not_found_error += 1 504 | continue 505 | 506 | for j_m in new_sha_methods: 507 | bad_code = False 508 | if ( 509 | f"{j_m['path']}.{j_m['class_name']}.{j_m['name']}-{j_m['params']}" 510 | == java_method_info["new-method-hash"] 511 | ): 512 | for add_code in java_add_code: 513 | if " ".join(add_code.split()) not in " ".join( 514 | j_m["code"].split() 515 | ): 516 | bad_code = True 517 | break 518 | if bad_code: 519 | continue 520 | java_new_method = j_m 521 | break 522 | # sanity check 523 | if java_new_method is None: 524 | raise RuntimeError( 525 | f"Cannot find new version of Java method : {java_method_info['new-method-hash']} on SHA {new_sha}." 526 | ) 527 | assert java_old_method["code"] != java_new_method["code"] 528 | 529 | # 2. 
extract C# code 530 | cs_old_method, cs_new_method = None, None 531 | cs_method_info = method_info["cs"] 532 | old_sha, new_sha = ( 533 | cs_method_info["old-sha"], 534 | cs_method_info["new-sha"], 535 | ) 536 | 537 | # 2.1 get old_method 538 | try: 539 | old_sha_methods = io.load( 540 | self.repos_results_dir 541 | / cs_project 542 | / "collector" 543 | / f"cs-method-data-{old_sha}.json" 544 | ) 545 | except FileNotFoundError: 546 | self.logger.error( 547 | f"Cannot find old SHA {old_sha} for C# project {cs_project}" 548 | ) 549 | file_not_found_error + 1 550 | continue 551 | 552 | cs_del_code = [ 553 | code[2:] 554 | for code in cs_method_info["del-code"] 555 | if not code[2:].startswith("//") 556 | ] 557 | cs_add_code = [ 558 | code[2:] 559 | for code in cs_method_info["add-code"] 560 | if not code[2:].startswith("//") 561 | ] 562 | 563 | for c_m in old_sha_methods: 564 | bad_code = False 565 | if ( 566 | f"{c_m['path']}.{c_m['class_name']}.{c_m['name']}-{c_m['params']}" 567 | == cs_method_info["old-method-hash"] 568 | ): 569 | # to make sure the data is correct 570 | for del_code in cs_del_code: 571 | if "//" in c_m["code"]: 572 | c_m["code"] = re.sub( 573 | line_comment_pattern, "", c_m["code"] 574 | ) 575 | if "/*" in c_m["code"]: 576 | c_m["code"] = re.sub( 577 | block_comment_pattern, "", c_m["code"] 578 | ) 579 | if " ".join(del_code.split()) not in " ".join( 580 | c_m["code"].split() 581 | ): 582 | bad_code = True 583 | break 584 | if bad_code: 585 | continue 586 | cs_old_method = c_m 587 | break 588 | if cs_old_method is None: 589 | # raise RuntimeError( 590 | # f"Cannot find old version of C# method: {cs_method_info['old-method-hash']} on SHA {old_sha}." 591 | # ) 592 | self.logger.warning( 593 | f"Cannot find old version of C# method: {cs_method_info['old-method-hash']} on SHA {old_sha}." 594 | ) 595 | # 2.2 get new_method 596 | try: 597 | new_sha_methods = io.load( 598 | self.repos_results_dir 599 | / cs_project 600 | / "collector" 601 | / f"cs-method-data-{new_sha}.json" 602 | ) 603 | except FileNotFoundError: 604 | self.logger.error( 605 | f"Cannot find new SHA {new_sha} for C# project {cs_project}" 606 | ) 607 | file_not_found_error += 1 608 | continue 609 | for c_m in new_sha_methods: 610 | bad_code = False 611 | if ( 612 | f"{c_m['path']}.{c_m['class_name']}.{c_m['name']}-{c_m['params']}" 613 | == cs_method_info["new-method-hash"] 614 | ): 615 | for add_code in cs_add_code: 616 | if "//" in c_m["code"]: 617 | c_m["code"] = re.sub( 618 | line_comment_pattern, "", c_m["code"] 619 | ) 620 | if "/*" in c_m["code"]: 621 | c_m["code"] = re.sub( 622 | block_comment_pattern, "", c_m["code"] 623 | ) 624 | if " ".join(add_code.split()) not in " ".join( 625 | c_m["code"].split() 626 | ): 627 | bad_code = True 628 | break 629 | if bad_code: 630 | continue 631 | cs_new_method = c_m 632 | break 633 | if cs_new_method is None: 634 | self.logger.warning( 635 | f"Cannot find new version of c# method: {cs_method_info['new-method-hash']} on SHA {new_sha}." 
636 | ) 637 | if cs_new_method is None or cs_old_method is None: 638 | missing_cs_method_count += 1 639 | continue 640 | assert cs_old_method["code"] != cs_new_method["code"] 641 | 642 | if ( 643 | java_old_method["code"].split() == java_new_method["code"].split() 644 | or cs_old_method["code"].split() == cs_new_method["code"].split() 645 | ): 646 | no_diff_count += 1 647 | continue 648 | dataset.append( 649 | { 650 | "project": project, 651 | "java-SHA": java_method_info["new-sha"], 652 | "java-old": java_old_method, 653 | "java-new": java_new_method, 654 | "java-commit-date": java_method_info["commit-date"], 655 | "cs-SHA": cs_method_info["new-sha"], 656 | "cs-old": cs_old_method, 657 | "cs-new": cs_new_method, 658 | "cs-commit-date": cs_method_info["commit-date"], 659 | } 660 | ) 661 | self.logger.info(f"Total no diff data is {no_diff_count} pairs.") 662 | self.logger.info( 663 | f"Total file not found error is {file_not_found_error} pairs, and missing method is {missing_cs_method_count} for project {project}." 664 | ) 665 | 666 | self.logger.info(f"In total collect {len(dataset)} pairs.") 667 | io.dump( 668 | Macros.raw_data_dir / "delta-translation-dataset-cs2java.jsonl", 669 | dataset, 670 | io.Fmt.jsonList, 671 | ) 672 | 673 | # 3. split based on commit date 674 | def time_sort_dataset(self, data_file: str, target_lang: str = "cs"): 675 | """Sort the dataset by the date of the commit.""" 676 | 677 | data_list = io.load(Macros.raw_data_dir / data_file) 678 | time_sorted_data = [] 679 | 680 | current_project = "" 681 | project_data = [] 682 | for dt in tqdm(data_list, total=len(data_list)): 683 | prj = dt["project"] 684 | 685 | if prj != current_project and current_project != "": 686 | # process all data in the prior project 687 | sorted_project_data = sorted( 688 | project_data, key=lambda x: x[f"{target_lang}-commit-date"] 689 | ) # sorted from old to new 690 | self.logger.info( 691 | f"project {current_project} has {len(sorted_project_data)} data points." 692 | ) 693 | time_sorted_data.extend(sorted_project_data) 694 | project_data = [dt] 695 | current_project = prj 696 | elif current_project == "": 697 | project_data = [dt] 698 | current_project = prj 699 | else: 700 | project_data.append(dt) 701 | sorted_project_data = sorted( 702 | project_data, key=lambda x: x[f"{target_lang}-commit-date"] 703 | ) # sorted from old to new 704 | self.logger.info( 705 | f"project {current_project} has {len(sorted_project_data)} data points." 706 | ) 707 | time_sorted_data.extend(sorted_project_data) 708 | io.dump( 709 | Macros.raw_data_dir / "delta-translation-dataset-time-sorted.jsonl", 710 | time_sorted_data, 711 | io.Fmt.jsonList, 712 | ) 713 | 714 | # 4. 
time segement dataset 715 | def time_segment_dataset(self): 716 | """Split the raw data into by time based on Java project.""" 717 | 718 | data_list = io.load( 719 | Macros.raw_data_dir / "delta-translation-dataset-time-sorted.jsonl" 720 | ) 721 | 722 | split_date = {} # the date to split train, valid, test set 723 | train_set = [] 724 | valid_set = [] 725 | test_set = [] 726 | 727 | project_data = defaultdict(list) 728 | 729 | for dt in data_list: 730 | project_data[dt["project"]].append(dt) 731 | 732 | for prj, prj_data in project_data.items(): 733 | total_size = len(prj_data) 734 | self.logger.info(f"project {prj} has {total_size} data points.") 735 | if len(prj_data) == 2: 736 | assert ( 737 | prj_data[0]["java-commit-date"] <= prj_data[1]["java-commit-date"] 738 | ) 739 | if prj_data[0]["java-commit-date"] < prj_data[1]["java-commit-date"]: 740 | train_set = 1 741 | test_set = 1 742 | else: 743 | test_set = 2 744 | else: 745 | train_size = int(total_size * (self.val_ratio + self.train_ratio)) 746 | test_size = total_size - train_size 747 | if train_size == total_size: 748 | test_size = 1 749 | while ( 750 | prj_data[train_size - 1]["cs-commit-date"] 751 | >= prj_data[train_size]["cs-commit-date"] 752 | and train_size > 0 753 | ): 754 | train_size -= 1 755 | test_size += 1 756 | # endif 757 | val_size = int( 758 | train_size * (self.val_ratio / (self.val_ratio + self.train_ratio)) 759 | ) 760 | if val_size == 0 and train_size > 1: 761 | val_size = 1 762 | train_size = train_size - val_size 763 | train_set.extend(prj_data[:train_size]) 764 | valid_set.extend(prj_data[train_size : train_size + val_size]) 765 | test_set.extend(prj_data[train_size + val_size :]) 766 | 767 | # write down date 768 | split_date[prj] = { 769 | "train": { 770 | "java": ( 771 | prj_data[0]["java-commit-date"], 772 | prj_data[train_size - 1]["java-commit-date"], 773 | ), 774 | "cs": ( 775 | prj_data[0]["cs-commit-date"], 776 | prj_data[train_size - 1]["cs-commit-date"], 777 | ), 778 | }, 779 | "valid": { 780 | "java": ( 781 | prj_data[train_size]["java-commit-date"], 782 | prj_data[train_size + val_size - 1]["java-commit-date"], 783 | ), 784 | "cs": ( 785 | prj_data[train_size]["cs-commit-date"], 786 | prj_data[train_size + val_size - 1]["cs-commit-date"], 787 | ), 788 | }, 789 | "test": { 790 | "java": ( 791 | prj_data[train_size + val_size]["java-commit-date"], 792 | prj_data[-1]["java-commit-date"], 793 | ), 794 | "cs": ( 795 | prj_data[train_size + val_size]["cs-commit-date"], 796 | prj_data[-1]["cs-commit-date"], 797 | ), 798 | }, 799 | } 800 | assert split_date[prj]["valid"]["cs"][1] < split_date[prj]["test"]["cs"][0] 801 | 802 | self.logger.info( 803 | f"{len(train_set)} training data, {len(valid_set)} validation data and {len(test_set)} test data." 
804 | ) 805 | io.dump( 806 | Macros.data_dir / "raw" / "delta-translation-train.jsonl", 807 | train_set, 808 | io.Fmt.jsonList, 809 | ) 810 | io.dump( 811 | Macros.data_dir / "raw" / "delta-translation-valid.jsonl", 812 | valid_set, 813 | io.Fmt.jsonList, 814 | ) 815 | io.dump( 816 | Macros.data_dir / "raw" / "delta-translation-test.jsonl", 817 | test_set, 818 | io.Fmt.jsonList, 819 | ) 820 | 821 | io.dump( 822 | Macros.results_dir / "stats" / "stats-data-split-date.json", 823 | split_date, 824 | io.Fmt.jsonPretty, 825 | ) 826 | 827 | # 4.1 projects segement dataset 828 | def projects_segment_dataset(self): 829 | """Split the raw data by projects based on Java project.""" 830 | 831 | # setup 832 | raw_data_dir = Macros.data_dir / "raw" 833 | output_dir = raw_data_dir / "cross-project" 834 | io.mkdir(output_dir) 835 | data_list = io.load( 836 | raw_data_dir / "delta-translation-dataset-time-sorted.jsonl" 837 | ) 838 | 839 | # because the projects data amount is not even, manually specify the split 840 | train_projects = ["itext_itext7"] 841 | valid_projects = ["terabyte_jgit"] 842 | test_projects = set() 843 | 844 | train_set = [] 845 | valid_set = [] 846 | test_set = [] 847 | 848 | for dt in data_list: 849 | if dt["project"] in train_projects: 850 | train_set.append(dt) 851 | elif dt["project"] in valid_projects: 852 | valid_set.append(dt) 853 | else: 854 | test_set.append(dt) 855 | test_projects.add(dt["project"]) 856 | 857 | self.logger.info( 858 | f"{len(train_set)} training data, {len(valid_set)} validation data and {len(test_set)} test data." 859 | ) 860 | io.dump( 861 | output_dir / "delta-translation-train.jsonl", 862 | train_set, 863 | io.Fmt.jsonList, 864 | ) 865 | io.dump( 866 | output_dir / "delta-translation-valid.jsonl", 867 | valid_set, 868 | io.Fmt.jsonList, 869 | ) 870 | io.dump( 871 | output_dir / "delta-translation-test.jsonl", 872 | test_set, 873 | io.Fmt.jsonList, 874 | ) 875 | io.dump( 876 | output_dir / "projects-split.json", 877 | { 878 | "training-projects": train_projects, 879 | "valida-projects": valid_projects, 880 | "test-projects": list(test_projects), 881 | }, 882 | ) 883 | 884 | def tokenize_collected_data(self, file_path: str): 885 | """Tokenize the collected data.""" 886 | from deltr.exe.CodeTokenizer import CodeTokenizer 887 | import atexit 888 | 889 | tokenizer = CodeTokenizer(main_class="org.csevo.Tokenizer") 890 | tokenizer.setup() 891 | atexit.register(tokenizer.teardown) 892 | 893 | data_list = io.load( 894 | file_path, 895 | io.Fmt.jsonList, 896 | ) 897 | new_data = [] 898 | for dt in tqdm(data_list, total=len(data_list)): 899 | # java 900 | dt["java-new"]["tokenized_code"] = tokenizer.tokenize( 901 | dt["java-new"]["code"], "java" 902 | ).strip() 903 | dt["java-old"]["tokenized_code"] = tokenizer.tokenize( 904 | dt["java-old"]["code"], "java" 905 | ).strip() 906 | # c# 907 | dt["cs-new"]["tokenized_code"] = tokenizer.tokenize( 908 | dt["cs-new"]["code"], "cs" 909 | ).strip() 910 | dt["cs-old"]["tokenized_code"] = tokenizer.tokenize( 911 | dt["cs-old"]["code"], "cs" 912 | ).strip() 913 | 914 | # sanity check 915 | assert dt["cs-new"]["tokenized_code"] != dt["cs-old"]["tokenized_code"] 916 | new_data.append(dt) 917 | 918 | io.dump( 919 | file_path, 920 | new_data, 921 | io.Fmt.jsonList, 922 | ) 923 | 924 | # Helper functions 925 | def collect_changed_methods_for_commit( 926 | self, 927 | project_name: str, 928 | lang: str, 929 | prev_sha: str, 930 | cur_sha: str, 931 | ): 932 | """Collect changed methods from changed files between two consecutive 
commits.""" 933 | 934 | import json 935 | 936 | repo_results_dir = self.repos_results_dir / project_name / "collector" 937 | try: 938 | prev_sha_methods = io.load( 939 | repo_results_dir / f"{lang}-method-data-{prev_sha}.json" 940 | ) 941 | cur_sha_methods = io.load( 942 | repo_results_dir / f"{lang}-method-data-{cur_sha}.json" 943 | ) 944 | except json.decoder.JSONDecodeError: 945 | return [] 946 | except: 947 | self.logger.warning( 948 | f"{cur_sha} and {prev_sha} can not be parsed for project {project_name}." 949 | ) 950 | return [] 951 | changed_methods_list = [] 952 | 953 | io.dump( 954 | repo_results_dir / f"{lang}-method-data-{cur_sha}.json", 955 | cur_sha_methods, 956 | io.Fmt.jsonNoSort, 957 | ) 958 | io.dump( 959 | repo_results_dir / f"{lang}-method-data-{prev_sha}.json", 960 | prev_sha_methods, 961 | io.Fmt.jsonNoSort, 962 | ) 963 | for p_m in prev_sha_methods: 964 | for c_m in cur_sha_methods: 965 | new_path = c_m["path"] 966 | old_path = p_m["path"] # we consider the exact match of two paths 967 | if new_path != old_path: 968 | continue 969 | if ( 970 | p_m["name"] == c_m["name"] 971 | and p_m["class_name"] == c_m["class_name"] 972 | and p_m["params"] == c_m["params"] 973 | and p_m["code"].split() != c_m["code"].split() 974 | ): 975 | changed_methods_list.append((p_m, c_m)) 976 | break 977 | # end for 978 | return changed_methods_list 979 | 980 | def collect_methods_for_commit( 981 | self, project_name: str, lang: str, sha: str, changed_files: List[str] 982 | ): 983 | """Collect methods for a given SHA and given project.""" 984 | 985 | downloads_dir = Macros.repos_downloads_dir / project_name 986 | # download repo 987 | project_url = f"git@github.com:{project_name.split('_')[0]}/{project_name.split('_')[1]}.git" 988 | if not downloads_dir.exists(): 989 | self.logger.info(f"Cloning repo {project_name} ... ") 990 | with IOUtils.cd(self.repos_downloads_dir): 991 | try: 992 | with TimeUtils.time_limit(300): 993 | BashUtils.run( 994 | f"git clone {project_url} {project_name}", 995 | expected_return_code=0, 996 | ) 997 | # end with 998 | except TimeoutException: 999 | self.logger.info( 1000 | f"{project_name} exceeds time limit, ignore this one." 1001 | ) 1002 | return 1003 | except: 1004 | self.logger.warning( 1005 | f"Project {project_name} failed: {sys.exc_info()}" 1006 | ) 1007 | return 1008 | # end with 1009 | with io.cd(downloads_dir): 1010 | self.logger.info(f"Checkout {sha}") 1011 | bash.run(f"git checkout {sha} -f", check_returncode=0) 1012 | self.collect_method_data( 1013 | project_url="", 1014 | project_name=project_name, 1015 | project_sha=sha, 1016 | lang=lang, 1017 | changed_files=changed_files, 1018 | ) 1019 | 1020 | def collect_method_data( 1021 | self, 1022 | project_url: str, 1023 | project_name: str, 1024 | project_sha: str = None, 1025 | lang: str = "java", 1026 | changed_files: List[str] = None, 1027 | ): 1028 | """Collect methods in the project. If changed_files are given, only collect methods in the files.""" 1029 | 1030 | Environment.require_collector() 1031 | 1032 | # 0. Download repo 1033 | downloads_dir = self.repos_downloads_dir / project_name 1034 | results_dir = self.repos_results_dir / project_name 1035 | 1036 | IOUtils.mk_dir(results_dir) 1037 | assert downloads_dir.exists() 1038 | 1039 | # 2. 
Use parser to parse project 1040 | project_data = ProjectData.create() 1041 | project_data.name = project_name 1042 | project_data.url = project_url 1043 | 1044 | # Get revision (SHA) 1045 | with IOUtils.cd(downloads_dir): 1046 | try: 1047 | if project_sha: 1048 | BashUtils.run( 1049 | f"git checkout {project_sha} -f", expected_return_code=0 1050 | ) 1051 | else: 1052 | project_sha = bash.run(f"git rev-parse HEAD").stdout.strip() 1053 | except: 1054 | self.logger.warning(f"Project {project_name} failed: {sys.exc_info()}") 1055 | return 1056 | project_data.revision = project_sha 1057 | 1058 | project_data_file = results_dir / "project.json" 1059 | IOUtils.dump( 1060 | project_data_file, IOUtils.jsonfy(project_data), IOUtils.Format.jsonPretty 1061 | ) 1062 | 1063 | # Prepare config 1064 | log_file = results_dir / "collector-log.txt" 1065 | output_dir = results_dir / "collector" 1066 | 1067 | config = { 1068 | "collect": True, 1069 | "projectDir": str(downloads_dir), 1070 | "projectDataFile": str(project_data_file), 1071 | "logFile": str(log_file), 1072 | "outputDir": str(output_dir), 1073 | "lang": lang, 1074 | "revision": str(project_data.revision), 1075 | } 1076 | if changed_files: 1077 | config["fileNames"] = changed_files 1078 | # self.logger.info(f"Project parsing config file: \n {config}") 1079 | config_file = results_dir / "collector-config.json" 1080 | IOUtils.dump(config_file, config, IOUtils.Format.jsonPretty) 1081 | 1082 | self.logger.info( 1083 | f"Starting the collector. Check log at {log_file} and outputs at {output_dir}" 1084 | ) 1085 | rr = BashUtils.run( 1086 | f"java -jar {Environment.collector_jar} {config_file}", 1087 | ) 1088 | if rr.stderr: 1089 | self.logger.warning(f"Stderr of collector:\n{rr.stderr}") 1090 | # end if 1091 | 1092 | return 1093 | 1094 | def collect_pymethod_data( 1095 | self, 1096 | project_url: str, 1097 | project_name: str, 1098 | project_sha: str = None, 1099 | lang: str = "python", 1100 | changed_files: List[str] = None, 1101 | ): 1102 | """Collect python methods in the project. If changed_files are given, only collect methods in the files.""" 1103 | 1104 | # Environment.require_collector() 1105 | 1106 | # 0. Download repo 1107 | downloads_dir = self.repos_downloads_dir / project_name 1108 | results_dir = self.repos_results_dir / project_name 1109 | 1110 | IOUtils.mk_dir(results_dir) 1111 | 1112 | # Clone the repo if not exists 1113 | if not downloads_dir.exists(): 1114 | self.logger.info(f"Cloning repo {project_name} ... ") 1115 | with IOUtils.cd(self.repos_downloads_dir): 1116 | try: 1117 | with TimeUtils.time_limit(300): 1118 | BashUtils.run( 1119 | f"git clone {project_url} {project_name}", 1120 | expected_return_code=0, 1121 | ) 1122 | # end with 1123 | except TimeoutException: 1124 | self.logger.info( 1125 | f"{project_name} exceeds time limit, ignore this one." 1126 | ) 1127 | return 1128 | except: 1129 | self.logger.warning( 1130 | f"Project {project_name} failed: {sys.exc_info()}" 1131 | ) 1132 | return 1133 | # end with 1134 | # end if 1135 | 1136 | # 2. 
Use parser to parse project 1137 | project_data = ProjectData.create() 1138 | project_data.name = project_name 1139 | project_data.url = project_url 1140 | 1141 | # Get revision (SHA) 1142 | with IOUtils.cd(downloads_dir): 1143 | try: 1144 | if project_sha: 1145 | BashUtils.run( 1146 | f"git checkout {project_sha} -f", expected_return_code=0 1147 | ) 1148 | else: 1149 | project_sha = bash.run(f"git rev-parse HEAD").stdout.strip() 1150 | except: 1151 | self.logger.warning(f"Project {project_name} failed: {sys.exc_info()}") 1152 | return 1153 | project_data.revision = project_sha 1154 | 1155 | project_data_file = results_dir / "project.json" 1156 | IOUtils.dump( 1157 | project_data_file, IOUtils.jsonfy(project_data), IOUtils.Format.jsonPretty 1158 | ) 1159 | 1160 | # Run python parser to parse python files 1161 | py_func_dict = self.parse_python_methods( 1162 | project_dir=downloads_dir / "runtime" / "Python3" 1163 | ) 1164 | print(f"In total collected {len(py_func_dict)} functions") 1165 | io.dump( 1166 | Macros.results_dir / "temp-test-py-funcs.json", 1167 | py_func_dict, 1168 | io.Fmt.jsonPretty, 1169 | ) 1170 | 1171 | return 1172 | 1173 | def parse_python_methods(self, project_dir: str): 1174 | """Parse all python methods in a project given directory.""" 1175 | 1176 | from deltr.collector.PythonParser import PythonParser 1177 | import dataclasses 1178 | 1179 | parser = PythonParser() 1180 | pyfiles = [] 1181 | for root, dirs, files in os.walk(project_dir): 1182 | for file in files: 1183 | if str(file).endswith(".py"): 1184 | pyfiles.append(os.path.join(root, file)) 1185 | 1186 | assert len(pyfiles) > 0, "No python files found." 1187 | function_dict = {} 1188 | for pyfile in pyfiles: 1189 | functions = parser.collect_functions(project_dir / pyfile) 1190 | if functions is None: 1191 | continue 1192 | for m in functions: 1193 | function_dict[m.path + m.className + m.name] = dataclasses.asdict(m) 1194 | 1195 | return function_dict 1196 | 1197 | def filter_project_aligned_diff( 1198 | self, 1199 | java_project_name: str, 1200 | filter_type: str = "time+sim", 1201 | java_first: bool = True, 1202 | ): 1203 | """Find the aligned diff from the aligned histories and filter based on jaccard similarity.""" 1204 | 1205 | aligned_method_history = io.load( 1206 | self.results_dir / f"{java_project_name}-method-diff-history.json" 1207 | ) 1208 | aligned_diff_data = [] 1209 | for method_history in tqdm( 1210 | aligned_method_history, total=len(aligned_method_history) 1211 | ): 1212 | if filter_type == "time": 1213 | matched_diff_list = self.match_diff_based_on_time( 1214 | method_history, java_project_name, java_first=java_first 1215 | ) 1216 | aligned_diff_data.append(matched_diff_list) 1217 | elif filter_type == "sim": 1218 | matched_diff_list = self.match_diff_based_on_jaccard(method_history) 1219 | aligned_diff_data.extend(matched_diff_list) 1220 | else: 1221 | matched_diff_list = self.match_diff_based_on_time( 1222 | method_history, java_project_name, java_first=java_first 1223 | ) # [{}] 1224 | matched_diff_list = self.match_diff_based_on_jaccard(matched_diff_list) 1225 | aligned_diff_data.extend(matched_diff_list) 1226 | 1227 | self.logger.info( 1228 | f"Collect {len(aligned_diff_data)} aligned method diff for project {java_project_name}." 
1229 | ) 1230 | io.dump( 1231 | self.results_dir 1232 | / f"{filter_type}-{java_project_name}-aligned-method-diff.json", 1233 | aligned_diff_data, 1234 | io.Fmt.jsonNoSort, 1235 | ) 1236 | 1237 | def match_diff_based_on_jaccard(self, method_diff_history: dict): 1238 | """Find the matched diff based on smilarity.""" 1239 | method_diff_list = [] 1240 | if isinstance(method_diff_history, dict): 1241 | java_diffs = method_diff_history["java"] 1242 | cs_diffs = method_diff_history["cs"] 1243 | elif isinstance(method_diff_history, list): 1244 | java_diffs = [dt["java"] for dt in method_diff_history] 1245 | cs_diffs = [dt["cs"] for dt in method_diff_history] 1246 | 1247 | cs_best_match_sim = defaultdict(float) 1248 | for i, java_diff in enumerate(java_diffs): 1249 | best_sim = 0 1250 | diff_pair = None 1251 | cs_diff_list = cs_diffs[i] 1252 | for cs_diff in cs_diff_list: 1253 | if cs_diff["add-tokens"] == [] and cs_diff["del-tokens"] == []: 1254 | continue 1255 | token_sim, line_sim = self.compare_diff_similarity(java_diff, cs_diff) 1256 | if ( 1257 | token_sim < self.token_sim_threshold 1258 | or line_sim < self.line_sim_threshold 1259 | ): 1260 | continue 1261 | 1262 | if token_sim + line_sim >= best_sim: 1263 | best_sim = token_sim + line_sim 1264 | diff_pair = { 1265 | "java": java_diff, 1266 | "cs": cs_diff, 1267 | } 1268 | if diff_pair: 1269 | cs_diff = diff_pair["cs"] 1270 | exist_best_sim = cs_best_match_sim[ 1271 | cs_diff["new-sha"] 1272 | + cs_diff["old-sha"] 1273 | + str(cs_diff["add-code"]) 1274 | + str(cs_diff["del-code"]) 1275 | + cs_diff["new-method-hash"] 1276 | + cs_diff["commit-date"] 1277 | ] 1278 | if exist_best_sim > 0 and exist_best_sim < best_sim: 1279 | # de-duplicate 1280 | duplicate = True 1281 | for method_diff_pair in method_diff_list: 1282 | if ( 1283 | method_diff_pair["cs"]["new-sha"] == cs_diff["new-sha"] 1284 | and method_diff_pair["cs"]["old-sha"] == cs_diff["old-sha"] 1285 | and str(method_diff_pair["cs"]["add-code"]) 1286 | == str(cs_diff["add-code"]) 1287 | and str(method_diff_pair["cs"]["del-code"]) 1288 | == str(cs_diff["del-code"]) 1289 | and method_diff_pair["cs"]["new-method-hash"] 1290 | == cs_diff["new-method-hash"] 1291 | and method_diff_pair["cs"]["commit-date"] 1292 | == cs_diff["commit-date"] 1293 | ): 1294 | duplicate = False 1295 | method_diff_list.remove(method_diff_pair) 1296 | break 1297 | assert duplicate == False 1298 | method_diff_list.append(diff_pair) 1299 | cs_best_match_sim[ 1300 | cs_diff["new-sha"] 1301 | + cs_diff["old-sha"] 1302 | + str(cs_diff["add-code"]) 1303 | + str(cs_diff["del-code"]) 1304 | + cs_diff["new-method-hash"] 1305 | + cs_diff["commit-date"] 1306 | ] = best_sim 1307 | elif exist_best_sim == 0: 1308 | cs_best_match_sim[ 1309 | cs_diff["new-sha"] 1310 | + cs_diff["old-sha"] 1311 | + str(cs_diff["add-code"]) 1312 | + str(cs_diff["del-code"]) 1313 | + cs_diff["new-method-hash"] 1314 | + cs_diff["commit-date"] 1315 | ] = best_sim 1316 | method_diff_list.append(diff_pair) 1317 | 1318 | return method_diff_list 1319 | 1320 | def match_diff_based_on_time( 1321 | self, method_diff_history: dict, java_project_name: str, java_first: bool = True 1322 | ): 1323 | """Find the matched diff based on commit date.""" 1324 | 1325 | matched_diff_list = [] 1326 | which_older = 1 if java_first else -1 1327 | for java_diff in method_diff_history["java"]: 1328 | potential_match_cs_diffs = [] 1329 | if java_diff["add-tokens"] == [] and java_diff["del-tokens"] == []: 1330 | continue 1331 | java_commit_date = datetime.strptime( 1332 | 
get_commit_date(java_diff["new-sha"], java_project_name), 1333 | "%Y-%m-%d %H:%M:%S", 1334 | ) 1335 | java_diff["commit-date"] = str(java_commit_date) 1336 | # min_time = float("inf") 1337 | cs_project_name = projects_map[java_project_name] 1338 | for cs_diff in method_diff_history["cs"]: 1339 | if cs_diff["add-tokens"] == [] and cs_diff["del-tokens"] == []: 1340 | continue 1341 | cs_commit_date = datetime.strptime( 1342 | get_commit_date(cs_diff["new-sha"], cs_project_name), 1343 | "%Y-%m-%d %H:%M:%S", 1344 | ) 1345 | cs_diff["commit-date"] = str(cs_commit_date) 1346 | delta_days = (java_commit_date - cs_commit_date).days 1347 | if ( 1348 | abs(delta_days) < self.commit_date_threshold 1349 | and which_older * delta_days < 0 1350 | ): 1351 | potential_match_cs_diffs.append(cs_diff) 1352 | if len(potential_match_cs_diffs) > 0: 1353 | matched_diff_list.append( 1354 | {"java": java_diff, "cs": potential_match_cs_diffs} 1355 | ) 1356 | return matched_diff_list 1357 | 1358 | def collect_project_diff_history(self, java_project_name: str): 1359 | """Align the history of methods in c# and java.""" 1360 | 1361 | method_align_history = [] 1362 | cs_project_name = projects_map[java_project_name] 1363 | java_method_history = io.load( 1364 | Macros.repos_results_dir 1365 | / java_project_name 1366 | / f"java-method-change-history.json" 1367 | ) 1368 | cs_method_history = io.load( 1369 | Macros.repos_results_dir 1370 | / cs_project_name 1371 | / f"cs-method-change-history.json" 1372 | ) 1373 | java_cs_method_map = io.load( 1374 | Macros.results_dir 1375 | / "repo-data" 1376 | / f"{java_project_name}-java-csharp-method-map.json" 1377 | ) 1378 | for j_method_hash, c_method_hash in tqdm(java_cs_method_map.items()): 1379 | if isinstance(c_method_hash, list): 1380 | c_method_hash = c_method_hash[0] 1381 | # endif 1382 | method_align_history.append( 1383 | { 1384 | "java": java_method_history[j_method_hash], 1385 | "cs": cs_method_history[c_method_hash], 1386 | } 1387 | ) 1388 | # endfor 1389 | self.logger.info( 1390 | f"Found {len(method_align_history)} methods with diff history in project {java_project_name}." 1391 | ) 1392 | io.dump( 1393 | self.results_dir / f"{java_project_name}-method-diff-history.json", 1394 | method_align_history, 1395 | io.Fmt.jsonNoSort, 1396 | ) 1397 | 1398 | def align_java_csharp_method(self, java_project_name: str): 1399 | """ 1400 | 1.4.1 1401 | Build java and csharp method mapping. 1402 | 1403 | NOTE: it involves manually checking the mapping. 
1404 | """ 1405 | 1406 | cs_project_name = projects_map[java_project_name] 1407 | java_id_map, cs_id_map = self.build_unique_method_id( 1408 | java_changed_methods_file=Macros.repos_results_dir 1409 | / java_project_name 1410 | / "java-method-change-history.json", 1411 | cs_changed_methods_file=Macros.repos_results_dir 1412 | / cs_project_name 1413 | / "cs-method-change-history.json", 1414 | ) 1415 | java_2_cs_map = self.map_java_to_cs(java_id_map, cs_id_map) 1416 | self.logger.info( 1417 | f"Size of methods mapping for project {java_project_name} is {len(java_2_cs_map)}" 1418 | ) 1419 | io.dump( 1420 | Macros.results_dir 1421 | / "repo-data" 1422 | / f"{java_project_name}-java-csharp-method-map.json", 1423 | java_2_cs_map, 1424 | io.Fmt.jsonPretty, 1425 | ) 1426 | 1427 | def map_java_to_cs(self, java_id_map: dict, cs_id_map: dict): 1428 | """Map java unique id with c# unique id.""" 1429 | 1430 | java_2_cs_methods = {} 1431 | uncertain_match_count = 0 1432 | for java_id, java_hash in tqdm(java_id_map.items()): 1433 | if java_id in cs_id_map: 1434 | java_2_cs_methods[java_hash] = cs_id_map[java_id] 1435 | else: 1436 | # try manually inspect 1437 | possible_match_cs_methods = [] 1438 | for cs_id, cs_hash in cs_id_map.items(): 1439 | if java_id.split("-")[0] == cs_id.split("-")[0]: 1440 | # first check the number of parameters 1441 | if str(java_id.count("[")) != str(cs_id.count("[")): 1442 | continue 1443 | temp_cs_id = copy.deepcopy(cs_id) 1444 | for cs_type, java_type in cs_java_type_map.items(): 1445 | temp_cs_id = temp_cs_id.replace(cs_type, java_type) 1446 | if java_id == temp_cs_id: 1447 | java_2_cs_methods[java_hash] = cs_hash 1448 | break 1449 | else: 1450 | possible_match_cs_methods.append(cs_id) 1451 | # endfor 1452 | 1453 | # if len(possible_match_cs_methods) > 0: 1454 | # java_2_cs_methods[java_hash] = [ 1455 | # cs_id_map[cs_id] for cs_id in possible_match_cs_methods 1456 | # ] 1457 | # uncertain_match_count += 1 1458 | 1459 | if len(possible_match_cs_methods) > 0: 1460 | max_sim = 0 1461 | match_cs_id = "" 1462 | for cs_id in possible_match_cs_methods: 1463 | param_sim = similar(java_id.split("-")[1], cs_id.split("-")[1]) 1464 | if param_sim > max_sim and param_sim > 0.5: 1465 | max_sim = param_sim 1466 | match_cs_id = cs_id 1467 | if match_cs_id != "": 1468 | if max_sim < 0.8: 1469 | uncertain_match_count += 1 1470 | java_2_cs_methods[java_hash] = [cs_id_map[match_cs_id]] 1471 | else: 1472 | java_2_cs_methods[java_hash] = cs_id_map[match_cs_id] 1473 | self.logger.info(f"Uncertain match count: {uncertain_match_count}") 1474 | 1475 | return java_2_cs_methods 1476 | 1477 | def compare_diff_similarity( 1478 | self, 1479 | java_diff: dict, 1480 | cs_diff: dict, 1481 | ) -> Tuple[float, float]: 1482 | """Algorithm to find the aligned pair""" 1483 | 1484 | add_java_tks, del_java_tks = java_diff["add-tokens"], java_diff["del-tokens"] 1485 | add_cs_tks, del_cs_tks = cs_diff["add-tokens"], cs_diff["del-tokens"] 1486 | # 1. 
token level similarity 1487 | if len(add_java_tks) == 0 and len(add_cs_tks) == 0: 1488 | add_tokens_similarity = self.token_sim_threshold 1489 | else: 1490 | add_tokens_similarity = self.compute_diff_similarity( 1491 | add_java_tks, 1492 | add_cs_tks, 1493 | ) 1494 | if len(del_java_tks) == 0 and len(del_cs_tks) == 0: 1495 | del_tokens_similarity = self.token_sim_threshold 1496 | else: 1497 | del_tokens_similarity = self.compute_diff_similarity( 1498 | del_java_tks, del_cs_tks 1499 | ) 1500 | tokens_sim = add_tokens_similarity * (0.5) + del_tokens_similarity * (0.5) 1501 | # 2. line level similarity 1502 | add_line_similarity = self.compute_diff_similarity( 1503 | java_diff["add-code"], cs_diff["add-code"] 1504 | ) 1505 | del_line_similarity = self.compute_diff_similarity( 1506 | java_diff["del-code"], cs_diff["del-code"] 1507 | ) 1508 | line_sim = add_line_similarity * 0.5 + del_line_similarity * 0.5 1509 | 1510 | return tokens_sim, line_sim 1511 | 1512 | def compute_diff_similarity( 1513 | self, sha1_lines: List[str], sha2_lines: List[str], task: str = None 1514 | ): 1515 | """Compute similarity between the changes (added lines & deleted lines)""" 1516 | 1517 | sha1_diff = tokenize_code(" ".join(sha1_lines).strip().lower()) 1518 | sha2_diff = tokenize_code(" ".join(sha2_lines).strip().lower()) 1519 | if task == "inclusion": 1520 | jaccard_sim = include_jaccard(sha1_diff, sha2_diff) 1521 | else: 1522 | jaccard_sim = jaccard(sha1_diff, sha2_diff) 1523 | 1524 | return jaccard_sim 1525 | 1526 | def remove_csharp_comment(self, cs_code: str): 1527 | line_comment_pattern = r"//(.*?)\n" 1528 | block_comment_pattern = r"/\*(.*?)\*/" 1529 | cs_code = re.sub(line_comment_pattern, "", cs_code) 1530 | cs_code = re.sub(block_comment_pattern, "", cs_code) 1531 | 1532 | return cs_code 1533 | 1534 | # temp script 1535 | def sample_filtered_diff(self, filter_type: str = "time+sim"): 1536 | """Sample the filtered method diff for inspection.""" 1537 | 1538 | total_number = 0 1539 | sample_diff_method_list = [] 1540 | for java_project in projects_map: 1541 | filtered_aligned_diff = io.load( 1542 | self.results_dir 1543 | / f"{filter_type}-{java_project}-aligned-method-diff.json" 1544 | ) 1545 | total_number += len(filtered_aligned_diff) 1546 | K = 20 1547 | sample_ids = random.choices(range(len(filtered_aligned_diff)), k=K) 1548 | for i, dt in enumerate(filtered_aligned_diff): 1549 | if i in sample_ids: 1550 | del dt["cs"]["add-tokens"] 1551 | del dt["cs"]["del-tokens"] 1552 | del dt["java"]["add-tokens"] 1553 | del dt["java"]["del-tokens"] 1554 | sample_diff_method_list.append(dt) 1555 | self.logger.info(f"In total collect {total_number} pairs of changes.") 1556 | io.dump( 1557 | Macros.raw_data_dir / f"{filter_type}-filtered-sampled-aligned-diff.json", 1558 | sample_diff_method_list, 1559 | io.Fmt.jsonPretty, 1560 | ) 1561 | 1562 | def analyze_project_aligned_diff(self, java_project_name: str): 1563 | """ 1564 | This method is designed for finding the aligned diff from the aligned histories and 1565 | filter based on jaccard similarity and inspect the examples. 
1566 | """
1567 |
1568 | aligned_method_history = io.load(
1569 | self.results_dir / f"{java_project_name}-method-diff-history.json"
1570 | )
1571 | within_time_not_sim = []
1572 | no_diff_within_time = []
1573 | for method_history in tqdm(
1574 | aligned_method_history, total=len(aligned_method_history)
1575 | ):
1576 | time_matched_diff_list = self.match_diff_based_on_time(
1577 | method_history, java_project_name
1578 | )
1579 | if len(time_matched_diff_list) == 0:
1580 | for mh in method_history["java"]:
1581 | del mh["add-tokens"]
1582 | del mh["del-tokens"]
1583 | for mh in method_history["cs"]:
1584 | del mh["add-tokens"]
1585 | del mh["del-tokens"]
1586 | no_diff_within_time.append(method_history)
1587 | sim_matched_diff_list = self.match_diff_based_on_jaccard(
1588 | time_matched_diff_list
1589 | )
1590 | for dt in time_matched_diff_list:
1591 | if dt not in sim_matched_diff_list:
1592 | try:
1593 | del dt["java"]["add-tokens"]
1594 | del dt["java"]["del-tokens"]
1595 | del dt["cs"]["add-tokens"]
1596 | del dt["cs"]["del-tokens"]
1597 | except:
1598 | pass
1599 | within_time_not_sim.append(dt)
1600 |
1601 | io.dump(
1602 | self.results_dir
1603 | / f"time-not-sim-{java_project_name}-aligned-method-diff.json",
1604 | within_time_not_sim,
1605 | io.Fmt.jsonNoSort,
1606 | )
1607 | io.dump(
1608 | self.results_dir
1609 | / f"not-time-not-sim-{java_project_name}-aligned-method-diff.json",
1610 | no_diff_within_time,
1611 | io.Fmt.jsonNoSort,
1612 | )
1613 |
1614 |
1615 | def similar(a, b):
1616 | return SequenceMatcher(None, a, b).ratio()
1617 |
1618 |
1619 | if __name__ == "__main__":
1620 | LoggingUtils.setup(LoggingUtils.INFO, Macros.log_file)
1621 | CLI(RealDiffCollector, as_positional=False)
1622 |
-------------------------------------------------------------------------------- /python/prepare-conda-env.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # This script prepares a conda environment for running/developing
4 | # Codeditor, with GPU support if an Nvidia GPU is detected.
5 | #
6 | # Requires conda to be installed and available in PATH.
7 | #
8 | # Usage:
9 | # ./prepare-conda-env.sh
10 | # # after the script finishes, activate the environment:
11 | # conda activate deltr
12 | #
13 | # Usage with options:
14 | # ./prepare-conda-env.sh [cuda_version] [env_name] [conda_path]
15 | # # cuda_version: {cpu,10.2,11.3,11.6,system} the CUDA toolkit version for PyTorch (default: "11.6" if an Nvidia GPU is detected by nvidia-smi, "cpu" otherwise)
16 | # # env_name: name of the conda environment to create (default: deltr)
17 | # # conda_path: path to conda.sh (default: automatically detected)
18 |
19 |
20 | _DIR=$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )
21 |
22 | function get_conda_path() {
23 | local conda_exe=$(which conda)
24 | if [[ -z ${conda_exe} ]]; then
25 | echo "Failed to detect conda! Have you installed Anaconda/Miniconda?"
1>&2
26 | exit 1
27 | fi
28 |
29 | echo "$(dirname ${conda_exe})/../etc/profile.d/conda.sh"
30 | }
31 |
32 | function get_gpu_avail() {
33 | if [[ -z $(which nvidia-smi) ]]; then
34 | echo "cpu"
35 | else
36 | echo "gpu"
37 | fi
38 | }
39 |
40 | function get_cuda_version() {
41 | if [[ -z $(which nvcc) ]]; then
42 | echo "cpu"
43 | else
44 | echo "$(nvcc -V | grep "release" | sed -E "s/.*release ([^,]+),.*/\1/")"
45 | fi
46 | }
47 |
48 | readonly PYTORCH_V=1.12.1
49 | readonly TORCHTEXT_V=0.13.1
50 |
51 | function prepare_conda_env() {
52 | local cuda_version=$1; shift # cpu|system|10.2|11.3|11.6
53 | local env_name=${1:-deltr}; shift
54 | local conda_path=$1; shift
55 |
56 | set -e
57 | if [[ -z ${cuda_version} ]]; then
58 | if [[ $(get_gpu_avail) == "gpu" ]]; then
59 | # by default, use a newer cuda version for better compatibility with newer GPUs
60 | cuda_version="11.6"
61 | else
62 | cuda_version="cpu"
63 | fi
64 | fi
65 | if [[ -z ${conda_path} ]]; then
66 | conda_path=$(get_conda_path)
67 | fi
68 | echo ">>> Preparing conda environment \"${env_name}\"; cuda_version: ${cuda_version}; conda path: ${conda_path}"
69 |
70 | # Preparation
71 | source ${conda_path}
72 | conda env remove --name $env_name
73 | conda create --name $env_name python=3.9 pip -y
74 | conda activate $env_name
75 |
76 | # Install Pytorch
77 | case ${cuda_version} in
78 | cpu)
79 | conda install -y pytorch=${PYTORCH_V} torchtext=${TORCHTEXT_V} cpuonly -c pytorch
80 | ;;
81 | 10.2)
82 | conda install -y pytorch=${PYTORCH_V} torchtext=${TORCHTEXT_V} cudatoolkit=10.2 -c pytorch
83 | ;;
84 | 11.3)
85 | conda install -y pytorch=${PYTORCH_V} torchtext=${TORCHTEXT_V} cudatoolkit=11.3 -c pytorch
86 | ;;
87 | 11.6)
88 | conda install -y pytorch=${PYTORCH_V} torchtext=${TORCHTEXT_V} cudatoolkit=11.6 -c pytorch -c conda-forge
89 | ;;
90 | system)
91 | local sys_cuda_version=$(get_cuda_version)
92 | case ${sys_cuda_version} in
93 | 10.2)
94 | echo ">>> Found system cuda ${sys_cuda_version}, attempting to install pytorch with pip..."
95 | pip install torch==${PYTORCH_V}+cu102 torchtext==${TORCHTEXT_V} --extra-index-url https://download.pytorch.org/whl/cu102
96 | ;;
97 | 11.3)
98 | echo ">>> Found system cuda ${sys_cuda_version}, attempting to install pytorch with pip..."
99 | pip install torch==${PYTORCH_V}+cu113 torchtext==${TORCHTEXT_V} --extra-index-url https://download.pytorch.org/whl/cu113
100 | ;;
101 | 11.6)
102 | echo ">>> Found system cuda ${sys_cuda_version}, attempting to install pytorch with pip..."
103 | pip install torch==${PYTORCH_V}+cu116 torchtext==${TORCHTEXT_V} --extra-index-url https://download.pytorch.org/whl/cu116
104 | ;;
105 | *)
106 | echo ">>> [ERROR] No compatible system CUDA found (detected ${sys_cuda_version})!"
107 | return
108 | ;;
109 | esac
110 | ;;
111 | *)
112 | echo ">>> [ERROR] cuda_version=${cuda_version} is not supported!"
113 | return 114 | ;; 115 | esac 116 | 117 | } 118 | 119 | 120 | prepare_conda_env "$@" -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.6.5 2 | numpy==1.20.3 3 | pytorch_lightning==1.5.8 4 | recordclass==0.16.2 5 | sacrebleu==1.4.14 6 | sentencepiece==0.1.96 7 | setuptools==58.0.4 8 | seutil==0.8.4 9 | six==1.16.0 10 | tqdm==4.62.3 11 | transformers==4.16.2 12 | tree_sitter==0.20.0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nltk==3.6.5 2 | numpy==1.20.3 3 | pytorch_lightning==1.5.8 4 | recordclass==0.16.2 5 | sacrebleu==1.4.14 6 | sentencepiece==0.1.96 7 | setuptools==58.0.4 8 | seutil==0.8.4 9 | six==1.16.0 10 | tqdm==4.62.3 11 | transformers==4.16.2 12 | tree_sitter==0.20.0 13 | antlr4-python3-runtime==4.9.3 14 | jsonargparse==4.24.1 --------------------------------------------------------------------------------
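
Note on the diff-alignment step in `RealDiffCollector.py`: the heart of `compare_diff_similarity` / `match_diff_based_on_jaccard` is a Jaccard comparison between the added/deleted lines of a Java method change and a candidate C# method change. The snippet below is a minimal, self-contained sketch of that idea only; the whitespace tokenizer and the `0.5` threshold are illustrative stand-ins for the repo's own `tokenize_code` helper and its configured `token_sim_threshold` / `line_sim_threshold`, not the exact values used in the project.

```python
from typing import List


def jaccard(a: set, b: set) -> float:
    """Jaccard similarity of two token sets (defined as 1.0 when both are empty)."""
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


def hunk_similarity(java_lines: List[str], cs_lines: List[str]) -> float:
    """Lower-case and whitespace-tokenize both change hunks, then compare the token sets."""
    java_tokens = set(" ".join(java_lines).lower().split())
    cs_tokens = set(" ".join(cs_lines).lower().split())
    return jaccard(java_tokens, cs_tokens)


# Toy example: the added line of a Java change vs. the added line of a C# change.
java_add = ['if (x == null) { throw new IllegalArgumentException("x"); }']
cs_add = ['if (x == null) { throw new ArgumentNullException("x"); }']

sim = hunk_similarity(java_add, cs_add)
print(f"Jaccard similarity of added lines: {sim:.2f}")

# Keep the pair only if it clears the threshold (illustrative value, not the repo's setting).
TOKEN_SIM_THRESHOLD = 0.5
if sim >= TOKEN_SIM_THRESHOLD:
    print("Treated as an aligned Java/C# change pair.")
```

In the actual collector, token-level and line-level similarities are computed separately (each averaged over the added and deleted hunks), a candidate C# diff must clear both thresholds, and only the best-scoring, de-duplicated match per Java diff is kept.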