├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── asdl ├── README.md ├── __init__.py ├── asdl.py ├── asdl_ast.py ├── hypothesis.py ├── lang │ ├── __init__.py │ └── csharp │ │ ├── Syntax.xml │ │ ├── __init__.py │ │ ├── csharp_grammar.py │ │ ├── csharp_hypothesis.py │ │ ├── csharp_transition.py │ │ ├── demo.py │ │ ├── demo_edits.py │ │ └── grammar.json ├── transition_system.py └── utils.py ├── common ├── __init__.py ├── registerable.py ├── savable.py └── utils.py ├── datasets ├── __init__.py ├── githubedits │ └── common │ │ ├── __init__.py │ │ └── config.py └── utils.py ├── edit_components ├── __init__.py ├── change_entry.py ├── change_graph.py ├── dataset.py ├── diff_utils.py ├── evaluate.py ├── utils │ ├── __init__.py │ ├── decode.py │ ├── relevance.py │ ├── sub_token.py │ ├── unary_closure.py │ ├── utils.py │ └── wikidata.py └── vocab.py ├── edit_model ├── __init__.py ├── data_model.py ├── edit_encoder │ ├── __init__.py │ ├── bag_of_edits_change_encoder.py │ ├── edit_encoder.py │ ├── graph_change_encoder.py │ ├── hybrid_change_encoder.py │ ├── sequential_change_encoder.py │ └── tree_diff_encoder.py ├── editor.py ├── embedder.py ├── encdec │ ├── __init__.py │ ├── decoder.py │ ├── edit_decoder.py │ ├── encoder.py │ ├── graph_encoder.py │ ├── sequential_decoder.py │ ├── sequential_encoder.py │ └── transition_decoder.py ├── gnn.py ├── nn_utils.py ├── pointer_net.py └── utils.py ├── exp_githubedits.py ├── scripts └── githubedits │ ├── test.sh │ └── train.sh ├── source_data └── githubedits.tar.gz ├── structural_edits.yml └── trees ├── __init__.py ├── edits.py ├── hypothesis.py ├── substitution_system.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | 
eggs/ 18 | .eggs/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # PyCharm files 104 | .idea/ 105 | 106 | # VIM stuff 107 | *.swp 108 | 109 | # data downloaded automatically using pull_data.sh 110 | data/ 111 | 112 | # saved model files 113 | saved_models/ 114 | 115 | # decodes output files 116 | decodes/ 117 | 118 | # private files 119 | *private/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 NeuLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of 
this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Structural Edits via Incremental Tree Transformations 2 | 3 | Code for ["Learning Structural Edits via Incremental Tree Transformations" (ICLR'21)](https://openreview.net/pdf?id=v9hAX77--cZ) 4 | 5 | If you use our code and data, please cite our paper: 6 | ``` 7 | @inproceedings{yao2021learning, 8 | title={Learning Structural Edits via Incremental Tree Transformations}, 9 | author={Ziyu Yao and Frank F. Xu and Pengcheng Yin and Huan Sun and Graham Neubig}, 10 | booktitle={International Conference on Learning Representations}, 11 | year={2021}, 12 | url={https://openreview.net/forum?id=v9hAX77--cZ} 13 | } 14 | ``` 15 | 16 | Our implementation is adapted from [TranX](https://github.com/pcyin/tranx) and [Graph2Tree](https://github.com/microsoft/iclr2019-learning-to-represent-edits). 
17 | We are grateful to both works! 18 | ``` 19 | @inproceedings{yin18emnlpdemo, 20 | title = {{TRANX}: A Transition-based Neural Abstract Syntax Parser for Semantic Parsing and Code Generation}, 21 | author = {Pengcheng Yin and Graham Neubig}, 22 | booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP) Demo Track}, 23 | year = {2018} 24 | } 25 | @inproceedings{yin2018learning, 26 | title={Learning to Represent Edits}, 27 | author={Pengcheng Yin and Graham Neubig and Miltiadis Allamanis and Marc Brockschmidt and Alexander L. Gaunt}, 28 | booktitle={International Conference on Learning Representations}, 29 | year={2019}, 30 | url={https://openreview.net/forum?id=BJl6AjC5F7}, 31 | } 32 | ``` 33 | 34 | 35 | ## 1. Prepare Environment 36 | We recommend using `conda` to manage the environment: 37 | ``` 38 | conda env create -n "structural_edits" -f structural_edits.yml 39 | conda activate structural_edits 40 | ``` 41 | 42 | Install the punkt tokenizer: 43 | ``` 44 | python 45 | >>> import nltk 46 | >>> nltk.download('punkt') 47 | >>> 48 | ``` 49 | 50 | ## 2. Data 51 | Please extract the datasets and vocabulary files by running: 52 | ``` 53 | cd source_data 54 | tar -xzvf githubedits.tar.gz 55 | ``` 56 | 57 | All necessary source data is included and organized as follows: 58 | ``` 59 | | --source_data 60 | | |-- githubedits 61 | | |-- githubedits.{train|train_20p|dev|test}.jsonl 62 | | |-- csharp_fixers.jsonl 63 | | |-- vocab.from_repo.{080910.freq10|edit}.json 64 | | |-- Syntax.xml 65 | | |-- configs 66 | | |-- ...(model config json files) 67 | ``` 68 | A sample file containing 20% of the GitHubEdits training data is included as `source_data/githubedits/githubedits.train_20p.jsonl` for running small experiments. 69 | 70 | We have generated and included the vocabulary files as well. To create your own vocabulary, see `edit_components/vocab.py`.
71 | 72 | Copyright: The original data were downloaded from [Yin et al., (2019)](http://www.cs.cmu.edu/~pengchey/githubedits.zip). 73 | 74 | 75 | ## 3. Experiments 76 | See training and test scripts in `scripts/githubedits/`. Please configure the `PYTHONPATH` environment variable on line 6. 77 | 78 | ### 3.1 Training 79 | For training, uncomment the desired setting in `scripts/githubedits/train.sh` and run: 80 | ``` 81 | bash scripts/githubedits/train.sh source_data/githubedits/configs/CONFIGURATION_FILE 82 | ``` 83 | where `CONFIGURATION_FILE` is the json file of your setting. 84 | Please check out the `TODO`'s in [`scripts/githubedits/train.sh`](scripts/githubedits/train.sh). 85 | 86 | 87 | #### 3.1.1 Supervised Learning 88 | For example, if you want to train Graph2Edit + Sequence Edit Encoder on GitHubEdits's 20\% sample data, 89 | please uncomment only lines 22-26 in `scripts/githubedits/train.sh` and run: 90 | ``` 91 | bash scripts/githubedits/train.sh source_data/githubedits/configs/graph2iteredit.seq_edit_encoder.20p.json 92 | ``` 93 | **Note**: 94 | - When you run the experiment for the first time, you might need to wait for ~15 minutes for data preprocessing. 95 | - By default, the data preprocessing includes generating and saving the target edit sequences for instances in the training data. 96 | However, this may cause an `out of (cpu) memory` issue. **A simple way to solve this problem is to set `--small_memory` in the `train.sh` script.** 97 | We explain the details in [Section 4.2 Out of Memory Issue](#42-out-of-memory-issue). 98 | 99 | 100 | 101 | #### 3.1.2 Imitation Learning 102 | To further train the model with PostRefine imitation learning, 103 | please replace `FOLDER_OF_SUPERVISED_PRETRAINED_MODEL` with your model dir in `source_data/githubedits/configs/graph2iteredit.seq_edit_encoder.20p.postrefine.imitation.json`.
104 | Uncomment only line 27-31 in `scripts/githubedits/train.sh` and run: 105 | ``` 106 | bash scripts/githubedits/train.sh source_data/githubedits/configs/graph2iteredit.seq_edit_encoder.20p.postrefine.imitation.json 107 | ``` 108 | Note that `--small_memory` cannot be used in this setting. 109 | 110 | ### 3.2 Test 111 | To test a trained model, first uncomment only the desired setting in `scripts/githubedits/test.sh` and replace `work_dir` with your model directory, 112 | and then run: 113 | ``` 114 | bash scripts/githubedits/test.sh 115 | ``` 116 | Please check out the `TODO`'s in [`scripts/githubedits/test.sh`](scripts/githubedits/test.sh). 117 | 118 | ## 4. FAQ 119 | 120 | ### 4.1 Applications to Other Languages 121 | 122 | 123 | In principle, our framework can work with various programming languages. 124 | To this end, several changes are needed: 125 | 1. Implementing a language-specific `ASDLGrammar` class for the new language. 126 | - This class could inherit the [`asdl.asdl.ASDLGrammar`](asdl/asdl.py) class. 127 | - Basic functions should include 128 | - Defining the `primitive` and `composite` types, 129 | - Implementing the class constructor (e.g., converting from the `.xml` or `.txt` syntax descriptions), 130 | - Converting the source AST data into an `asdl.asdl_ast.AbstractSyntaxTree` object. 131 | - Example: see the [`asdl.lang.csharp.CSharpASDLGrammar`](asdl/lang/csharp/csharp_grammar.py) class. 132 | - **Sanity check**: It is very helpful to implement a `demo_edits.py` file like [this one for csharp](asdl/lang/csharp/demo_edits.py) and 133 | make sure you have checked out the generated ASTs and target edit sequences. 134 | - **Useful resource**: The [TranX](https://github.com/pcyin/tranx) library contains ASDLGrammar classes for some other languages. 
135 | _Note that we have revised the `asdl.asdl.ASDLGrammar` class, so directly using the TranX implementation may not work._ 136 | However, this resource is still a good starting point; you may consider modifying it based on the sanity check outputs. 137 | 138 | 2. Implementing a language-specific `TransitionSystem` class. 139 | - The target edit sequences (of the training data) are calculated by `trees.substitution_system.SubstitutionSystem`, 140 | which depends on an `asdl.transition_system.TransitionSystem` object (or a subclass of it) (see [reference](trees/substitution_system.py#L16)). 141 | - In our current implementation of CSharp, we have reused the `CSharpTransitionSystem` class implemented in the [Graph2Tree library](https://github.com/microsoft/iclr2019-learning-to-represent-edits). 142 | However, only the `get_primitive_field_actions` function of the `TransitionSystem` class is actually used by the `SubstitutionSystem` ([example](trees/substitution_system.py#L131)). 143 | Therefore, for simplicity, one can implement only this function. 144 | Basically, this `get_primitive_field_actions` function defines how the leaf string should be generated 145 | (e.g., multiple `GenTokenAction` actions should be taken for generating a multi-word leaf string), which we will discuss next. 146 | 147 | 3. Customizing the leaf string generation. 148 | - Following the last item, one may also need to customize the `GenTokenAction` action, especially regarding whether and how the [stop signal](asdl/transition_system.py#L29) will be used. 149 | For CSharp, we do not use any stop signal, since in our datasets the leaf string is typically a single-word token. 150 | However, it will be needed when the leaf string contains multiple words. 151 | - Accordingly, one may customize the [`Add`](/trees/edits.py#L61) edit action and the [`SubstitutionSystem`](trees/substitution_system.py#L170) 152 | regarding how the leaf string should be added to the current tree.
153 | 154 | 155 | ### 4.2 Out of Memory Issue 156 | **The issue:** 157 | By default, the data preprocessing step will 158 | (1) run a dynamic programming algorithm to calculate the shortest edit sequence `(a_1, a_2, ..., a_T)` 159 | as the target edit sequence for each code pair `(C-, C+)`, and 160 | (2) save every intermediate tree graph `(g_1, g_2, ..., g_T)`, where `g_{t+1}` is the transformation result of 161 | applying edit action `a_t` to tree `g_t` at time step `t`, as the input to the tree encoder (see [3.1.2 in our paper](https://openreview.net/pdf?id=v9hAX77--cZ)). 162 | Therefore, a fully preprocessed training set is very large and will take up a lot of CPU memory 163 | every time you load the data for model training. 164 | 165 | **The solution:** 166 | A simple solution is to avoid saving any intermediate tree graph, 167 | i.e., we will only save the shortest edit sequence results from (1) 168 | while deferring the generation of the intermediate tree graphs in (2) to model training time. 169 | This can be done by setting `--small_memory` in the [train.sh](scripts/githubedits/train.sh) script. 170 | _Currently this option can only be used for regular supervised learning; for imitation learning, this has to be off._ 171 | 172 | Note that there will be a trade-off between the CPU memory and the GPU utilization/training speed, 173 | since the generation of the intermediate tree graphs is done at the CPU level.
174 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | -------------------------------------------------------------------------------- /asdl/README.md: -------------------------------------------------------------------------------- 1 | ## ASDL Transition System 2 | 3 | This package contains a standalone transition system based on the ASDL formalism, 4 | and its instantiations in different languages (lambda calculus, prolog, Python, SQL). 5 | A transition system defines the set of tree-constructing actions to generate an ASDL AST. 6 | This package can be used as a standalone component independent of `tranX`. 7 | 8 | ### File Structure 9 | 10 | * `asdl.py` contains classes that implement basic concepts in ASDL (grammar, constructor, production, field, type, etc.) 11 | * `asdl_ast.py` contains the `AbstractSyntaxNode` class that define an abstract syntax tree 12 | * `transition_system.py` contains the abstract class of a transition system, 13 | instantiated by the transition systems in each language folder. 14 | A transition system defines the set of tree-constructing actions used to generate an AST. 15 | * `hypothesis.py` contains the `Hypothesis` class, which records the state of a partially generated AST constructed 16 | by a series of actions. 17 | 18 | ### Example 19 | 20 | Below is an example usage of the `PythonTransitionSystem`, which defines the actions 21 | to generate Python code snippets. 
22 | 23 | ```python 24 | # coding=utf-8 25 | 26 | import ast 27 | from asdl.asdl import ASDLGrammar 28 | from asdl.lang.py.py_asdl_helper import * 29 | from asdl.lang.py.py_transition_system import * 30 | from asdl.hypothesis import * 31 | import astor 32 | 33 | # read in the grammar specification of Python 2.7, defined in ASDL 34 | asdl_text = open('py_asdl.txt').read() 35 | grammar = ASDLGrammar.from_text(asdl_text) 36 | 37 | py_code = """pandas.read('file.csv', nrows=100)""" 38 | 39 | # get the (domain-specific) python AST of the example Python code snippet 40 | py_ast = ast.parse(py_code) 41 | 42 | # convert the python AST into general-purpose ASDL AST used by tranX 43 | asdl_ast = python_ast_to_asdl_ast(py_ast.body[0], grammar) 44 | print('String representation of the ASDL AST: \n%s' % asdl_ast.to_string()) 45 | print('Size of the AST: %d' % asdl_ast.size) 46 | 47 | # we can also convert the ASDL AST back into Python AST 48 | py_ast_reconstructed = asdl_ast_to_python_ast(asdl_ast, grammar) 49 | 50 | # initialize the Python transition parser 51 | parser = PythonTransitionSystem(grammar) 52 | 53 | # get the sequence of gold-standard actions to construct the ASDL AST 54 | actions = parser.get_actions(asdl_ast) 55 | 56 | # a hypothesis is an (partial) ASDL AST generated using a sequence of tree-construction actions 57 | hypothesis = Hypothesis() 58 | for t, action in enumerate(actions, 1): 59 | # the type of the action should belong to one of the valid continuing types 60 | # of the transition system 61 | assert action.__class__ in parser.get_valid_continuation_types(hypothesis) 62 | 63 | # if it's an ApplyRule action, the production rule should belong to the 64 | # set of rules with the same LHS type as the current rule 65 | if isinstance(action, ApplyRuleAction) and hypothesis.frontier_node: 66 | assert action.production in grammar[hypothesis.frontier_field.type] 67 | 68 | print('t=%d, Action=%s' % (t, action)) 69 | hypothesis.apply_action(action) 70 | 71 | # 
get the surface code snippets from the original Python AST, 72 | # the reconstructed AST and the AST generated using actions 73 | # they should be the same 74 | src1 = astor.to_source(py_ast).strip() 75 | src2 = astor.to_source(py_ast_reconstructed).strip() 76 | src3 = astor.to_source(asdl_ast_to_python_ast(hypothesis.tree, grammar)).strip() 77 | 78 | assert src1 == src2 == src3 == "pandas.read('file.csv', nrows=100)" 79 | 80 | ``` -------------------------------------------------------------------------------- /asdl/__init__.py: -------------------------------------------------------------------------------- 1 | from .lang.csharp.csharp_transition import CSharpTransitionSystem 2 | 3 | -------------------------------------------------------------------------------- /asdl/hypothesis.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from .asdl import * 4 | from .asdl_ast import AbstractSyntaxNode, SyntaxToken 5 | from .transition_system import * 6 | 7 | 8 | class Hypothesis(object): 9 | def __init__(self): 10 | self.tree = None 11 | self.actions = [] 12 | self.score = 0. 
13 | self.frontier_node = None 14 | self.frontier_field = None 15 | self._value_buffer = [] 16 | 17 | # record the current time step 18 | self.t = 0 19 | 20 | def apply_action(self, action): 21 | if self.tree is None: 22 | assert isinstance(action, ApplyRuleAction), 'Invalid action [%s], only ApplyRule action is valid ' \ 23 | 'at the beginning of decoding' 24 | 25 | self.tree = AbstractSyntaxNode(action.production) 26 | self.update_frontier_info() 27 | elif self.frontier_node: 28 | if isinstance(self.frontier_field.type, ASDLCompositeType): 29 | if isinstance(action, ApplyRuleAction): 30 | field_value = AbstractSyntaxNode(action.production) 31 | field_value.created_time = self.t 32 | self.frontier_field.add_value(field_value) 33 | self.update_frontier_info() 34 | elif isinstance(action, ReduceAction): 35 | assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 36 | 'applied on field with multiple ' \ 37 | 'cardinality' 38 | self.frontier_field.set_finish() 39 | self.update_frontier_info() 40 | else: 41 | raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field)) 42 | else: # fill in a primitive field 43 | if isinstance(action, GenTokenAction): 44 | # only field of type string requires termination signal 45 | end_primitive = False 46 | if self.frontier_field.type.name == 'string': 47 | if action.is_stop_signal(): 48 | self.frontier_field.add_value( 49 | SyntaxToken(self.frontier_field.type, 50 | ' '.join(self._value_buffer))) 51 | self._value_buffer = [] 52 | 53 | end_primitive = True 54 | else: 55 | self._value_buffer.append(action.token) 56 | else: 57 | self.frontier_field.add_value( 58 | SyntaxToken(self.frontier_field.type, 59 | action.token)) 60 | end_primitive = True 61 | 62 | if end_primitive and self.frontier_field.cardinality in ('single', 'optional'): 63 | self.frontier_field.set_finish() 64 | self.update_frontier_info() 65 | 66 | elif isinstance(action, ReduceAction): 67 | assert 
self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 68 | 'applied on field with multiple ' \ 69 | 'cardinality' 70 | self.frontier_field.set_finish() 71 | self.update_frontier_info() 72 | else: 73 | raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields') 74 | 75 | self.t += 1 76 | self.actions.append(action) 77 | 78 | def update_frontier_info(self): 79 | def _find_frontier_node_and_field(tree_node): 80 | if tree_node: 81 | for field in tree_node.fields: 82 | # if it's an intermediate node, check its children 83 | if isinstance(field.type, ASDLCompositeType) and field.value: 84 | if field.cardinality in ('single', 'optional'): iter_values = [field.value] 85 | else: iter_values = field.value 86 | 87 | for child_node in iter_values: 88 | result = _find_frontier_node_and_field(child_node) 89 | if result: return result 90 | 91 | # now all its possible children are checked 92 | if not field.finished: 93 | return tree_node, field 94 | 95 | return None 96 | else: return None 97 | 98 | frontier_info = _find_frontier_node_and_field(self.tree) 99 | if frontier_info: 100 | self.frontier_node, self.frontier_field = frontier_info 101 | else: 102 | self.frontier_node, self.frontier_field = None, None 103 | 104 | def clone_and_apply_action(self, action): 105 | new_hyp = self.copy() 106 | new_hyp.apply_action(action) 107 | 108 | return new_hyp 109 | 110 | def copy(self): 111 | new_hyp = Hypothesis() 112 | if self.tree: 113 | new_hyp.tree = self.tree.copy() 114 | 115 | new_hyp.actions = list(self.actions) 116 | new_hyp.score = self.score 117 | new_hyp._value_buffer = list(self._value_buffer) 118 | new_hyp.t = self.t 119 | 120 | new_hyp.update_frontier_info() 121 | 122 | return new_hyp 123 | 124 | @property 125 | def completed(self): 126 | return self.tree and self.frontier_field is None 127 | -------------------------------------------------------------------------------- /asdl/lang/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/asdl/lang/__init__.py -------------------------------------------------------------------------------- /asdl/lang/csharp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/asdl/lang/csharp/__init__.py -------------------------------------------------------------------------------- /asdl/lang/csharp/csharp_grammar.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import re 3 | from collections import OrderedDict 4 | from itertools import chain 5 | from typing import Dict, Union 6 | 7 | from bs4 import BeautifulSoup 8 | import json 9 | 10 | from asdl.asdl import ASDLGrammar, ASDLType, ASDLProduction, ASDLConstructor, Field 11 | 12 | 13 | class CSharpASDLGrammar(ASDLGrammar): 14 | """ 15 | Collection of types, constructors and productions 16 | """ 17 | def __init__(self, productions, root_type): 18 | super().__init__(productions, root_type, language='csharp') 19 | 20 | self._primitive_types = [type for type in self.types if type not in self.type2productions and type.is_leaf] 21 | for type in self._primitive_types: 22 | type.is_composite = False 23 | 24 | self._composite_types = [type for type in self.types if type not in self.primitive_types] 25 | for type in self._composite_types: 26 | type.is_composite = True 27 | 28 | @property 29 | def primitive_types(self): 30 | return self._primitive_types 31 | 32 | @property 33 | def composite_types(self): 34 | return self._composite_types 35 | 36 | def to_json(self): 37 | grammar_rules = [] 38 | for prod in self.productions: 39 | entry = dict(constructor=prod.constructor.name, 40 | fields=[dict(name=f.name, type=f.type.name) for f in 
prod.constructor.fields]) 41 | grammar_rules.append(entry) 42 | 43 | return json.dumps(grammar_rules, indent=2) 44 | 45 | @classmethod 46 | def from_roslyn_xml(cls, xml_text, pruning=False): 47 | bs = BeautifulSoup(xml_text, 'xml') 48 | token_kinds_to_keep = {'NumericLiteralToken', 'StringLiteralToken', 'CharacterLiteralToken'} 49 | 50 | from bs4 import Tag 51 | 52 | all_types = dict() 53 | productions = [] 54 | generic_list_productions = set() 55 | 56 | # add base type 57 | grammar_root_type = ASDLType('SyntaxNode') 58 | all_types[grammar_root_type.name] = grammar_root_type 59 | 60 | for node in bs.Tree.find_all(lambda x: isinstance(x, Tag), recursive=False): 61 | # process type information 62 | base_type_name = node['Base'] 63 | if base_type_name not in all_types: 64 | all_types[base_type_name] = ASDLType(base_type_name) 65 | base_type = all_types[base_type_name] 66 | 67 | node_name = node['Name'] 68 | if node_name in all_types: 69 | node_type = all_types[node_name] 70 | if node_type not in base_type.child_types: 71 | base_type.add_child(node_type) 72 | else: 73 | node_type = ASDLType(node_name, parent_type=base_type) 74 | all_types[node_type.name] = node_type 75 | 76 | if node.name == 'Node': 77 | fields = [] 78 | for field_node in node.find_all('Field', recursive=False): 79 | field_name = field_node['Name'] 80 | field_type_str = field_node['Type'] 81 | 82 | field_kinds = set(kind['Name'] for kind in field_node.find_all('Kind')) 83 | 84 | if pruning: 85 | # if field_type_str == 'SyntaxToken' and (field_name not in {'Identifier', 'OperatorToken'} and #!= 'Identifier' 86 | # len(field_kinds.intersection(token_kinds_to_keep)) == 0): 87 | # continue 88 | 89 | if field_type_str == 'SyntaxToken' and \ 90 | field_name != 'Identifier' and \ 91 | not (field_name == 'OperatorToken' and node_name in {'BinaryExpressionSyntax', 'AssignmentExpressionSyntax', 92 | 'PostfixUnaryExpressionSyntax', 'PrefixUnaryExpressionSyntax'}) and \ 93 | not (field_name == 'Keyword' and 
node_name == 'PredefinedTypeSyntax') and \ 94 | len(field_kinds.intersection(token_kinds_to_keep)) == 0: 95 | continue 96 | 97 | if field_type_str not in all_types: 98 | all_types[field_type_str] = ASDLType(field_type_str) 99 | field_type = all_types[field_type_str] 100 | 101 | if 'SyntaxList' in field_type_str: 102 | base_type_name = re.match('\w+<(.*?)>', field_type_str).group(1) 103 | if base_type_name not in all_types: 104 | all_types[base_type_name] = ASDLType(base_type_name) 105 | base_type = all_types[base_type_name] 106 | 107 | production = ASDLProduction(field_type, 108 | ASDLConstructor(field_type.name, fields=[ 109 | Field('Element', base_type, 'multiple')])) 110 | generic_list_productions.add(production) 111 | 112 | field_cardinality = 'optional' if field_node.get('Optional', None) == 'true' else 'single' 113 | field = Field(field_name, field_type, field_cardinality) 114 | fields.append(field) 115 | 116 | constructor = ASDLConstructor(node['Name'], fields) 117 | production = ASDLProduction(node_type, constructor) 118 | productions.append(production) 119 | 120 | productions.extend(generic_list_productions) 121 | grammar = CSharpASDLGrammar(productions, root_type=all_types['CSharpSyntaxNode']) 122 | 123 | return grammar 124 | 125 | def get_ast_from_json_str(self, json_str): 126 | json_obj = json.loads(json_str) 127 | 128 | return self.get_ast_from_json_obj(json_obj) 129 | 130 | def convert_ast_into_json_obj(self, ast_node): 131 | from asdl.asdl_ast import AbstractSyntaxNode, RealizedField, SyntaxToken, AbstractSyntaxTree 132 | 133 | if isinstance(ast_node, SyntaxToken): 134 | entry = OrderedDict(Constructor='SyntaxToken', 135 | Value=ast_node.value, 136 | Position=-1) 137 | else: 138 | entry_fields = dict() 139 | for realized_field in ast_node.fields: 140 | field = realized_field.field 141 | 142 | if 'SyntaxList' in field.type.name: 143 | child_entry = [] 144 | # SyntaxList -> (T* Element) 145 | field_elements = 
realized_field.value.fields[0].as_value_list 146 | 147 | for field_element_ast in field_elements: 148 | element_ast = self.convert_ast_into_json_obj(field_element_ast) 149 | child_entry.append(element_ast) 150 | elif realized_field.value is not None: 151 | child_entry = self.convert_ast_into_json_obj(realized_field.value) 152 | else: 153 | child_entry = None 154 | 155 | entry_fields[field.name] = child_entry 156 | 157 | constructor_name = ast_node.production.constructor.name 158 | entry = OrderedDict(Constructor=constructor_name, 159 | Fields=entry_fields) 160 | 161 | return entry 162 | 163 | def get_ast_from_json_obj(self, json_obj: Dict): 164 | """read an AST from serialized JSON string""" 165 | # FIXME: cyclic import 166 | from asdl.asdl_ast import AbstractSyntaxNode, RealizedField, SyntaxToken, AbstractSyntaxTree 167 | 168 | def get_subtree(entry, parent_field, next_available_id): 169 | if entry is None: 170 | return None, next_available_id 171 | 172 | constructor_name = entry['Constructor'] 173 | 174 | # terminal case 175 | if constructor_name == 'SyntaxToken': 176 | if entry['Value'] is None: 177 | return None, next_available_id # return None for optional field whose value is null 178 | 179 | token = SyntaxToken(parent_field.type, entry['Value'], position=entry['Position'], id=next_available_id) 180 | next_available_id += 1 181 | 182 | return token, next_available_id 183 | 184 | field_entries = entry['Fields'] 185 | node_id = next_available_id 186 | next_available_id += 1 187 | prod = self.get_prod_by_ctr_name(constructor_name) 188 | realized_fields = [] 189 | for field in prod.constructor.fields: 190 | field_value = field_entries[field.name] 191 | 192 | if isinstance(field_value, list): 193 | assert 'SyntaxList' in field.type.name 194 | 195 | sub_ast_id = next_available_id 196 | next_available_id += 1 197 | 198 | sub_ast_prod = self.get_prod_by_ctr_name(field.type.name) 199 | sub_ast_constr_field = sub_ast_prod.constructor.fields[0] 200 | 
sub_ast_field_values = [] 201 | for field_child_entry in field_value: 202 | child_sub_ast, next_available_id = get_subtree(field_child_entry, sub_ast_constr_field, next_available_id=next_available_id) 203 | sub_ast_field_values.append(child_sub_ast) 204 | 205 | sub_ast = AbstractSyntaxNode(sub_ast_prod, 206 | [RealizedField(sub_ast_constr_field, 207 | sub_ast_field_values)], 208 | id=sub_ast_id) 209 | 210 | # FIXME: have a global mark_finished method! 211 | for sub_ast_field in sub_ast.fields: 212 | if sub_ast_field.cardinality in ('multiple', 'optional'): 213 | sub_ast_field._not_single_cardinality_finished = True 214 | 215 | realized_field = RealizedField(field, sub_ast) 216 | else: 217 | # if the child is an AST or terminal SyntaxNode 218 | sub_ast, next_available_id = get_subtree(field_value, field, next_available_id) 219 | realized_field = RealizedField(field, sub_ast) 220 | 221 | realized_fields.append(realized_field) 222 | 223 | ast_node = AbstractSyntaxNode(prod, realized_fields, id=node_id) 224 | for field in ast_node.fields: 225 | if field.cardinality in ('multiple', 'optional'): 226 | field._not_single_cardinality_finished = True 227 | 228 | return ast_node, next_available_id 229 | 230 | ast_root, _ = get_subtree(json_obj, parent_field=None, next_available_id=0) 231 | ast = AbstractSyntaxTree(ast_root) 232 | 233 | return ast 234 | -------------------------------------------------------------------------------- /asdl/lang/csharp/csharp_hypothesis.py: -------------------------------------------------------------------------------- 1 | from asdl.asdl_ast import AbstractSyntaxNode 2 | from asdl.hypothesis import Hypothesis 3 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem 4 | from asdl.transition_system import ApplySubTreeAction, ApplyRuleAction, ReduceAction, GenTokenAction 5 | 6 | 7 | class CSharpHypothesis(Hypothesis): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def apply_action(self, action): 12 | if self.tree is 
None: 13 | assert isinstance(action, (ApplyRuleAction, ApplySubTreeAction)), \ 14 | 'Invalid action [%s], only ApplyRule and ApplySutTree actions are valid at the beginning of decoding' 15 | 16 | if isinstance(action, ApplyRuleAction): 17 | self.tree = AbstractSyntaxNode(action.production) 18 | elif isinstance(action, ApplySubTreeAction): 19 | self.tree = action.tree.copy() 20 | 21 | self.tree.created_time = self.t 22 | self.update_frontier_info() 23 | elif self.frontier_node: 24 | if self.frontier_field.type.is_composite: 25 | if isinstance(action, ApplyRuleAction): 26 | field_value = AbstractSyntaxNode(action.production) 27 | field_value.created_time = self.t 28 | self.frontier_field.add_value(field_value) 29 | self.update_frontier_info() 30 | elif isinstance(action, ApplySubTreeAction): 31 | field_value = action.tree.copy() 32 | self.frontier_field.add_value(field_value) 33 | self.update_frontier_info() 34 | elif isinstance(action, ReduceAction): 35 | assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 36 | 'applied on field with multiple ' \ 37 | 'cardinality' 38 | self.frontier_field.set_finish() 39 | self.update_frontier_info() 40 | else: 41 | raise ValueError('Invalid action [%s] on field [%s]' % (action, self.frontier_field)) 42 | else: # fill in a primitive field 43 | if isinstance(action, GenTokenAction): 44 | # only field of type string requires termination signal 45 | if action.token.value == CSharpTransitionSystem.END_OF_SYNTAX_TOKEN_LIST_SYMBOL: 46 | assert self.frontier_field.cardinality in ('optional', 'multiple'), 'Reduce action can only be ' \ 47 | 'applied on field with multiple ' \ 48 | 'cardinality' 49 | self.frontier_field.set_finish() 50 | self.update_frontier_info() 51 | else: 52 | self.frontier_field.add_value(action.token) 53 | 54 | if self.frontier_field.cardinality in ('single', 'optional'): 55 | self.frontier_field.set_finish() 56 | self.update_frontier_info() 57 | else: 58 | raise 
ValueError('Can only invoke GenToken actions on primitive fields') 59 | 60 | self.t += 1 61 | self.actions.append(action) 62 | 63 | def update_frontier_info(self): 64 | def _find_frontier_node_and_field(tree_node): 65 | if tree_node: 66 | for field in tree_node.fields: 67 | # if it's an intermediate node, check its children 68 | if field.type.is_composite and field.value: 69 | if field.cardinality in ('single', 'optional'): iter_values = [field.value] 70 | else: iter_values = field.value 71 | 72 | for child_node in iter_values: 73 | result = _find_frontier_node_and_field(child_node) 74 | if result: return result 75 | 76 | # now all its possible children are checked 77 | if not field.finished: 78 | return tree_node, field 79 | 80 | return None 81 | else: return None 82 | 83 | frontier_info = _find_frontier_node_and_field(self.tree) 84 | if frontier_info: 85 | self.frontier_node, self.frontier_field = frontier_info 86 | else: 87 | self.frontier_node, self.frontier_field = None, None 88 | 89 | def copy(self): 90 | new_hyp = CSharpHypothesis() 91 | if self.tree: 92 | new_hyp.tree = self.tree.copy() 93 | 94 | new_hyp.actions = list(self.actions) 95 | new_hyp.score = self.score 96 | new_hyp._value_buffer = list(self._value_buffer) 97 | new_hyp.t = self.t 98 | 99 | new_hyp.update_frontier_info() 100 | 101 | return new_hyp 102 | -------------------------------------------------------------------------------- /asdl/lang/csharp/demo.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from asdl.hypothesis import ApplyRuleAction, GenTokenAction 4 | from asdl.lang.csharp.csharp_hypothesis import CSharpHypothesis 5 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem 6 | from asdl.lang.csharp.csharp_grammar import CSharpASDLGrammar 7 | 8 | 9 | if __name__ == '__main__': 10 | # csharp_grammar_text = request.urlopen('https://raw.githubusercontent.com/dotnet/roslyn/master/src/Compilers' 11 | # 
'/CSharp/Portable/Syntax/Syntax.xml').read() 12 | csharp_grammar_text = open('Syntax.xml').read() 13 | # fields_to_ignore = ['SemicolonToken', 'OpenBraceToken', 'CloseBraceToken', 'CommaToken', 'ColonToken', 14 | # 'StartQuoteToken', 'EndQuoteToken', 'OpenBracketToken', 'CloseBracketToken', 'NewKeyword'] 15 | 16 | grammar = CSharpASDLGrammar.from_roslyn_xml(csharp_grammar_text, pruning=True) 17 | 18 | open('grammar.json', 'w').write(grammar.to_json()) 19 | 20 | ast_json_list = open('../../../source_data/githubedits/githubedits.train_20p.jsonl').readlines() 21 | ast_json = ast_json_list[0] 22 | 23 | ast_json_obj = json.loads(ast_json)['PrevCodeAST'] 24 | syntax_tree = grammar.get_ast_from_json_obj(ast_json_obj) 25 | ast_root = syntax_tree.root_node 26 | print(ast_root.to_string()) 27 | print(ast_root.size) 28 | 29 | ast_json_obj_updt = json.loads(ast_json)['UpdatedCodeAST'] 30 | syntax_tree_updt = grammar.get_ast_from_json_obj(ast_json_obj_updt) 31 | ast_root_updt = syntax_tree_updt.root_node 32 | print(ast_root_updt.to_string()) 33 | print(ast_root_updt.size) 34 | 35 | transition = CSharpTransitionSystem(grammar) 36 | actions = transition.get_actions(ast_root) 37 | decode_actions = transition.get_decoding_actions(syntax_tree) 38 | 39 | print('Len actions:', len(decode_actions)) 40 | 41 | with open('actions.txt', 'w') as f: 42 | for action in actions: 43 | f.write(str(action) + '\n') 44 | 45 | hyp = CSharpHypothesis() 46 | for action, decode_action in zip(actions, decode_actions): 47 | assert action.__class__ in transition.get_valid_continuation_types(hyp) 48 | if isinstance(action, ApplyRuleAction): 49 | assert action.production in transition.get_valid_continuating_productions(hyp) 50 | assert action.production == decode_action.production 51 | elif isinstance(action, GenTokenAction): 52 | assert action.token == decode_action.token 53 | 54 | if hyp.frontier_node: 55 | assert hyp.frontier_field == decode_action.frontier_field 56 | assert hyp.frontier_node.production 
== decode_action.frontier_prod 57 | 58 | hyp.apply_action(action) 59 | print(hyp.tree.to_string() == ast_root.to_string()) 60 | 61 | 62 | -------------------------------------------------------------------------------- /asdl/lang/csharp/demo_edits.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | import pickle 4 | 5 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem 6 | from asdl.lang.csharp.csharp_grammar import CSharpASDLGrammar 7 | from trees.substitution_system import SubstitutionSystem 8 | 9 | 10 | def _encode(word_list): 11 | return [w.replace('\n', '-NEWLINE-') for w in word_list] 12 | 13 | 14 | if __name__ == '__main__': 15 | # csharp_grammar_text = request.urlopen('https://raw.githubusercontent.com/dotnet/roslyn/master/src/Compilers' 16 | # '/CSharp/Portable/Syntax/Syntax.xml').read() 17 | csharp_grammar_text = open('Syntax.xml').read() 18 | # fields_to_ignore = ['SemicolonToken', 'OpenBraceToken', 'CloseBraceToken', 'CommaToken', 'ColonToken', 19 | # 'StartQuoteToken', 'EndQuoteToken', 'OpenBracketToken', 'CloseBracketToken', 'NewKeyword'] 20 | 21 | grammar = CSharpASDLGrammar.from_roslyn_xml(csharp_grammar_text, pruning=True) 22 | 23 | open('grammar.json', 'w').write(grammar.to_json()) 24 | # tgt_edits_outputs = [] 25 | 26 | ast_json_lines = open('../../../source_data/githubedits/githubedits.train_20p.jsonl').readlines() 27 | for ast_json_idx in tqdm(range(len(ast_json_lines))): 28 | ast_json = ast_json_lines[ast_json_idx] 29 | loaded_ast_json = json.loads(ast_json) 30 | 31 | ast_json_obj_prev = loaded_ast_json['PrevCodeAST'] 32 | syntax_tree_prev = grammar.get_ast_from_json_obj(ast_json_obj_prev) 33 | # ast_root_prev = syntax_tree_prev.root_node 34 | # print(loaded_ast_json['PrevCodeChunk']) 35 | # print(ast_root_prev.to_string(), "\n") 36 | # print(ast_root_prev.size) 37 | 38 | ast_json_obj_updated = loaded_ast_json['UpdatedCodeAST'] 39 | 
syntax_tree_updated = grammar.get_ast_from_json_obj(ast_json_obj_updated) 40 | # ast_root_updated = syntax_tree_updated.root_node 41 | # print(loaded_ast_json['UpdatedCodeChunk']) 42 | # print(ast_root_updated.to_string(), "\n") 43 | # print(ast_root_updated.size) 44 | 45 | transition = CSharpTransitionSystem(grammar) 46 | # actions = transition.get_actions(ast_root_updated) 47 | # decode_actions = transition.get_decoding_actions( 48 | # target_ast=syntax_tree_updated, 49 | # prev_ast=syntax_tree_prev, 50 | # copy_identifier=True) 51 | # # print('Len actions:', len(decode_actions)) 52 | # decode_action_lens.append(len(decode_actions)) 53 | # non_reduce_decode_action_lens.append(len(list(filter( 54 | # lambda x: not isinstance(x, ReduceAction), decode_actions)))) 55 | 56 | syntax_tree_prev.reindex_w_dummy_reduce() 57 | syntax_tree_updated.reindex_w_dummy_reduce() 58 | previous_code_chunk = _encode(loaded_ast_json['PrevCodeChunkTokens']) 59 | substitution_system = SubstitutionSystem(transition) 60 | tgt_edits = substitution_system.get_decoding_edits_fast(syntax_tree_prev, syntax_tree_updated, 61 | bool_copy_subtree=True, 62 | init_code_tokens=previous_code_chunk, 63 | bool_debug=True) 64 | 65 | # tgt_edits_outputs.append(tgt_edits) 66 | 67 | # pickle.dump(tgt_edits_outputs, open('./tgt_edits_outputs.pkl', 'wb')) 68 | -------------------------------------------------------------------------------- /asdl/transition_system.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | class Action(object): 5 | pass 6 | 7 | 8 | class ApplyRuleAction(Action): 9 | def __init__(self, production): 10 | self.production = production 11 | 12 | def __hash__(self): 13 | return hash(self.production) 14 | 15 | def __eq__(self, other): 16 | return isinstance(other, ApplyRuleAction) and self.production == other.production 17 | 18 | def __ne__(self, other): 19 | return not self.__eq__(other) 20 | 21 | def __repr__(self): 22 | return 
'ApplyRule[%s]' % self.production.__repr__() 23 | 24 | 25 | class GenTokenAction(Action): 26 | def __init__(self, token): 27 | self.token = token 28 | 29 | def is_stop_signal(self): 30 | return self.token == '' 31 | 32 | def __repr__(self): 33 | return 'GenToken[%s]' % self.token 34 | 35 | def __eq__(self, other): 36 | return isinstance(other, GenTokenAction) and self.token == other.token 37 | 38 | 39 | class ReduceAction(Action): 40 | def __repr__(self): 41 | return 'Reduce' 42 | 43 | def __eq__(self, other): 44 | return isinstance(other, ReduceAction) 45 | 46 | 47 | class ApplySubTreeAction(Action): 48 | def __init__(self, tree, tree_node_ids=-1): 49 | self.tree = tree 50 | self.tree_node_ids = tree_node_ids 51 | 52 | def __repr__(self): 53 | return 'ApplySubTree[%s], Node[%d]' % (repr(self.tree), self.tree.id) 54 | 55 | 56 | class DecodingAction: 57 | def __init__(self, t, parent_t, frontier_prod, frontier_field, preceding_syntax_token_index): 58 | self.t = t 59 | self.parent_t = parent_t 60 | self.frontier_prod = frontier_prod 61 | self.frontier_field = frontier_field 62 | self.preceding_syntax_token_index = preceding_syntax_token_index 63 | 64 | 65 | class ApplyRuleDecodingAction(ApplyRuleAction, DecodingAction): 66 | def __init__(self, t, parent_t, frontier_prod, frontier_field, production, preceding_syntax_token_index=None): 67 | ApplyRuleAction.__init__(self, production) 68 | DecodingAction.__init__(self, t, parent_t, frontier_prod, frontier_field, preceding_syntax_token_index) 69 | 70 | 71 | class ApplySubTreeDecodingAction(ApplySubTreeAction, DecodingAction): 72 | def __init__(self, t, parent_t, frontier_prod, frontier_field, tree, tree_node_ids, 73 | preceding_syntax_token_index=None): 74 | ApplySubTreeAction.__init__(self, tree, tree_node_ids) 75 | DecodingAction.__init__(self, t, parent_t, frontier_prod, frontier_field, preceding_syntax_token_index) 76 | 77 | 78 | class ReduceDecodingAction(ReduceAction, DecodingAction): 79 | def __init__(self, t, 
parent_t, frontier_prod, frontier_field, preceding_syntax_token_index=None): 80 | DecodingAction.__init__(self, t, parent_t, frontier_prod, frontier_field, preceding_syntax_token_index) 81 | 82 | 83 | class GenTokenDecodingAction(GenTokenAction, DecodingAction): 84 | def __init__(self, t, parent_t, frontier_prod, frontier_field, token, preceding_syntax_token_index=None): 85 | GenTokenAction.__init__(self, token) 86 | DecodingAction.__init__(self, t, parent_t, frontier_prod, frontier_field, preceding_syntax_token_index) 87 | 88 | 89 | class TransitionSystem(object): 90 | def __init__(self, grammar): 91 | self.grammar = grammar 92 | 93 | def get_actions(self, asdl_ast): 94 | """ 95 | generate action sequence given the ASDL Syntax Tree 96 | """ 97 | 98 | actions = [] 99 | 100 | parent_action = ApplyRuleAction(asdl_ast.production) 101 | actions.append(parent_action) 102 | 103 | for field in asdl_ast.fields: 104 | # is a composite field 105 | if self.grammar.is_composite_type(field.type): 106 | if field.cardinality == 'single': 107 | field_actions = self.get_actions(field.value) 108 | else: 109 | field_actions = [] 110 | 111 | if field.value is not None: 112 | if field.cardinality == 'multiple': 113 | for val in field.value: 114 | cur_child_actions = self.get_actions(val) 115 | field_actions.extend(cur_child_actions) 116 | elif field.cardinality == 'optional': 117 | field_actions = self.get_actions(field.value) 118 | 119 | # if an optional field is filled, then do not need Reduce action 120 | if field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions: 121 | field_actions.append(ReduceAction()) 122 | else: # is a primitive field 123 | field_actions = self.get_primitive_field_actions(field) 124 | 125 | # if an optional field is filled, then do not need Reduce action 126 | if field.cardinality == 'multiple' or field.cardinality == 'optional' and not field_actions: 127 | # reduce action 128 | field_actions.append(ReduceAction()) 129 | 130 | 
actions.extend(field_actions) 131 | 132 | return actions 133 | 134 | def tokenize_code(self, code, mode): 135 | raise NotImplementedError 136 | 137 | def compare_ast(self, hyp_ast, ref_ast): 138 | raise NotImplementedError 139 | 140 | def ast_to_surface_code(self, asdl_ast): 141 | raise NotImplementedError 142 | 143 | def surface_code_to_ast(self, code): 144 | raise NotImplementedError 145 | 146 | def get_primitive_field_actions(self, realized_field): 147 | raise NotImplementedError 148 | 149 | def get_valid_continuation_types(self, hyp): 150 | if hyp.tree: 151 | if self.grammar.is_composite_type(hyp.frontier_field.type): 152 | if hyp.frontier_field.cardinality == 'single': 153 | return ApplyRuleAction, 154 | else: # optional, multiple 155 | return ApplyRuleAction, ReduceAction 156 | else: 157 | if hyp.frontier_field.cardinality == 'single': 158 | return GenTokenAction, 159 | elif hyp.frontier_field.cardinality == 'optional': 160 | if hyp._value_buffer: 161 | return GenTokenAction, 162 | else: 163 | return GenTokenAction, ReduceAction 164 | else: 165 | return GenTokenAction, ReduceAction 166 | else: 167 | return ApplyRuleAction, 168 | 169 | def get_valid_continuating_productions(self, hyp): 170 | if hyp.tree: 171 | if self.grammar.is_composite_type(hyp.frontier_field.type): 172 | return self.grammar[hyp.frontier_field.type] 173 | else: 174 | raise ValueError 175 | else: 176 | return self.grammar[self.grammar.root_type] 177 | 178 | @staticmethod 179 | def get_class_by_lang(lang): 180 | if lang == 'python': 181 | from .lang.py.py_transition_system import PythonTransitionSystem 182 | return PythonTransitionSystem 183 | elif lang == 'python3': 184 | from .lang.py3.py3_transition_system import Python3TransitionSystem 185 | return Python3TransitionSystem 186 | elif lang == 'lambda_dcs': 187 | from .lang.lambda_dcs.lambda_dcs_transition_system import LambdaCalculusTransitionSystem 188 | return LambdaCalculusTransitionSystem 189 | elif lang == 'prolog': 190 | from 
.lang.prolog.prolog_transition_system import PrologTransitionSystem 191 | return PrologTransitionSystem 192 | 193 | raise ValueError('unknown language %s' % lang) 194 | -------------------------------------------------------------------------------- /asdl/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import re 3 | 4 | 5 | def remove_comment(text): 6 | text = re.sub(re.compile("#.*"), "", text) 7 | text = '\n'.join(filter(lambda x: x, text.split('\n'))) 8 | 9 | return text 10 | -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/common/__init__.py -------------------------------------------------------------------------------- /common/registerable.py: -------------------------------------------------------------------------------- 1 | class Registrable(object): 2 | """ 3 | A class that collects all registered components, 4 | adapted from `common.registrable.Registrable` from AllenNLP 5 | """ 6 | registered_components = dict() 7 | 8 | @staticmethod 9 | def register(name): 10 | def register_class(cls): 11 | if name in Registrable.registered_components: 12 | raise RuntimeError('class %s already registered' % name) 13 | 14 | Registrable.registered_components[name] = cls 15 | return cls 16 | 17 | return register_class 18 | 19 | @staticmethod 20 | def by_name(name): 21 | return Registrable.registered_components[name] 22 | -------------------------------------------------------------------------------- /common/savable.py: -------------------------------------------------------------------------------- 1 | class Savable(object): 2 | @staticmethod 3 | def load(model_path, *args, **kwargs): 4 | pass 5 | 6 | @staticmethod 7 | def save(model_path, *args, **kwargs): 8 | pass 9 | 
# --------------------------------------------------------------------------------
# /common/utils.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Shared helpers: a per-instance cached property descriptor and the
# command-line argument parser (plus a helper that back-fills parser defaults).
import argparse


class cached_property(object):
    """A property computed once per instance, then replaced by an ordinary attribute.

    On first access the wrapped method is called and its result is stored in
    the instance ``__dict__`` under the method's name. This descriptor defines
    no ``__set__``, so that instance attribute shadows it on every later access.
    Deleting the attribute (``del obj.name``) resets the property, and the next
    access recomputes it.

    Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76
    """

    def __init__(self, func):
        self.__doc__ = getattr(func, '__doc__')
        self.func = func

    def __get__(self, obj, cls):
        # Accessed on the class itself: hand back the descriptor object.
        if obj is None:
            return self
        # Cache in the instance dict so later lookups bypass this descriptor.
        value = obj.__dict__[self.func.__name__] = self.func(obj)
        return value


def init_arg_parser():
    """Build and return the command-line parser with every supported option.

    Options are grouped into general, modularized, model, training,
    decoding/validation/testing, reranking, self-training and interactive-mode
    sections. ``--mode`` is the only required option.

    Returns:
        argparse.ArgumentParser: a freshly constructed parser.
    """
    arg_parser = argparse.ArgumentParser()

    #### General configuration ####
    arg_parser.add_argument('--seed', default=0, type=int, help='Random seed')
    arg_parser.add_argument('--cuda', action='store_true', default=False, help='Use gpu')
    arg_parser.add_argument('--lang', choices=['python', 'lambda_dcs', 'prolog', 'python3'], default='python',
                            help='[Deprecated] language to parse. Deprecated, use --transition_system and --parser instead')
    arg_parser.add_argument('--asdl_file', type=str, help='Path to ASDL grammar specification')
    arg_parser.add_argument('--mode', choices=['train', 'test', 'interactive',
                                               'train_paraphrase_identifier', 'train_reconstructor', 'rerank'],
                            required=True, help='Run mode')

    #### Modularized configuration ####
    arg_parser.add_argument('--parser', type=str, default='default_parser', required=False, help='name of parser class to load')
    arg_parser.add_argument('--transition_system', type=str, default='python2', required=False, help='name of transition system to use')
    arg_parser.add_argument('--evaluator', type=str, default='default_evaluator', required=False, help='name of evaluator class to use')

    #### Model configuration ####
    arg_parser.add_argument('--lstm', choices=['lstm'], default='lstm', help='Type of LSTM used, currently only standard LSTM cell is supported')

    # Embedding sizes
    arg_parser.add_argument('--embed_size', default=128, type=int, help='Size of word embeddings')
    arg_parser.add_argument('--action_embed_size', default=128, type=int, help='Size of ApplyRule/GenToken action embeddings')
    arg_parser.add_argument('--field_embed_size', default=64, type=int, help='Embedding size of ASDL fields')
    arg_parser.add_argument('--type_embed_size', default=64, type=int, help='Embedding size of ASDL types')

    # Hidden sizes
    arg_parser.add_argument('--hidden_size', default=256, type=int, help='Size of LSTM hidden states')
    arg_parser.add_argument('--ptrnet_hidden_dim', default=32, type=int, help='Hidden dimension used in pointer network')
    arg_parser.add_argument('--att_vec_size', default=256, type=int, help='size of attentional vector')

    # readout layer
    arg_parser.add_argument('--no_query_vec_to_action_map', default=False, action='store_true',
                            help='Do not use additional linear layer to transform the attentional vector for computing action probabilities')
    arg_parser.add_argument('--readout', default='identity', choices=['identity', 'non_linear'],
                            help='Type of activation if using additional linear layer')
    arg_parser.add_argument('--query_vec_to_action_diff_map', default=False, action='store_true',
                            help='Use different linear mapping ')

    # supervised attention
    arg_parser.add_argument('--sup_attention', default=False, action='store_true', help='Use supervised attention')

    # parent information switch for decoder LSTM
    arg_parser.add_argument('--no_parent_production_embed', default=False, action='store_true',
                            help='Do not use embedding of parent ASDL production to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_field_embed', default=False, action='store_true',
                            help='Do not use embedding of parent field to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_field_type_embed', default=False, action='store_true',
                            help='Do not use embedding of the ASDL type of parent field to update decoder LSTM state')
    arg_parser.add_argument('--no_parent_state', default=False, action='store_true',
                            help='Do not use the parent hidden state to update decoder LSTM state')

    arg_parser.add_argument('--no_input_feed', default=False, action='store_true', help='Do not use input feeding in decoder LSTM')
    arg_parser.add_argument('--no_copy', default=False, action='store_true', help='Do not use copy mechanism')

    #### Training ####
    arg_parser.add_argument('--vocab', type=str, help='Path of the serialized vocabulary')
    arg_parser.add_argument('--glove_embed_path', default=None, type=str, help='Path to pretrained Glove embedding')

    arg_parser.add_argument('--train_file', type=str, help='path to the training target file')
    arg_parser.add_argument('--dev_file', type=str, help='path to the dev source file')
    arg_parser.add_argument('--pretrain', type=str, help='path to the pretrained model file')

    arg_parser.add_argument('--batch_size', default=10, type=int, help='Batch size')
    arg_parser.add_argument('--dropout', default=0., type=float, help='Dropout rate')
    arg_parser.add_argument('--word_dropout', default=0., type=float, help='Word dropout rate')
    arg_parser.add_argument('--decoder_word_dropout', default=0.3, type=float, help='Word dropout rate on decoder')
    arg_parser.add_argument('--primitive_token_label_smoothing', default=0.0, type=float,
                            help='Apply label smoothing when predicting primitive tokens')
    arg_parser.add_argument('--src_token_label_smoothing', default=0.0, type=float,
                            help='Apply label smoothing in reconstruction model when predicting source tokens')

    arg_parser.add_argument('--negative_sample_type', default='best', type=str, choices=['best', 'sample', 'all'])

    # training schedule details
    arg_parser.add_argument('--valid_metric', default='acc', choices=['acc'],
                            help='Metric used for validation')
    arg_parser.add_argument('--valid_every_epoch', default=1, type=int, help='Perform validation every x epoch')
    arg_parser.add_argument('--log_every', default=10, type=int, help='Log training statistics every n iterations')

    arg_parser.add_argument('--save_to', default='model', type=str, help='Save trained model to')
    arg_parser.add_argument('--save_all_models', default=False, action='store_true', help='Save all intermediate checkpoints')
    arg_parser.add_argument('--patience', default=5, type=int, help='Training patience')
    arg_parser.add_argument('--max_num_trial', default=10, type=int, help='Stop training after x number of trials')
    arg_parser.add_argument('--uniform_init', default=None, type=float,
                            help='If specified, use uniform initialization for all parameters')
    arg_parser.add_argument('--glorot_init', default=False, action='store_true', help='Use glorot initialization')
    arg_parser.add_argument('--clip_grad', default=5., type=float, help='Clip gradients')
    arg_parser.add_argument('--max_epoch', default=-1, type=int, help='Maximum number of training epochs')
    arg_parser.add_argument('--optimizer', default='Adam', type=str, help='optimizer')
    arg_parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
    arg_parser.add_argument('--lr_decay', default=0.5, type=float,
                            help='decay learning rate if the validation performance drops')
    arg_parser.add_argument('--lr_decay_after_epoch', default=0, type=int, help='Decay learning rate after x epoch')
    arg_parser.add_argument('--decay_lr_every_epoch', action='store_true', default=False, help='force to decay learning rate after each epoch')
    arg_parser.add_argument('--reset_optimizer', action='store_true', default=False, help='Whether to reset optimizer when loading the best checkpoint')
    arg_parser.add_argument('--verbose', action='store_true', default=False, help='Verbose mode')
    arg_parser.add_argument('--eval_top_pred_only', action='store_true', default=False,
                            help='Only evaluate the top prediction in validation')

    #### decoding/validation/testing ####
    arg_parser.add_argument('--load_model', default=None, type=str, help='Load a pre-trained model')
    arg_parser.add_argument('--beam_size', default=5, type=int, help='Beam size for beam search')
    arg_parser.add_argument('--decode_max_time_step', default=100, type=int, help='Maximum number of time steps used '
                                                                                  'in decoding and sampling')
    arg_parser.add_argument('--sample_size', default=5, type=int, help='Sample size')
    arg_parser.add_argument('--test_file', type=str, help='Path to the test file')
    arg_parser.add_argument('--save_decode_to', default=None, type=str, help='Save decoding results to file')

    #### reranking ####
    arg_parser.add_argument('--features', nargs='+')
    arg_parser.add_argument('--load_reconstruction_model', type=str, help='Load reconstruction model')
    arg_parser.add_argument('--load_paraphrase_model', type=str, help='Load paraphrase model')
    arg_parser.add_argument('--load_reranker', type=str, help='Load reranking model')
    arg_parser.add_argument('--tie_embed', action='store_true', help='tie source and target embedding in training paraphrasing model')
    arg_parser.add_argument('--train_decode_file', default=None, type=str, help='Decoding results on training set')
    arg_parser.add_argument('--test_decode_file', default=None, type=str, help='Decoding results on test set')
    arg_parser.add_argument('--dev_decode_file', default=None, type=str, help='Decoding results on dev set')
    arg_parser.add_argument('--metric', default='accuracy', choices=['bleu', 'accuracy'])
    arg_parser.add_argument('--num_workers', default=1, type=int, help='number of multiprocess workers')

    #### self-training ####
    arg_parser.add_argument('--load_decode_results', default=None, type=str)
    arg_parser.add_argument('--unsup_loss_weight', default=1., type=float, help='loss of unsupervised learning weight')
    arg_parser.add_argument('--unlabeled_file', type=str, help='Path to the training source file used in semi-supervised self-training')

    #### interactive mode ####
    arg_parser.add_argument('--example_preprocessor', default=None, type=str, help='name of the class that is used to pre-process raw input examples')

    return arg_parser


def update_args(args, arg_parser):
    """Back-fill attributes missing from ``args`` with the defaults declared in ``arg_parser``.

    Only plain ``store``, ``store_true`` and ``store_false`` options are
    considered; other action kinds (such as the built-in ``-h/--help``) are
    ignored. Attributes already present on ``args`` are never overwritten.

    NOTE(review): presumably used to upgrade a namespace restored from an older
    run so that options added later get their defaults. Confirm with callers.

    Args:
        args: a namespace-like object; it is modified in place.
        arg_parser: the ``argparse.ArgumentParser`` whose defaults are applied.
    """
    # argparse does not publicly export these action classes; the underscore
    # names are implementation details of the standard library module.
    store_action_types = (argparse._StoreAction,
                          argparse._StoreTrueAction,
                          argparse._StoreFalseAction)
    for action in arg_parser._actions:
        if isinstance(action, store_action_types) and not hasattr(args, action.dest):
            setattr(args, action.dest, action.default)
# --------------------------------------------------------------------------------
/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/githubedits/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/datasets/githubedits/common/__init__.py -------------------------------------------------------------------------------- /datasets/githubedits/common/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import OrderedDict 3 | 4 | 5 | def isfloat(x): 6 | try: 7 | a = float(x) 8 | except ValueError: 9 | return False 10 | else: 11 | return True 12 | 13 | 14 | def isint(x): 15 | try: 16 | if '.' 
in x: 17 | return False 18 | 19 | a = float(x) 20 | b = int(a) 21 | except: 22 | return False 23 | else: 24 | return a == b 25 | 26 | 27 | class Arguments(OrderedDict): 28 | EVAL_ARGS = {'--gnn_layer_timesteps', '--gnn_residual_connections', '--gnn_connections'} 29 | 30 | def __init__(self, *args, **kwargs): 31 | super(Arguments, self).__init__(*args, **kwargs) 32 | 33 | @staticmethod 34 | def from_file(file_path, cmd_args=None): 35 | config = json.load(open(file_path, 'r')) 36 | 37 | args = Arguments(config) 38 | args['cmd_args'] = cmd_args 39 | 40 | return args 41 | 42 | def to_string(self): 43 | return json.dumps(self, indent=2) 44 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | 4 | class ExampleProcessor(object): 5 | """ 6 | Process a raw input utterance using domain-specific procedures (e.g., stemming), 7 | and post-process a generated hypothesis to the final form 8 | """ 9 | def pre_process_utterance(self, utterance): 10 | raise NotImplementedError 11 | 12 | def post_process_hypothesis(self, hyp, meta_info, **kwargs): 13 | raise NotImplementedError 14 | 15 | 16 | def get_example_processor_cls(dataset): 17 | if dataset == 'conala': 18 | from datasets.conala.example_processor import ConalaExampleProcessor 19 | return ConalaExampleProcessor 20 | else: 21 | raise RuntimeError() 22 | -------------------------------------------------------------------------------- /edit_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/edit_components/__init__.py -------------------------------------------------------------------------------- /edit_components/change_entry.py: -------------------------------------------------------------------------------- 1 | from 
typing import List 2 | import difflib 3 | 4 | from edit_components.diff_utils import TokenLevelDiffer 5 | 6 | 7 | class ChangeExample: 8 | def __init__(self, prev_data: List[str], updated_data: List[str], context: List[str], 9 | raw_prev_data: str=None, raw_updated_data: str=None, 10 | id: str='default_id', **kwargs): 11 | self.id = id 12 | 13 | self.prev_data = prev_data 14 | self.updated_data = updated_data 15 | 16 | self.raw_prev_data = raw_prev_data 17 | self.raw_updated_data = raw_updated_data 18 | 19 | self.context = context 20 | 21 | diff_hunk = '\n'.join(list(x.strip('\n') if x.startswith('@') else x 22 | for x in difflib.unified_diff(a=prev_data, b=updated_data, 23 | n=len(self.prev_data) + len(self.updated_data), 24 | lineterm=''))[2:]) 25 | self.diff_hunk = diff_hunk 26 | 27 | self._init_change_seq() 28 | 29 | self.__dict__.update(kwargs) 30 | 31 | def _init_change_seq(self): 32 | differ = TokenLevelDiffer() 33 | diff_result = differ.unified_format(dict(diff=self.diff_hunk)) 34 | change_seq = [] 35 | 36 | prev_token_ptr = updated_token_ptr = 0 37 | for i, (added, removed, same) in enumerate(zip(diff_result.added, diff_result.removed, diff_result.same)): 38 | if same is not None: 39 | tag = 'SAME' 40 | token = same 41 | 42 | assert self.prev_data[prev_token_ptr] == self.updated_data[updated_token_ptr] == token 43 | 44 | prev_token_ptr += 1 45 | updated_token_ptr += 1 46 | elif added is not None and removed is not None: 47 | tag = 'REPLACE' 48 | token = (removed, added) 49 | 50 | assert self.prev_data[prev_token_ptr] == removed 51 | assert self.updated_data[updated_token_ptr] == added 52 | 53 | prev_token_ptr += 1 54 | updated_token_ptr += 1 55 | elif added is not None and removed is None: 56 | tag = 'ADD' 57 | token = added 58 | 59 | assert self.updated_data[updated_token_ptr] == added 60 | 61 | updated_token_ptr += 1 62 | elif added is None and removed is not None: 63 | tag = 'DEL' 64 | token = removed 65 | 66 | assert self.prev_data[prev_token_ptr] == 
removed 67 | 68 | prev_token_ptr += 1 69 | else: 70 | raise ValueError('unknown change entry') 71 | 72 | change_seq.append((tag, token)) 73 | 74 | setattr(self, 'change_seq', change_seq) 75 | -------------------------------------------------------------------------------- /edit_components/change_graph.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from urllib import request 3 | 4 | from asdl.asdl_ast import AbstractSyntaxTree, AbstractSyntaxNode 5 | from edit_components.dataset import DataSet 6 | 7 | 8 | class ChangeGraph(object): 9 | def __init__(self): 10 | pass 11 | 12 | @staticmethod 13 | def build_change_graph(old_ast: AbstractSyntaxTree, new_ast: AbstractSyntaxTree): 14 | equality_links = [] 15 | 16 | def _modify_id(node, prefix=''): 17 | node.id = f'{prefix}-{node.id}' 18 | if isinstance(node, AbstractSyntaxNode): 19 | for field in node.fields: 20 | for field_val in field.as_value_list: 21 | _modify_id(field_val, prefix) 22 | 23 | old_ast_root_copy = old_ast.root_node.copy() 24 | _modify_id(old_ast_root_copy, 'old') 25 | 26 | new_ast_root_copy = new_ast.root_node.copy() 27 | _modify_id(new_ast_root_copy, 'new') 28 | 29 | old_ast = AbstractSyntaxTree(old_ast_root_copy) 30 | new_ast = AbstractSyntaxTree(new_ast_root_copy) 31 | 32 | def _search_common_sub_tree(tgt_ast_node): 33 | node_query_result = old_ast.find_node(tgt_ast_node) 34 | if node_query_result: 35 | src_node_id, src_node = node_query_result 36 | tgt_ast_node.parent_field.replace(tgt_ast_node, src_node) 37 | 38 | # register this link 39 | equality_links.append((tgt_ast_node.id, src_node.id)) 40 | else: 41 | for field in tgt_ast_node.fields: 42 | if field.type.is_composite: 43 | for field_val in field.as_value_list: 44 | _search_common_sub_tree(field_val) 45 | 46 | _search_common_sub_tree(new_ast.root_node) 47 | 48 | visited = set() 49 | adjacency_list = [] 50 | 51 | def _visit(node, parent_node): 52 | if parent_node: 53 | 
adjacency_list.append((parent_node.id, node.id)) 54 | 55 | if node.id in visited: 56 | return 57 | 58 | if isinstance(node, AbstractSyntaxNode): 59 | for field in node.fields: 60 | for field_val in field.as_value_list: 61 | _visit(field_val, node) 62 | 63 | visited.add(node.id) 64 | 65 | _visit(old_ast.root_node, None) 66 | _visit(new_ast.root_node, None) 67 | pass 68 | 69 | 70 | if __name__ == '__main__': 71 | dataset_path = 'data/commit_files.from_repo.processed.071009.jsonl.dev.top100' 72 | grammar_text = request.urlopen('https://raw.githubusercontent.com/dotnet/roslyn/master/src/Compilers' 73 | '/CSharp/Portable/Syntax/Syntax.xml').read() 74 | 75 | from asdl.lang.csharp.csharp_grammar import CSharpASDLGrammar 76 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem 77 | 78 | grammar = CSharpASDLGrammar.from_roslyn_xml(grammar_text, pruning=True) 79 | transition_system = CSharpTransitionSystem(grammar) 80 | 81 | print('Loading datasets...', file=sys.stderr) 82 | dataset = DataSet.load_from_jsonl(dataset_path, type='tree2tree_subtree_copy', 83 | transition_system=transition_system, 84 | parallel=False, 85 | debug=True) 86 | for example in dataset: 87 | list1 = list(example.prev_code_ast.id2node.values()) 88 | list2 = list(example.prev_code_ast.root_node.descendant_nodes_and_tokens) 89 | 90 | assert list1 == list2 91 | 92 | ChangeGraph.build_change_graph(example.prev_code_ast, example.updated_code_ast) 93 | -------------------------------------------------------------------------------- /edit_components/dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | from torch.utils.data import dataloader 4 | from torch.multiprocessing import reductions 5 | from multiprocessing.reduction import ForkingPickler 6 | 7 | default_collate_func = dataloader.default_collate 8 | 9 | 10 | def default_collate_override(batch): 11 | dataloader._use_shared_memory = False 12 | return 
default_collate_func(batch) 13 | 14 | setattr(dataloader, 'default_collate', default_collate_override) 15 | 16 | for t in torch._storage_classes: 17 | if sys.version_info[0] == 2: 18 | if t in ForkingPickler.dispatch: 19 | del ForkingPickler.dispatch[t] 20 | else: 21 | if t in ForkingPickler._extra_reducers: 22 | del ForkingPickler._extra_reducers[t] 23 | 24 | 25 | import json 26 | import sys 27 | from collections import OrderedDict 28 | from functools import partial 29 | from multiprocessing import Pool 30 | 31 | import numpy as np 32 | 33 | from asdl.lang.csharp.csharp_hypothesis import CSharpHypothesis 34 | from asdl.lang.csharp.csharp_transition import ApplyRuleAction, ApplySubTreeAction 35 | from edit_components.change_entry import ChangeExample 36 | from edit_model.edit_encoder import SequentialChangeEncoder, GraphChangeEncoder, BagOfEditsChangeEncoder 37 | from edit_model.encdec import SequentialDecoder 38 | 39 | 40 | def _encode(word_list): 41 | return [w.replace('\n', '-NEWLINE-') for w in word_list] 42 | 43 | 44 | def load_one_change_entry_csharp(json_str, editor_type='seq2seq', edit_encoder_type='seq', tensorization=True, 45 | transition_system=None, substitution_system=None, vocab=None, args=None, save_edits=True): 46 | entry = json.loads(json_str) 47 | previous_code_chunk = _encode(entry['PrevCodeChunkTokens']) 48 | updated_code_chunk = _encode(entry['UpdatedCodeChunkTokens']) 49 | context = _encode(entry['PrecedingContext'] + ['|||'] + entry['SucceedingContext']) 50 | 51 | if editor_type == 'seq2seq': 52 | prev_code_ast_json = entry['PrevCodeAST'] 53 | prev_code_ast = transition_system.grammar.get_ast_from_json_obj(prev_code_ast_json) 54 | 55 | updated_code_ast_json = entry['UpdatedCodeAST'] 56 | updated_code_ast = transition_system.grammar.get_ast_from_json_obj(updated_code_ast_json) 57 | 58 | example = ChangeExample(id=entry['Id'], 59 | prev_data=previous_code_chunk, 60 | updated_data=updated_code_chunk, 61 | raw_prev_data=entry['PrevCodeChunk'], 62 
| raw_updated_data=entry['UpdatedCodeChunk'], 63 | context=context, 64 | prev_code_ast=prev_code_ast, 65 | updated_code_ast=updated_code_ast) 66 | 67 | # preform tensorization 68 | if tensorization: 69 | if edit_encoder_type == 'sequential': 70 | SequentialChangeEncoder.populate_aligned_token_index_and_mask(example) 71 | elif edit_encoder_type == 'graph': 72 | example.change_edges = GraphChangeEncoder.compute_change_edges(example) 73 | 74 | # SequentialChangeEncoder.populate_aligned_token_index_and_mask(example) 75 | SequentialDecoder.populate_gen_and_copy_index_and_mask(example, vocab, copy_token=args['decoder']['copy_token']) 76 | 77 | elif editor_type in ('graph2tree', 'graph2iteredit'): 78 | prev_code_ast_json = entry['PrevCodeAST'] 79 | prev_code_ast = transition_system.grammar.get_ast_from_json_obj(prev_code_ast_json) 80 | 81 | updated_code_ast_json = entry['UpdatedCodeAST'] 82 | updated_code_ast = transition_system.grammar.get_ast_from_json_obj(updated_code_ast_json) 83 | 84 | if editor_type == 'graph2tree': 85 | tgt_actions = transition_system.get_decoding_actions(target_ast=updated_code_ast, 86 | prev_ast=prev_code_ast, 87 | copy_identifier=args['decoder']['copy_identifier_node']) 88 | else: 89 | prev_code_ast.reindex_w_dummy_reduce() 90 | updated_code_ast.reindex_w_dummy_reduce() 91 | # Note: when args['small_memory'] is True, `tgt_actions` is the `edit_mappings`, not the actual tgt edits 92 | tgt_actions = substitution_system.get_decoding_edits_fast(prev_code_ast, updated_code_ast, 93 | bool_copy_subtree=args['decoder']['copy_subtree'], 94 | init_code_tokens=previous_code_chunk, 95 | return_edits=save_edits, 96 | bool_debug=args['debug']) 97 | 98 | example = ChangeExample(id=entry['Id'], 99 | prev_data=previous_code_chunk, 100 | updated_data=updated_code_chunk, 101 | raw_prev_data=entry['PrevCodeChunk'], 102 | raw_updated_data=entry['UpdatedCodeChunk'], 103 | context=context, 104 | prev_code_ast=prev_code_ast, 105 | updated_code_ast=updated_code_ast, 106 
| tgt_actions=tgt_actions) 107 | 108 | # preform tensorization 109 | if tensorization: 110 | if edit_encoder_type == 'sequential': 111 | SequentialChangeEncoder.populate_aligned_token_index_and_mask(example) 112 | elif edit_encoder_type == 'graph': 113 | example.change_edges = GraphChangeEncoder.compute_change_edges(example) 114 | else: 115 | # raise ValueError('unknown dataset type') 116 | example = ChangeExample(id=entry['Id'], 117 | prev_data=previous_code_chunk, 118 | updated_data=updated_code_chunk, 119 | raw_prev_data=entry['PrevCodeChunk'], 120 | raw_updated_data=entry['UpdatedCodeChunk'], 121 | context=context) 122 | 123 | return example 124 | 125 | 126 | class DataSet: 127 | def __init__(self, examples): 128 | self.examples = examples 129 | self.example_id_to_index = OrderedDict([(e.id, idx) for idx, e in enumerate(self.examples)]) 130 | 131 | def batch_iter(self, batch_size, shuffle=False): 132 | index_arr = np.arange(len(self.examples)) 133 | if shuffle: 134 | np.random.shuffle(index_arr) 135 | 136 | batch_num = int(np.ceil(len(self.examples) / float(batch_size))) 137 | for batch_id in range(batch_num): 138 | batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)] 139 | batch_examples = [self.examples[i] for i in batch_ids] 140 | # sort by the length of the change sequence in descending order 141 | batch_examples.sort(key=lambda e: -len(e.change_seq)) 142 | 143 | yield batch_examples 144 | 145 | def __len__(self): 146 | return len(self.examples) 147 | 148 | def __iter__(self): 149 | return iter(self.examples) 150 | 151 | def get_example_by_id(self, eid): 152 | idx = self.example_id_to_index[eid] 153 | return self.examples[idx] 154 | 155 | @staticmethod 156 | def load_from_jsonl(file_path, language='csharp', editor=None, 157 | editor_type=None, edit_encoder_type=None, args=None, vocab=None, transition_system=None, 158 | substitution_system=None, tensorization=True, from_ipython=False, max_workers=1, save_edits=True): 159 | 160 | from 
edit_model.editor import Seq2SeqEditor, Graph2TreeEditor, Graph2IterEditEditor 161 | 162 | if editor: 163 | if isinstance(editor, Seq2SeqEditor): 164 | editor_type = 'seq2seq' 165 | elif isinstance(editor, Graph2TreeEditor): 166 | editor_type = 'graph2tree' 167 | elif isinstance(editor, Graph2IterEditEditor): 168 | editor_type = 'graph2iteredit' 169 | 170 | if isinstance(editor.edit_encoder, SequentialChangeEncoder): 171 | edit_encoder_type = 'sequential' 172 | elif isinstance(editor.edit_encoder, GraphChangeEncoder): 173 | edit_encoder_type = 'graph' 174 | elif isinstance(editor.edit_encoder, BagOfEditsChangeEncoder): 175 | edit_encoder_type = 'bag' 176 | 177 | if hasattr(editor, 'transition_system'): 178 | transition_system = editor.transition_system 179 | 180 | if hasattr(editor, 'substitution_system'): 181 | substitution_system = editor.substitution_system 182 | 183 | vocab = editor.vocab 184 | args = editor.args 185 | 186 | if editor_type is None: 187 | print("WARNING: unknown dataset type") 188 | 189 | if language == 'csharp': 190 | load_one_change_entry = load_one_change_entry_csharp 191 | else: 192 | raise Exception(f"unavailable language={language}") 193 | 194 | examples = [] 195 | with open(file_path) as f: 196 | print('reading all lines from the dataset', file=sys.stderr) 197 | all_lines = [l for l in f] 198 | print('%d lines. 
Done' % len(all_lines), file=sys.stderr) 199 | 200 | if from_ipython: 201 | from tqdm import tqdm_notebook 202 | iter_log_func = partial(tqdm_notebook, total=len(all_lines), desc='loading dataset') 203 | else: 204 | from tqdm import tqdm 205 | iter_log_func = partial(tqdm, total=len(all_lines), desc='loading dataset', file=sys.stdout) 206 | 207 | if max_workers > 1: 208 | print('Parallel data loading...', file=sys.stderr) 209 | with Pool(max_workers) as pool: 210 | processed_examples = pool.map(partial(load_one_change_entry, 211 | editor_type=editor_type, 212 | edit_encoder_type=edit_encoder_type, 213 | tensorization=tensorization, 214 | transition_system=transition_system, 215 | substitution_system=substitution_system, 216 | vocab=vocab, args=args, save_edits=save_edits), 217 | iterable=all_lines) # chunksize=min(1000, int(len(all_lines)/max_workers)) 218 | for example in iter_log_func(processed_examples): 219 | examples.append(example) 220 | else: 221 | for line in iter_log_func(all_lines): 222 | example = load_one_change_entry(line, 223 | editor_type=editor_type, 224 | edit_encoder_type=edit_encoder_type, 225 | tensorization=tensorization, 226 | transition_system=transition_system, 227 | substitution_system=substitution_system, 228 | vocab=vocab, args=args, save_edits=save_edits) 229 | examples.append(example) 230 | 231 | data_set = DataSet([e for e in examples if e]) 232 | 233 | return data_set 234 | -------------------------------------------------------------------------------- /edit_components/diff_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import gzip 3 | import glob 4 | import re 5 | import io 6 | import os 7 | import codecs 8 | from collections import defaultdict, OrderedDict, Counter, namedtuple 9 | from itertools import chain 10 | from typing import Any 11 | 12 | import difflib 13 | 14 | 15 | class TokenLevelDiffer: 16 | ''' 17 | Re-render diffs as word-level diffs with the appropriate 
alignment. 18 | This class is NOT thread-safe. 19 | ''' 20 | 21 | def __init__(self): 22 | self.__matcher = difflib.SequenceMatcher() 23 | 24 | tokenization_regex = re.compile('((?!\W)\w+)|(\w+(?=\W))') 25 | 26 | TokenDiff = namedtuple('TokenDiff', ['added', 'removed', 'same', 'has_comment', 'diff_lines']) 27 | 28 | def __tokenize(self, code: str): 29 | '''Tokenize code in a naive way: split on spaces and at the intersection of alphanumeric characters with other 30 | characters. This is necessarily noisy, but is universal.''' 31 | for token in TokenLevelDiffer.tokenization_regex.split(code): 32 | if token is not None and len(token) > 0: 33 | yield token 34 | 35 | def __assert_correct_size(self): 36 | assert len(self.add_list) == len(self.remove_list) == len(self.same_list) == len(self.diff_line_numbers), \ 37 | (len(self.add_list), len(self.remove_list), len(self.same_list), len(self.diff_line_numbers)) 38 | 39 | def __consolidate_changed_region_buffers(self): 40 | self.__matcher.set_seqs(self.removed_buffer, self.added_buffer) 41 | for tag, i1, i2, j1, j2 in self.__matcher.get_opcodes(): 42 | if tag == 'equal': 43 | assert i2 - i1 == j2 - j1 44 | padding = [None] * (i2 - i1) 45 | self.add_list.extend(padding) 46 | self.remove_list.extend(padding) 47 | self.same_list.extend(self.added_buffer[j1:j2]) 48 | self.has_comment.extend(a_has_comment or b_had_comment for a_has_comment, b_had_comment in 49 | zip(self.removed_has_comment_buffer[i1: i2], 50 | self.added_has_comment_buffer[j1: j2])) 51 | self.diff_line_numbers.extend((k1, k2) for k1, k2 in zip(self.removed_diff_line_number_buffer[i1:i2], 52 | self.added_diff_line_number_buffer[j1:j2])) 53 | self.__assert_correct_size() 54 | else: 55 | self.__assert_correct_size() 56 | max_change_size = max(i2 - i1, j2 - j1) 57 | self.same_list.extend([None] * max_change_size) 58 | self.add_list.extend(self.added_buffer[j1:j2]) 59 | self.add_list.extend([None] * (max_change_size - (j2 - j1))) 60 | 
self.remove_list.extend(self.removed_buffer[i1:i2]) 61 | self.remove_list.extend([None] * (max_change_size - (i2 - i1))) 62 | comment_data = [False] * max_change_size 63 | line_data = [list() for k in range(max_change_size)] 64 | for k, (com_data, line_num) in enumerate( 65 | zip(self.added_has_comment_buffer[j1:j2], self.added_diff_line_number_buffer[j1:j2])): 66 | comment_data[k] |= com_data 67 | line_data[k].append(line_num) 68 | for k, (com_data, line_num) in enumerate( 69 | zip(self.removed_has_comment_buffer[i1:i2], self.removed_diff_line_number_buffer[i1:i2])): 70 | comment_data[k] |= com_data 71 | line_data[k].append(line_num) 72 | 73 | self.has_comment.extend(comment_data) 74 | self.diff_line_numbers.extend(tuple(k) for k in line_data) 75 | self.__assert_correct_size() 76 | 77 | # Clean-up buffers 78 | self.added_buffer = [] 79 | self.added_diff_line_number_buffer = [] 80 | self.removed_buffer = [] 81 | self.has_comment_buffer = [] 82 | self.removed_diff_line_number_buffer = [] 83 | 84 | def unified_format(self, diff_file): 85 | self.add_list = [] 86 | self.remove_list = [] 87 | self.same_list = [] 88 | self.has_comment = [] 89 | self.diff_line_numbers = [] # type: List[Tuple] 90 | 91 | self.added_buffer = [] 92 | self.added_has_comment_buffer = [] 93 | self.added_diff_line_number_buffer = [] 94 | self.removed_buffer = [] 95 | self.removed_has_comment_buffer = [] 96 | self.removed_diff_line_number_buffer = [] 97 | 98 | for i, line in enumerate(diff_file['diff'].split('\n')): 99 | self.__assert_correct_size() 100 | line_has_comment = False # str(i) in diff_file['comments'] 101 | 102 | if line.startswith('@'): # Ignore diff header 103 | continue 104 | elif line.startswith('+'): 105 | change_type = 0 106 | line = line[1:] 107 | elif line.startswith('-'): 108 | change_type = 1 109 | line = line[1:] 110 | else: 111 | assert line[0] == ' ' 112 | line = line[1:] 113 | change_type = 2 114 | 115 | # tokenized_line = list(self.__tokenize(line)) 116 | # 
tokenized_line = ([line] if len(line.strip()) > 0 else []) + ['\n'] 117 | tokenized_line = [line] # FIXME: revise this hack! 118 | if change_type == 2 and len(self.added_buffer) + len(self.removed_buffer) > 0: 119 | self.__consolidate_changed_region_buffers() 120 | 121 | # Now add current change! 122 | if change_type == 0: 123 | self.added_buffer.extend(tokenized_line) 124 | self.added_has_comment_buffer.extend([line_has_comment] * len(tokenized_line)) 125 | self.added_diff_line_number_buffer.extend([i] * len(tokenized_line)) 126 | elif change_type == 1: 127 | self.removed_buffer.extend(tokenized_line) 128 | self.removed_has_comment_buffer.extend([line_has_comment] * len(tokenized_line)) 129 | self.removed_diff_line_number_buffer.extend([i] * len(tokenized_line)) 130 | else: 131 | self.__assert_correct_size() 132 | self.same_list.extend(tokenized_line) 133 | padding = [None] * len(tokenized_line) 134 | self.remove_list.extend(padding) 135 | self.add_list.extend(padding) 136 | self.has_comment.extend([line_has_comment] * len(tokenized_line)) 137 | self.diff_line_numbers.extend([(i,)] * len(tokenized_line)) 138 | self.__assert_correct_size() 139 | 140 | if len(self.added_buffer) + len(self.removed_buffer) > 0: 141 | self.__consolidate_changed_region_buffers() 142 | 143 | self.__assert_correct_size() 144 | return self.TokenDiff(self.add_list, self.remove_list, self.same_list, self.has_comment, self.diff_line_numbers) 145 | 146 | 147 | if __name__ == '__main__': 148 | differ = TokenLevelDiffer() 149 | diff_dict = {"diff": """@@ -1,4 +1,3 @@ 150 | 01 151 | -02 152 | +05 153 | -03 154 | 04 155 | """} 156 | 157 | diff_result = differ.unified_format(diff_dict) 158 | pass -------------------------------------------------------------------------------- /edit_components/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from edit_model.editor import Graph2IterEditEditor, Graph2TreeEditor 4 
| 5 | 6 | def evaluate_nll(model, test_set, batch_size=32, change_vectors=None, return_nll_list=False, 7 | return_details=False): 8 | was_training = model.training 9 | model.eval() 10 | 11 | cum_nll = 0. 12 | cum_ppl = 0. 13 | cum_examples = 0. 14 | 15 | cum_other_nll_per_cat = dict() 16 | cum_examples_per_cat = dict() 17 | cum_nll_per_cat_by_step = dict() 18 | cum_examples_per_cat_by_step = dict() 19 | 20 | nll_dict = dict() 21 | 22 | if change_vectors is not None: 23 | assert isinstance(model, (Graph2IterEditEditor, Graph2TreeEditor)) 24 | 25 | with torch.no_grad(): 26 | for batch_examples in test_set.batch_iter(batch_size, shuffle=False): 27 | batch_change_vectors = None 28 | if change_vectors is not None: 29 | batch_change_vectors = torch.stack([change_vectors[e.change_vec_idx] for e in batch_examples]) 30 | 31 | # neg_log_probs = -model(batch_examples)['log_probs'] 32 | if isinstance(model, (Graph2IterEditEditor, Graph2TreeEditor)): 33 | results = model(batch_examples, change_vectors=batch_change_vectors) 34 | else: 35 | results = model(batch_examples) 36 | 37 | neg_log_probs = - results['log_probs'] 38 | batch_code_tokens_num = torch.tensor([len(e.updated_data) for e in batch_examples], 39 | dtype=torch.float, 40 | device=neg_log_probs.device) 41 | 42 | batch_nlls = neg_log_probs.cpu().numpy() 43 | batch_ppls = (neg_log_probs / batch_code_tokens_num).cpu().numpy() 44 | for batch_id in range(len(batch_examples)): 45 | nll_dict[batch_examples[batch_id].id] = batch_nlls[batch_id] 46 | 47 | cum_ppl += batch_ppls.sum() 48 | cum_nll += batch_nlls.sum() 49 | cum_examples += len(batch_examples) 50 | 51 | del neg_log_probs 52 | 53 | if isinstance(model, Graph2IterEditEditor) and return_details: 54 | log_probs = results['ungated_log_probs'] 55 | batch_edit_mask = results['batch_edit_mask'] 56 | 57 | tgt_op_log_probs = (results['tgt_op_log_probs'] * batch_edit_mask).sum(dim=0) 58 | tgt_op_mask = batch_edit_mask.sum(dim=0) 59 | tgt_op_log_probs_by_step = 
torch.unbind((results['tgt_op_log_probs'] * batch_edit_mask).sum(dim=1), dim=0) 60 | tgt_op_mask_by_step = torch.unbind(batch_edit_mask.sum(dim=1), dim=0) 61 | results.update({'tgt_op_log_probs': tgt_op_log_probs, 62 | 'tgt_op_mask': tgt_op_mask, 63 | 'tgt_op_log_probs_by_step': tgt_op_log_probs_by_step, 64 | 'tgt_op_mask_by_step': tgt_op_mask_by_step}) 65 | 66 | if 'tgt_node_log_probs' in results: 67 | tgt_node_log_probs = results['tgt_node_log_probs'] 68 | node_selection_mask = results['node_selection_mask'] 69 | tgt_node_log_probs_by_step = torch.unbind((tgt_node_log_probs * node_selection_mask).sum(dim=1), dim=0) 70 | node_selection_mask_by_step = torch.unbind(node_selection_mask.sum(dim=1), dim=0) 71 | results.update({'tgt_node_log_probs': (tgt_node_log_probs * node_selection_mask).sum(dim=0), 72 | 'node_selection_mask': node_selection_mask.sum(dim=0), 73 | 'tgt_node_log_probs_by_step': tgt_node_log_probs_by_step, 74 | 'node_selection_mask_by_step': node_selection_mask_by_step}) 75 | 76 | if 'tgt_add_log_probs' in results: 77 | tgt_add_log_probs = results['tgt_add_log_probs'] 78 | tgt_add_operator_mask = results['tgt_add_operator_mask'] 79 | tgt_add_log_probs_by_step = torch.unbind((tgt_add_log_probs * tgt_add_operator_mask).sum(dim=1), dim=0) 80 | tgt_add_operator_mask_by_step = torch.unbind(tgt_add_operator_mask.sum(dim=1), dim=0) 81 | results.update({'tgt_add_log_probs': (tgt_add_log_probs * tgt_add_operator_mask).sum(dim=0), 82 | 'tgt_add_operator_mask': tgt_add_operator_mask.sum(dim=0), 83 | 'tgt_add_log_probs_by_step': tgt_add_log_probs_by_step, 84 | 'tgt_add_operator_mask_by_step': tgt_add_operator_mask_by_step}) 85 | 86 | if 'tgt_add_subtree_log_probs' in results: 87 | tgt_add_subtree_log_probs = results['tgt_add_subtree_log_probs'] 88 | tgt_add_subtree_operator_mask = results['tgt_add_subtree_operator_mask'] 89 | tgt_add_subtree_log_probs_by_step = torch.unbind((tgt_add_subtree_log_probs * tgt_add_subtree_operator_mask).sum(dim=1), dim=0) 90 | 
tgt_add_subtree_operator_mask_by_step = torch.unbind(tgt_add_subtree_operator_mask.sum(dim=1), dim=0) 91 | results.update({'tgt_add_subtree_log_probs': (tgt_add_subtree_log_probs * tgt_add_subtree_operator_mask).sum(dim=0), 92 | 'tgt_add_subtree_operator_mask': tgt_add_subtree_operator_mask.sum(dim=0), 93 | 'tgt_add_subtree_log_probs_by_step': tgt_add_subtree_log_probs_by_step, 94 | 'tgt_add_subtree_operator_mask_by_step': tgt_add_subtree_operator_mask_by_step}) 95 | 96 | keys = ['tgt_op_log_probs', 'tgt_node_log_probs', 'tgt_add_log_probs', 'tgt_add_subtree_log_probs'] 97 | key_masks = ['tgt_op_mask', 'node_selection_mask', 'tgt_add_operator_mask', 'tgt_add_subtree_operator_mask'] 98 | for key_idx, key in enumerate(keys): 99 | if key in results: 100 | _neg_log_probs = - results[key].cpu().numpy().sum() 101 | cum_other_nll_per_cat[key] = cum_other_nll_per_cat.get(key, 0.) + _neg_log_probs 102 | 103 | _count_examples = results[key_masks[key_idx]].cpu().numpy().sum() 104 | cum_examples_per_cat[key] = cum_examples_per_cat.get(key, 0.) 
+ _count_examples 105 | 106 | key_by_step = key + "_by_step" 107 | key_mask_by_step = key_masks[key_idx] + "_by_step" 108 | _log_probs_by_step = results[key_by_step] # a list of summed log_prob 109 | _count_examples_by_step = results[key_mask_by_step] # a list of counts 110 | 111 | if key_by_step not in cum_nll_per_cat_by_step: 112 | cum_nll_per_cat_by_step[key_by_step] = [] 113 | cum_examples_per_cat_by_step[key_by_step] = [] 114 | for step, _log_prob in enumerate(_log_probs_by_step): 115 | if _count_examples_by_step[step] == 0: 116 | continue 117 | if len(cum_nll_per_cat_by_step[key_by_step]) - 1 < step: 118 | cum_nll_per_cat_by_step[key_by_step].append(- _log_prob.item()) 119 | cum_examples_per_cat_by_step[key_by_step].append(_count_examples_by_step[step].item()) 120 | else: 121 | cum_nll_per_cat_by_step[key_by_step][step] -= _log_prob.item() 122 | cum_examples_per_cat_by_step[key_by_step][step] += _count_examples_by_step[step].item() 123 | 124 | gated_log_probs_by_step = torch.unbind(torch.sum(log_probs * batch_edit_mask, dim=1), dim=0) 125 | batch_edit_mask_by_step = torch.unbind(batch_edit_mask.sum(dim=1), dim=0) 126 | key_by_step = "log_probs_by_step" 127 | if key_by_step not in cum_nll_per_cat_by_step: 128 | cum_nll_per_cat_by_step[key_by_step] = [] 129 | cum_examples_per_cat_by_step[key_by_step] = [] 130 | for step, _log_prob in enumerate(gated_log_probs_by_step): 131 | if batch_edit_mask_by_step[step] == 0: 132 | continue 133 | if len(cum_nll_per_cat_by_step[key_by_step]) - 1 < step: 134 | cum_nll_per_cat_by_step[key_by_step].append(- _log_prob.item()) 135 | cum_examples_per_cat_by_step[key_by_step].append(batch_edit_mask_by_step[step].item()) 136 | else: 137 | cum_nll_per_cat_by_step[key_by_step][step] -= _log_prob.item() 138 | cum_examples_per_cat_by_step[key_by_step][step] += batch_edit_mask_by_step[step].item() 139 | 140 | del results 141 | 142 | avg_ppl = np.exp(cum_ppl / cum_examples) 143 | avg_nll = cum_nll / cum_examples 144 | 145 | if 
was_training: 146 | model.train(was_training) 147 | 148 | if isinstance(model, Graph2IterEditEditor) and return_details: 149 | avg_other_nll_per_cat = {key: cum_other_nll_per_cat[key] / cum_examples_per_cat[key] 150 | for key in cum_other_nll_per_cat.keys()} 151 | 152 | avg_nll_per_cat_per_step = {} 153 | for key in cum_nll_per_cat_by_step.keys(): 154 | _avg_scores = np.array(cum_nll_per_cat_by_step[key]) / np.array(cum_examples_per_cat_by_step[key]) 155 | _print = ["%.3f (count=%d)" % (_score, cum_examples_per_cat_by_step[key][_idx]) 156 | for _idx, _score in enumerate(_avg_scores[:10])] # only the first 10 steps 157 | avg_nll_per_cat_per_step[key] = _print 158 | if return_nll_list: 159 | return avg_nll, avg_ppl, nll_dict, avg_other_nll_per_cat, avg_nll_per_cat_per_step 160 | else: 161 | return avg_nll, avg_ppl, avg_other_nll_per_cat, avg_nll_per_cat_per_step 162 | 163 | else: 164 | if return_nll_list: 165 | return avg_nll, avg_ppl, nll_dict 166 | else: 167 | return avg_nll, avg_ppl 168 | -------------------------------------------------------------------------------- /edit_components/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/edit_components/utils/__init__.py -------------------------------------------------------------------------------- /edit_components/utils/decode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Neural Representations of code revisions 3 | 4 | Usage: 5 | decode.py rerank --mode= --dataset= --relevance_db= [options] MODEL_PATH 6 | decode.py decode_change_vec --mode= --dataset= [options] MODEL_PATH 7 | 8 | Options: 9 | -h --help Show this screen. 
10 | --mode= Mode: seq2seq|seq2tree|tree2tree_subtree_copy 11 | --dataset= Dataset 12 | --relevance_db= Relevance DB path 13 | --save_to= Save decode results to [default: None] 14 | --cuda Use gpu 15 | """ 16 | 17 | from edit_model.editor import SequentialAutoEncoder, TreeBasedAutoEncoderWithGraphEncoder 18 | from edit_model.edit_encoder import GraphChangeEncoder 19 | from edit_components.dataset import DataSet 20 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem 21 | from asdl.lang.csharp.csharp_grammar import CSharpASDLGrammar 22 | from asdl.lang.csharp.csharp_hypothesis import CSharpHypothesis 23 | from edit_components.utils.utils import * 24 | from edit_components.utils.relevance import generate_reranked_list 25 | import urllib.request as request 26 | from tqdm import tqdm_notebook 27 | import json 28 | from docopt import docopt 29 | import pickle, datetime, time 30 | 31 | 32 | def dump_rerank_file(args): 33 | dataset_file = args['--dataset'] 34 | print(f'load dataset {dataset_file}') 35 | dataset = DataSet.load_from_jsonl(dataset_file, type='sequential', max_workers=10) 36 | 37 | model_file = args['MODEL_PATH'] 38 | print(f'load model from {model_file}') 39 | 40 | if args['--mode'] == 'seq2seq': 41 | model_cls = SequentialAutoEncoder 42 | elif args['--mode'].startswith('tree2tree'): 43 | model_cls = TreeBasedAutoEncoderWithGraphEncoder 44 | else: 45 | model_cls = TreeBasedAutoEncoder 46 | 47 | model = model_cls.load(model_file) 48 | model.eval() 49 | # print(model.training) 50 | print(model.args) 51 | 52 | # feature_vecs = model.code_change_encoder.encode_code_changes(dataset.examples, code_encoder=model.sequential_code_encoder, batch_size=256) 53 | # print(f'decoded {feature_vecs.shape[0]} entries') 54 | 55 | print(f"load relevance db from {args['--relevance_db']}") 56 | relevance_db = pickle.load(open(args['--relevance_db'], 'rb')) 57 | 58 | algo_label = [x for x in model_file.split('/') if 'branch' in x][0] 59 | model_label = [x for x 
in model_file.split('/') if x.endswith('.bin')][0] 60 | output_folder = f"reranking/{algo_label}/{model_label}-{datetime.datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d-%H-%M-%S')}" 61 | 62 | generate_reranked_list(model, relevance_db, dataset, output_folder) 63 | 64 | 65 | def decode_change_vec(args): 66 | dataset_file = args['--dataset'] 67 | model_file = args['MODEL_PATH'] 68 | print(f'load model from {model_file}') 69 | 70 | if args['--mode'] == 'seq2seq': 71 | model_cls = SequentialAutoEncoder 72 | elif args['--mode'].startswith('tree2tree'): 73 | model_cls = TreeBasedAutoEncoderWithGraphEncoder 74 | else: 75 | model_cls = TreeBasedAutoEncoder 76 | 77 | model = model_cls.load(model_file, use_cuda=args['--cuda']) 78 | model.eval() 79 | print(model.args) 80 | 81 | dataset_file = args['--dataset'] 82 | print(f'load dataset {dataset_file}') 83 | 84 | is_graph_change_encoder = isinstance(model_cls.code_change_encoder, GraphChangeEncoder) 85 | 86 | dataset = DataSet.load_from_jsonl(type='tree2tree_subtree_copy' if is_graph_change_encoder else 'sequential', 87 | transition_system=CSharpTransitionSystem(model.grammar) if '2tree' in args['--mode'] else None, 88 | max_workers=1, 89 | parallel=False, 90 | annotate_tree_change=is_graph_change_encoder) 91 | 92 | change_vecs = model.code_change_encoder.encode_code_changes(dataset.examples, code_encoder=model.sequential_code_encoder, batch_size=256) 93 | print(f'decoded {change_vecs.shape[0]} entries') 94 | 95 | save_to = args['--save_to'] 96 | if save_to == 'None': 97 | save_to = model_file + '.change_vec.pkl' 98 | 99 | pickle.dump(change_vecs, open(save_to, 'wb')) 100 | print(f'saved decoding results to {save_to}') 101 | 102 | return change_vecs 103 | 104 | 105 | if __name__ == '__main__': 106 | args = docopt(__doc__) 107 | 108 | if args['rerank']: 109 | dump_rerank_file(args) 110 | elif args['decode_change_vec']: 111 | decode_change_vec(args) 112 | 
-------------------------------------------------------------------------------- /edit_components/utils/sub_token.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class SubTokenHelper: 5 | @staticmethod 6 | def get_sub_tokens(toen: str) -> List[str]: 7 | return [] 8 | 9 | -------------------------------------------------------------------------------- /edit_components/utils/unary_closure.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from urllib import request 3 | 4 | from asdl.asdl_ast import AbstractSyntaxNode 5 | from edit_components.dataset import DataSet 6 | from asdl.lang.csharp.csharp_transition import CSharpTransitionSystem # FIXME: move this .py to datasets/githubedits/? 7 | from asdl.lang.csharp.csharp_grammar import CSharpASDLGrammar 8 | 9 | 10 | def extract_unary_closure(data_file): 11 | csharp_grammar_text = request.urlopen('https://raw.githubusercontent.com/dotnet/roslyn/master/src/Compilers' 12 | '/CSharp/Portable/Syntax/Syntax.xml').read() 13 | grammar = CSharpASDLGrammar.from_roslyn_xml(csharp_grammar_text, pruning=True) 14 | transition_system = CSharpTransitionSystem(grammar) 15 | 16 | DataSet.load_from_jsonl(data_file, type='tree', transition_system=transition_system) 17 | 18 | 19 | def get_unary_closure_syntax_sub_tree(ast_root: AbstractSyntaxNode, 20 | unary_closure_root: AbstractSyntaxNode, 21 | unary_closure_last_node: AbstractSyntaxNode) -> List[AbstractSyntaxNode]: 22 | 23 | if ast_root.is_pre_terminal and len(ast_root.fields) <= 1: 24 | if unary_closure_root and not unary_closure_root.is_pre_terminal: # has at least one intermediate production 25 | return [unary_closure_root] 26 | else: 27 | return [] 28 | else: 29 | # case 1: only one child field has instantiated values 30 | instantiated_fields = [(i, field) for i, field in enumerate(ast_root.fields) if field.value_count > 1] 31 | if 
len(instantiated_fields) == 1: 32 | idx, instantiated_field = instantiated_fields[0] 33 | tgt_field = unary_closure_last_node.fields[idx] 34 | if tgt_field.type.is_composite and instantiated_field.value_count == 1: 35 | child_node = instantiated_field.as_value_list[0] 36 | cloned_node = AbstractSyntaxNode(child_node.production) 37 | tgt_field.add_value(cloned_node) 38 | unary_closure_last_node = cloned_node 39 | 40 | results = get_unary_closure_syntax_sub_tree(child_node, unary_closure_root, unary_closure_last_node) 41 | return results 42 | 43 | # other cases 44 | results = [] 45 | if unary_closure_root and not unary_closure_root.is_pre_terminal: 46 | results.append(unary_closure_root) 47 | 48 | for field_id, instantiated_field in instantiated_fields: 49 | tgt_field = unary_closure_last_node.fields[field_id] 50 | if instantiated_field.field.is_composite and instantiated_field.value_count == 1: 51 | pass 52 | 53 | if len(ast_root.fields) > 1: 54 | unary_closures = [] 55 | if unary_closure_root and not unary_closure_root.is_pre_terminal: 56 | unary_closures.append(unary_closure_root) 57 | 58 | for field in ast_root.fields: 59 | if field.type.is_composite: 60 | pass 61 | else: 62 | pass 63 | 64 | -------------------------------------------------------------------------------- /edit_components/utils/utils.py: -------------------------------------------------------------------------------- 1 | import scipy.spatial.distance as distance 2 | 3 | 4 | def run_from_ipython(): 5 | try: 6 | __IPYTHON__ 7 | return True 8 | except NameError: 9 | return False 10 | 11 | 12 | def get_entry_str(change_entry, dist=None, change_seq=False, score=None): 13 | entry_str = '' 14 | entry_str += 'Id: %s\n' % (change_entry.id) 15 | entry_str += 'Prev:\n%s\nAfter:\n%s\n' % (change_entry.untokenized_previous_code_chunk, 16 | change_entry.untokenized_updated_code_chunk) 17 | if dist is not None: 18 | entry_str += 'Distance: %f\n' % dist 19 | 20 | if change_seq: 21 | entry_str += 'Change Sequence: 
%s\n' % change_entry.change_seq 22 | 23 | entry_str += f"Score: {score if score is not None else '#'}\n" 24 | 25 | entry_str += '*' * 5 26 | 27 | return entry_str 28 | 29 | -------------------------------------------------------------------------------- /edit_components/utils/wikidata.py: -------------------------------------------------------------------------------- 1 | from edit_components.dataset import OrderedDict, json 2 | 3 | import sys 4 | 5 | assert len(sys.argv) == 4 6 | 7 | wiki_insert_file = sys.argv[1] 8 | wiki_del_file = sys.argv[2] 9 | output_file = sys.argv[3] 10 | 11 | with open(output_file, 'w') as f: 12 | for i, line in enumerate(open(wiki_insert_file, errors='ignore')): 13 | data = line.strip().split('\t') 14 | data = [x.lower() for x in data] 15 | 16 | entry = OrderedDict(Id=f'Ins_{i}', 17 | PrevCodeChunkTokens=data[0].split(' '), 18 | UpdatedCodeChunkTokens=data[2].split(' '), 19 | PrevCodeChunk=data[0] + ' ||| ' + data[1], 20 | UpdatedCodeChunk=data[2], 21 | PrecedingContext=[], 22 | SucceedingContext=[]) 23 | 24 | f.write(json.dumps(entry) + '\n') 25 | 26 | for i, line in enumerate(open(wiki_del_file, errors='ignore')): 27 | data = line.strip().split('\t') 28 | data = [x.lower() for x in data] 29 | 30 | entry = OrderedDict(Id=f'Del_{i}', 31 | PrevCodeChunkTokens=data[0].split(' '), 32 | UpdatedCodeChunkTokens=data[2].split(' '), 33 | PrevCodeChunk=data[0] + ' ||| ' + data[1], 34 | UpdatedCodeChunk=data[2], 35 | PrecedingContext=[], 36 | SucceedingContext=[]) 37 | 38 | f.write(json.dumps(entry) + '\n') 39 | -------------------------------------------------------------------------------- /edit_components/vocab.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Usage: 4 | vocab.py [options] TRAIN_FILE VOCAB_FILE 5 | 6 | Options: 7 | -h --help Show this screen. 
8 | --size= vocab size [default: 10000] 9 | --freq_cutoff= frequency cutoff [default: 2] 10 | """ 11 | 12 | from collections import Counter 13 | from itertools import chain 14 | from docopt import docopt 15 | import json 16 | 17 | 18 | class VocabEntry: 19 | def __init__(self): 20 | self.word2id = dict() 21 | self.unk_id = 3 22 | self.word2id[''] = 0 23 | self.word2id[''] = 1 24 | self.word2id[''] = 2 25 | self.word2id[''] = 3 26 | 27 | self.id2word = {v: k for k, v in self.word2id.items()} 28 | 29 | # # insert 100 indexed unks 30 | # for i in range(100): 31 | # self.add('UNK_%d' % i) 32 | 33 | self.add('[DUMMY-REDUCE]') 34 | 35 | def __getitem__(self, word): 36 | return self.word2id.get(word, self.unk_id) 37 | 38 | def is_unk(self, word): 39 | return word not in self.word2id 40 | 41 | def __contains__(self, word): 42 | return word in self.word2id 43 | 44 | def __setitem__(self, key, value): 45 | raise ValueError('vocabulary is readonly') 46 | 47 | def __len__(self): 48 | return len(self.word2id) 49 | 50 | def __repr__(self): 51 | return 'Vocabulary[size=%d]' % len(self) 52 | 53 | def id2word(self, wid): 54 | return self.id2word[wid] 55 | 56 | def add(self, word): 57 | if word not in self: 58 | wid = self.word2id[word] = len(self) 59 | self.id2word[wid] = word 60 | return wid 61 | else: 62 | return self[word] 63 | 64 | def save(self, path): 65 | params = dict(unk_id=self.unk_id, word2id=self.word2id, word_freq=self.word_freq) 66 | json.dump(params, open(path, 'w'), indent=2) 67 | 68 | @staticmethod 69 | def load(path): 70 | entry = VocabEntry() 71 | params = json.load(open(path, 'r')) 72 | 73 | setattr(entry, 'unk_id', params['unk_id']) 74 | setattr(entry, 'word2id', params['word2id']) 75 | setattr(entry, 'word_freq', params['word_freq']) 76 | setattr(entry, 'id2word', {v: k for k, v in params['word2id'].items()}) 77 | 78 | return entry 79 | 80 | @staticmethod 81 | def from_corpus(corpus, size, freq_cutoff=0): 82 | vocab_entry = VocabEntry() 83 | 84 | word_freq = 
Counter(chain(*corpus)) 85 | freq_words = [w for w in word_freq if word_freq[w] >= freq_cutoff] 86 | print('number of word types: %d, number of word types w/ frequency >= %d: %d' % (len(word_freq), freq_cutoff, 87 | len(freq_words))) 88 | 89 | top_k_words = sorted(word_freq, key=lambda x: (-word_freq[x], x))[:size] 90 | print('top 10 words: %s' % ', '.join(top_k_words[:10])) 91 | 92 | for word in top_k_words: 93 | if len(vocab_entry) < size: 94 | if word_freq[word] >= freq_cutoff: 95 | vocab_entry.add(word) 96 | 97 | # store the work frequency table in the 98 | setattr(vocab_entry, 'word_freq', word_freq) 99 | 100 | return vocab_entry 101 | 102 | 103 | class Vocab(object): 104 | def __init__(self, **kwargs): 105 | self.entries = [] 106 | for key, item in kwargs.items(): 107 | assert isinstance(item, VocabEntry) 108 | self.__setattr__(key, item) 109 | 110 | self.entries.append(key) 111 | 112 | def __repr__(self): 113 | return 'Vocab(%s)' % (', '.join('%s %swords' % (entry, getattr(self, entry)) for entry in self.entries)) 114 | 115 | 116 | if __name__ == '__main__': 117 | from edit_components.dataset import DataSet 118 | 119 | args = docopt(__doc__) 120 | train_set = DataSet.load_from_jsonl(args['TRAIN_FILE'], max_workers=2) 121 | corpus = [change.prev_data + change.updated_data + change.context for change in train_set] 122 | 123 | vocab_entry = VocabEntry.from_corpus(corpus, size=int(args['--size']), freq_cutoff=int(args['--freq_cutoff'])) 124 | print('built vocabulary %s' % vocab_entry) 125 | 126 | vocab_entry.save(args['VOCAB_FILE']) 127 | 128 | # torch.save(vocab_entry, open(args['VOCAB_FILE'], 'wb')) 129 | -------------------------------------------------------------------------------- /edit_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/edit_model/__init__.py 
-------------------------------------------------------------------------------- /edit_model/data_model.py: -------------------------------------------------------------------------------- 1 | from edit_model import nn_utils 2 | from edit_model.utils import cached_property 3 | 4 | 5 | class BatchedCodeChunk: 6 | def __init__(self, code_list, vocab, device=None, append_boundary_sym=False): 7 | self.code_list = code_list 8 | 9 | self.max_len = max(len(code) for code in code_list) 10 | self.batch_size = len(code_list) 11 | 12 | self._vocab = vocab 13 | self._append_boundary_sym = append_boundary_sym 14 | self.device=device 15 | 16 | # to be set by encoders 17 | self.encoding = None 18 | self.last_state = None 19 | self.last_cell = None 20 | 21 | @cached_property 22 | def index_var(self): 23 | return nn_utils.to_input_variable(self.code_list, 24 | vocab=self._vocab, 25 | device=self.device, 26 | append_boundary_sym=self._append_boundary_sym) 27 | 28 | @cached_property 29 | def mask(self): 30 | """mask for attention, null entries are masked as **1** """ 31 | 32 | assert self._append_boundary_sym is False 33 | 34 | return nn_utils.length_array_to_mask_tensor([len(x) for x in self.code_list], device=self.device) 35 | -------------------------------------------------------------------------------- /edit_model/edit_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .bag_of_edits_change_encoder import BagOfEditsChangeEncoder 2 | from .graph_change_encoder import GraphChangeEncoder 3 | from .sequential_change_encoder import SequentialChangeEncoder 4 | from .hybrid_change_encoder import HybridChangeEncoder 5 | from .tree_diff_encoder import TreeDiffEncoder 6 | -------------------------------------------------------------------------------- /edit_model/edit_encoder/bag_of_edits_change_encoder.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | 
import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | from torch.autograd import Variable 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | from tqdm import tqdm 9 | import sys 10 | 11 | from edit_components.change_entry import ChangeExample 12 | from edit_model import nn_utils 13 | from edit_model.embedder import EmbeddingTable 14 | 15 | 16 | class BagOfEditsChangeEncoder(nn.Module): 17 | """project a CodeChange instance into distributed vectors""" 18 | 19 | def __init__(self, token_embedder, vocab, **kwargs): 20 | super(BagOfEditsChangeEncoder, self).__init__() 21 | 22 | self.token_embedder = token_embedder 23 | self.token_embedding_size = self.token_embedder.weight.size(1) 24 | self.vocab = vocab 25 | self.change_vector_size = self.token_embedding_size * 2 26 | 27 | @property 28 | def device(self): 29 | return self.token_embedder.device 30 | 31 | def forward(self, code_changes, *args, **kwargs): 32 | """ 33 | given the token encodings of the previous and updated code, 34 | and the diff information (alignment between the tokens between the 35 | previous and updated code), generate the diff representation 36 | """ 37 | 38 | added_tokens = [] 39 | added_token_batch_ids = [] 40 | deled_tokens = [] 41 | deled_token_batch_ids = [] 42 | for e_id, example in enumerate(code_changes): 43 | for entry in example.change_seq: 44 | tag, token = entry 45 | if tag == 'ADD': 46 | token_id = self.vocab[token] 47 | added_tokens.append(token_id) 48 | added_token_batch_ids.append(e_id) 49 | elif tag == 'DEL': 50 | token_id = self.vocab[token] 51 | deled_tokens.append(token_id) 52 | deled_token_batch_ids.append(e_id) 53 | elif tag == 'REPLACE': 54 | added_token_id = self.vocab[token[1]] 55 | deled_token_id = self.vocab[token[0]] 56 | 57 | added_tokens.append(added_token_id) 58 | deled_tokens.append(deled_token_id) 59 | 60 | added_token_batch_ids.append(e_id) 61 | deled_token_batch_ids.append(e_id) 62 | 63 | changed_token_ids = 
added_tokens + deled_tokens 64 | changed_token_ids = torch.tensor(changed_token_ids, dtype=torch.long, device=self.device) 65 | # (token_num, embed_size) 66 | changed_token_embeds = self.token_embedder.weight[changed_token_ids] 67 | 68 | added_token_embeds = changed_token_embeds[:len(added_tokens)] 69 | deled_token_embeds = changed_token_embeds[len(added_tokens):] 70 | 71 | added_change_embeds = torch.zeros(len(code_changes), self.token_embedding_size, dtype=torch.float, 72 | device=self.device) 73 | if added_token_batch_ids: 74 | added_change_embeds = added_change_embeds.scatter_add_(0, 75 | torch.tensor(added_token_batch_ids, device=self.device).unsqueeze(-1).expand_as(added_token_embeds), 76 | added_token_embeds) 77 | 78 | deled_change_embeds = torch.zeros(len(code_changes), self.token_embedding_size, dtype=torch.float, 79 | device=self.device) 80 | if deled_token_batch_ids: 81 | deled_change_embeds = deled_change_embeds.scatter_add_(0, 82 | torch.tensor(deled_token_batch_ids, device=self.device).unsqueeze(-1).expand_as(deled_token_embeds), 83 | deled_token_embeds) 84 | 85 | change_vectors = torch.cat([added_change_embeds, deled_change_embeds], dim=-1) 86 | 87 | return change_vectors 88 | 89 | def encode_code_change(self, prev_code_tokens, updated_code_tokens, code_encoder): 90 | example = ChangeExample(prev_code_tokens, updated_code_tokens, context=None) 91 | 92 | change_vec = self.forward([example]).data.cpu().numpy()[0] 93 | 94 | return change_vec 95 | 96 | def encode_code_changes(self, examples, code_encoder, batch_size=32): 97 | """encode each change in the list `code_changes`, 98 | return a 2D numpy array of shape (len(code_changes), code_change_embed_dim)""" 99 | 100 | change_vecs = [] 101 | 102 | for batch_examples in tqdm(nn_utils.batch_iter(examples, batch_size), file=sys.stdout, total=len(examples)): 103 | batch_change_vecs = self.forward(batch_examples).data.cpu().numpy() 104 | change_vecs.append(batch_change_vecs) 105 | 106 | change_vecs = 
np.concatenate(change_vecs, axis=0) 107 | 108 | return change_vecs 109 | -------------------------------------------------------------------------------- /edit_model/edit_encoder/edit_encoder.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | from torch.autograd import Variable 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | from tqdm import tqdm 9 | import sys 10 | 11 | from edit_components.change_entry import ChangeExample 12 | from edit_model import nn_utils 13 | from edit_model.edit_encoder.bag_of_edits_change_encoder import BagOfEditsChangeEncoder 14 | from edit_model.edit_encoder.graph_change_encoder import GraphChangeEncoder 15 | from edit_model.edit_encoder.hybrid_change_encoder import HybridChangeEncoder 16 | from edit_model.edit_encoder.sequential_change_encoder import SequentialChangeEncoder 17 | from edit_model.edit_encoder.tree_diff_encoder import TreeDiffEncoder 18 | 19 | 20 | class EditEncoder(nn.Module): 21 | def __init__(self): 22 | super(EditEncoder, self).__init__() 23 | 24 | @staticmethod 25 | def build(args, vocab, embedder=None, **kwargs): 26 | if args['edit_encoder']['type'] == 'sequential': 27 | edit_encoder = SequentialChangeEncoder(args['encoder']['token_encoding_size'], 28 | args['edit_encoder']['edit_encoding_size'], 29 | args['edit_encoder']['change_tag_embed_size'], 30 | vocab, 31 | no_unchanged_token_encoding_in_diff_seq=args['edit_encoder']['no_unchanged_token_encoding_in_diff_seq']) 32 | elif args['edit_encoder']['type'] == 'graph': 33 | edit_encoder = GraphChangeEncoder(args['edit_encoder']['edit_encoding_size'], 34 | syntax_tree_embedder=embedder, 35 | layer_time_steps=args['edit_encoder']['layer_timesteps'], 36 | dropout=args['edit_encoder']['dropout'], 37 | gnn_use_bias_for_message_linear=args['edit_encoder']['use_bias_for_message_linear'], 38 | 
master_node_option=args['edit_encoder']['master_node_option'], 39 | connections=args['edit_encoder']['connections']) 40 | elif args['edit_encoder']['type'] == 'hybrid': 41 | edit_encoder = HybridChangeEncoder(token_encoding_size=args['encoder']['token_encoding_size'], 42 | change_vector_dim=args['edit_encoder']['edit_encoding_size'], 43 | syntax_tree_embedder=embedder, 44 | layer_timesteps=args['edit_encoder']['layer_timesteps'], 45 | dropout=args['edit_encoder']['dropout'], 46 | vocab=vocab, 47 | gnn_use_bias_for_message_linear=args['edit_encoder']['no_unchanged_token_encoding_in_diff_seq']) 48 | elif args['edit_encoder']['type'] == 'bag': 49 | edit_encoder = BagOfEditsChangeEncoder(embedder, vocab) 50 | elif args['edit_encoder']['type'] == 'treediff': 51 | edit_encoder = TreeDiffEncoder(graph_encoding_size=args['encoder']['token_encoding_size'], 52 | input_size=args['edit_encoder']['input_size'], 53 | change_vector_size=args['edit_encoder']['edit_encoding_size'], 54 | operators=kwargs['operators'], 55 | operator_embedding=kwargs['operator_embedding'], 56 | production_embedding=kwargs['production_embedding'], 57 | field_embedding=kwargs['field_embedding'], 58 | token_embedding=kwargs['token_embedding'], 59 | copy_syntax_token=args['edit_encoder']['copy_token']) 60 | else: 61 | raise ValueError('unknown code change encoder type [%s]' % args['edit_encoder']['type']) 62 | 63 | return edit_encoder 64 | -------------------------------------------------------------------------------- /edit_model/edit_encoder/hybrid_change_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import tqdm 4 | import sys 5 | from itertools import chain 6 | import numpy as np 7 | 8 | from edit_model.edit_encoder.sequential_change_encoder import SequentialChangeEncoder 9 | from edit_model.edit_encoder.graph_change_encoder import GraphChangeEncoder 10 | from edit_model import nn_utils 11 | from 
edit_model.embedder import EmbeddingTable 12 | 13 | 14 | class HybridChangeEncoder(nn.Module): 15 | def __init__(self, change_vector_dim, token_encoding_size, syntax_tree_embedder, 16 | layer_timesteps, dropout, vocab, 17 | tag_embed_size=32, gnn_use_bias_for_message_linear=True): 18 | super(HybridChangeEncoder, self).__init__() 19 | 20 | self.seq_change_encoder = SequentialChangeEncoder(token_encoding_size, change_vector_dim, tag_embed_size, vocab=vocab) 21 | self.graph_change_encoder = GraphChangeEncoder(change_vector_dim, layer_timesteps, dropout, syntax_tree_embedder, 22 | tag_embed_size=tag_embed_size, 23 | gnn_use_bias_for_message_linear=gnn_use_bias_for_message_linear) 24 | 25 | self.change_vector_dim = change_vector_dim 26 | self.combo_linear = nn.Linear(change_vector_dim * 2, change_vector_dim) 27 | 28 | @property 29 | def device(self): 30 | return self.seq_change_encoder.device 31 | 32 | def forward(self, examples, prev_code_token_encoding, updated_code_token_encoding): 33 | seq_change_vec = self.seq_change_encoder(examples, prev_code_token_encoding, updated_code_token_encoding) 34 | graph_change_vec = self.graph_change_encoder(examples, prev_code_token_encoding, updated_code_token_encoding) 35 | change_vec = self.combo_linear(torch.cat([seq_change_vec, graph_change_vec], dim=-1)) 36 | 37 | return change_vec 38 | 39 | def encode_code_changes(self, examples, code_encoder, batch_size=32): 40 | change_vecs = [] 41 | 42 | for batch_examples, sorted_example_ids, example_old2new_pos in tqdm(nn_utils.batch_iter(examples, batch_size, sort_func=lambda e: -len(e.change_seq), return_sort_map=True), 43 | total=len(examples) // batch_size, file=sys.stdout): 44 | previous_code_chunk_list = [e.previous_code_chunk for e in batch_examples] 45 | updated_code_chunk_list = [e.updated_code_chunk for e in batch_examples] 46 | context_list = [e.context for e in batch_examples] 47 | 48 | embedding_cache = EmbeddingTable( 49 | chain.from_iterable(previous_code_chunk_list + 
updated_code_chunk_list + context_list)) 50 | code_encoder.code_token_embedder.populate_embedding_table(embedding_cache) 51 | 52 | batched_prev_code = code_encoder.encode(previous_code_chunk_list, embedding_cache=embedding_cache) 53 | batched_updated_code = code_encoder.encode(updated_code_chunk_list, embedding_cache=embedding_cache) 54 | 55 | batch_change_vecs = self.forward(batch_examples, batched_prev_code, batched_updated_code).data.cpu().numpy() 56 | batch_change_vecs = batch_change_vecs[example_old2new_pos] 57 | change_vecs.append(batch_change_vecs) 58 | 59 | change_vecs = np.concatenate(change_vecs, axis=0) 60 | 61 | return change_vecs 62 | -------------------------------------------------------------------------------- /edit_model/edit_encoder/tree_diff_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | 5 | 6 | class TreeDiffEncoder(nn.Module): 7 | def __init__(self, graph_encoding_size, input_size, change_vector_size, operators, 8 | operator_embedding, production_embedding, field_embedding, token_embedding, 9 | **kwargs): 10 | super(TreeDiffEncoder, self).__init__() 11 | 12 | self.input_size = input_size 13 | self.change_vector_size = change_vector_size 14 | self.operators = operators 15 | self.copy_syntax_token = kwargs['copy_syntax_token'] 16 | 17 | self.operator_embedding = operator_embedding 18 | self.production_embedding = production_embedding 19 | self.field_embedding = field_embedding 20 | self.token_embedding = token_embedding 21 | 22 | operator_embed_size = self.operator_embedding.embedding_dim 23 | action_embed_size = self.production_embedding.embedding_dim 24 | field_embed_size = self.field_embedding.embedding_dim 25 | 26 | self.change_seq_encoder_lstm = nn.LSTM(self.input_size, self.change_vector_size // 2, bidirectional=True) 27 | 28 | # project [op; field; node] for Delete 
29 | self.delete_projection = nn.Linear(operator_embed_size + field_embed_size + graph_encoding_size, 30 | input_size, bias=True) 31 | # project [op; field; node; action] for Add 32 | self.add_projection = nn.Linear(operator_embed_size + field_embed_size + graph_encoding_size + 33 | action_embed_size, input_size, bias=True) 34 | if 'add_subtree' in self.operators: 35 | # project [op; field; node; subtree] for AddSubtree 36 | self.add_subtree_projection = nn.Linear(operator_embed_size + field_embed_size + graph_encoding_size * 2, 37 | input_size, bias=True) 38 | # project [op] for Stop 39 | self.stop_projection = nn.Linear(operator_embed_size, input_size, bias=True) 40 | 41 | @property 42 | def device(self): 43 | return self.production_embedding.weight.device 44 | 45 | def forward(self, batch_edits_list, batch_actual_edits_length, masks_cache, 46 | context_encodings, init_input_encodings, cur_input_encodings_list, 47 | batch_memory_encodings=None): 48 | 49 | assert len(batch_edits_list) == len(cur_input_encodings_list) 50 | max_iteration_step = len(batch_edits_list) 51 | batch_size = len(batch_edits_list[0]) 52 | batch_max_node_num_over_time = max(cur_input_encodings.size(1) 53 | for cur_input_encodings in cur_input_encodings_list) 54 | 55 | # (max_iteration_step, batch_size, batch_max_node_num_over_time, source_element_encoding_size) 56 | cur_input_encodings_encoding_over_time = torch.zeros( 57 | max_iteration_step, batch_size, batch_max_node_num_over_time, 58 | cur_input_encodings_list[0].size(2)).to(self.device) 59 | for t, cur_input_encodings in enumerate(cur_input_encodings_list): 60 | cur_input_encodings_encoding_over_time[t, :, :cur_input_encodings.size(1)] = cur_input_encodings 61 | 62 | operator_selection_idx, \ 63 | node_selection_idx, node_selection_mask, node_cand_mask, parent_field_idx, \ 64 | tgt_apply_rule_idx, tgt_apply_rule_mask, apply_rule_cand_mask, \ 65 | tgt_apply_subtree_idx, tgt_apply_subtree_idx_mask, tgt_apply_subtree_mask, 
apply_subtree_cand_mask, \ 66 | tgt_gen_token_idx, tgt_gen_token_mask, tgt_copy_ctx_token_idx_mask, tgt_copy_ctx_token_mask, \ 67 | tgt_copy_init_token_idx_mask, tgt_copy_init_token_mask = masks_cache 68 | 69 | # (max_iteration_step, batch_size, operator_emb_size) 70 | tgt_operator_embeddings = self.operator_embedding(operator_selection_idx) 71 | 72 | # (max_iteration_step, batch_size, field_emb_size) 73 | tgt_field_embeddings = self.field_embedding(parent_field_idx) 74 | 75 | self_mask = torch.zeros(max_iteration_step * batch_size, batch_max_node_num_over_time, 76 | dtype=torch.long).to(self.device) 77 | self_mask[torch.arange(0, max_iteration_step * batch_size, dtype=torch.long).to(self.device), 78 | node_selection_idx.view(-1)] = 1 79 | self_mask = self_mask.reshape(max_iteration_step, batch_size, batch_max_node_num_over_time) 80 | # (max_iteration_step, batch_size, source_element_encoding_size) 81 | tgt_node_encodings = torch.sum(cur_input_encodings_encoding_over_time * self_mask.unsqueeze(-1), dim=2) 82 | 83 | tgt_production_embeddings = self.production_embedding(tgt_apply_rule_idx) 84 | tgt_gen_token_embeddings = self.token_embedding(tgt_gen_token_idx) 85 | 86 | if self.copy_syntax_token: 87 | tgt_copy_ctx_token_idx_mask[tgt_copy_ctx_token_idx_mask.sum(-1).eq(0), :] = 1 88 | tgt_copy_ctx_token_embeddings = torch.sum( 89 | context_encodings.expand(max_iteration_step, -1, -1, -1) * tgt_copy_ctx_token_idx_mask.unsqueeze(-1), 90 | dim=2) / tgt_copy_ctx_token_idx_mask.sum(dim=-1, keepdim=True) 91 | 92 | tgt_copy_init_token_idx_mask[tgt_copy_init_token_idx_mask.sum(-1).eq(0), :] = 1 93 | tgt_copy_init_token_embeddings = torch.sum( 94 | init_input_encodings.expand(max_iteration_step, -1, -1, -1) * tgt_copy_init_token_idx_mask.unsqueeze(-1), 95 | dim=2) / tgt_copy_init_token_idx_mask.sum(dim=-1, keepdim=True) 96 | 97 | # prepare inputs 98 | _cand_inputs = [] # a list of (max_iteration_step, batch_size, input_size) 99 | for operator in self.operators: # ['delete', 
'add', 'add_subtree', 'stop'] 100 | if operator == 'stop': 101 | stop_inputs = self.stop_projection(tgt_operator_embeddings) 102 | _cand_inputs.append(stop_inputs) 103 | elif operator == 'delete': 104 | delete_inputs = self.delete_projection( 105 | torch.cat([tgt_operator_embeddings, tgt_field_embeddings, tgt_node_encodings], dim=-1) 106 | ) 107 | _cand_inputs.append(delete_inputs) 108 | elif operator == 'add': 109 | if self.copy_syntax_token: 110 | tgt_token_gates = tgt_gen_token_mask + tgt_copy_ctx_token_mask + tgt_copy_init_token_mask 111 | tgt_token_gates[tgt_token_gates.eq(0)] = 1 # safeguard 112 | tgt_token_embeddings = tgt_gen_token_embeddings * tgt_gen_token_mask.unsqueeze(-1) + \ 113 | tgt_copy_ctx_token_embeddings * tgt_copy_ctx_token_mask.unsqueeze(-1) + \ 114 | tgt_copy_init_token_embeddings * tgt_copy_init_token_mask.unsqueeze(-1) 115 | # (max_iteration_step, batch_size, action_emb_size) 116 | tgt_token_embeddings = tgt_token_embeddings / tgt_token_gates.unsqueeze(-1) 117 | else: 118 | tgt_token_embeddings = tgt_gen_token_embeddings 119 | 120 | tgt_action_embeddings = tgt_production_embeddings * tgt_apply_rule_mask.unsqueeze(-1) + \ 121 | tgt_token_embeddings * (1 - tgt_apply_rule_mask.unsqueeze(-1)) 122 | 123 | add_inputs = self.add_projection( 124 | torch.cat([tgt_operator_embeddings, tgt_field_embeddings, tgt_node_encodings, 125 | tgt_action_embeddings], dim=-1) 126 | ) 127 | _cand_inputs.append(add_inputs) 128 | elif operator == 'add_subtree': # and tgt_apply_subtree_mask.eq(1).any() 129 | if batch_memory_encodings is None: 130 | assert tgt_apply_subtree_mask.eq(0).all() 131 | add_subtree_inputs = torch.zeros(max_iteration_step, batch_size, self.input_size).to(self.device) 132 | _cand_inputs.append(add_subtree_inputs) 133 | continue 134 | 135 | # a list of (max_iteration_step, batch_size), length=max_num_copied_nodes 136 | _tgt_apply_subtree_idx_unbind = torch.unbind(tgt_apply_subtree_idx, dim=-1) 137 | # reshape to (max_iteration_step*batch_size, 
max_num_nodes, encoding_size) 138 | _expanded_batch_memory_encodings = batch_memory_encodings.expand(max_iteration_step, -1, -1, -1).\ 139 | reshape(max_iteration_step * batch_size, batch_memory_encodings.size(1), batch_memory_encodings.size(2)) 140 | 141 | _gathered_tgt_apply_subtree_encodings = [] 142 | _count = torch.arange(max_iteration_step * batch_size).to(self.device) 143 | for _tgt_apply_subtree_idx_col in _tgt_apply_subtree_idx_unbind: 144 | _gathered_tgt_apply_subtree_encodings.append( 145 | _expanded_batch_memory_encodings[_count, _tgt_apply_subtree_idx_col.reshape(-1)]) 146 | gathered_tgt_apply_subtree_encodings = torch.stack(_gathered_tgt_apply_subtree_encodings, dim=1).\ 147 | reshape(max_iteration_step, batch_size, -1, batch_memory_encodings.size(2)) 148 | 149 | # safeguard 150 | _reshaped_tgt_apply_subtree_idx_mask = tgt_apply_subtree_idx_mask.reshape( 151 | max_iteration_step * batch_size, -1) 152 | _reshaped_tgt_apply_subtree_idx_mask[_reshaped_tgt_apply_subtree_idx_mask.sum(-1).eq(0), 0] = 1 153 | tgt_apply_subtree_idx_mask = _reshaped_tgt_apply_subtree_idx_mask.reshape(max_iteration_step, batch_size, -1) 154 | 155 | # (max_iteration_step, batch_size, source_element_encoding_size) 156 | tgt_apply_subtree_encodings = torch.sum( 157 | gathered_tgt_apply_subtree_encodings * tgt_apply_subtree_idx_mask.unsqueeze(-1), dim=2) / \ 158 | tgt_apply_subtree_idx_mask.sum(dim=-1, keepdim=True) 159 | 160 | add_subtree_inputs = self.add_subtree_projection( 161 | torch.cat([tgt_operator_embeddings, tgt_field_embeddings, tgt_node_encodings, 162 | tgt_apply_subtree_encodings], dim=-1) 163 | ) 164 | _cand_inputs.append(add_subtree_inputs) 165 | 166 | # (max_iteration_step*batch_size, num of operators, input_size) 167 | _cand_inputs = torch.stack(_cand_inputs, dim=2).reshape(max_iteration_step * batch_size, -1, self.input_size) 168 | _count = torch.arange(max_iteration_step * batch_size).to(self.device) 169 | # (max_iteration_step, batch_size, input_size) 170 | 
tgt_inputs = _cand_inputs[_count, operator_selection_idx.view(-1)].reshape(max_iteration_step, batch_size, self.input_size) 171 | 172 | padded_inputs_in_steps = pack_padded_sequence(tgt_inputs, batch_actual_edits_length, 173 | batch_first=False, enforce_sorted=False) 174 | 175 | change_seq_encodings, (last_state, last_cell) = self.change_seq_encoder_lstm(padded_inputs_in_steps) 176 | # change_seq_encodings, _ = pad_packed_sequence(change_seq_encodings) 177 | 178 | # (batch_size, hidden_size * 2) 179 | last_state = torch.cat([last_state[0], last_state[1]], 1) 180 | return last_state 181 | -------------------------------------------------------------------------------- /edit_model/embedder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from itertools import chain 3 | from typing import List, Dict 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.utils 8 | from torch.autograd import Variable 9 | import torch.nn.functional as F 10 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 11 | 12 | from asdl.asdl_ast import AbstractSyntaxNode, SyntaxToken, AbstractSyntaxTree 13 | from asdl.asdl import ASDLGrammar 14 | from edit_model import nn_utils 15 | from edit_model.utils import cached_property 16 | from edit_components.vocab import VocabEntry 17 | from edit_model.pointer_net import PointerNet 18 | 19 | import numpy as np 20 | 21 | ALPHABET = "abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}" 22 | ALPHABET_DICT = {char: idx + 1 for (idx, char) in enumerate(ALPHABET)} # "0" is PAD 23 | ALPHABET_DICT = {char: idx + 2 for (idx, char) in enumerate(ALPHABET)} # "0" is PAD, "1" is UNK 24 | ALPHABET_DICT["PAD"] = 0 25 | ALPHABET_DICT["UNK"] = 1 26 | 27 | 28 | class Embedder: 29 | def __init__(self, vocab=None): 30 | self.vocab = vocab 31 | 32 | @property 33 | def device(self): 34 | return self.weight.device 35 | 36 | def 
populate_embedding_table(self, embedding_table): 37 | raise NotImplementedError() 38 | 39 | def to_input_variable(self, code_list, return_mask=False): 40 | word_ids = nn_utils.word2id(code_list, self.vocab) 41 | sents_t, masks = nn_utils.input_transpose(word_ids, pad_token=0) 42 | sents_var = torch.tensor(sents_t, dtype=torch.long, device=self.device) 43 | 44 | if return_mask: 45 | mask_var = torch.tensor(masks, dtype=torch.long, device=self.device) 46 | return sents_var, mask_var 47 | 48 | return sents_var 49 | 50 | 51 | class CodeTokenEmbedder(Embedder, nn.Embedding): 52 | def __init__(self, embedding_size, vocab: VocabEntry): 53 | nn.Embedding.__init__(self, len(vocab), embedding_size) 54 | Embedder.__init__(self, vocab) 55 | 56 | nn.init.xavier_normal_(self.weight.data) 57 | 58 | def forward(self, code_list): 59 | if isinstance(code_list, list): 60 | index = self.to_input_variable(code_list) 61 | else: 62 | index = code_list 63 | embedding = super(CodeTokenEmbedder, self).forward(index) 64 | 65 | return embedding 66 | 67 | def populate_embedding_table(self, embedding_table): 68 | tokens = list(embedding_table.tokens.keys()) 69 | indices = [self.vocab[token] for token in tokens] 70 | token_embedding = super(CodeTokenEmbedder, self).forward(torch.tensor(indices, dtype=torch.long, device=self.device)) 71 | 72 | embedding_table.init_with_embeddings(token_embedding) 73 | 74 | def get_embed_for_token_sequences(self, sequences): 75 | return self.forward(sequences) 76 | 77 | 78 | class ConvolutionalCharacterEmbedder(nn.Module, Embedder): 79 | def __init__(self, embed_size: int, max_character_size): 80 | super(ConvolutionalCharacterEmbedder, self).__init__() 81 | 82 | self.max_character_size = max_character_size 83 | self.embed_size = embed_size 84 | 85 | self.conv11_layer = nn.Conv1d(in_channels=len(ALPHABET_DICT), out_channels=20, kernel_size=5) 86 | self.maxpool_layer = nn.MaxPool1d(kernel_size=5, stride=1) 87 | self.conv12_layer = nn.Conv1d(in_channels=20, 
out_channels=embed_size, kernel_size=12) 88 | 89 | @property 90 | def device(self): 91 | return self.conv11_layer.weight.device 92 | 93 | def populate_embedding_table(self, embedding_table): 94 | # tensorization 95 | # (word_num, max_character_size, encode_char_num) 96 | x = embedding_table.character_input_tensor(max_character_num=self.max_character_size).to(self.device) 97 | 98 | conv1 = F.leaky_relu(self.conv11_layer(x.permute(0, 2, 1))) 99 | maxpool = self.maxpool_layer(conv1) 100 | conv2 = self.conv12_layer(maxpool).squeeze(-1) 101 | 102 | embedding_table.init_with_embeddings(conv2) 103 | 104 | 105 | class EmbeddingTable: 106 | def __init__(self, tokens): 107 | self.tokens = OrderedDict() 108 | for token in tokens: 109 | self.add_token(token) 110 | 111 | def add_token(self, token): 112 | if token not in self.tokens: 113 | self.tokens[token] = len(self.tokens) 114 | 115 | def add_tokens(self, tokens): 116 | for token in tokens: 117 | self.add_token(token) 118 | 119 | def character_input_tensor(self, max_character_num=20, lowercase=True): 120 | # return: (word_num, max_character_num, char_num) 121 | idx_tensor = torch.zeros(len(self.tokens), max_character_num, len(ALPHABET_DICT), dtype=torch.float) 122 | for token, token_id in self.tokens.items(): 123 | if lowercase: 124 | token = token.lower() 125 | 126 | token_trimmed = token[:max_character_num] 127 | token_char_seq = [t for t in token_trimmed] + ['PAD'] * (max_character_num - len(token_trimmed)) 128 | token_char_seq_ids = [ALPHABET_DICT[char] if char in ALPHABET_DICT else ALPHABET_DICT['UNK'] for char in token_char_seq] 129 | idx_tensor[token_id, list(range(max_character_num)), token_char_seq_ids] = 1.0 130 | 131 | return idx_tensor 132 | 133 | def __getitem__(self, token): 134 | token_id = self.tokens[token] 135 | embed = self.embedding[token_id] 136 | 137 | return embed 138 | 139 | def init_with_embeddings(self, embedding_tensor): 140 | # input: (word_num, token_embedding_size) 141 | self.embedding = 
embedding_tensor 142 | 143 | def to_input_variable(self, sequences, return_mask=False): 144 | """ 145 | given a list of sequences, 146 | return a tensor of shape (max_sent_len, batch_size) 147 | """ 148 | word_ids = nn_utils.word2id(sequences, self.tokens) 149 | sents_t, masks = nn_utils.input_transpose(word_ids, pad_token=0) 150 | sents_var = torch.tensor(sents_t, dtype=torch.long, device=self.embedding.device) 151 | 152 | if return_mask: 153 | mask_var = torch.tensor(masks, dtype=torch.long, device=self.embedding.device) 154 | return sents_var, mask_var 155 | 156 | return sents_var 157 | 158 | def get_embed_for_token_sequences(self, sequences): 159 | # (max_sent_len, batch_size) 160 | input_var = self.to_input_variable(sequences) 161 | 162 | # (max_sent_len, batch_size, embed_size) 163 | seq_embed = self.embedding[input_var] 164 | 165 | return seq_embed 166 | 167 | 168 | class SyntaxTreeEmbedder(nn.Embedding, Embedder): 169 | def __init__(self, embedding_size, vocab: VocabEntry, grammar: ASDLGrammar, node_embed_method: str='type'): 170 | if node_embed_method == 'type': 171 | node_embed_num = len(grammar.types) 172 | elif node_embed_method == 'type_and_field': 173 | node_embed_num = len(grammar.types) + len(grammar.prod_field2id) + 1 # 1 for root field 174 | else: 175 | raise ValueError 176 | 177 | nn.Embedding.__init__(self, len(vocab) + node_embed_num, embedding_size) 178 | Embedder.__init__(self, vocab) 179 | 180 | if node_embed_method == 'type_and_field': 181 | self.combine_type_and_field_embed = nn.Linear(embedding_size * 2, embedding_size) 182 | self.field_embed_offset = len(vocab) + len(grammar.types) 183 | 184 | self.node_embed_method = node_embed_method 185 | 186 | nn.init.xavier_normal_(self.weight.data) 187 | 188 | self.grammar = grammar 189 | 190 | @property 191 | def device(self): 192 | return self.weight.device 193 | 194 | def embed_syntax_tree(self, 195 | batch_syntax_trees: List[AbstractSyntaxTree], 196 | batch_graph_node2example_node: Dict, 197 | 
prev_code_token_encoding: torch.FloatTensor, 198 | bool_use_position=True): 199 | indices = [] 200 | field_indices = [] 201 | token_indices = [] 202 | batch_ids_for_token = [] 203 | token_pos = [] 204 | for i, (batch_node_id, (e_id, e_node_id)) in enumerate(batch_graph_node2example_node.items()): 205 | node = batch_syntax_trees[e_id].id2node[e_node_id] 206 | if isinstance(node, AbstractSyntaxNode): 207 | idx = self.grammar.type2id[node.production.type] + len(self.vocab) 208 | if self.node_embed_method == 'type_and_field': 209 | field_idx = self.field_embed_offset if node.parent_field is None \ 210 | else self.field_embed_offset + 1 + self.grammar.prod_field2id[(node.parent_field.parent_node.production, node.parent_field.field)] 211 | elif isinstance(node, SyntaxToken): 212 | if bool_use_position and node.position >= 0: 213 | batch_ids_for_token.append(e_id) 214 | token_pos.append(node.position) 215 | token_indices.append(i) 216 | 217 | idx = 0 218 | else: 219 | idx = self.vocab[node.value] 220 | 221 | field_idx = 0 222 | else: 223 | raise ValueError('unknown node type %s' % node) 224 | 225 | indices.append(idx) 226 | if self.node_embed_method == 'type_and_field': 227 | field_indices.append(field_idx) 228 | 229 | # (all_node_num, embedding_dim) 230 | node_embedding = super(SyntaxTreeEmbedder, self).forward(torch.tensor(indices, dtype=torch.long, device=self.device)) 231 | 232 | if self.node_embed_method == 'type_and_field': 233 | node_field_embedding = super(SyntaxTreeEmbedder, self).forward(torch.tensor(field_indices, dtype=torch.long, device=self.device)) 234 | node_embedding = self.combine_type_and_field_embed(torch.cat([node_embedding, node_field_embedding], dim=-1)) 235 | 236 | syntax_token_encodings = prev_code_token_encoding[batch_ids_for_token, token_pos] 237 | node_embedding[token_indices] = syntax_token_encodings 238 | 239 | return node_embedding 240 | 241 | def populate_embedding_table(self, embedding_table): 242 | tokens = 
list(embedding_table.tokens.keys()) 243 | indices = [self.vocab[token] for token in tokens] 244 | token_embedding = super(SyntaxTreeEmbedder, self).forward(torch.tensor(indices, dtype=torch.long, device=self.device)) 245 | 246 | embedding_table.init_with_embeddings(token_embedding) 247 | 248 | def get_embed_for_token_sequences(self, sequences): 249 | # (max_sent_len, batch_size) 250 | input_var = self.to_input_variable(sequences) 251 | 252 | # (max_sent_len, batch_size, embed_size) 253 | seq_embed = super(SyntaxTreeEmbedder, self).forward(input_var) 254 | 255 | return seq_embed 256 | -------------------------------------------------------------------------------- /edit_model/encdec/__init__.py: -------------------------------------------------------------------------------- 1 | from .graph_encoder import SyntaxTreeEncoder 2 | from .sequential_decoder import SequentialDecoder 3 | from .transition_decoder import TransitionDecoder 4 | -------------------------------------------------------------------------------- /edit_model/encdec/decoder.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class Decoder(nn.Module): 7 | VALID_IDENTIFIER_RE = re.compile(r'^[A-Za-z0-9_-]*$') 8 | 9 | def __init__(self): 10 | super(Decoder, self).__init__() 11 | 12 | @staticmethod 13 | def _can_only_generate_this_token(token): 14 | # remove the BPE delimiter 15 | if token.startswith('\u2581'): 16 | token = token[1:] 17 | 18 | return not Decoder.VALID_IDENTIFIER_RE.match(token) 19 | -------------------------------------------------------------------------------- /edit_model/encdec/encoder.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | EncodingResult = namedtuple('EncodingResult', ['data', 'encoding', 'last_state', 'last_cell', 'mask']) 4 | 
-------------------------------------------------------------------------------- /edit_model/encdec/graph_encoder.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict, namedtuple 2 | from typing import List 3 | import sys 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from asdl.asdl_ast import AbstractSyntaxTree 9 | from edit_model.gnn import AdjacencyList, GatedGraphNeuralNetwork 10 | from edit_model.utils import get_method_args_dict 11 | 12 | TreeEncodingResult = namedtuple('TreeEncodingResult', ['data', 'encoding', 'mask', 'syntax_token_mask']) 13 | 14 | 15 | class SyntaxTreeEncoder(nn.Module): 16 | def __init__(self, hidden_size, syntax_tree_embedder, connections, layer_timesteps, residual_connections, dropout, 17 | vocab, grammar, **kwargs): 18 | super(SyntaxTreeEncoder, self).__init__() 19 | 20 | self.grammar = grammar 21 | self.vocab = vocab 22 | 23 | self.connections = connections 24 | self.token_bidirectional_connection = 'bi_token' in connections 25 | self.top_down_connection = 'top_down' in connections 26 | self.bottom_up_connection = 'bottom_up' in connections 27 | self.next_sibling_connection = 'next_sibling' in connections 28 | self.prev_sibling_connection = 'prev_sibling' in connections 29 | self.gnn_use_bias_for_message_linear = kwargs.pop('gnn_use_bias_for_message_linear', True) 30 | 31 | self.num_edge_types = 0 32 | if self.token_bidirectional_connection: 33 | self.num_edge_types += 2 34 | if self.top_down_connection: 35 | self.num_edge_types += 1 36 | if self.bottom_up_connection: 37 | self.num_edge_types += 1 38 | if self.next_sibling_connection: 39 | self.num_edge_types += 1 40 | if self.prev_sibling_connection: 41 | self.num_edge_types += 1 42 | 43 | assert self.num_edge_types > 0 44 | 45 | self.syntax_tree_embedder = syntax_tree_embedder 46 | self.gnn = GatedGraphNeuralNetwork(hidden_size=hidden_size, 47 | num_edge_types=self.num_edge_types, 48 | 
layer_timesteps=layer_timesteps, 49 | residual_connections=residual_connections, 50 | state_to_message_dropout=dropout, 51 | rnn_dropout=dropout, 52 | use_bias_for_message_linear=self.gnn_use_bias_for_message_linear) 53 | 54 | @property 55 | def device(self): 56 | return self.syntax_tree_embedder.device 57 | 58 | def forward(self, batch_syntax_trees, prev_code_token_encoding=None): 59 | # combine ASTs into a huge graph 60 | batch_adj_lists, example_node2batch_node_map, batch_node2_example_node_map = self.get_batch_adjacency_lists(batch_syntax_trees) 61 | 62 | # get initial embeddings for every nodes 63 | # (V, D) 64 | if prev_code_token_encoding is not None: 65 | init_node_embeddings = self.syntax_tree_embedder.embed_syntax_tree( 66 | batch_syntax_trees, batch_node2_example_node_map, prev_code_token_encoding) 67 | else: 68 | init_node_embeddings = torch.zeros(len(batch_node2_example_node_map), self.syntax_tree_embedder.embedding_dim, 69 | dtype=torch.float).to(self.device) 70 | 71 | # perform encoding 72 | # (V, D) 73 | flattened_node_encodings = self.gnn(init_node_embeddings, batch_adj_lists) 74 | 75 | # split node encodings from the huge graph, List[Variable[batch_node_num]] 76 | encoding_result = self.to_batch_encoding(flattened_node_encodings, batch_syntax_trees, example_node2batch_node_map) 77 | 78 | return encoding_result 79 | 80 | def get_batch_adjacency_lists(self, batch_syntax_trees: List[AbstractSyntaxTree]): 81 | example_node2batch_graph_node = OrderedDict() 82 | 83 | # add parent -> child node on ASTs 84 | ast_adj_list = [] 85 | reversed_ast_adj_list = [] 86 | terminal_tokens_adj_list = [] 87 | reversed_terminal_tokens_adj_list = [] 88 | next_sibling_adj_list = [] 89 | prev_sibling_adj_list = [] 90 | for e_id, syntax_tree in enumerate(batch_syntax_trees): 91 | for node_s_id, node_t_id in syntax_tree.adjacency_list: 92 | # source -> target 93 | node_s_batch_id = example_node2batch_graph_node.setdefault((e_id, node_s_id), 94 | 
len(example_node2batch_graph_node)) 95 | node_t_batch_id = example_node2batch_graph_node.setdefault((e_id, node_t_id), 96 | len(example_node2batch_graph_node)) 97 | 98 | if self.top_down_connection: 99 | ast_adj_list.append((node_s_batch_id, node_t_batch_id)) 100 | if self.bottom_up_connection: 101 | reversed_ast_adj_list.append((node_t_batch_id, node_s_batch_id)) 102 | 103 | # add bi-directional connection between adjacent terminal nodes 104 | if self.token_bidirectional_connection: 105 | for i in range(len(syntax_tree.syntax_tokens_and_ids) - 1): 106 | cur_token_id, cur_token = syntax_tree.syntax_tokens_and_ids[i] 107 | next_token_id, next_token = syntax_tree.syntax_tokens_and_ids[i + 1] 108 | 109 | cur_token_batch_id = example_node2batch_graph_node[(e_id, cur_token_id)] 110 | next_token_batch_id = example_node2batch_graph_node[(e_id, next_token_id)] 111 | 112 | terminal_tokens_adj_list.append((cur_token_batch_id, next_token_batch_id)) 113 | reversed_terminal_tokens_adj_list.append((next_token_batch_id, cur_token_batch_id)) 114 | 115 | if self.prev_sibling_connection or self.next_sibling_connection: 116 | for left_node_id, right_node_id in syntax_tree.next_siblings_adjacency_list: 117 | left_node_batch_id = example_node2batch_graph_node[(e_id, left_node_id)] 118 | right_node_batch_id = example_node2batch_graph_node[(e_id, right_node_id)] 119 | if self.next_sibling_connection: 120 | next_sibling_adj_list.append((left_node_batch_id, right_node_batch_id)) 121 | if self.prev_sibling_connection: 122 | prev_sibling_adj_list.append((right_node_batch_id, left_node_batch_id)) 123 | 124 | batch_graph_node2example_node = OrderedDict([(v, k) for k, v in example_node2batch_graph_node.items()]) 125 | 126 | all_nodes_num = len(example_node2batch_graph_node) 127 | adj_lists = [] 128 | if self.top_down_connection: 129 | ast_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=ast_adj_list, device=self.device) 130 | adj_lists.append(ast_adj_list) 131 | if 
self.bottom_up_connection: 132 | reversed_ast_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=reversed_ast_adj_list, device=self.device) 133 | adj_lists.append(reversed_ast_adj_list) 134 | 135 | if self.token_bidirectional_connection and terminal_tokens_adj_list: 136 | terminal_tokens_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=terminal_tokens_adj_list, device=self.device) 137 | reversed_terminal_tokens_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=reversed_terminal_tokens_adj_list, device=self.device) 138 | 139 | adj_lists.append(terminal_tokens_adj_list) 140 | adj_lists.append(reversed_terminal_tokens_adj_list) 141 | 142 | if self.prev_sibling_connection and prev_sibling_adj_list: 143 | prev_sibling_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=prev_sibling_adj_list, device=self.device) 144 | adj_lists.append(prev_sibling_adj_list) 145 | 146 | if self.next_sibling_connection and next_sibling_adj_list: 147 | next_sibling_adj_list = AdjacencyList(node_num=all_nodes_num, adj_list=next_sibling_adj_list, 148 | device=self.device) 149 | adj_lists.append(next_sibling_adj_list) 150 | 151 | # print(f'batch size: {len(batch_syntax_trees)}, ' 152 | # f'total_edges: {sum(adj_list.edge_num for adj_list in adj_lists)}, ' 153 | # f'total nodes: {len(example_node2batch_graph_node)}, ' 154 | # f'max edges: {max(len(tree.adjacency_list) for tree in batch_syntax_trees)}, ' 155 | # f'max nodes: {max(tree.node_num for tree in batch_syntax_trees)}', file=sys.stderr) 156 | 157 | return adj_lists, \ 158 | example_node2batch_graph_node, \ 159 | batch_graph_node2example_node 160 | 161 | def to_batch_encoding(self, flattened_node_encodings, batch_syntax_trees, example_node2batch_node_map): 162 | max_node_num = max(tree.node_num for tree in batch_syntax_trees) 163 | index_list = [] 164 | mask_list = [] 165 | syntax_token_mask_list = [] 166 | 167 | for e_id, syntax_tree in enumerate(batch_syntax_trees): 168 | example_nodes_with_batch_id = 
[(example_node_id, batch_node_id) 169 | for (_e_id, example_node_id), batch_node_id 170 | in example_node2batch_node_map.items() 171 | if _e_id == e_id] 172 | # example_nodes_batch_id = list(map(lambda x: x[1], sorted(example_nodes_with_batch_id, key=lambda t: t[0]))) 173 | sorted_example_nodes_with_batch_id = sorted(example_nodes_with_batch_id, key=lambda t: t[0]) 174 | example_nodes_batch_id = [t[1] for t in sorted_example_nodes_with_batch_id] 175 | 176 | example_idx_list = example_nodes_batch_id + [0] * (max_node_num - syntax_tree.node_num) 177 | example_node_masks = [0] * len(example_nodes_batch_id) + [1] * (max_node_num - syntax_tree.node_num) 178 | syntax_token_masks = [0 if syntax_tree.is_syntax_token(node_id) else 1 for node_id, batch_node_id in sorted_example_nodes_with_batch_id] + [1] * (max_node_num - syntax_tree.node_num) 179 | 180 | index_list.append(example_idx_list) 181 | mask_list.append(example_node_masks) 182 | syntax_token_mask_list.append(syntax_token_masks) 183 | 184 | # (batch_size, max_node_num, node_encoding_size) 185 | batch_node_encoding = flattened_node_encodings[index_list, :] 186 | batch_node_encoding_mask = torch.tensor(mask_list, dtype=torch.bool, device=self.device) # uint8 -> bool, pytorch version upgrade 187 | batch_node_syntax_token_mask = torch.tensor(syntax_token_mask_list, dtype=torch.bool, device=self.device) # FIXME: [Ziyu] syntax_token_mask_list? 188 | 189 | batch_node_encoding.data.masked_fill_(batch_node_encoding_mask.unsqueeze(-1), 0.) 
190 | 191 | return TreeEncodingResult(batch_syntax_trees, batch_node_encoding, batch_node_encoding_mask, batch_node_syntax_token_mask) 192 | -------------------------------------------------------------------------------- /edit_model/encdec/sequential_encoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from collections import namedtuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.utils 7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 8 | 9 | from edit_model.data_model import BatchedCodeChunk 10 | from edit_model import nn_utils 11 | 12 | 13 | from .encoder import EncodingResult 14 | 15 | 16 | class SequentialEncoder(nn.Module): 17 | """encode the input data""" 18 | 19 | def __init__(self, token_embed_size, token_encoding_size, token_embedder, vocab): 20 | super(SequentialEncoder, self).__init__() 21 | 22 | self.vocab = vocab 23 | 24 | self.token_embedder = token_embedder 25 | self.encoder_lstm = nn.LSTM(token_embed_size, token_encoding_size // 2, bidirectional=True) 26 | 27 | @property 28 | def device(self): 29 | return self.token_embedder.device 30 | 31 | def forward(self, prev_data, is_sorted=False, embedding_cache=None): 32 | batched_code_lens = [len(code) for code in prev_data] 33 | 34 | if is_sorted is False: 35 | original_prev_data = prev_data 36 | sorted_example_ids, example_old2new_pos = nn_utils.get_sort_map(batched_code_lens) 37 | prev_data = [prev_data[i] for i in sorted_example_ids] 38 | 39 | if embedding_cache: 40 | # (code_seq_len, batch_size, token_embed_size) 41 | token_embed = embedding_cache.get_embed_for_token_sequences(prev_data) 42 | else: 43 | # (code_seq_len, batch_size, token_embed_size) 44 | token_embed = self.token_embedder.get_embed_for_token_sequences(prev_data) 45 | 46 | packed_token_embed = pack_padded_sequence(token_embed, [len(code) for code in prev_data]) 47 | 48 | # token_encodings: (tgt_query_len, batch_size, 
hidden_size) 49 | token_encodings, (last_state, last_cell) = self.encoder_lstm(packed_token_embed) 50 | token_encodings, _ = pad_packed_sequence(token_encodings) 51 | 52 | # (batch_size, hidden_size * 2) 53 | last_state = torch.cat([last_state[0], last_state[1]], 1) 54 | last_cell = torch.cat([last_cell[0], last_cell[1]], 1) 55 | 56 | # (batch_size, tgt_query_len, hidden_size) 57 | token_encodings = token_encodings.permute(1, 0, 2) 58 | if is_sorted is False: 59 | token_encodings = token_encodings[example_old2new_pos] 60 | last_state = last_state[example_old2new_pos] 61 | last_cell = last_cell[example_old2new_pos] 62 | prev_data = original_prev_data 63 | 64 | return EncodingResult(prev_data, token_encodings, last_state, last_cell, 65 | nn_utils.length_array_to_mask_tensor(batched_code_lens, device=self.device)) 66 | 67 | 68 | class ContextEncoder(SequentialEncoder): 69 | def __init__(self, **kwargs): 70 | super(ContextEncoder, self).__init__(**kwargs) 71 | 72 | def forward(self, context_list, is_sorted=False): 73 | return super(ContextEncoder, self).forward(context_list, is_sorted=is_sorted) 74 | -------------------------------------------------------------------------------- /edit_model/gnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.utils 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence 8 | 9 | from typing import List, Tuple, Dict, Sequence, Any 10 | 11 | 12 | class AdjacencyList: 13 | """represent the topology of a graph""" 14 | def __init__(self, node_num: int, adj_list: List, device: torch.device): 15 | self.node_num = node_num 16 | self.data = torch.tensor(adj_list, dtype=torch.long, device=device) 17 | self.edge_num = len(adj_list) 18 | 19 | @property 20 | def device(self): 21 | return self.data.device 22 | 23 | def 
__getitem__(self, item): 24 | return self.data[item] 25 | 26 | 27 | class GatedGraphNeuralNetwork(nn.Module): 28 | def __init__(self, hidden_size, num_edge_types, layer_timesteps, 29 | residual_connections, 30 | state_to_message_dropout=0.3, 31 | rnn_dropout=0.3, 32 | use_bias_for_message_linear=True): 33 | 34 | super(GatedGraphNeuralNetwork, self).__init__() 35 | 36 | self.hidden_size = hidden_size 37 | self.num_edge_types = num_edge_types 38 | self.layer_timesteps = layer_timesteps 39 | self.residual_connections = {int(k): v for k, v in residual_connections.items()} 40 | self.state_to_message_dropout = state_to_message_dropout 41 | self.rnn_dropout = rnn_dropout 42 | self.use_bias_for_message_linear = use_bias_for_message_linear 43 | 44 | # Prepare linear transformations from node states to messages, for each layer and each edge type 45 | # Prepare rnn cells for each layer 46 | self.state_to_message_linears = [] 47 | self.rnn_cells = [] 48 | for layer_idx in range(len(self.layer_timesteps)): 49 | state_to_msg_linears_cur_layer = [] 50 | # Initiate a linear transformation for each edge type 51 | for edge_type_j in range(self.num_edge_types): 52 | # TODO: glorot_init? 
53 | state_to_msg_linear_layer_i_type_j = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias_for_message_linear) 54 | setattr(self, 55 | 'state_to_message_linear_layer%d_type%d' % (layer_idx, edge_type_j), 56 | state_to_msg_linear_layer_i_type_j) 57 | 58 | state_to_msg_linears_cur_layer.append(state_to_msg_linear_layer_i_type_j) 59 | self.state_to_message_linears.append(state_to_msg_linears_cur_layer) 60 | 61 | layer_residual_connections = self.residual_connections.get(layer_idx, []) 62 | rnn_cell_layer_i = nn.GRUCell(self.hidden_size * (1 + len(layer_residual_connections)), self.hidden_size) 63 | setattr(self, 'rnn_cell_layer%d' % layer_idx, rnn_cell_layer_i) 64 | self.rnn_cells.append(rnn_cell_layer_i) 65 | 66 | self.state_to_message_dropout_layer = nn.Dropout(self.state_to_message_dropout) 67 | self.rnn_dropout_layer = nn.Dropout(self.rnn_dropout) 68 | 69 | @property 70 | def device(self): 71 | return self.rnn_cells[0].weight_hh.device 72 | 73 | def forward(self, 74 | initial_node_representation: Variable, 75 | adjacency_lists: List[AdjacencyList], 76 | return_all_states=False) -> Variable: 77 | return self.compute_node_representations(initial_node_representation, adjacency_lists, 78 | return_all_states=return_all_states) 79 | 80 | def compute_node_representations(self, 81 | initial_node_representation: Variable, 82 | adjacency_lists: List[AdjacencyList], 83 | return_all_states=False) -> Variable: 84 | # If the dimension of initial node embedding is smaller, then perform padding first 85 | # one entry per layer (final state of that layer), shape: number of nodes in batch v x D 86 | init_node_repr_size = initial_node_representation.size(1) 87 | device = adjacency_lists[0].data.device 88 | if init_node_repr_size < self.hidden_size: 89 | pad_size = self.hidden_size - init_node_repr_size 90 | zero_pads = torch.zeros(initial_node_representation.size(0), pad_size, dtype=torch.float, device=device) 91 | initial_node_representation = 
torch.cat([initial_node_representation, zero_pads], dim=-1) 92 | node_states_per_layer = [initial_node_representation] 93 | 94 | node_num = initial_node_representation.size(0) 95 | 96 | message_targets = [] # list of tensors of message targets of shape [E] 97 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists): 98 | if adjacency_list_for_edge_type.edge_num > 0: 99 | edge_targets = adjacency_list_for_edge_type[:, 1] 100 | message_targets.append(edge_targets) 101 | message_targets = torch.cat(message_targets, dim=0) # Shape [M] 102 | 103 | # sparse matrix of shape [V, M] 104 | # incoming_msg_sparse_matrix = self.get_incoming_message_sparse_matrix(adjacency_lists).to(device) 105 | for layer_idx, num_timesteps in enumerate(self.layer_timesteps): 106 | # Used shape abbreviations: 107 | # V ~ number of nodes 108 | # D ~ state dimension 109 | # E ~ number of edges of current type 110 | # M ~ number of messages (sum of all E) 111 | 112 | # Extract residual messages, if any: 113 | layer_residual_connections = self.residual_connections.get(layer_idx, []) 114 | # List[(V, D)] 115 | layer_residual_states: List[torch.FloatTensor] = [node_states_per_layer[residual_layer_idx] 116 | for residual_layer_idx in layer_residual_connections] 117 | 118 | # Record new states for this layer. 
Initialised to last state, but will be updated below: 119 | node_states_for_this_layer = node_states_per_layer[-1] 120 | # For each message propagation step 121 | for t in range(num_timesteps): 122 | messages: List[torch.FloatTensor] = [] # list of tensors of messages of shape [E, D] 123 | message_source_states: List[torch.FloatTensor] = [] # list of tensors of edge source states of shape [E, D] 124 | 125 | # Collect incoming messages per edge type 126 | for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists): 127 | if adjacency_list_for_edge_type.edge_num > 0: 128 | # shape [E] 129 | edge_sources = adjacency_list_for_edge_type[:, 0] 130 | # shape [E, D] 131 | edge_source_states = node_states_for_this_layer[edge_sources] 132 | 133 | f_state_to_message = self.state_to_message_linears[layer_idx][edge_type_idx] 134 | # Shape [E, D] 135 | all_messages_for_edge_type = self.state_to_message_dropout_layer(f_state_to_message(edge_source_states)) 136 | 137 | messages.append(all_messages_for_edge_type) 138 | message_source_states.append(edge_source_states) 139 | 140 | # shape [M, D] 141 | messages: torch.FloatTensor = torch.cat(messages, dim=0) 142 | 143 | # Sum up messages that go to the same target node 144 | # shape [V, D] 145 | incoming_messages = torch.zeros(node_num, messages.size(1), device=device) 146 | incoming_messages = incoming_messages.scatter_add_(0, 147 | message_targets.unsqueeze(-1).expand_as(messages), 148 | messages) 149 | 150 | # shape [V, D * (1 + num of residual connections)] 151 | incoming_information = torch.cat(layer_residual_states + [incoming_messages], dim=-1) 152 | 153 | # pass updated vertex features into RNN cell 154 | # Shape [V, D] 155 | updated_node_states = self.rnn_cells[layer_idx](incoming_information, node_states_for_this_layer) 156 | updated_node_states = self.rnn_dropout_layer(updated_node_states) 157 | node_states_for_this_layer = updated_node_states 158 | 159 | 
node_states_per_layer.append(node_states_for_this_layer) 160 | 161 | if return_all_states: 162 | return node_states_per_layer[1:] 163 | else: 164 | node_states_for_last_layer = node_states_per_layer[-1] 165 | return node_states_for_last_layer 166 | 167 | 168 | def main(): 169 | gnn = GatedGraphNeuralNetwork(hidden_size=64, num_edge_types=2, 170 | layer_timesteps=[3, 5, 7, 2], residual_connections={2: [0], 3: [0, 1]}) 171 | 172 | adj_list_type1 = AdjacencyList(node_num=4, adj_list=[(0, 2), (2, 1), (1, 3)], device=gnn.device) 173 | adj_list_type2 = AdjacencyList(node_num=4, adj_list=[(0, 0), (0, 1)], device=gnn.device) 174 | 175 | node_representations = gnn.compute_node_representations(initial_node_representation=torch.randn(4, 64), 176 | adjacency_lists=[adj_list_type1, adj_list_type2]) 177 | 178 | print(node_representations) 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /edit_model/nn_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | import numpy as np 10 | 11 | 12 | def dot_prod_attention(h_t, src_encoding, src_encoding_att_linear, mask=None, return_log_att_weight=False): 13 | """ 14 | :param h_t: (batch_size, hidden_size) 15 | :param src_encoding: (batch_size, src_sent_len, hidden_size * 2) 16 | :param src_encoding_att_linear: (batch_size, src_sent_len, hidden_size) 17 | :param mask: (batch_size, src_sent_len) 18 | """ 19 | # (batch_size, src_sent_len) 20 | att_weight = torch.bmm(src_encoding_att_linear, h_t.unsqueeze(2)).squeeze(2) 21 | if mask is not None: 22 | att_weight.data.masked_fill_(mask.bool(), -float('inf')) # byte -> bool, pytorch version upgrade 23 | softmaxed_att_weight = F.softmax(att_weight, dim=-1) 24 | if return_log_att_weight: 25 | 
log_softmaxed_att_weight = F.log_softmax(att_weight, dim=-1) 26 | 27 | att_view = (att_weight.size(0), 1, att_weight.size(1)) 28 | # (batch_size, hidden_size) 29 | ctx_vec = torch.bmm(softmaxed_att_weight.view(*att_view), src_encoding).squeeze(1) 30 | 31 | if return_log_att_weight: 32 | return ctx_vec, softmaxed_att_weight, log_softmaxed_att_weight 33 | else: 34 | return ctx_vec, softmaxed_att_weight 35 | 36 | 37 | def log_sum_exp(inputs, keepdim=False, mask=None): 38 | """Numerically stable logsumexp on the last dim of `inputs`. 39 | reference: https://github.com/pytorch/pytorch/issues/2591 40 | Args: 41 | inputs: A Variable with any shape. 42 | keepdim: A boolean. 43 | mask: A mask variable of type float. It has the same shape as `inputs`. 44 | Returns: 45 | Equivalent of log(sum(exp(inputs), keepdim=keepdim)). 46 | """ 47 | 48 | if mask is not None: 49 | mask = 1. - mask 50 | max_offset = -1e7 * mask 51 | else: 52 | max_offset = 0. 53 | 54 | s, _ = torch.max(inputs + max_offset, dim=-1, keepdim=True) 55 | 56 | inputs_offset = inputs - s 57 | if mask is not None: 58 | inputs_offset.masked_fill_(mask.bool(), -float('inf')) # byte -> bool, pytorch version upgrade 59 | 60 | outputs = s + inputs_offset.exp().sum(dim=-1, keepdim=True).log() 61 | 62 | if not keepdim: 63 | outputs = outputs.squeeze(-1) 64 | return outputs 65 | 66 | 67 | def log_softmax(inputs, dim=-1, mask=None): 68 | if mask is not None: 69 | inputs.masked_fill_((1 - mask).bool(), -float('inf')) # byte -> bool, pytorch version upgrade 70 | 71 | return F.log_softmax(inputs, dim=dim) 72 | 73 | 74 | def length_array_to_mask_tensor(length_array, device=None): 75 | max_len = max(length_array) 76 | batch_size = len(length_array) 77 | 78 | mask = np.ones((batch_size, max_len), dtype=np.uint8) 79 | for i, seq_len in enumerate(length_array): 80 | mask[i][:seq_len] = 0 81 | 82 | mask = torch.tensor(mask, dtype=torch.uint8, device=device) 83 | 84 | return mask 85 | 86 | 87 | def pad_lists(indices, pad_id, 
return_mask=False): 88 | max_len = max(len(idx_list) for idx_list in indices) 89 | padded_indices = [] 90 | if return_mask: masks = [] 91 | for idx_list in indices: 92 | padded_indices.append(idx_list + [pad_id] * (max_len - len(idx_list))) 93 | if return_mask: 94 | masks.append([0] * len(idx_list) + [1] * (max_len - len(idx_list))) 95 | 96 | if return_mask: return padded_indices, masks 97 | return padded_indices 98 | 99 | 100 | def to_input_variable(sequences, vocab, device, append_boundary_sym=False, return_mask=False, pad_id=-1, batch_first=False): 101 | """ 102 | given a list of sequences, 103 | return a tensor of shape (max_sent_len, batch_size) 104 | """ 105 | if append_boundary_sym: 106 | sequences = [[''] + seq + [''] for seq in sequences] 107 | 108 | pad_id = pad_id if pad_id >= 0 else vocab[''] 109 | 110 | word_ids = word2id(sequences, vocab) 111 | if batch_first: 112 | result = pad_lists(word_ids, pad_id, return_mask=return_mask) 113 | if return_mask: sents_t, masks = result 114 | else: sents_t = result 115 | else: 116 | sents_t, masks = input_transpose(word_ids, pad_id) 117 | 118 | sents_var = torch.tensor(sents_t, dtype=torch.long, device=device) 119 | 120 | if return_mask: 121 | mask_var = torch.tensor(masks, dtype=torch.long, device=device) 122 | return sents_var, mask_var 123 | 124 | return sents_var 125 | 126 | 127 | def word2id(sents, vocab): 128 | if type(sents[0]) == list: 129 | return [[vocab[w] for w in s] for s in sents] 130 | else: 131 | return [vocab[w] for w in sents] 132 | 133 | 134 | def id2word(sents, vocab): 135 | if type(sents[0]) == list: 136 | return [[vocab.id2word[w] for w in s] for s in sents] 137 | else: 138 | return [vocab.id2word[w] for w in sents] 139 | 140 | 141 | def input_transpose(sents, pad_token): 142 | """ 143 | transform the input List[sequence] of size (batch_size, max_sent_len) 144 | into a list of size (max_sent_len, batch_size), with proper padding 145 | """ 146 | max_len = max(len(s) for s in sents) 147 | 
batch_size = len(sents) 148 | 149 | sents_t = [] 150 | masks = [] 151 | for i in range(max_len): 152 | sents_t.append([sents[k][i] if len(sents[k]) > i else pad_token for k in range(batch_size)]) 153 | masks.append([1 if len(sents[k]) > i else 0 for k in range(batch_size)]) 154 | 155 | return sents_t, masks 156 | 157 | 158 | def get_sort_map(lens_array): 159 | """sort input by length in descending order, 160 | return the sorted index and the mapping between old and new positions""" 161 | 162 | sorted_example_ids = sorted(list(range(len(lens_array))), key=lambda x: -lens_array[x]) 163 | 164 | example_old2new_pos_map = [-1] * len(lens_array) 165 | for new_pos, old_pos in enumerate(sorted_example_ids): 166 | example_old2new_pos_map[old_pos] = new_pos 167 | 168 | return sorted_example_ids, example_old2new_pos_map 169 | 170 | 171 | def batch_iter(examples, batch_size, shuffle=False, sort_func=None, return_sort_map=False): 172 | index_arr = np.arange(len(examples)) 173 | if shuffle: 174 | np.random.shuffle(index_arr) 175 | 176 | batch_num = int(np.ceil(len(examples) / float(batch_size))) 177 | for batch_id in range(batch_num): 178 | batch_ids = index_arr[batch_size * batch_id: batch_size * (batch_id + 1)] 179 | batch_examples = [examples[i] for i in batch_ids] 180 | # sort by the length of the change sequence in descending order 181 | if sort_func: 182 | sorted_examples_with_ids = sorted([(_idx, e) for _idx, e in enumerate(batch_examples)], key=lambda x: sort_func(x[1])) 183 | 184 | if return_sort_map and sort_func: 185 | sorted_example_ids = [x[0] for x in sorted_examples_with_ids] 186 | 187 | example_old2new_pos_map = [-1] * len(sorted_example_ids) 188 | for new_pos, old_pos in enumerate(sorted_example_ids): 189 | example_old2new_pos_map[old_pos] = new_pos 190 | 191 | sorted_examples = [x[1] for x in sorted_examples_with_ids] 192 | yield sorted_examples, sorted_example_ids, example_old2new_pos_map 193 | else: 194 | yield batch_examples 195 | 196 | 197 | def 
anonymize_unk_tokens(prev_code, updated_code, context, vocab): 198 | unk_name_map = dict() 199 | 200 | def __to_new_token_seq(tokens): 201 | for token in tokens: 202 | if token in unk_name_map: 203 | new_token_name = unk_name_map[token] 204 | else: 205 | if vocab.is_unk(token): 206 | new_token_name = 'UNK_%d' % len(unk_name_map) 207 | unk_name_map[token] = new_token_name 208 | else: 209 | new_token_name = token 210 | 211 | yield new_token_name 212 | 213 | new_prev_code = list(__to_new_token_seq(prev_code)) 214 | new_context = list(__to_new_token_seq(context)) 215 | new_updated_code = list(__to_new_token_seq(updated_code)) 216 | 217 | return new_prev_code, new_updated_code, new_context 218 | 219 | -------------------------------------------------------------------------------- /edit_model/pointer_net.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.utils 6 | from torch.autograd import Variable 7 | import torch.nn.functional as F 8 | 9 | 10 | class PointerNet(nn.Module): 11 | def __init__(self, query_vec_size, src_encoding_size): 12 | super(PointerNet, self).__init__() 13 | 14 | self.src_encoding_linear = nn.Linear(src_encoding_size, query_vec_size, bias=False) 15 | 16 | def forward(self, src_encodings, src_token_mask, query_vec, 17 | log=False, valid_masked_as_one=False, return_logits=False): 18 | """ 19 | :param src_encodings: Variable(batch_size, src_sent_len, hidden_size * 2) 20 | :param src_token_mask: Variable(batch_size, src_sent_len) 21 | :param query_vec: Variable(tgt_action_num, batch_size, query_vec_size) 22 | :return: Variable(tgt_action_num, batch_size, src_sent_len) 23 | """ 24 | 25 | # (batch_size, 1, src_sent_len, query_vec_size) 26 | src_trans = self.src_encoding_linear(src_encodings).unsqueeze(1) 27 | # (batch_size, tgt_action_num, query_vec_size, 1) 28 | q = query_vec.permute(1, 0, 2).unsqueeze(3) 29 | 30 | # (batch_size, 
tgt_action_num, src_sent_len) 31 | weights = torch.matmul(src_trans, q).squeeze(3) 32 | 33 | # (tgt_action_num, batch_size, src_sent_len) 34 | weights = weights.permute(1, 0, 2) 35 | 36 | if src_token_mask is not None: 37 | # (tgt_action_num, batch_size, src_sent_len) 38 | if len(src_token_mask.size()) == len(weights.size()) + 1: 39 | src_token_mask = src_token_mask.unsqueeze(0).expand_as(weights) 40 | 41 | if valid_masked_as_one: 42 | src_token_mask = 1 - src_token_mask 43 | 44 | weights.data.masked_fill_(src_token_mask.bool(), -float('inf')) # byte -> bool, pytorch version upgrade 45 | 46 | if return_logits: 47 | return weights 48 | 49 | if log: 50 | ptr_weights = F.log_softmax(weights, dim=-1) 51 | else: 52 | ptr_weights = F.softmax(weights, dim=-1) 53 | 54 | return ptr_weights 55 | -------------------------------------------------------------------------------- /edit_model/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from typing import Dict 3 | 4 | 5 | class cached_property: 6 | """ 7 | A property that is only computed once per instance and then replaces 8 | itself with an ordinary attribute. Deleting the attribute resets the 9 | property. 
10 | 11 | Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 12 | """ 13 | 14 | def __init__(self, func): 15 | self.__doc__ = getattr(func, '__doc__') 16 | self.func = func 17 | 18 | def __get__(self, obj, cls): 19 | if obj is None: 20 | return self 21 | value = obj.__dict__[self.func.__name__] = self.func(obj) 22 | return value 23 | 24 | 25 | def get_method_args_dict(func, locals) -> Dict: 26 | arg_spec = inspect.getfullargspec(func) 27 | args = dict() 28 | 29 | for arg_name in filter(lambda x: x not in ('self'), arg_spec.args): 30 | arg_val = locals[arg_name] 31 | if arg_val is None or isinstance(arg_val, (list, dict, set, str, int, float, bool)): 32 | args[arg_name] = locals[arg_name] 33 | 34 | return args 35 | -------------------------------------------------------------------------------- /scripts/githubedits/test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate structural_edits 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | export PYTHONPATH=/scratch/yao.470/CMU_project/incremental_tree_edit:$PYTHONPATH 7 | 8 | config_file=$1 9 | 10 | work_dir=exp_githubedits_runs/FOLDER_OF_MODEL # TODO: replace FOLDER_OF_MODEL with your model dir 11 | 12 | echo use config file ${config_file} 13 | echo work dir=${work_dir} 14 | 15 | mkdir -p ${work_dir} 16 | 17 | # TODO: uncomment the test setting 18 | # beam search 19 | #test_file=source_data/githubedits/githubedits.test.jsonl 20 | #OMP_NUM_THREADS=1 python -m exp_githubedits decode_updated_data \ 21 | # --cuda \ 22 | # --beam_size=1 \ 23 | # --evaluate_ppl \ 24 | # ${work_dir}/model.bin \ 25 | # ${test_file} 2>${work_dir}/model.bin.${filename}_test.log # dev_debug, dev, test, train_debug 26 | 27 | # test ppl 28 | #OMP_NUM_THREADS=1 python -m exp_githubedits test_ppl \ 29 | # --cuda \ 30 | # --evaluate_ppl \ 31 | # ${work_dir}/model.bin \ 32 | # ${test_file} #2>${work_dir}/model.bin.decode_ppl.log 33 | 34 | 35 | ## 
csharp_fixer 36 | #test_file=source_data/githubedits/csharp_fixers.jsonl 37 | #scorer='default' 38 | #OMP_NUM_THREADS=1 python -m exp_githubedits eval_csharp_fixer \ 39 | # --cuda \ 40 | # --beam_size=1 \ 41 | # --scorer=${scorer} \ 42 | # ${work_dir}/model.bin \ 43 | # ${test_file} 2>${work_dir}/model.bin.${filename}_csharp_fixer_${scorer}.log 44 | 45 | 46 | ## beam search on csharp_fixer with gold inputs 47 | #test_file=source_data/githubedits/csharp_fixers.jsonl 48 | #OMP_NUM_THREADS=1 python -m exp_githubedits decode_updated_data \ 49 | # --cuda \ 50 | # --beam_size=1 \ 51 | # ${work_dir}/model.bin \ 52 | # ${test_file} 2>${work_dir}/model.bin.${filename}_csharp_fixer_gold.log 53 | 54 | 55 | # collect edit vecs 56 | ##test_file=source_data/githubedits/csharp_fixers.jsonl 57 | #test_file=source_data/githubedits/githubedits.test.jsonl 58 | #OMP_NUM_THREADS=1 python -m exp_githubedits collect_edit_vecs \ 59 | # --cuda \ 60 | # ${work_dir}/model.bin \ 61 | # ${test_file} 62 | 63 | -------------------------------------------------------------------------------- /scripts/githubedits/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source activate structural_edits 4 | export CUDA_VISIBLE_DEVICES=0 5 | 6 | export PYTHONPATH=/scratch/yao.470/CMU_project/incremental_tree_edit:$PYTHONPATH 7 | 8 | seed=0 # remember to change this! 
9 | config_file=$1 10 | branch=$(git branch | sed -n -e 's/^\* \(.*\)/\1/p') 11 | commit=$(git rev-parse HEAD | cut -c 1-7) 12 | timestamp=`date "+%Y%m%d-%H%M%S"` 13 | work_dir=exp_githubedits_runs/$(basename ${config_file})_branch_${branch}_${commit}.seed${seed}.${timestamp} 14 | 15 | echo use config file ${config_file} 16 | echo work dir=${work_dir} 17 | 18 | mkdir -p ${work_dir} 19 | 20 | # TODO 1: uncomment the training setting 21 | # TODO 2: consider adding `--small_memory` to disable training data preprocessing 22 | #OMP_NUM_THREADS=1 python -u -m exp_githubedits train \ 23 | # --cuda \ 24 | # --seed=${seed} \ 25 | # --work_dir=${work_dir} \ 26 | # ${config_file} 2>${work_dir}/err.log 27 | 28 | #OMP_NUM_THREADS=1 python -u -m exp_githubedits imitation_learning \ 29 | # --cuda \ 30 | # --seed=${seed} \ 31 | # --work_dir=${work_dir} \ 32 | # ${config_file} 2>${work_dir}/err.log 33 | -------------------------------------------------------------------------------- /source_data/githubedits.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neulab/incremental_tree_edit/8651f2c75154bd776682726ea7e1d3da8a12924b/source_data/githubedits.tar.gz -------------------------------------------------------------------------------- /structural_edits.yml: -------------------------------------------------------------------------------- 1 | name: structural_edits 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - defaults 6 | dependencies: 7 | - astor=0.7.1=py37_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2020.1.1=0 10 | - certifi=2020.4.5.2=py37_0 11 | - click=7.1.2=py_0 12 | - cudatoolkit=10.0.130=0 13 | - intel-openmp=2020.1=217 14 | - joblib=0.15.1=py_0 15 | - libedit=3.1.20191231=h7b6447c_0 16 | - libffi=3.3=he6710b0_1 17 | - libgcc-ng=9.1.0=hdf63c60_0 18 | - libgfortran-ng=7.3.0=hdf63c60_0 19 | - libstdcxx-ng=9.1.0=hdf63c60_0 20 | - mkl=2019.4=243 21 | - mkl-service=2.3.0=py37he904b0f_0 22 | - 
mkl_fft=1.1.0=py37h23d657b_0 23 | - mkl_random=1.1.0=py37hd6b4f25_0 24 | - ncurses=6.2=he6710b0_1 25 | - ninja=1.9.0=py37hfd86e86_0 26 | - nltk=3.5=py_0 27 | - numpy=1.18.1=py37h4f9e942_0 28 | - numpy-base=1.18.1=py37hde5b4d6_1 29 | - openssl=1.1.1g=h7b6447c_0 30 | - pip=20.1.1=py37_1 31 | - python=3.7.3=h0371630_0 32 | - pytorch=1.4.0=py3.7_cuda10.0.130_cudnn7.6.3_0 33 | - readline=7.0=h7b6447c_5 34 | - regex=2020.5.14=py37h7b6447c_0 35 | - setuptools=47.3.0=py37_0 36 | - six=1.15.0=py_0 37 | - sqlite=3.32.2=h62c20be_0 38 | - tk=8.6.10=hbc83047_0 39 | - tqdm=4.46.1=py_0 40 | - wheel=0.34.2=py37_0 41 | - xz=5.2.5=h7b6447c_0 42 | - zlib=1.2.11=h7b6447c_3 43 | - pip: 44 | - absl-py==0.9.0 45 | - backcall==0.2.0 46 | - beautifulsoup4==4.9.1 47 | - bs4==0.0.1 48 | - cachetools==4.1.0 49 | - cffi==1.14.0 50 | - chardet==3.0.4 51 | - cycler==0.10.0 52 | - cython==0.29.20 53 | - decorator==4.4.2 54 | - docopt==0.6.2 55 | - editdistance==0.5.3 56 | - elasticsearch==7.8.0 57 | - gitdb==4.0.5 58 | - gitpython==3.1.3 59 | - google-auth==1.18.0 60 | - google-auth-oauthlib==0.4.1 61 | - grpcio==1.30.0 62 | - idna==2.9 63 | - importlib-metadata==1.6.1 64 | - ipython==7.16.1 65 | - ipython-genutils==0.2.0 66 | - jedi==0.17.1 67 | - kiwisolver==1.2.0 68 | - line-profiler==3.0.2 69 | - lxml==4.5.1 70 | - markdown==3.2.2 71 | - matplotlib==3.3.0 72 | - oauthlib==3.1.0 73 | - parso==0.7.0 74 | - pexpect==4.8.0 75 | - pickleshare==0.7.5 76 | - pillow==7.2.0 77 | - portalocker==1.7.0 78 | - prompt-toolkit==3.0.5 79 | - protobuf==3.12.2 80 | - ptyprocess==0.6.0 81 | - pyasn1==0.4.8 82 | - pyasn1-modules==0.2.8 83 | - pycparser==2.20 84 | - pygments==2.6.1 85 | - pyparsing==2.4.7 86 | - python-dateutil==2.8.1 87 | - requests==2.24.0 88 | - requests-oauthlib==1.3.0 89 | - rsa==4.6 90 | - sacrebleu==1.4.10 91 | - scikit-learn==0.23.2 92 | - scipy==1.5.0 93 | - sklearn==0.0 94 | - smmap==3.0.4 95 | - snakeviz==2.1.0 96 | - soupsieve==2.0.1 97 | - tensorboard==2.2.2 98 | - 
tensorboard-plugin-wit==1.6.0.post3 99 | - tensorboardx==2.0 100 | - threadpoolctl==2.1.0 101 | - tornado==6.0.4 102 | - traitlets==4.3.3 103 | - urllib3==1.25.9 104 | - wcwidth==0.2.5 105 | - werkzeug==1.0.1 106 | - xgboost==1.1.1 107 | - zipp==3.1.0 108 | prefix: /home/yao.470/anaconda2/envs/structural_edits 109 | 110 | -------------------------------------------------------------------------------- /trees/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 -------------------------------------------------------------------------------- /trees/edits.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from asdl.asdl import ASDLCompositeType, ASDLPrimitiveType 3 | from asdl.transition_system import ApplyRuleAction, GenTokenAction 4 | from asdl.asdl_ast import AbstractSyntaxTree, AbstractSyntaxNode, SyntaxToken, RealizedField, DummyReduce 5 | 6 | 7 | class Edit(object): 8 | pass 9 | 10 | 11 | class Delete(Edit): 12 | def __init__(self, field, value_idx, node, meta=None): 13 | self.field = field 14 | self.value_idx = value_idx 15 | self.node = node 16 | self.meta = meta 17 | 18 | def _apply_edit(self): 19 | assert self.field.as_value_list[self.value_idx] == self.node, \ 20 | "Delete: Node not found in Field (value idx %d)!" 
% self.value_idx 21 | 22 | edited_field = self.field.copy() 23 | edited_field.remove_w_idx(self.value_idx) 24 | return edited_field 25 | 26 | @property 27 | def output(self): 28 | return self._apply_edit() 29 | 30 | def __repr__(self): 31 | return 'Delete[%s, field %s, value idx %d]' % ( 32 | self.node.__repr__(), self.field.__repr__(), self.value_idx) 33 | 34 | 35 | class Add(Edit): 36 | def __init__(self, field, value_idx, action, value_buffer=None, meta=None): 37 | self.field = field 38 | self.value_idx = value_idx 39 | self.action = action 40 | self._value_buffer = value_buffer 41 | self.meta = meta 42 | 43 | def _apply_action(self): 44 | edited_field = self.field.copy() 45 | action = self.action 46 | 47 | if isinstance(edited_field.type, ASDLCompositeType) or \ 48 | (not isinstance(edited_field.type, ASDLPrimitiveType) and edited_field.type.is_composite): 49 | if isinstance(action, ApplyRuleAction): 50 | field_value = AbstractSyntaxNode(action.production) 51 | edited_field.add_value_w_idx(field_value, self.value_idx) 52 | # edited_field.set_open() # open the field once Add 53 | 54 | else: 55 | raise ValueError('Invalid action [%s] on field [%s]' % (action, edited_field)) 56 | else: # fill in a primitive field 57 | if isinstance(action, GenTokenAction): 58 | # only field of type string requires termination signal 59 | end_primitive = False 60 | if edited_field.type.name == 'string': 61 | if action.is_stop_signal(): 62 | assert self._value_buffer is not None and len(self._value_buffer) 63 | edited_field.add_value_w_idx( 64 | SyntaxToken(edited_field.type, ' '.join(self._value_buffer)), self.value_idx) 65 | end_primitive = True 66 | else: 67 | edited_field.add_value_w_idx( 68 | SyntaxToken(edited_field.type, action.token.value if isinstance(action.token, SyntaxToken) 69 | else action.token), self.value_idx) 70 | end_primitive = True 71 | 72 | # if not end_primitive: 73 | # edited_field.set_open() 74 | 75 | # if end_primitive and edited_field.cardinality in 
('single', 'optional'): 76 | # edited_field.set_finish() 77 | 78 | else: 79 | raise ValueError('Can only invoke GenToken or Reduce actions on primitive fields') 80 | 81 | return edited_field 82 | 83 | @property 84 | def output(self): 85 | return self._apply_action() 86 | 87 | def __repr__(self): 88 | return 'Add[%s, %s, value idx %d]' % ( 89 | self.action.__repr__(), self.field.__repr__(), self.value_idx) 90 | 91 | 92 | class AddSubtree(Edit): 93 | def __init__(self, field, value_idx, node, meta=None): 94 | self.field = field 95 | self.value_idx = value_idx 96 | self.node = node 97 | self.meta = meta 98 | 99 | def _apply_edit(self): 100 | edited_field = self.field.copy() 101 | edited_field.add_value_w_idx(self.node.copy(), self.value_idx) 102 | return edited_field 103 | 104 | @property 105 | def output(self): 106 | return self._apply_edit() 107 | 108 | def __repr__(self): 109 | return 'AddSubtree[%s, field %s, value idx %d]' % ( 110 | self.node.__repr__(), self.field.__repr__(), self.value_idx) 111 | 112 | 113 | class Stop(Edit): 114 | def __init__(self, meta=None): 115 | self.meta = meta 116 | 117 | def __repr__(self): 118 | return 'StopEdit' 119 | -------------------------------------------------------------------------------- /trees/hypothesis.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from copy import deepcopy 3 | import numpy as np 4 | 5 | from .substitution_system import * 6 | from asdl.asdl import ASDLCompositeType, ASDLPrimitiveType 7 | from asdl.asdl_ast import AbstractSyntaxTree, DummyReduce, AbstractSyntaxNode 8 | from trees.utils import get_field_repr, find_by_id, stack_subtrees 9 | from trees.edits import Edit, Delete, Add, AddSubtree, Stop 10 | 11 | 12 | class Hypothesis(object): 13 | def __init__(self, init_tree_w_dummy_reduce: AbstractSyntaxTree, bool_copy_subtree=False, tree=None, 14 | memory=None, memory_type='all_init_joint', init_code_tokens=None, length_norm=False): 15 | 
self.init_tree_w_dummy_reduce = init_tree_w_dummy_reduce 16 | self.bool_copy_subtree = bool_copy_subtree 17 | assert memory_type in ('all_init_joint', 'all_init_distinct', 'deleted_distinct') 18 | self.memory_type = memory_type 19 | self.init_code_tokens = init_code_tokens 20 | self.length_norm = length_norm 21 | 22 | if tree is not None: 23 | self.tree = tree 24 | else: 25 | self.tree = init_tree_w_dummy_reduce.copy() 26 | 27 | if bool_copy_subtree and memory is None: 28 | if self.memory_type == 'all_init_joint': 29 | self.memory = stack_subtrees(self.init_tree_w_dummy_reduce.root_node) 30 | elif self.memory_type == 'all_init_distinct': 31 | self.memory = [] 32 | for node in stack_subtrees(self.init_tree_w_dummy_reduce.root_node): 33 | if node not in self.memory: 34 | self.memory.append(node) 35 | else: 36 | self.memory = [] 37 | else: 38 | self.memory = memory 39 | # self.set_tree_all_finish() # redundant? 40 | 41 | self.edits = [] 42 | self.score_per_edit = [] 43 | self.score = 0. 44 | 45 | self.repr2field = {} 46 | self.open_del_node_and_ids = [] # nodes available to delete 47 | self.open_add_fields = [] # fields open to add nodes 48 | self.restricted_frontier_fields = [] # fields (esp. with single cardinality) grammatically need to fill 49 | self.update_frontier_info() 50 | 51 | # record the current time step 52 | self.last_edit_field_node = None # trace the last edit 53 | self.t = 0 54 | self.stop_t = None 55 | 56 | def apply_edit(self, edit: Edit, score=0.0): 57 | if isinstance(edit, Stop): 58 | self.stop_t = self.t 59 | self.last_edit_field_node = None 60 | elif isinstance(edit, (Add, Delete, AddSubtree)): 61 | if isinstance(edit, AddSubtree): 62 | assert self.bool_copy_subtree and edit.node in self.memory 63 | 64 | field_repr = get_field_repr(edit.field) 65 | assert field_repr in self.repr2field, "Apply edit: Field not found in state!" 
66 | old_field = self.repr2field[field_repr] 67 | 68 | field_idx = find_by_id(old_field.parent_node.fields, old_field) #old_field.parent_node.fields.index(old_field) 69 | assert field_idx != -1 70 | edited_field = edit.output 71 | old_field.parent_node.replace_child_w_idx(edited_field, field_idx) 72 | 73 | if self.bool_copy_subtree and self.memory_type == 'deleted_distinct' and isinstance(edit, Delete): 74 | for new_subtree in stack_subtrees(edit.node): 75 | if new_subtree not in self.memory: 76 | self.memory.append(new_subtree) # note this does not map to the original init tree 77 | 78 | self.tree.reindex_w_dummy_reduce() 79 | self.update_frontier_info() 80 | 81 | if isinstance(edit, (Add, AddSubtree)): 82 | self.last_edit_field_node = (edited_field, edited_field.as_value_list[edit.value_idx]) 83 | elif isinstance(edit, Delete): 84 | # self.last_edit_field_node = (edited_field, None) 85 | valid_edit_value_idx = edit.value_idx 86 | if len(edited_field.as_value_list) <= valid_edit_value_idx: 87 | valid_edit_value_idx = len(edited_field.as_value_list) - 1 88 | self.last_edit_field_node = (edited_field, edited_field.as_value_list[valid_edit_value_idx]) 89 | 90 | else: 91 | raise ValueError('Invalid edit!') 92 | 93 | self.t += 1 94 | self.edits.append(edit) 95 | 96 | self.score_per_edit.append(score) 97 | if self.length_norm: 98 | assert len(self.edits) == len(self.score_per_edit) 99 | self.score = np.average(self.score_per_edit) 100 | else: 101 | self.score = sum(self.score_per_edit) 102 | 103 | def copy_and_apply_edit(self, edit: Edit, score=0.0): 104 | new_hyp = self.copy() 105 | new_hyp.apply_edit(edit, score=score) 106 | 107 | return new_hyp 108 | 109 | def copy(self): 110 | new_hyp = Hypothesis(self.init_tree_w_dummy_reduce, 111 | bool_copy_subtree=self.bool_copy_subtree, 112 | tree=self.tree.copy(), 113 | memory=list(self.memory), # deepcopy(self.memory) # usually existing memory will not be modified 114 | memory_type=self.memory_type, 115 | 
init_code_tokens=self.init_code_tokens, 116 | length_norm=self.length_norm) 117 | 118 | new_hyp.edits = list(self.edits) 119 | new_hyp.score_per_edit = list(self.score_per_edit) 120 | new_hyp.score = self.score 121 | new_hyp.t = self.t 122 | new_hyp.last_edit_field_node = None 123 | 124 | if hasattr(self, 'meta'): 125 | new_hyp.meta = deepcopy(self.meta) 126 | 127 | new_hyp.update_frontier_info() 128 | 129 | return new_hyp 130 | 131 | def update_frontier_info(self): 132 | open_del_node_and_ids = [] 133 | open_add_fields = [] 134 | restricted_frontier_fields = [] 135 | repr2field = {} 136 | 137 | def _find_frontier_node_and_field(tree_node, field_repr_prefix): 138 | if tree_node: 139 | for field_idx, field in enumerate(tree_node.fields): 140 | tmp_field_repr_prefix = field_repr_prefix + "%s-%d" % (str(tree_node), field_idx) + "-SEP-" 141 | cur_field_repr = tmp_field_repr_prefix + str(field) 142 | repr2field[cur_field_repr] = field 143 | 144 | # if it's an intermediate node, check its children 145 | if (isinstance(field.type, ASDLCompositeType) or 146 | (not isinstance(field.type, ASDLPrimitiveType) and field.type.is_composite)) and \ 147 | field.value: 148 | if field.cardinality in ('single', 'optional'): iter_values = [field.value] 149 | else: iter_values = field.value 150 | 151 | for child_node_idx, child_node in enumerate(iter_values): 152 | if isinstance(child_node, DummyReduce): 153 | continue 154 | _find_frontier_node_and_field( 155 | child_node, tmp_field_repr_prefix + "%s-%d" % (str(field), child_node_idx) + '-SEP-') 156 | 157 | # now all its possible children are checked 158 | # fields must add node 159 | if field.cardinality == 'single' and (field.value is None or isinstance(field.value, DummyReduce)): 160 | restricted_frontier_fields.append(field) # break grammar 161 | 162 | # fields okay to delete node 163 | if (field.cardinality in ('single', 'optional') and field.value is not None and 164 | not isinstance(field.value, DummyReduce)) or \ 165 | 
(field.cardinality == 'multiple' and len(field.as_value_list)): 166 | open_del_node_and_ids.extend([(node, node.id) for node in field.as_value_list 167 | if not isinstance(node, DummyReduce)]) 168 | 169 | # fields okay to add node 170 | if (field.cardinality in ('single', 'optional') and 171 | (field.value is None or isinstance(field.value, DummyReduce))) or \ 172 | (field.cardinality == 'multiple'): 173 | open_add_fields.append(field) 174 | 175 | _find_frontier_node_and_field(self.tree.root_node, '') 176 | self.open_del_node_and_ids = open_del_node_and_ids 177 | self.open_add_fields = open_add_fields 178 | self.restricted_frontier_fields = restricted_frontier_fields 179 | self.repr2field = repr2field 180 | 181 | def set_tree_all_finish(self): 182 | def _set_field_finish(tree_node): 183 | for field in tree_node.fields: 184 | field.set_finish() 185 | 186 | if (isinstance(field.type, ASDLCompositeType) or 187 | (not isinstance(field.type, ASDLPrimitiveType) and field.type.is_composite)) and \ 188 | field.value: 189 | if field.cardinality in ('single', 'optional'): iter_values = [field.value] 190 | else: iter_values = field.value 191 | 192 | for child_node in iter_values: 193 | _set_field_finish(child_node) 194 | 195 | _set_field_finish(self.tree.root_node) 196 | 197 | @property 198 | def syntax_valid(self): 199 | return len(self.restricted_frontier_fields) == 0 200 | 201 | @property 202 | def stopped(self): 203 | return self.stop_t is not None 204 | -------------------------------------------------------------------------------- /trees/utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | from asdl.asdl_ast import RealizedField, AbstractSyntaxTree, AbstractSyntaxNode, SyntaxToken, DummyReduce 3 | 4 | 5 | def find_by_id(candidate_list, candidate): 6 | for idx, candidate_i in enumerate(candidate_list): 7 | if id(candidate_i) == id(candidate): 8 | return idx 9 | 10 | return -1 # not found 11 | 12 | 13 | def 
get_field_repr(field): 14 | field_string = "" 15 | parent_node = field.parent_node 16 | cur_field = field 17 | while parent_node: 18 | # which field of parent_node 19 | field_idx = find_by_id(parent_node.fields, cur_field) 20 | field_string = "%s-%d" % (str(parent_node), field_idx) + "-SEP-" + field_string 21 | parent_field = parent_node.parent_field 22 | if parent_field: 23 | node_idx = find_by_id(parent_field.as_value_list, parent_node) 24 | field_string = "%s-%d" % (str(parent_field), node_idx) + "-SEP-" + field_string 25 | parent_node = parent_field.parent_node 26 | cur_field = parent_field 27 | else: 28 | break 29 | 30 | field_string += str(field) 31 | 32 | return field_string 33 | 34 | 35 | def get_field_node_queue(node): 36 | output = [] 37 | 38 | field = node.parent_field 39 | while field is not None: 40 | node_idx = find_by_id(field.as_value_list, node) 41 | assert node_idx != -1 42 | output.append((node_idx, node)) 43 | 44 | node = field.parent_node 45 | field_idx = find_by_id(node.fields, field) 46 | assert field_idx != -1 47 | output.append((field_idx, field)) 48 | 49 | field = node.parent_field 50 | 51 | output = output[::-1] 52 | return output 53 | 54 | 55 | def copy_tree_field(tree: AbstractSyntaxTree, field: RealizedField, bool_w_dummy_reduce=False): 56 | if bool_w_dummy_reduce: 57 | new_tree = tree.copy_and_reindex_w_dummy_reduce() 58 | else: 59 | new_tree = tree.copy_and_reindex_wo_dummy_reduce() 60 | 61 | root_to_field_trace = [] 62 | cur_field = field 63 | while cur_field: 64 | cur_parent_node = cur_field.parent_node 65 | cur_field_idx = find_by_id(cur_parent_node.fields, cur_field) 66 | assert cur_field_idx != -1 67 | root_to_field_trace.append(('field', cur_field_idx)) 68 | 69 | cur_parent_node_parent_field = cur_parent_node.parent_field 70 | if cur_parent_node_parent_field: 71 | cur_parent_node_idx = find_by_id(cur_parent_node_parent_field.as_value_list, cur_parent_node) 72 | assert cur_parent_node_idx != -1 73 | 
root_to_field_trace.append(('node', cur_parent_node_idx)) 74 | 75 | cur_field = cur_parent_node_parent_field 76 | 77 | pointer = new_tree.root_node 78 | while root_to_field_trace: 79 | trace = root_to_field_trace.pop() 80 | if trace[0] == 'field': 81 | assert isinstance(pointer, AbstractSyntaxNode) 82 | field_idx = trace[1] 83 | pointer = pointer.fields[field_idx] 84 | else: 85 | assert trace[0] == 'node' 86 | assert isinstance(pointer, RealizedField) 87 | node_idx = trace[1] 88 | pointer = pointer.as_value_list[node_idx] 89 | 90 | assert isinstance(pointer, RealizedField) 91 | new_field = pointer 92 | 93 | # assert new_tree == tree # not necessary since DummyReduce may have been inserted 94 | # assert new_field == field 95 | 96 | return new_tree, new_field 97 | 98 | 99 | def stack_subtrees(tree_node, bool_repr=False, bool_stack_syntax_token=False): 100 | if isinstance(tree_node, AbstractSyntaxNode): 101 | # safety check: need depth >= 2 102 | bool_has_child = False 103 | for field in tree_node.fields: 104 | if len(field.as_value_list): 105 | bool_has_child = True 106 | break 107 | if not bool_has_child: 108 | return [] 109 | 110 | if bool_repr: 111 | new_memory = [tree_node.to_string()] 112 | else: 113 | new_memory = [tree_node] 114 | 115 | for field in tree_node.fields: 116 | for val in field.as_value_list: 117 | new_memory.extend(stack_subtrees(val, bool_repr=bool_repr, 118 | bool_stack_syntax_token=bool_stack_syntax_token)) 119 | elif bool_stack_syntax_token: 120 | if bool_repr: 121 | new_memory = [str(tree_node).replace(' ', '-SPACE-')] 122 | else: 123 | new_memory = [tree_node] 124 | 125 | else: 126 | new_memory = [] 127 | 128 | return new_memory 129 | 130 | 131 | def get_productions_str(tree_node): 132 | productions = dict() 133 | 134 | if isinstance(tree_node, AbstractSyntaxNode): 135 | productions[str(tree_node)] = productions.get(str(tree_node), 0) + 1 136 | for field in tree_node.fields: 137 | for val in field.as_value_list: 138 | for k,v in 
get_productions_str(val).items(): 139 | productions[k] = productions.get(k, 0) + v 140 | elif not isinstance(tree_node, DummyReduce): # dummy nodes are excluded 141 | productions[str(tree_node)] = productions.get(str(tree_node), 0) + 1 142 | 143 | return productions 144 | 145 | 146 | def calculate_tree_prod_f1(tree_prod_pred, tree_prod_gold): 147 | all_preds = 0 148 | true_pos = 0 149 | for prod, count in tree_prod_pred.items(): 150 | all_preds += count 151 | if prod in tree_prod_gold: 152 | true_pos += min(count, tree_prod_gold[prod]) 153 | 154 | precision = true_pos * 1.0 / all_preds 155 | 156 | all_golds = sum([count for prod, count in tree_prod_gold.items()]) 157 | recall = true_pos * 1.0 / all_golds 158 | 159 | if true_pos == 0: 160 | f1 = 0 161 | else: 162 | f1 = (2 * precision * recall) / (precision + recall) 163 | return f1 164 | 165 | 166 | def get_sibling_ids(field, anchor_node, bool_rm_dummy=False): 167 | assert id(anchor_node.parent_field) == id(field) # has to be the same field instance 168 | 169 | left_sibling_ids, right_sibling_ids = [], [] 170 | 171 | parent_node = field.parent_node 172 | all_sibling_nodes = [node for field in parent_node.fields for node in field.as_value_list] 173 | bool_left = True 174 | for node in all_sibling_nodes: 175 | if node.id == anchor_node.id: # anchor node will be shifted right 176 | bool_left = False 177 | 178 | if bool_left: 179 | if bool_rm_dummy and isinstance(node, DummyReduce): 180 | continue 181 | left_sibling_ids.append(node.id) 182 | else: 183 | if bool_rm_dummy and isinstance(node, DummyReduce): 184 | continue 185 | right_sibling_ids.append(node.id) 186 | 187 | return left_sibling_ids, right_sibling_ids 188 | 189 | --------------------------------------------------------------------------------