├── .github └── workflows │ ├── publish.yml │ └── ubuntu.yml ├── .gitignore ├── LICENSE ├── README.md ├── code2seq ├── README.md ├── __init__.py ├── code2class_wrapper.py ├── code2seq_wrapper.py ├── data │ ├── __init__.py │ ├── path_context.py │ ├── path_context_data_module.py │ ├── path_context_dataset.py │ ├── typed_path_context_data_module.py │ ├── typed_path_context_dataset.py │ └── vocabulary.py ├── model │ ├── __init__.py │ ├── code2class.py │ ├── code2seq.py │ ├── modules │ │ ├── __init__.py │ │ ├── path_encoder.py │ │ └── typed_path_encoder.py │ └── typed_code2seq.py ├── typed_code2seq_wrapper.py └── utils │ ├── __init__.py │ ├── common.py │ ├── optimization.py │ ├── test.py │ └── train.py ├── config ├── code2class-poj104.yaml ├── code2seq-java-med.yaml ├── code2seq-java-small.yaml ├── code2seq-java-test.yaml ├── typed-code2seq-java-small.yaml └── typed-code2seq-java-test.yaml ├── requirements.txt ├── scripts └── split_dataset.sh ├── setup.py └── tests └── test_tokenization.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | # Publish to Test PyPI in case of pushing into the master 4 | # Publish to PyPI in case of releasing 5 | 6 | on: 7 | push: 8 | 9 | jobs: 10 | build-n-publish: 11 | name: Build and publish Python 🐍 distributions 📦 to PyPI 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.8 20 | - name: Install pypa/build 21 | run: | 22 | python -m pip install build --user 23 | - name: Build a binary wheel and a source tarball 24 | run: | 25 | python -m build --sdist --wheel --outdir dist/ . 26 | # - name: Publish distribution 📦 to Test PyPI 27 | # if: github.ref == 'refs/heads/master' 28 | # uses: pypa/gh-action-pypi-publish@master 29 | # with: 30 | # password: ${{ secrets.TEST_PYPI_API_TOKEN }} 31 | # repository_url: https://test.pypi.org/legacy/ 32 | - name: Publish distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags') 34 | uses: pypa/gh-action-pypi-publish@master 35 | with: 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/ubuntu.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | 6 | jobs: 7 | build: 8 | 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 21 | - name: Lint with Black 22 | run: | 23 | pip install black 24 | black . --check -l 120 25 | - name: Test with unittest 26 | run: | 27 | python -m unittest discover tests -v 28 | - name: Check types with mypy 29 | run: | 30 | pip install mypy 31 | mypy . 
--ignore-missing-imports 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | .ipynb_checkpoints/ 3 | __pycache__/ 4 | .DS_Store 5 | .mypy_cache/ 6 | 7 | *.py[cod] 8 | *$py.class 9 | 10 | wandb/ 11 | notebooks/ 12 | outputs/ 13 | 14 | code2seq.egg-info/ 15 | dist/ 16 | build/ 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Egor Spirin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # code2seq 2 | 3 | [![JetBrains Research](https://jb.gg/badges/research.svg)](https://confluence.jetbrains.com/display/ALL/JetBrains+on+GitHub) 4 | [![Github action: build](https://github.com/SpirinEgor/code2seq/workflows/Build/badge.svg)](https://github.com/SpirinEgor/code2seq/actions?query=workflow%3ABuild) 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | 7 | 8 | A PyTorch implementation of the code2seq model. 9 | 10 | ## Installation 11 | 12 | You can easily install the model via pip: 13 | ```shell 14 | pip install code2seq 15 | ``` 16 | 17 | ## Dataset mining 18 | 19 | To prepare your own dataset in a storage format supported by this implementation, use one of the following (a sketch of the resulting line format is shown after this list): 20 | 1. The original dataset preprocessing from the vanilla repository 21 | 2. [`astminer`](https://github.com/JetBrains-Research/astminer): 22 | a tool for mining path-based representations (and more) with support for multiple languages. 23 | 3. [`PSIMiner`](https://github.com/JetBrains-Research/psiminer): 24 | a tool for extracting PSI trees from the IntelliJ Platform and creating datasets from them.
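Whichever tool you use, the format consumed by `PathContextDataset` is one example per line: a `|`-separated target label followed by whitespace-separated path contexts, where each context is a comma-separated triple of start token, path nodes, and end token, with sub-tokens and node names again joined by `|`. A minimal sketch of how such a line is split (the tokens and node names below are made up for illustration and not taken from a real dataset):

```python
# Hypothetical `.c2s` line: a label plus two path contexts (sub-parts joined by "|").
raw_sample = "get|name my|field,NameExpr|MethodDecl|NameExpr,get|name this,Block|Return,my|field"

label, *path_contexts = raw_sample.split()                   # "get|name" and two contexts
for raw_path in path_contexts:
    from_token, path_nodes, to_token = raw_path.split(",")   # comma-separated triple
    print(from_token.split("|"), path_nodes.split("|"), to_token.split("|"))
```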
25 | ## Available checkpoints 26 | 27 | ### Method name prediction 28 | | Dataset (with link) | Checkpoint | # epochs | F1-score | Precision | Recall | ChrF | 29 | |-------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------|----------|----------|-----------|--------|-------| 30 | | [Java-small](https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-small.tar.gz) | [link](https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/checkpoints/code2seq_java_small.ckpt) | 11 | 41.49 | 54.26 | 33.59 | 30.21 | 31 | | [Java-med](https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-med.tar.gz) | [link](https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/checkpoints/code2seq_java_med.ckpt) | 10 | 48.17 | 58.87 | 40.76 | 42.32 | 32 | 33 | ## Configuration 34 | 35 | The model is fully configurable via a standalone YAML file. 36 | Navigate to the [config](config) directory to see example configs. 37 | 38 | ## Examples 39 | 40 | Model training can be done via the PyTorch Lightning trainer. 41 | See its [documentation](https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html) for more information. 42 | 43 | ```python 44 | from argparse import ArgumentParser 45 | 46 | from omegaconf import DictConfig, OmegaConf 47 | from pytorch_lightning import Trainer 48 | 49 | from code2seq.data.path_context_data_module import PathContextDataModule 50 | from code2seq.model import Code2Seq 51 | 52 | 53 | def train(config: DictConfig): 54 | # Define data module 55 | data_module = PathContextDataModule(config.data_folder, config.data) 56 | 57 | # Define model 58 | model = Code2Seq( 59 | config.model, 60 | config.optimizer, 61 | data_module.vocabulary, 62 | config.train.teacher_forcing 63 | ) 64 | 65 | # Define trainer 66 | trainer = Trainer(max_epochs=config.train.n_epochs) 67 | 68 | # Train model 69 | trainer.fit(model, datamodule=data_module) 70 | 71 | 72 | if __name__ == "__main__": 73 | __arg_parser = ArgumentParser() 74 | __arg_parser.add_argument("config", help="Path to YAML configuration file", type=str) 75 | __args = __arg_parser.parse_args() 76 | 77 | __config = OmegaConf.load(__args.config) 78 | train(__config) 79 | ``` 80 | -------------------------------------------------------------------------------- /code2seq/README.md: -------------------------------------------------------------------------------- 1 | # code2seq 2 | 3 | A PyTorch implementation of the code2seq model. 4 | 5 | ## Configuration 6 | 7 | Use the `yaml` files from the [config](code2seq/configs) directory to configure all processes. 8 | The `model` option defines the model; for now, the repository supports: 9 | - code2seq 10 | - typed-code2seq 11 | - code2class 12 | 13 | `data_folder` is the path to the folder with the dataset. 14 | For checkpoints with a predefined config, users can specify the data folder via an argument of the corresponding script. 15 | 16 | ## Data 17 | 18 | This code2seq implementation supports the same data format as the original [model](https://github.com/tech-srl/code2seq). 19 | The only difference is how the vocabulary is stored. To recollect the vocabulary, use 20 | ```shell 21 | PYTHONPATH='.'
python preprocessing/build_vocabulary.py 22 | ``` 23 | 24 | ## Train model 25 | 26 | To train a model, use the `train.py` script: 27 | ```shell 28 | python train.py model 29 | ``` 30 | Use [`main.yaml`](code2seq/configs/main.yaml) to set up hyper-parameters. 31 | Use the corresponding configuration from [`configs/model`](code2seq/configs/model) to set up the dataset. 32 | 33 | To resume training from a saved checkpoint, use the `--resume` argument: 34 | ```shell 35 | python train.py model --resume checkpoint.ckpt 36 | ``` 37 | 38 | ## Evaluate model 39 | 40 | To evaluate a trained model, use the `test.py` script: 41 | ```shell 42 | python test.py checkpoint.ckpt 43 | ``` 44 | 45 | To specify the folder with data (in case you evaluate on a machine different from the one used for training), use the `--data-folder` argument: 46 | ```shell 47 | python test.py checkpoint.ckpt --data-folder path 48 | ``` 49 | -------------------------------------------------------------------------------- /code2seq/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/code2seq/d48c3a28b2f855414c4107dc19692927acb47d9d/code2seq/__init__.py -------------------------------------------------------------------------------- /code2seq/code2class_wrapper.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import cast 3 | 4 | import torch 5 | from commode_utils.common import print_config 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from code2seq.data.path_context_data_module import PathContextDataModule 9 | from code2seq.model import Code2Class 10 | from code2seq.utils.common import filter_warnings 11 | from code2seq.utils.test import test 12 | from code2seq.utils.train import train 13 | 14 | 15 | def configure_arg_parser() -> ArgumentParser: 16 | arg_parser = ArgumentParser() 17 | arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) 18 | arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) 19 | return arg_parser 20 | 21 | 22 | def train_code2class(config: DictConfig): 23 | filter_warnings() 24 | 25 | if config.print_config: 26 | print_config(config, fields=["model", "data", "train", "optimizer"]) 27 | 28 | # Load data module 29 | data_module = PathContextDataModule(config.data_folder, config.data, is_class=True) 30 | 31 | # Load model 32 | code2class = Code2Class(config.model, config.optimizer, data_module.vocabulary) 33 | 34 | train(code2class, data_module, config) 35 | 36 | 37 | def test_code2class(config: DictConfig): 38 | filter_warnings() 39 | 40 | # Load data module 41 | data_module = PathContextDataModule(config.data_folder, config.data) 42 | 43 | # Load model 44 | code2class = Code2Class.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) 45 | 46 | test(code2class, data_module, config.seed) 47 | 48 | 49 | if __name__ == "__main__": 50 | __arg_parser = configure_arg_parser() 51 | __args = __arg_parser.parse_args() 52 | 53 | __config = cast(DictConfig, OmegaConf.load(__args.config)) 54 | if __args.mode == "train": 55 | train_code2class(__config) 56 | else: 57 | test_code2class(__config) 58 | -------------------------------------------------------------------------------- /code2seq/code2seq_wrapper.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import cast 3 | 4 | import torch 5 | from
commode_utils.common import print_config 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from code2seq.data.path_context_data_module import PathContextDataModule 9 | from code2seq.model import Code2Seq 10 | from code2seq.utils.common import filter_warnings 11 | from code2seq.utils.test import test 12 | from code2seq.utils.train import train 13 | 14 | 15 | def configure_arg_parser() -> ArgumentParser: 16 | arg_parser = ArgumentParser() 17 | arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) 18 | arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) 19 | return arg_parser 20 | 21 | 22 | def train_code2seq(config: DictConfig): 23 | filter_warnings() 24 | 25 | if config.print_config: 26 | print_config(config, fields=["model", "data", "train", "optimizer"]) 27 | 28 | # Load data module 29 | data_module = PathContextDataModule(config.data_folder, config.data) 30 | 31 | # Load model 32 | code2seq = Code2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) 33 | 34 | train(code2seq, data_module, config) 35 | 36 | 37 | def test_code2seq(config: DictConfig): 38 | filter_warnings() 39 | 40 | # Load data module 41 | data_module = PathContextDataModule(config.data_folder, config.data) 42 | 43 | # Load model 44 | code2seq = Code2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) 45 | 46 | test(code2seq, data_module, config.seed) 47 | 48 | 49 | if __name__ == "__main__": 50 | __arg_parser = configure_arg_parser() 51 | __args = __arg_parser.parse_args() 52 | 53 | __config = cast(DictConfig, OmegaConf.load(__args.config)) 54 | if __args.mode == "train": 55 | train_code2seq(__config) 56 | else: 57 | test_code2seq(__config) 58 | -------------------------------------------------------------------------------- /code2seq/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/code2seq/d48c3a28b2f855414c4107dc19692927acb47d9d/code2seq/data/__init__.py -------------------------------------------------------------------------------- /code2seq/data/path_context.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable, Tuple, Optional, Sequence, List, cast 3 | 4 | import torch 5 | 6 | 7 | @dataclass 8 | class Path: 9 | from_token: List[int] # [max token parts] 10 | path_node: List[int] # [path length] 11 | to_token: List[int] # [max token parts] 12 | 13 | 14 | @dataclass 15 | class LabeledPathContext: 16 | label: List[int] # [max label parts] 17 | path_contexts: Sequence[Path] 18 | 19 | 20 | def transpose(list_of_lists: List[List[int]]) -> List[List[int]]: 21 | return [cast(List[int], it) for it in zip(*list_of_lists)] 22 | 23 | 24 | class BatchedLabeledPathContext: 25 | def __init__(self, all_samples: Sequence[Optional[LabeledPathContext]]): 26 | samples = [s for s in all_samples if s is not None] 27 | 28 | # [max label parts; batch size] 29 | self.labels = torch.tensor(transpose([s.label for s in samples]), dtype=torch.long) 30 | # [batch size] 31 | self.contexts_per_label = torch.tensor([len(s.path_contexts) for s in samples]) 32 | 33 | # [max token parts; n contexts] 34 | self.from_token = torch.tensor( 35 | transpose([path.from_token for s in samples for path in s.path_contexts]), dtype=torch.long 36 | ) 37 | # [path length; n contexts] 38 | self.path_nodes = torch.tensor( 39 | 
transpose([path.path_node for s in samples for path in s.path_contexts]), dtype=torch.long 40 | ) 41 | # [max token parts; n contexts] 42 | self.to_token = torch.tensor( 43 | transpose([path.to_token for s in samples for path in s.path_contexts]), dtype=torch.long 44 | ) 45 | 46 | def __len__(self) -> int: 47 | return len(self.contexts_per_label) 48 | 49 | def __get_all_tensors(self) -> Iterable[Tuple[str, torch.Tensor]]: 50 | for name, value in vars(self).items(): 51 | if isinstance(value, torch.Tensor): 52 | yield name, value 53 | 54 | def pin_memory(self) -> "BatchedLabeledPathContext": 55 | for name, value in self.__get_all_tensors(): 56 | setattr(self, name, value.pin_memory()) 57 | return self 58 | 59 | def move_to_device(self, device: torch.device): 60 | for name, value in self.__get_all_tensors(): 61 | setattr(self, name, value.to(device)) 62 | 63 | 64 | @dataclass 65 | class TypedPath(Path): 66 | from_type: List[int] # [max type parts] 67 | to_type: List[int] # [max type parts] 68 | 69 | 70 | @dataclass 71 | class LabeledTypedPathContext(LabeledPathContext): 72 | path_contexts: Sequence[TypedPath] 73 | 74 | 75 | class BatchedLabeledTypedPathContext(BatchedLabeledPathContext): 76 | def __init__(self, all_samples: Sequence[Optional[LabeledTypedPathContext]]): 77 | super().__init__(all_samples) 78 | samples = [s for s in all_samples if s is not None] 79 | # [max type parts; n contexts] 80 | self.from_type = torch.tensor( 81 | transpose([path.from_type for s in samples for path in s.path_contexts]), dtype=torch.long 82 | ) 83 | # [max type parts; n contexts] 84 | self.to_type = torch.tensor( 85 | transpose([path.to_type for s in samples for path in s.path_contexts]), dtype=torch.long 86 | ) 87 | -------------------------------------------------------------------------------- /code2seq/data/path_context_data_module.py: -------------------------------------------------------------------------------- 1 | from os.path import exists, join, basename 2 | from typing import List, Optional 3 | 4 | import torch 5 | from commode_utils.common import download_dataset 6 | from commode_utils.vocabulary import build_from_scratch 7 | from omegaconf import DictConfig 8 | from pytorch_lightning import LightningDataModule 9 | from torch.utils.data import DataLoader 10 | 11 | from code2seq.data.path_context import LabeledPathContext, BatchedLabeledPathContext 12 | from code2seq.data.path_context_dataset import PathContextDataset 13 | from code2seq.data.vocabulary import Vocabulary 14 | 15 | 16 | class PathContextDataModule(LightningDataModule): 17 | _train = "train" 18 | _val = "val" 19 | _test = "test" 20 | 21 | def __init__(self, data_dir: str, config: DictConfig, is_class: bool = False): 22 | super().__init__() 23 | self._config = config 24 | self._data_dir = data_dir 25 | self._name = basename(data_dir) 26 | self._is_class = is_class 27 | 28 | self._vocabulary = self.setup_vocabulary() 29 | 30 | @property 31 | def vocabulary(self) -> Vocabulary: 32 | if self._vocabulary is None: 33 | raise RuntimeError(f"Setup data module for initializing vocabulary") 34 | return self._vocabulary 35 | 36 | def prepare_data(self): 37 | if exists(self._data_dir): 38 | print(f"Dataset is already downloaded") 39 | return 40 | if "url" not in self._config: 41 | raise ValueError(f"Config doesn't contain url for, can't download it automatically") 42 | download_dataset(self._config.url, self._data_dir, self._name) 43 | 44 | def setup_vocabulary(self) -> Vocabulary: 45 | vocabulary_path = join(self._data_dir, 
Vocabulary.vocab_filename) 46 | if not exists(vocabulary_path): 47 | print("Can't find vocabulary, collect it from train holdout") 48 | build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), Vocabulary) 49 | return Vocabulary(vocabulary_path, self._config.labels_count, self._config.tokens_count, self._is_class) 50 | 51 | @staticmethod 52 | def collate_wrapper(batch: List[Optional[LabeledPathContext]]) -> BatchedLabeledPathContext: 53 | return BatchedLabeledPathContext(batch) 54 | 55 | def _create_dataset(self, holdout_file: str, random_context: bool) -> PathContextDataset: 56 | if self._vocabulary is None: 57 | raise RuntimeError(f"Setup vocabulary before creating data loaders") 58 | return PathContextDataset(holdout_file, self._config, self._vocabulary, random_context) 59 | 60 | def _shared_dataloader(self, holdout: str) -> DataLoader: 61 | if self._vocabulary is None: 62 | raise RuntimeError(f"Setup vocabulary before creating data loaders") 63 | 64 | holdout_file = join(self._data_dir, f"{holdout}.c2s") 65 | random_context = self._config.random_context if holdout == self._train else False 66 | dataset = self._create_dataset(holdout_file, random_context) 67 | 68 | batch_size = self._config.batch_size if holdout == self._train else self._config.test_batch_size 69 | shuffle = holdout == self._train 70 | 71 | return DataLoader( 72 | dataset, 73 | batch_size, 74 | shuffle=shuffle, 75 | num_workers=self._config.num_workers, 76 | collate_fn=self.collate_wrapper, 77 | pin_memory=True, 78 | ) 79 | 80 | def train_dataloader(self, *args, **kwargs) -> DataLoader: 81 | return self._shared_dataloader(self._train) 82 | 83 | def val_dataloader(self, *args, **kwargs) -> DataLoader: 84 | return self._shared_dataloader(self._val) 85 | 86 | def test_dataloader(self, *args, **kwargs) -> DataLoader: 87 | return self._shared_dataloader(self._test) 88 | 89 | def predict_dataloader(self, *args, **kwargs) -> DataLoader: 90 | return self.test_dataloader(*args, **kwargs) 91 | 92 | def transfer_batch_to_device( 93 | self, batch: BatchedLabeledPathContext, device: torch.device, dataloader_idx: int 94 | ) -> BatchedLabeledPathContext: 95 | batch.move_to_device(device) 96 | return batch 97 | -------------------------------------------------------------------------------- /code2seq/data/path_context_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import exists 2 | from random import shuffle 3 | from typing import Dict, List, Optional 4 | 5 | from commode_utils.filesystem import get_lines_offsets, get_line_by_offset 6 | from omegaconf import DictConfig 7 | from torch.utils.data import Dataset 8 | 9 | from code2seq.data.path_context import LabeledPathContext, Path 10 | from code2seq.data.vocabulary import Vocabulary 11 | 12 | 13 | class PathContextDataset(Dataset): 14 | _log_file = "bad_samples.log" 15 | _separator = "|" 16 | 17 | def __init__(self, data_file: str, config: DictConfig, vocabulary: Vocabulary, random_context: bool): 18 | if not exists(data_file): 19 | raise ValueError(f"Can't find file with data: {data_file}") 20 | self._data_file = data_file 21 | self._config = config 22 | self._vocab = vocabulary 23 | self._random_context = random_context 24 | 25 | self._line_offsets = get_lines_offsets(data_file) 26 | self._n_samples = len(self._line_offsets) 27 | 28 | open(self._log_file, "w").close() 29 | 30 | def __len__(self): 31 | return self._n_samples 32 | 33 | def __getitem__(self, index) -> Optional[LabeledPathContext]: 34 | raw_sample = 
get_line_by_offset(self._data_file, self._line_offsets[index]) 35 | try: 36 | raw_label, *raw_path_contexts = raw_sample.split() 37 | except ValueError as e: 38 | with open(self._log_file, "a") as f_out: 39 | f_out.write(f"Error reading sample from line #{index}: {e}") 40 | return None 41 | 42 | # Choose paths for current data sample 43 | n_contexts = min(len(raw_path_contexts), self._config.max_context) 44 | if self._random_context: 45 | shuffle(raw_path_contexts) 46 | raw_path_contexts = raw_path_contexts[:n_contexts] 47 | 48 | # Tokenize label 49 | if self._config.max_label_parts == 1: 50 | label = self.tokenize_class(raw_label, self._vocab.label_to_id) 51 | else: 52 | label = self.tokenize_label(raw_label, self._vocab.label_to_id, self._config.max_label_parts) 53 | 54 | # Tokenize paths 55 | try: 56 | paths = [self._get_path(raw_path.split(",")) for raw_path in raw_path_contexts] 57 | except ValueError as e: 58 | with open(self._log_file, "a") as f_out: 59 | f_out.write(f"Error parsing sample from line #{index}: {e}") 60 | return None 61 | 62 | return LabeledPathContext(label, paths) 63 | 64 | @staticmethod 65 | def tokenize_class(raw_class: str, vocab: Dict[str, int]) -> List[int]: 66 | return [vocab[raw_class]] 67 | 68 | @staticmethod 69 | def tokenize_label(raw_label: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: 70 | sublabels = raw_label.split(PathContextDataset._separator) 71 | max_parts = max_parts or len(sublabels) 72 | label_unk = vocab[Vocabulary.UNK] 73 | 74 | label = [vocab[Vocabulary.SOS]] + [vocab.get(st, label_unk) for st in sublabels[:max_parts]] 75 | if len(sublabels) < max_parts: 76 | label.append(vocab[Vocabulary.EOS]) 77 | label += [vocab[Vocabulary.PAD]] * (max_parts + 1 - len(label)) 78 | return label 79 | 80 | @staticmethod 81 | def tokenize_token(token: str, vocab: Dict[str, int], max_parts: Optional[int]) -> List[int]: 82 | sub_tokens = token.split(PathContextDataset._separator) 83 | max_parts = max_parts or len(sub_tokens) 84 | token_unk = vocab[Vocabulary.UNK] 85 | 86 | result = [vocab.get(st, token_unk) for st in sub_tokens[:max_parts]] 87 | result += [vocab[Vocabulary.PAD]] * (max_parts - len(result)) 88 | return result 89 | 90 | def _get_path(self, raw_path: List[str]) -> Path: 91 | return Path( 92 | from_token=self.tokenize_token(raw_path[0], self._vocab.token_to_id, self._config.max_token_parts), 93 | path_node=self.tokenize_token(raw_path[1], self._vocab.node_to_id, self._config.path_length), 94 | to_token=self.tokenize_token(raw_path[2], self._vocab.token_to_id, self._config.max_token_parts), 95 | ) 96 | -------------------------------------------------------------------------------- /code2seq/data/typed_path_context_data_module.py: -------------------------------------------------------------------------------- 1 | from os.path import exists, join 2 | from typing import List, Optional 3 | 4 | from commode_utils.vocabulary import build_from_scratch 5 | from omegaconf import DictConfig 6 | 7 | from code2seq.data.path_context import LabeledTypedPathContext, BatchedLabeledTypedPathContext 8 | from code2seq.data.path_context_data_module import PathContextDataModule 9 | from code2seq.data.typed_path_context_dataset import TypedPathContextDataset 10 | from code2seq.data.vocabulary import TypedVocabulary 11 | 12 | 13 | class TypedPathContextDataModule(PathContextDataModule): 14 | _vocabulary: TypedVocabulary 15 | 16 | def __init__(self, data_dir: str, config: DictConfig): 17 | super().__init__(data_dir, config) 18 | 19 | @staticmethod 20 
| def collate_wrapper( # type: ignore[override] 21 | batch: List[Optional[LabeledTypedPathContext]], 22 | ) -> BatchedLabeledTypedPathContext: 23 | return BatchedLabeledTypedPathContext(batch) 24 | 25 | def _create_dataset(self, holdout_file: str, random_context: bool) -> TypedPathContextDataset: 26 | if self._vocabulary is None: 27 | raise RuntimeError(f"Setup vocabulary before creating data loaders") 28 | return TypedPathContextDataset(holdout_file, self._config, self._vocabulary, random_context) 29 | 30 | def setup_vocabulary(self) -> TypedVocabulary: 31 | if not exists(join(self._data_dir, TypedVocabulary.vocab_filename)): 32 | print("Can't find vocabulary, collect it from train holdout") 33 | build_from_scratch(join(self._data_dir, f"{self._train}.c2s"), TypedVocabulary) 34 | vocabulary_path = join(self._data_dir, TypedVocabulary.vocab_filename) 35 | return TypedVocabulary( 36 | vocabulary_path, self._config.labels_count, self._config.tokens_count, self._config.types_count 37 | ) 38 | 39 | @property 40 | def vocabulary(self) -> TypedVocabulary: 41 | if self._vocabulary is None: 42 | raise RuntimeError(f"Setup data module for initializing vocabulary") 43 | return self._vocabulary 44 | -------------------------------------------------------------------------------- /code2seq/data/typed_path_context_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from omegaconf import DictConfig 4 | 5 | from code2seq.data.path_context import TypedPath 6 | from code2seq.data.path_context_dataset import PathContextDataset 7 | from code2seq.data.vocabulary import TypedVocabulary 8 | 9 | 10 | class TypedPathContextDataset(PathContextDataset): 11 | def __init__(self, data_file: str, config: DictConfig, vocabulary: TypedVocabulary, random_context: bool): 12 | super().__init__(data_file, config, vocabulary, random_context) 13 | self._vocab: TypedVocabulary = vocabulary 14 | 15 | def _get_path(self, raw_path: List[str]) -> TypedPath: 16 | return TypedPath( 17 | from_type=self.tokenize_token(raw_path[0], self._vocab.type_to_id, self._config.max_type_parts), 18 | from_token=self.tokenize_token(raw_path[1], self._vocab.token_to_id, self._config.max_token_parts), 19 | path_node=self.tokenize_token(raw_path[2], self._vocab.node_to_id, self._config.path_length), 20 | to_token=self.tokenize_token(raw_path[3], self._vocab.token_to_id, self._config.max_token_parts), 21 | to_type=self.tokenize_token(raw_path[4], self._vocab.type_to_id, self._config.max_type_parts), 22 | ) 23 | -------------------------------------------------------------------------------- /code2seq/data/vocabulary.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from collections import Counter 3 | from os.path import dirname, join 4 | from pickle import load, dump 5 | from typing import Dict, Counter as CounterType, Optional, List 6 | 7 | from commode_utils.vocabulary import BaseVocabulary, build_from_scratch 8 | 9 | 10 | class Vocabulary(BaseVocabulary): 11 | def __init__( 12 | self, 13 | vocabulary_file: str, 14 | labels_count: Optional[int] = None, 15 | tokens_count: Optional[int] = None, 16 | is_class: bool = False, 17 | ): 18 | super().__init__(vocabulary_file, labels_count, tokens_count) 19 | if is_class: 20 | labels = self._extract_tokens_by_count(self._counters[self.LABEL], labels_count) 21 | self._label_to_id = {token: i for i, token in enumerate(labels)} 22 | 23 | @staticmethod 24 | def 
_process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]], context_seq: List[str]): 25 | label, *path_contexts = raw_sample.split(" ") 26 | counters[Vocabulary.LABEL].update(label.split(Vocabulary._separator)) 27 | for path_context in path_contexts: 28 | for token, desc in zip(path_context.split(","), context_seq): 29 | counters[desc].update(token.split(Vocabulary._separator)) 30 | 31 | @staticmethod 32 | def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]): 33 | Vocabulary._process_raw_sample( 34 | raw_sample, counters, [BaseVocabulary.TOKEN, BaseVocabulary.NODE, BaseVocabulary.TOKEN] 35 | ) 36 | 37 | 38 | class TypedVocabulary(Vocabulary): 39 | TYPE = "tokenType" 40 | 41 | _path_context_seq = [TYPE, Vocabulary.TOKEN, Vocabulary.NODE, Vocabulary.TOKEN, TYPE] 42 | 43 | def __init__( 44 | self, 45 | vocabulary_file: str, 46 | labels_count: Optional[int] = None, 47 | tokens_count: Optional[int] = None, 48 | types_count: Optional[int] = None, 49 | ): 50 | super().__init__(vocabulary_file, labels_count, tokens_count) 51 | 52 | self._type_to_id = {self.PAD: 0, self.UNK: 1, self.SOS: 2, self.EOS: 3} 53 | types = self._extract_tokens_by_count(self._counters[self.TYPE], types_count) 54 | self._type_to_id.update((token, i + 4) for i, token in enumerate(types)) 55 | 56 | @property 57 | def type_to_id(self) -> Dict[str, int]: 58 | return self._type_to_id 59 | 60 | @staticmethod 61 | def process_raw_sample(raw_sample: str, counters: Dict[str, CounterType[str]]): 62 | if TypedVocabulary.TYPE not in counters: 63 | counters[TypedVocabulary.TYPE] = Counter() 64 | context_seq = [ 65 | TypedVocabulary.TYPE, 66 | BaseVocabulary.TOKEN, 67 | BaseVocabulary.NODE, 68 | BaseVocabulary.TOKEN, 69 | TypedVocabulary.TYPE, 70 | ] 71 | TypedVocabulary._process_raw_sample(raw_sample, counters, context_seq) 72 | 73 | 74 | def convert_from_vanilla(vocabulary_path: str): 75 | counters: Dict[str, CounterType[str]] = {} 76 | with open(vocabulary_path, "rb") as dict_file: 77 | counters[Vocabulary.TOKEN] = Counter(load(dict_file)) 78 | counters[Vocabulary.NODE] = Counter(load(dict_file)) 79 | counters[Vocabulary.LABEL] = Counter(load(dict_file)) 80 | 81 | for feature, counter in counters.items(): 82 | print(f"Count {len(counter)} {feature}, top-5: {counter.most_common(5)}") 83 | 84 | dataset_dir = dirname(vocabulary_path) 85 | vocabulary_file = join(dataset_dir, Vocabulary.vocab_filename) 86 | with open(vocabulary_file, "wb") as f_out: 87 | dump(counters, f_out) 88 | 89 | 90 | if __name__ == "__main__": 91 | __arg_parse = ArgumentParser() 92 | __arg_parse.add_argument("data", type=str, help="Path to file with data") 93 | __arg_parse.add_argument("--typed", action="store_true", help="Use typed vocabulary") 94 | __args = __arg_parse.parse_args() 95 | 96 | if __args.data.endswith(".dict.c2s"): 97 | convert_from_vanilla(__args.data) 98 | else: 99 | __vocab_cls = TypedVocabulary if __args.typed else Vocabulary 100 | build_from_scratch(__args.data, __vocab_cls) 101 | -------------------------------------------------------------------------------- /code2seq/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .code2class import Code2Class 2 | from .code2seq import Code2Seq 3 | from .typed_code2seq import TypedCode2Seq 4 | 5 | __all__ = ["Code2Class", "Code2Seq", "TypedCode2Seq"] 6 | -------------------------------------------------------------------------------- /code2seq/model/code2class.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import torch 4 | from commode_utils.modules import Classifier 5 | from omegaconf import DictConfig 6 | from pytorch_lightning import LightningModule 7 | from pytorch_lightning.utilities.types import EPOCH_OUTPUT 8 | from torch.optim import Optimizer 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | from torchmetrics import Metric, Accuracy, MetricCollection 11 | 12 | from code2seq.data.path_context import BatchedLabeledPathContext 13 | from code2seq.data.vocabulary import Vocabulary 14 | from code2seq.model.modules import PathEncoder 15 | from code2seq.utils.optimization import configure_optimizers_alon 16 | 17 | 18 | class Code2Class(LightningModule): 19 | def __init__(self, model_config: DictConfig, optimizer_config: DictConfig, vocabulary: Vocabulary): 20 | super().__init__() 21 | self.save_hyperparameters() 22 | self._optim_config = optimizer_config 23 | 24 | self._encoder = PathEncoder( 25 | model_config, 26 | len(vocabulary.token_to_id), 27 | vocabulary.token_to_id[Vocabulary.PAD], 28 | len(vocabulary.node_to_id), 29 | vocabulary.node_to_id[Vocabulary.PAD], 30 | ) 31 | 32 | self._classifier = Classifier(model_config, len(vocabulary.label_to_id)) 33 | 34 | metrics: Dict[str, Metric] = { 35 | f"{holdout}_acc": Accuracy(num_classes=len(vocabulary.label_to_id)) for holdout in ["train", "val", "test"] 36 | } 37 | self.__metrics = MetricCollection(metrics) 38 | 39 | def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: 40 | return configure_optimizers_alon(self._optim_config, self.parameters()) 41 | 42 | def forward( # type: ignore 43 | self, 44 | from_token: torch.Tensor, 45 | path_nodes: torch.Tensor, 46 | to_token: torch.Tensor, 47 | contexts_per_label: torch.Tensor, 48 | ) -> torch.Tensor: 49 | encoded_paths = self._encoder(from_token, path_nodes, to_token) 50 | output_logits = self._classifier(encoded_paths, contexts_per_label) 51 | return output_logits 52 | 53 | # ========== MODEL STEP ========== 54 | 55 | def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: 56 | # [batch size; num_classes] 57 | logits = self(batch.from_token, batch.path_nodes, batch.to_token, batch.contexts_per_label) 58 | labels = batch.labels.squeeze(0) 59 | loss = torch.nn.functional.cross_entropy(logits, labels) 60 | 61 | with torch.no_grad(): 62 | predictions = logits.argmax(-1) 63 | accuracy = self.__metrics[f"{step}_acc"](predictions, labels) 64 | 65 | return {f"{step}/loss": loss, f"{step}/accuracy": accuracy} 66 | 67 | def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 68 | result = self._shared_step(batch, "train") 69 | self.log_dict(result, on_step=True, on_epoch=False) 70 | self.log("acc", result["train/accuracy"], prog_bar=True, logger=False) 71 | return result["train/loss"] 72 | 73 | def validation_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 74 | return self._shared_step(batch, "val") 75 | 76 | def test_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 77 | return self._shared_step(batch, "test") 78 | 79 | # ========== ON EPOCH END ========== 80 | 81 | def _shared_epoch_end(self, outputs: EPOCH_OUTPUT, step: str): 82 | assert isinstance(outputs, dict) 83 | with torch.no_grad(): 84 | mean_loss = torch.stack([out[f"{step}/loss"] for out in outputs]).mean() 85 | accuracy = 
self.__metrics[f"{step}_acc"].compute() 86 | log = {f"{step}/loss": mean_loss, f"{step}/accuracy": accuracy} 87 | self.__metrics[f"{step}_acc"].reset() 88 | self.log_dict(log, on_step=False, on_epoch=True) 89 | 90 | def training_epoch_end(self, outputs: EPOCH_OUTPUT): 91 | self._shared_epoch_end(outputs, "train") 92 | 93 | def validation_epoch_end(self, outputs: EPOCH_OUTPUT): 94 | self._shared_epoch_end(outputs, "val") 95 | 96 | def test_epoch_end(self, outputs: EPOCH_OUTPUT): 97 | self._shared_epoch_end(outputs, "test") 98 | -------------------------------------------------------------------------------- /code2seq/model/code2seq.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Dict, Optional 2 | 3 | import torch 4 | from commode_utils.losses import SequenceCrossEntropyLoss 5 | from commode_utils.metrics import SequentialF1Score, ClassificationMetrics 6 | from commode_utils.metrics.chrF import ChrF 7 | from commode_utils.modules import LSTMDecoderStep, Decoder 8 | from omegaconf import DictConfig 9 | from pytorch_lightning import LightningModule 10 | from pytorch_lightning.utilities.types import EPOCH_OUTPUT 11 | from torch import nn 12 | from torch.optim import Optimizer 13 | from torch.optim.lr_scheduler import _LRScheduler 14 | from torchmetrics import MetricCollection, Metric 15 | 16 | from code2seq.data.path_context import BatchedLabeledPathContext 17 | from code2seq.data.vocabulary import Vocabulary 18 | from code2seq.model.modules import PathEncoder 19 | from code2seq.utils.optimization import configure_optimizers_alon 20 | 21 | 22 | class Code2Seq(LightningModule): 23 | def __init__( 24 | self, 25 | model_config: DictConfig, 26 | optimizer_config: DictConfig, 27 | vocabulary: Vocabulary, 28 | teacher_forcing: float = 0.0, 29 | ): 30 | super().__init__() 31 | self.save_hyperparameters() 32 | self._optim_config = optimizer_config 33 | self._vocabulary = vocabulary 34 | 35 | if vocabulary.SOS not in vocabulary.label_to_id: 36 | raise ValueError(f"Can't find SOS token in label to id vocabulary") 37 | 38 | self.__pad_idx = vocabulary.label_to_id[vocabulary.PAD] 39 | eos_idx = vocabulary.label_to_id[vocabulary.EOS] 40 | ignore_idx = [vocabulary.label_to_id[vocabulary.SOS], vocabulary.label_to_id[vocabulary.UNK]] 41 | metrics: Dict[str, Metric] = { 42 | f"{holdout}_f1": SequentialF1Score(pad_idx=self.__pad_idx, eos_idx=eos_idx, ignore_idx=ignore_idx) 43 | for holdout in ["train", "val", "test"] 44 | } 45 | id2label = {v: k for k, v in vocabulary.label_to_id.items()} 46 | metrics.update( 47 | {f"{holdout}_chrf": ChrF(id2label, ignore_idx + [self.__pad_idx, eos_idx]) for holdout in ["val", "test"]} 48 | ) 49 | self.__metrics = MetricCollection(metrics) 50 | 51 | self._encoder = self._get_encoder(model_config) 52 | decoder_step = LSTMDecoderStep(model_config, len(vocabulary.label_to_id), self.__pad_idx) 53 | self._decoder = Decoder( 54 | decoder_step, len(vocabulary.label_to_id), vocabulary.label_to_id[vocabulary.SOS], teacher_forcing 55 | ) 56 | 57 | self.__loss = SequenceCrossEntropyLoss(self.__pad_idx, reduction="batch-mean") 58 | 59 | @property 60 | def vocabulary(self) -> Vocabulary: 61 | return self._vocabulary 62 | 63 | def _get_encoder(self, config: DictConfig) -> nn.Module: 64 | return PathEncoder( 65 | config, 66 | len(self._vocabulary.token_to_id), 67 | self._vocabulary.token_to_id[Vocabulary.PAD], 68 | len(self._vocabulary.node_to_id), 69 | self._vocabulary.node_to_id[Vocabulary.PAD], 70 | ) 71 | 72 | # 
========== Main PyTorch-Lightning hooks ========== 73 | 74 | def configure_optimizers(self) -> Tuple[List[Optimizer], List[_LRScheduler]]: 75 | return configure_optimizers_alon(self._optim_config, self.parameters()) 76 | 77 | def forward( # type: ignore 78 | self, 79 | from_token: torch.Tensor, 80 | path_nodes: torch.Tensor, 81 | to_token: torch.Tensor, 82 | contexts_per_label: torch.Tensor, 83 | output_length: int, 84 | target_sequence: torch.Tensor = None, 85 | ) -> Tuple[torch.Tensor, torch.Tensor]: 86 | encoded_paths = self._encoder(from_token, path_nodes, to_token) 87 | output_logits, attention_weights = self._decoder( 88 | encoded_paths, contexts_per_label, output_length, target_sequence 89 | ) 90 | return output_logits, attention_weights 91 | 92 | # ========== Model step ========== 93 | 94 | def logits_from_batch( 95 | self, batch: BatchedLabeledPathContext, target_sequence: Optional[torch.Tensor] = None 96 | ) -> Tuple[torch.Tensor, torch.Tensor]: 97 | return self( 98 | batch.from_token, 99 | batch.path_nodes, 100 | batch.to_token, 101 | batch.contexts_per_label, 102 | batch.labels.shape[0], 103 | target_sequence, 104 | ) 105 | 106 | def _shared_step(self, batch: BatchedLabeledPathContext, step: str) -> Dict: 107 | target_sequence = batch.labels if step == "train" else None 108 | # [seq length; batch size; vocab size] 109 | logits, _ = self.logits_from_batch(batch, target_sequence) 110 | result = {f"{step}/loss": self.__loss(logits[1:], batch.labels[1:])} 111 | 112 | with torch.no_grad(): 113 | prediction = logits.argmax(-1) 114 | metric: ClassificationMetrics = self.__metrics[f"{step}_f1"](prediction, batch.labels) 115 | result.update( 116 | {f"{step}/f1": metric.f1_score, f"{step}/precision": metric.precision, f"{step}/recall": metric.recall} 117 | ) 118 | if step != "train": 119 | result[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"](prediction, batch.labels) 120 | 121 | return result 122 | 123 | def training_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 124 | result = self._shared_step(batch, "train") 125 | self.log_dict(result, on_step=True, on_epoch=False) 126 | self.log("f1", result["train/f1"], prog_bar=True, logger=False) 127 | return result["train/loss"] 128 | 129 | def validation_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 130 | result = self._shared_step(batch, "val") 131 | return result["val/loss"] 132 | 133 | def test_step(self, batch: BatchedLabeledPathContext, batch_idx: int) -> Dict: # type: ignore 134 | result = self._shared_step(batch, "test") 135 | return result["test/loss"] 136 | 137 | # ========== On epoch end ========== 138 | 139 | def _shared_epoch_end(self, step_outputs: EPOCH_OUTPUT, step: str): 140 | with torch.no_grad(): 141 | losses = [so if isinstance(so, torch.Tensor) else so["loss"] for so in step_outputs] 142 | mean_loss = torch.stack(losses).mean() 143 | metric = self.__metrics[f"{step}_f1"].compute() 144 | log = { 145 | f"{step}/loss": mean_loss, 146 | f"{step}/f1": metric.f1_score, 147 | f"{step}/precision": metric.precision, 148 | f"{step}/recall": metric.recall, 149 | } 150 | self.__metrics[f"{step}_f1"].reset() 151 | if step != "train": 152 | log[f"{step}/chrf"] = self.__metrics[f"{step}_chrf"].compute() 153 | self.__metrics[f"{step}_chrf"].reset() 154 | self.log_dict(log, on_step=False, on_epoch=True) 155 | 156 | def training_epoch_end(self, step_outputs: EPOCH_OUTPUT): 157 | self._shared_epoch_end(step_outputs, "train") 158 | 159 | def 
validation_epoch_end(self, step_outputs: EPOCH_OUTPUT): 160 | self._shared_epoch_end(step_outputs, "val") 161 | 162 | def test_epoch_end(self, step_outputs: EPOCH_OUTPUT): 163 | self._shared_epoch_end(step_outputs, "test") 164 | -------------------------------------------------------------------------------- /code2seq/model/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .path_encoder import PathEncoder 2 | from .typed_path_encoder import TypedPathEncoder 3 | 4 | __all__ = ["PathEncoder", "TypedPathEncoder"] 5 | -------------------------------------------------------------------------------- /code2seq/model/modules/path_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torch import nn 6 | 7 | 8 | class PathEncoder(nn.Module): 9 | def __init__( 10 | self, 11 | config: DictConfig, 12 | n_tokens: int, 13 | token_pad_id: int, 14 | n_nodes: int, 15 | node_pad_id: int, 16 | ): 17 | super().__init__() 18 | self.node_pad_id = node_pad_id 19 | self.num_directions = 2 if config.use_bi_rnn else 1 20 | 21 | self.token_embedding = nn.Embedding(n_tokens, config.embedding_size, padding_idx=token_pad_id) 22 | self.node_embedding = nn.Embedding(n_nodes, config.embedding_size, padding_idx=node_pad_id) 23 | 24 | self.dropout_rnn = nn.Dropout(config.encoder_dropout) 25 | self.path_lstm = nn.LSTM( 26 | config.embedding_size, 27 | config.encoder_rnn_size, 28 | num_layers=config.rnn_num_layers, 29 | bidirectional=config.use_bi_rnn, 30 | dropout=config.encoder_dropout if config.rnn_num_layers > 1 else 0, 31 | ) 32 | 33 | concat_size = self._calculate_concat_size(config.embedding_size, config.encoder_rnn_size, self.num_directions) 34 | self.embedding_dropout = nn.Dropout(config.encoder_dropout) 35 | if "decoder_size" in config: 36 | out_size = config["decoder_size"] 37 | elif "classifier_size" in config: 38 | out_size = config["classifier_size"] 39 | else: 40 | raise ValueError("Specify out size of encoder") 41 | self.linear = nn.Linear(concat_size, out_size, bias=False) 42 | self.norm = nn.LayerNorm(out_size) 43 | 44 | @staticmethod 45 | def _calculate_concat_size(embedding_size: int, rnn_size: int, num_directions: int) -> int: 46 | return embedding_size * 2 + rnn_size * num_directions 47 | 48 | def _token_embedding(self, tokens: torch.Tensor) -> torch.Tensor: 49 | return self.token_embedding(tokens).sum(0) 50 | 51 | def _path_nodes_embedding(self, path_nodes: torch.Tensor) -> torch.Tensor: 52 | # [max path length; n contexts; embedding size] 53 | path_nodes_embeddings = self.node_embedding(path_nodes) 54 | 55 | with torch.no_grad(): 56 | is_contain_pad_id, first_pad_pos = torch.max(path_nodes == self.node_pad_id, dim=0) 57 | first_pad_pos[~is_contain_pad_id] = path_nodes.shape[0] # if no pad token use len+1 position 58 | sorted_path_lengths, sort_indices = torch.sort(first_pad_pos, descending=True) 59 | _, reverse_sort_indices = torch.sort(sort_indices) 60 | sorted_path_lengths = sorted_path_lengths.to(torch.device("cpu")) 61 | path_nodes_embeddings = path_nodes_embeddings[:, sort_indices] 62 | 63 | packed_path_nodes = nn.utils.rnn.pack_padded_sequence(path_nodes_embeddings, sorted_path_lengths) 64 | 65 | # [num layers * num directions; total paths; rnn size] 66 | _, (h_t, _) = self.path_lstm(packed_path_nodes) 67 | # [total_paths; rnn size * num directions] 68 | encoded_paths = h_t[-self.num_directions 
:].transpose(0, 1).reshape(h_t.shape[1], -1) 69 | encoded_paths = self.dropout_rnn(encoded_paths) 70 | 71 | encoded_paths = encoded_paths[reverse_sort_indices] 72 | return encoded_paths 73 | 74 | def _concat_with_linear(self, encoded_contexts: List[torch.Tensor]) -> torch.Tensor: 75 | # [n contexts; sum across all embeddings] 76 | concat = torch.cat(encoded_contexts, dim=-1) 77 | 78 | # [n contexts; output size] 79 | concat = self.embedding_dropout(concat) 80 | return torch.tanh(self.norm(self.linear(concat))) 81 | 82 | def forward(self, from_token: torch.Tensor, path_nodes: torch.Tensor, to_token: torch.Tensor) -> torch.Tensor: 83 | """Encode each path context into the vector 84 | 85 | :param from_token: [max token parts; n contexts] start tokens 86 | :param path_nodes: [path length; n contexts] path nodes 87 | :param to_token: [max tokens parts; n contexts] end tokens 88 | :return: [n contexts; encoder size] 89 | """ 90 | # [n contexts; embedding size] 91 | encoded_from_tokens = self._token_embedding(from_token) 92 | encoded_to_tokens = self._token_embedding(to_token) 93 | 94 | # [n contexts; rnn size * num directions] 95 | encoded_paths = self._path_nodes_embedding(path_nodes) 96 | 97 | # [n contexts; output size] 98 | output = self._concat_with_linear([encoded_from_tokens, encoded_paths, encoded_to_tokens]) 99 | return output 100 | -------------------------------------------------------------------------------- /code2seq/model/modules/typed_path_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from omegaconf import DictConfig 3 | from torch import nn 4 | 5 | from code2seq.model.modules import PathEncoder 6 | 7 | 8 | class TypedPathEncoder(PathEncoder): 9 | def __init__( 10 | self, 11 | config: DictConfig, 12 | n_tokens: int, 13 | token_pad_id: int, 14 | n_nodes: int, 15 | node_pad_id: int, 16 | n_types: int, 17 | type_pad_id: int, 18 | ): 19 | super().__init__(config, n_tokens, token_pad_id, n_nodes, node_pad_id) 20 | 21 | self.type_embedding = nn.Embedding(n_types, config.embedding_size, padding_idx=type_pad_id) 22 | 23 | @staticmethod 24 | def _calculate_concat_size(embedding_size: int, rnn_size: int, num_directions: int) -> int: 25 | return embedding_size * 4 + rnn_size * num_directions 26 | 27 | def _type_embedding(self, types: torch.Tensor) -> torch.Tensor: 28 | return self.type_embedding(types).sum(0) 29 | 30 | def forward( # type: ignore 31 | self, 32 | from_type: torch.Tensor, 33 | from_token: torch.Tensor, 34 | path_nodes: torch.Tensor, 35 | to_token: torch.Tensor, 36 | to_type: torch.Tensor, 37 | ) -> torch.Tensor: 38 | """Encode each path context into the vector 39 | 40 | :param from_type: [n contexts; max type parts] types of start tokens 41 | :param from_token: [n contexts; max token parts] start tokens 42 | :param path_nodes: [n contexts; path nodes] path nodes 43 | :param to_token: [n contexts; max tokens parts] end tokens 44 | :param to_type: [n contexts; max types parts] types of end tokens 45 | :return: [n contexts; encoder size] 46 | """ 47 | # [total paths; embedding size] 48 | encoded_from_tokens = self._token_embedding(from_token) 49 | encoded_to_tokens = self._token_embedding(to_token) 50 | 51 | # [total paths; embeddings size] 52 | encoded_from_types = self._type_embedding(from_type) 53 | encoded_to_types = self._type_embedding(to_type) 54 | 55 | # [total_paths; rnn size * num directions] 56 | encoded_paths = self._path_nodes_embedding(path_nodes) 57 | 58 | # [total_paths; output size] 59 | output 
= self._concat_with_linear( 60 | [encoded_from_types, encoded_from_tokens, encoded_paths, encoded_to_tokens, encoded_to_types] 61 | ) 62 | return output 63 | -------------------------------------------------------------------------------- /code2seq/model/typed_code2seq.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | 6 | from code2seq.data.path_context import BatchedLabeledTypedPathContext 7 | from code2seq.data.vocabulary import TypedVocabulary 8 | from code2seq.model import Code2Seq 9 | from code2seq.model.modules import TypedPathEncoder, PathEncoder 10 | 11 | 12 | class TypedCode2Seq(Code2Seq): 13 | def __init__( 14 | self, 15 | model_config: DictConfig, 16 | optimizer_config: DictConfig, 17 | vocabulary: TypedVocabulary, 18 | teacher_forcing: float = 0.0, 19 | ): 20 | super().__init__(model_config, optimizer_config, vocabulary, teacher_forcing) 21 | self._vocabulary: TypedVocabulary = vocabulary 22 | 23 | def _get_encoder(self, config: DictConfig) -> PathEncoder: 24 | return TypedPathEncoder( 25 | config, 26 | len(self._vocabulary.token_to_id), 27 | self._vocabulary.token_to_id[TypedVocabulary.PAD], 28 | len(self._vocabulary.node_to_id), 29 | self._vocabulary.node_to_id[TypedVocabulary.PAD], 30 | len(self._vocabulary.type_to_id), 31 | self._vocabulary.type_to_id[TypedVocabulary.PAD], 32 | ) 33 | 34 | def forward( # type: ignore 35 | self, 36 | from_type: torch.Tensor, 37 | from_token: torch.Tensor, 38 | path_nodes: torch.Tensor, 39 | to_token: torch.Tensor, 40 | to_type: torch.Tensor, 41 | contexts_per_label: torch.Tensor, 42 | output_length: int, 43 | target_sequence: torch.Tensor = None, 44 | ) -> Tuple[torch.Tensor, torch.Tensor]: 45 | encoded_paths = self._encoder(from_type, from_token, path_nodes, to_token, to_type) 46 | output_logits, attention_weights = self._decoder( 47 | encoded_paths, contexts_per_label, output_length, target_sequence 48 | ) 49 | return output_logits, attention_weights 50 | 51 | def logits_from_batch( # type: ignore[override] 52 | self, batch: BatchedLabeledTypedPathContext, target_sequence: Optional[torch.Tensor] 53 | ) -> Tuple[torch.Tensor, torch.Tensor]: 54 | return self( 55 | batch.from_type, 56 | batch.from_token, 57 | batch.path_nodes, 58 | batch.to_token, 59 | batch.to_type, 60 | batch.contexts_per_label, 61 | batch.labels.shape[0], 62 | target_sequence, 63 | ) 64 | -------------------------------------------------------------------------------- /code2seq/typed_code2seq_wrapper.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import cast 3 | 4 | import torch 5 | from commode_utils.common import print_config 6 | from omegaconf import DictConfig, OmegaConf 7 | 8 | from code2seq.data.typed_path_context_data_module import TypedPathContextDataModule 9 | from code2seq.model import TypedCode2Seq 10 | from code2seq.utils.common import filter_warnings 11 | from code2seq.utils.test import test 12 | from code2seq.utils.train import train 13 | 14 | 15 | def configure_arg_parser() -> ArgumentParser: 16 | arg_parser = ArgumentParser() 17 | arg_parser.add_argument("mode", help="Mode to run script", choices=["train", "test"]) 18 | arg_parser.add_argument("-c", "--config", help="Path to YAML configuration file", type=str) 19 | return arg_parser 20 | 21 | 22 | def train_typed_code2seq(config: DictConfig): 23 | filter_warnings() 24 | 25 | if 
config.print_config: 26 | print_config(config, fields=["model", "data", "train", "optimizer"]) 27 | 28 | # Load data module 29 | data_module = TypedPathContextDataModule(config.data_folder, config.data) 30 | 31 | # Load model 32 | typed_code2seq = TypedCode2Seq(config.model, config.optimizer, data_module.vocabulary, config.train.teacher_forcing) 33 | 34 | train(typed_code2seq, data_module, config) 35 | 36 | 37 | def test_typed_code2seq(config: DictConfig): 38 | filter_warnings() 39 | 40 | # Load data module 41 | data_module = TypedPathContextDataModule(config.data_folder, config.data) 42 | 43 | # Load model 44 | typed_code2seq = TypedCode2Seq.load_from_checkpoint(config.checkpoint, map_location=torch.device("cpu")) 45 | 46 | test(typed_code2seq, data_module, config.seed) 47 | 48 | 49 | if __name__ == "__main__": 50 | __arg_parser = configure_arg_parser() 51 | __args = __arg_parser.parse_args() 52 | 53 | __config = cast(DictConfig, OmegaConf.load(__args.config)) 54 | if __args.mode == "train": 55 | train_typed_code2seq(__config) 56 | else: 57 | test_typed_code2seq(__config) 58 | -------------------------------------------------------------------------------- /code2seq/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JetBrains-Research/code2seq/d48c3a28b2f855414c4107dc19692927acb47d9d/code2seq/utils/__init__.py -------------------------------------------------------------------------------- /code2seq/utils/common.py: -------------------------------------------------------------------------------- 1 | from warnings import filterwarnings 2 | 3 | 4 | def filter_warnings(): 5 | # "The dataloader does not have many workers which may be a bottleneck." 6 | filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.utilities.distributed", lineno=50) 7 | filterwarnings("ignore", category=UserWarning, module="pytorch_lightning.trainer.data_loading", lineno=110) 8 | # "Please also save or load the state of the optimizer when saving or loading the scheduler." 
9 | filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=216) # save 10 | filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler", lineno=234) # load 11 | -------------------------------------------------------------------------------- /code2seq/utils/optimization.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Iterable 2 | 3 | import torch 4 | from omegaconf import DictConfig 5 | from torch.optim import Adam, Optimizer, SGD 6 | from torch.optim.lr_scheduler import _LRScheduler, LambdaLR 7 | 8 | 9 | def configure_optimizers_alon( 10 | optim_config: DictConfig, parameters: Iterable[torch.Tensor] 11 | ) -> Tuple[List[Optimizer], List[_LRScheduler]]: 12 | """Create optimizers like in original Alon work 13 | https://github.com/tech-srl/code2seq/blob/a01076ef649d298e5f90ac2ce1f6a42f4ff49cc2/model.py#L386-L397 14 | :param optim_config: hyper parameters 15 | :param parameters: model parameters for optimization 16 | :return: list of optimizers and schedulers 17 | """ 18 | optimizer: Optimizer 19 | if optim_config.optimizer == "Momentum": 20 | # using the same momentum value as in original realization by Alon 21 | optimizer = SGD( 22 | parameters, 23 | optim_config.lr, 24 | momentum=0.95, 25 | nesterov=optim_config.nesterov, 26 | weight_decay=optim_config.weight_decay, 27 | ) 28 | elif optim_config.optimizer == "Adam": 29 | optimizer = Adam(parameters, optim_config.lr, weight_decay=optim_config.weight_decay) 30 | else: 31 | raise ValueError(f"Unknown optimizer name: {optim_config.optimizer}, try one of: Adam, Momentum") 32 | scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: optim_config.decay_gamma ** epoch) 33 | return [optimizer], [scheduler] 34 | -------------------------------------------------------------------------------- /code2seq/utils/test.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from pytorch_lightning import Trainer, seed_everything, LightningModule, LightningDataModule 5 | 6 | 7 | def test(model: LightningModule, data_module: LightningDataModule, seed: Optional[int] = None): 8 | seed_everything(seed) 9 | gpu = 1 if torch.cuda.is_available() else None 10 | trainer = Trainer(gpus=gpu) 11 | trainer.test(model, datamodule=data_module) 12 | -------------------------------------------------------------------------------- /code2seq/utils/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from commode_utils.callbacks import ModelCheckpointWithUploadCallback, PrintEpochResultCallback 3 | from omegaconf import DictConfig, OmegaConf 4 | from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule 5 | from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, TQDMProgressBar 6 | from pytorch_lightning.loggers import WandbLogger 7 | 8 | 9 | def train(model: LightningModule, data_module: LightningDataModule, config: DictConfig): 10 | seed_everything(config.seed) 11 | params = config.train 12 | 13 | # define logger 14 | wandb_logger = WandbLogger( 15 | project=config.wandb.project, 16 | group=config.wandb.group, 17 | log_model=False, 18 | offline=config.wandb.offline, 19 | config=OmegaConf.to_container(config), 20 | ) 21 | 22 | # define model checkpoint callback 23 | checkpoint_callback = ModelCheckpointWithUploadCallback( 24 | 
dirpath=wandb_logger.experiment.dir, 25 | filename="{epoch:02d}-val_loss={val/loss:.4f}", 26 | monitor="val/loss", 27 | every_n_epochs=params.save_every_epoch, 28 | save_top_k=-1, 29 | auto_insert_metric_name=False, 30 | ) 31 | # define early stopping callback 32 | early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min") 33 | # define callback for printing intermediate result 34 | print_epoch_result_callback = PrintEpochResultCallback(after_test=False) 35 | # define learning rate logger 36 | lr_logger = LearningRateMonitor("step") 37 | # define progress bar callback 38 | progress_bar = TQDMProgressBar(refresh_rate=config.progress_bar_refresh_rate) 39 | trainer = Trainer( 40 | max_epochs=params.n_epochs, 41 | gradient_clip_val=params.clip_norm, 42 | deterministic=True, 43 | check_val_every_n_epoch=params.val_every_epoch, 44 | log_every_n_steps=params.log_every_n_steps, 45 | logger=wandb_logger, 46 | gpus=params.gpu, 47 | callbacks=[lr_logger, early_stopping_callback, checkpoint_callback, print_epoch_result_callback, progress_bar], 48 | resume_from_checkpoint=config.get("checkpoint", None), 49 | ) 50 | 51 | trainer.fit(model=model, datamodule=data_module) 52 | trainer.test(model=model, datamodule=data_module) 53 | -------------------------------------------------------------------------------- /config/code2class-poj104.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/poj-104/poj-104-code2seq 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | # Training in notebooks (e.g. Google Colab) may crash with too small value 7 | progress_bar_refresh_rate: 1 8 | print_config: true 9 | 10 | wandb: 11 | project: Code2Class -- poj-104 12 | group: null 13 | offline: full 14 | 15 | data: 16 | url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/poj-104/poj-104-code2seq.tar.gz 17 | num_workers: 0 18 | 19 | max_labels: null 20 | max_label_parts: 1 21 | max_tokens: 190000 22 | max_token_parts: 5 23 | path_length: 9 24 | 25 | max_context: 200 26 | random_context: true 27 | 28 | batch_size: 512 29 | test_batch_size: 768 30 | 31 | model: 32 | # Encoder 33 | embedding_size: 128 34 | encoder_dropout: 0.25 35 | encoder_rnn_size: 128 36 | use_bi_rnn: true 37 | rnn_num_layers: 1 38 | 39 | # Classifier 40 | classifier_layers: 2 41 | classifier_size: 128 42 | activation: relu 43 | 44 | optimizer: 45 | optimizer: "Momentum" 46 | nesterov: true 47 | lr: 0.01 48 | weight_decay: 0 49 | decay_gamma: 0.95 50 | 51 | train: 52 | n_epochs: 10 53 | patience: 10 54 | clip_norm: 5 55 | teacher_forcing: 1.0 56 | val_every_epoch: 1 57 | save_every_epoch: 1 58 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /config/code2seq-java-med.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/code2seq/java-med 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | # Training in notebooks (e.g. 
Google Colab) may crash with too small value 7 | progress_bar_refresh_rate: 1 8 | print_config: true 9 | 10 | wandb: 11 | project: Code2Seq -- java-med 12 | group: null 13 | offline: false 14 | 15 | data: 16 | url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-med.tar.gz 17 | num_workers: 4 18 | 19 | # Each token appears at least 10 times (99.2% coverage) 20 | labels_count: 10 21 | max_label_parts: 7 22 | # Each token appears at least 1000 times (99.5% coverage) 23 | tokens_count: 1000 24 | max_token_parts: 5 25 | path_length: 9 26 | 27 | max_context: 200 28 | random_context: true 29 | 30 | batch_size: 512 31 | test_batch_size: 512 32 | 33 | model: 34 | # Encoder 35 | embedding_size: 128 36 | encoder_dropout: 0.25 37 | encoder_rnn_size: 128 38 | use_bi_rnn: true 39 | rnn_num_layers: 1 40 | 41 | # Decoder 42 | decoder_size: 320 43 | decoder_num_layers: 1 44 | rnn_dropout: 0.5 45 | 46 | optimizer: 47 | optimizer: "Momentum" 48 | nesterov: true 49 | lr: 0.01 50 | weight_decay: 0 51 | decay_gamma: 0.95 52 | 53 | train: 54 | n_epochs: 10 55 | patience: 10 56 | clip_norm: 5 57 | teacher_forcing: 1.0 58 | val_every_epoch: 1 59 | save_every_epoch: 1 60 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /config/code2seq-java-small.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/code2seq/java-small 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | # Training in notebooks (e.g. Google Colab) may crash with too small value 7 | progress_bar_refresh_rate: 1 8 | print_config: true 9 | 10 | wandb: 11 | project: Code2Seq -- java-small 12 | group: null 13 | offline: false 14 | 15 | data: 16 | url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-small.tar.gz 17 | num_workers: 4 18 | 19 | # Each token appears at least 10 times (98.5% coverage) 20 | labels_count: 10 21 | max_label_parts: 7 22 | # Each token appears at least 1000 times (99.2% coverage) 23 | tokens_count: 1000 24 | max_token_parts: 5 25 | path_length: 9 26 | 27 | max_context: 200 28 | random_context: true 29 | 30 | batch_size: 512 31 | test_batch_size: 512 32 | 33 | model: 34 | # Encoder 35 | embedding_size: 128 36 | encoder_dropout: 0.25 37 | encoder_rnn_size: 128 38 | use_bi_rnn: true 39 | rnn_num_layers: 1 40 | 41 | # Decoder 42 | decoder_size: 320 43 | decoder_num_layers: 1 44 | rnn_dropout: 0.5 45 | 46 | optimizer: 47 | optimizer: "Momentum" 48 | nesterov: true 49 | lr: 0.01 50 | weight_decay: 0 51 | decay_gamma: 0.95 52 | 53 | train: 54 | n_epochs: 15 55 | patience: 15 56 | clip_norm: 5 57 | teacher_forcing: 1.0 58 | val_every_epoch: 1 59 | save_every_epoch: 1 60 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /config/code2seq-java-test.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/code2seq/java-test 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | progress_bar_refresh_rate: 1 7 | print_config: true 8 | 9 | wandb: 10 | project: Test project 11 | group: Test group 12 | offline: true 13 | 14 | data: 15 | num_workers: 0 16 | 17 | labels_count: 10 18 | max_label_parts: 7 19 | tokens_count: 1000 20 | max_token_parts: 5 21 | path_length: 9 22 | 23 | max_context: 200 24 | random_context: true 25 | 26 | batch_size: 5 27 | test_batch_size: 10 28 | 29 | model: 30 | # Encoder 31 | embedding_size: 10 32 | encoder_dropout: 0.25 33 | 
encoder_rnn_size: 10 34 | use_bi_rnn: true 35 | rnn_num_layers: 1 36 | 37 | # Decoder 38 | decoder_size: 20 39 | decoder_num_layers: 1 40 | rnn_dropout: 0.5 41 | 42 | optimizer: 43 | optimizer: "Momentum" 44 | nesterov: true 45 | lr: 0.01 46 | weight_decay: 0 47 | decay_gamma: 0.95 48 | 49 | train: 50 | n_epochs: 5 51 | patience: 10 52 | clip_norm: 10 53 | teacher_forcing: 1.0 54 | val_every_epoch: 1 55 | save_every_epoch: 1 56 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /config/typed-code2seq-java-small.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/code2seq/java-small 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | # Training in notebooks (e.g. Google Colab) may crash with too small value 7 | progress_bar_refresh_rate: 1 8 | print_config: true 9 | 10 | wandb: 11 | project: Code2Seq -- java-small 12 | group: typed 13 | offline: false 14 | 15 | data: 16 | url: https://s3.eu-west-1.amazonaws.com/datasets.ml.labs.aws.intellij.net/java-paths-methods/java-small.tar.gz 17 | num_workers: 4 18 | 19 | # Each token appears at least 10 times (98.5% coverage) 20 | labels_count: 10 21 | max_label_parts: 7 22 | # Each token appears at least 1000 times (99.2% coverage) 23 | tokens_count: 1000 24 | max_token_parts: 5 25 | 26 | types_count: null 27 | max_type_parts: 5 28 | 29 | path_length: 9 30 | 31 | max_context: 200 32 | random_context: true 33 | 34 | batch_size: 512 35 | test_batch_size: 768 36 | 37 | model: 38 | # Encoder 39 | embedding_size: 128 40 | encoder_dropout: 0.25 41 | encoder_rnn_size: 128 42 | use_bi_rnn: true 43 | rnn_num_layers: 1 44 | 45 | # Decoder 46 | decoder_size: 320 47 | decoder_num_layers: 1 48 | rnn_dropout: 0.5 49 | 50 | optimizer: 51 | optimizer: "Momentum" 52 | nesterov: true 53 | lr: 0.01 54 | weight_decay: 0 55 | decay_gamma: 0.95 56 | 57 | train: 58 | n_epochs: 10 59 | patience: 10 60 | clip_norm: 5 61 | teacher_forcing: 1.0 62 | val_every_epoch: 1 63 | save_every_epoch: 1 64 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /config/typed-code2seq-java-test.yaml: -------------------------------------------------------------------------------- 1 | data_folder: ../data/code2seq/java-test-typed 2 | 3 | checkpoint: null 4 | 5 | seed: 7 6 | # Training in notebooks (e.g. 
Google Colab) may crash with too small value 7 | progress_bar_refresh_rate: 1 8 | print_config: true 9 | 10 | wandb: 11 | project: Test project 12 | group: Test group 13 | offline: true 14 | 15 | data: 16 | num_workers: 0 17 | 18 | # Each token appears at least 10 times (99.2% coverage) 19 | labels_count: 10 20 | max_label_parts: 7 21 | # Each token appears at least 1000 times (99.5% coverage) 22 | tokens_count: 1000 23 | max_token_parts: 5 24 | 25 | types_count: null 26 | max_type_parts: 5 27 | path_length: 9 28 | 29 | max_context: 200 30 | random_context: true 31 | 32 | batch_size: 5 33 | test_batch_size: 10 34 | 35 | model: 36 | # Encoder 37 | embedding_size: 10 38 | encoder_dropout: 0.25 39 | encoder_rnn_size: 10 40 | use_bi_rnn: true 41 | rnn_num_layers: 1 42 | 43 | # Decoder 44 | decoder_size: 20 45 | decoder_num_layers: 1 46 | rnn_dropout: 0.5 47 | 48 | optimizer: 49 | optimizer: "Momentum" 50 | nesterov: true 51 | lr: 0.01 52 | weight_decay: 0 53 | decay_gamma: 0.95 54 | 55 | train: 56 | n_epochs: 5 57 | patience: 10 58 | clip_norm: 10 59 | teacher_forcing: 1.0 60 | val_every_epoch: 1 61 | save_every_epoch: 1 62 | log_every_n_steps: 10 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.0 2 | pytorch-lightning==1.5.1 3 | torchmetrics==0.6.0 4 | tqdm==4.62.3 5 | wandb==0.12.6 6 | omegaconf==2.1.1 7 | commode-utils==0.4.1 8 | -------------------------------------------------------------------------------- /scripts/split_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Script - split a dataset into train, test, and validation parts 3 | # The dataset directory is expected to contain one sub-directory per class 4 | # (the SHUFFLE flag below is only echoed; files are split in directory-listing order) 5 | # Positional arguments: 6 | # $1 specify a directory where the dataset is located 7 | # $2 specify a directory to store output in 8 | # $3 specify a percentage of the dataset used as the train set 9 | # $4 specify a percentage of the dataset used as the test set 10 | # $5 specify a percentage of the dataset used as the validation set 11 | 12 | SHUFFLE=false 13 | 14 | ORIGINAL_DATASET_PATH=$1 15 | SPLIT_DATASET_PATH=$2 16 | TRAIN_SPLIT_PART=$3 17 | TEST_SPLIT_PART=$4 18 | VAL_SPLIT_PART=$5 19 | 20 | DIR_TRAIN="${SPLIT_DATASET_PATH}/train" 21 | DIR_VAL="${SPLIT_DATASET_PATH}/val" 22 | DIR_TEST="${SPLIT_DATASET_PATH}/test" 23 | 24 | echo "Train $TRAIN_SPLIT_PART %" 25 | echo "Val $VAL_SPLIT_PART %" 26 | echo "Test $TEST_SPLIT_PART %" 27 | echo "Shuffle $SHUFFLE" 28 | echo "Original dataset path: ${ORIGINAL_DATASET_PATH}" 29 | echo "Train dataset path: ${DIR_TRAIN}" 30 | echo "Val dataset path: ${DIR_VAL}" 31 | echo "Test dataset path: ${DIR_TEST}" 32 | 33 | echo "" 34 | echo "Removing all data inside ${SPLIT_DATASET_PATH}" 35 | rm -rf "$SPLIT_DATASET_PATH" 36 | mkdir "$SPLIT_DATASET_PATH" 37 | 38 | mkdir "$DIR_TRAIN" 39 | mkdir "$DIR_VAL" 40 | mkdir "$DIR_TEST" 41 | 42 | cp -r "$ORIGINAL_DATASET_PATH"/* "$DIR_TRAIN"/ 43 | 44 | find "$DIR_TRAIN"/* -type d -exec basename {} \; | while read DIR_CLASS 45 | do 46 | echo "Splitting class - $DIR_CLASS"; 47 | mkdir "$DIR_VAL/$DIR_CLASS" 48 | mkdir "$DIR_TEST/$DIR_CLASS" 49 | num_files=$(find "$DIR_TRAIN/$DIR_CLASS" -type f | wc -l) 50 | train_bound=$(expr $num_files \* $TRAIN_SPLIT_PART / 100) 51 | test_bound=$(expr $train_bound + $num_files \* $TEST_SPLIT_PART / 100) 52 | 53 | counter=$(expr 0) 54 | 55 | files=$(find "$DIR_TRAIN/$DIR_CLASS" -type f -exec basename {} \;) 56 | 57 | for file in $files; 58 | do 59
| counter=$(expr $counter + 1) 60 | if [ $counter -gt $train_bound ] && [ $counter -le $test_bound ]; then 61 | mv "$DIR_TRAIN/$DIR_CLASS/$file" "$DIR_TEST/$DIR_CLASS/$file" 62 | fi 63 | if [ $counter -gt $test_bound ]; then 64 | mv "$DIR_TRAIN/$DIR_CLASS/$file" "$DIR_VAL/$DIR_CLASS/$file" 65 | fi 66 | done 67 | done 68 | 69 | echo "Done" 70 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | VERSION = "1.2.1" 4 | 5 | with open("README.md") as readme_file: 6 | readme = readme_file.read() 7 | 8 | install_requires = [ 9 | "torch>=1.10.0", 10 | "pytorch-lightning~=1.5.0", 11 | "wandb~=0.12.0", 12 | "omegaconf~=2.1.1", 13 | "commode-utils>=0.4.1", 14 | ] 15 | 16 | setup_args = dict( 17 | name="code2seq", 18 | version=VERSION, 19 | description="Set of pytorch modules and utils to train code2seq model", 20 | long_description_content_type="text/markdown", 21 | long_description=readme, 22 | install_requires=install_requires, 23 | license="MIT", 24 | packages=find_packages(), 25 | author="Egor Spirin", 26 | author_email="spirin.egor@gmail.com", 27 | keywords=["code2seq", "pytorch", "pytorch-lightning", "ml4code", "ml4se"], 28 | url="https://github.com/JetBrains-Research/code2seq", 29 | download_url="https://pypi.org/project/code2seq/", 30 | ) 31 | 32 | if __name__ == "__main__": 33 | setup(**setup_args) 34 | -------------------------------------------------------------------------------- /tests/test_tokenization.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from code2seq.data.path_context_dataset import PathContextDataset 6 | from code2seq.data.vocabulary import Vocabulary 7 | 8 | 9 | class TestDatasetTokenization(unittest.TestCase): 10 | vocab = {Vocabulary.PAD: 0, Vocabulary.UNK: 1, Vocabulary.SOS: 2, Vocabulary.EOS: 3, "my": 4, "super": 5} 11 | 12 | def test_tokenize_label(self): 13 | raw_label = "my|super|label" 14 | tokenized = PathContextDataset.tokenize_label(raw_label, self.vocab, 5) 15 | # my super 16 | correct = [2, 4, 5, 1, 3, 0] 17 | 18 | self.assertListEqual(tokenized, correct) 19 | 20 | def test_tokenize_class(self): 21 | raw_class = "super" 22 | tokenized = PathContextDataset.tokenize_class(raw_class, self.vocab) 23 | correct = [5] 24 | 25 | self.assertListEqual(tokenized, correct) 26 | 27 | def test_tokenize_token(self): 28 | raw_token = "my|super|token" 29 | tokenized = PathContextDataset.tokenize_token(raw_token, self.vocab, 5) 30 | correct = [4, 5, 1, 0, 0] 31 | 32 | self.assertListEqual(tokenized, correct) 33 | 34 | 35 | if __name__ == "__main__": 36 | unittest.main() 37 | --------------------------------------------------------------------------------
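Below is a minimal end-to-end usage sketch assembled from the wrapper scripts and configs above. It covers two separate workflows: splitting a raw, class-per-directory dataset (as used for code2class-style classification) and training/testing the typed code2seq model. The dataset paths, the 70/20/10 split, and running the wrapper as a module are illustrative assumptions; adjust them to your setup, and for `test` mode first set `checkpoint` in the YAML config to a trained checkpoint path (it is `null` by default).

```shell
# Split a raw dataset (one sub-directory per class) into train/test/val parts.
# "../data/raw-dataset", "../data/split-dataset" and the 70/20/10 percentages are placeholders.
bash scripts/split_dataset.sh ../data/raw-dataset ../data/split-dataset 70 20 10

# Train the typed model with one of the provided configs, then evaluate the saved checkpoint.
python -m code2seq.typed_code2seq_wrapper train -c config/typed-code2seq-java-small.yaml
python -m code2seq.typed_code2seq_wrapper test -c config/typed-code2seq-java-small.yaml
```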