├── .gitignore ├── LICENSE ├── README.md ├── c2nl ├── __init__.py ├── config.py ├── decoders │ ├── __init__.py │ ├── decoder.py │ ├── rnn_decoder.py │ ├── state.py │ └── transformer.py ├── encoders │ ├── __init__.py │ ├── encoder.py │ ├── rnn_encoder.py │ └── transformer.py ├── eval │ ├── __init__.py │ ├── bleu │ │ ├── __init__.py │ │ ├── bleu.py │ │ ├── bleu_scorer.py │ │ ├── google_bleu.py │ │ └── nltk_bleu.py │ ├── meteor │ │ ├── __init__.py │ │ ├── data │ │ │ └── paraphrase-en.gz │ │ ├── meteor-1.5.jar │ │ └── meteor.py │ └── rouge │ │ ├── __init__.py │ │ └── rouge.py ├── inputters │ ├── __init__.py │ ├── constants.py │ ├── dataset.py │ ├── timer.py │ ├── utils.py │ ├── vector.py │ └── vocabulary.py ├── models │ ├── __init__.py │ ├── seq2seq.py │ └── transformer.py ├── modules │ ├── __init__.py │ ├── char_embedding.py │ ├── copy_generator.py │ ├── embeddings.py │ ├── global_attention.py │ ├── highway.py │ ├── multi_head_attn.py │ ├── position_ffn.py │ └── util_class.py ├── objects │ ├── __init__.py │ ├── code.py │ └── summary.py ├── tokenizers │ ├── __init__.py │ ├── code_tokenizer.py │ ├── simple_tokenizer.py │ └── tokenizer.py ├── translator │ ├── __init__.py │ ├── beam.py │ ├── penalties.py │ ├── translation.py │ └── translator.py └── utils │ ├── __init__.py │ ├── copy_utils.py │ ├── logging.py │ └── misc.py ├── data ├── README.md ├── java │ ├── get_data.sh │ ├── get_stat.py │ └── sample.code └── python │ ├── get_data.sh │ ├── get_stat.py │ └── sample.code ├── main ├── __init__.py ├── model.py ├── test.py └── train.py ├── requirements.txt ├── scripts ├── generate.sh ├── java │ ├── rnn.sh │ └── transformer.sh └── python │ ├── rnn.sh │ └── transformer.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | *.pyc 3 | */__pycache__/* 4 | c2nl/*/__pycache__/* 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Wasi Ahmad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## A Transformer-based Approach for Source Code Summarization 2 | Official implementation of our ACL 2020 paper on Source Code Summarization. 
[[arxiv](https://arxiv.org/abs/2005.00653)]
3 | 
4 | ### Installing C2NL
5 | 
6 | C2NL requires Linux, Python 3.6 or higher, and PyTorch version 1.3. Its other dependencies are listed in requirements.txt. CUDA is strongly recommended for speed, but is not necessary.
7 | 
8 | Run the following commands to clone the repository and install C2NL:
9 | 
10 | ```
11 | git clone https://github.com/wasiahmad/NeuralCodeSum.git
12 | cd NeuralCodeSum; pip install -r requirements.txt; python setup.py develop
13 | ```
14 | 
15 | ### Training/Testing Models
16 | 
17 | We provide an RNN-based sequence-to-sequence (Seq2Seq) model implementation along with our Transformer model. To perform training and evaluation, first go to the scripts directory associated with the target dataset.
18 | 
19 | ```
20 | $ cd scripts/DATASET_NAME
21 | ```
22 | 
23 | Here, the choices for DATASET_NAME are ["java", "python"].
24 | 
25 | To train/evaluate a model, run:
26 | 
27 | ```
28 | $ bash script_name.sh GPU_ID MODEL_NAME
29 | ```
30 | 
31 | For example, to train/evaluate the transformer model, run:
32 | 
33 | ```
34 | $ bash transformer.sh 0,1 code2jdoc
35 | ```
36 | 
37 | #### Generated log files
38 | 
39 | While training and evaluating the models, a number of files are generated inside a `tmp` directory. The files are as follows.
40 | 
41 | - **MODEL_NAME.mdl**
42 |   - Model file containing the parameters of the best model.
43 | - **MODEL_NAME.mdl.checkpoint**
44 |   - A model checkpoint, in case we need to restart training.
45 | - **MODEL_NAME.txt**
46 |   - Log file for training.
47 | - **MODEL_NAME.json**
48 |   - The predictions and gold references dumped during validation.
49 | - **MODEL_NAME_test.txt**
50 |   - Log file for evaluation (greedy).
51 | - **MODEL_NAME_test.json**
52 |   - The predictions and gold references dumped during evaluation (greedy).
53 | - **MODEL_NAME_beam.txt**
54 |   - Log file for evaluation (beam).
55 | - **MODEL_NAME_beam.json**
56 |   - The predictions and gold references dumped during evaluation (beam).
57 | 
58 | **[Structure of the JSON files]** Each line in a JSON file is a JSON object. An example is provided below.
59 | 
60 | ```json
61 | {
62 |     "id": 0,
63 |     "code": "private int current Depth ( ) { try { Integer one Based = ( ( Integer ) DEPTH FIELD . get ( this ) ) ; return one Based - NUM ; } catch ( Illegal Access Exception e ) { throw new Assertion Error ( e ) ; } }",
64 |     "predictions": [
65 |         "returns a 0 - based depth within the object graph of the current object being serialized ."
66 |     ],
67 |     "references": [
68 |         "returns a 0 - based depth within the object graph of the current object being serialized ."
69 |     ],
70 |     "bleu": 1,
71 |     "rouge_l": 1
72 | }
73 | ```
74 | 
75 | #### Generating Summaries for Source Codes
76 | 
77 | We can generate summaries for source code with a trained model by running the [generate.sh](https://github.com/wasiahmad/NeuralCodeSum/blob/master/scripts/generate.sh) script. The input source code file must be under the `java` or `python` directory, and we need to manually set the value of the [DATASET](https://github.com/wasiahmad/NeuralCodeSum/blob/master/scripts/generate.sh#L15) variable in the bash script.
78 | 
79 | Sample Java and Python code files are provided at `data/java/sample.code` and `data/python/sample.code`.
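Each of these dumps is in JSON Lines format (one JSON object per line, as shown above), so post-processing the predictions takes only a few lines of Python; the same format applies to the `*_beam.json` file produced by the generation commands below. A minimal sketch, assuming a dump with the fields from the example above (the path is illustrative):

```python
import json

# Load a predictions dump: one JSON object per line (path is illustrative).
with open('tmp/code2jdoc_beam.json') as f:
    records = [json.loads(line) for line in f]

# Average the per-example sentence-level BLEU stored in each record.
avg_bleu = sum(r['bleu'] for r in records) / len(records)
print('{} examples, average sentence BLEU: {:.4f}'.format(len(records), avg_bleu))

# Compare a few predictions against their gold references.
for r in records[:3]:
    print(r['predictions'][0], '|||', r['references'][0])
```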
80 | 
81 | ```
82 | $ cd scripts
83 | $ bash generate.sh 0 code2jdoc sample.code
84 | ```
85 | 
86 | The above command will generate a `tmp/code2jdoc_beam.json` file containing the predicted summaries.
87 | 
88 | #### Running experiments on CPU/GPU/Multi-GPU
89 | 
90 | - If GPU_ID is set to -1, the CPU will be used.
91 | - If GPU_ID is set to one specific number, only one GPU will be used.
92 | - If GPU_ID is set to multiple numbers (e.g., 0,1,2), those GPUs will be used in parallel.
93 | 
94 | ### Acknowledgement
95 | 
96 | We borrowed and modified code from [DrQA](https://github.com/facebookresearch/DrQA) and [OpenNMT](https://github.com/OpenNMT/OpenNMT-py). We would like to express our gratitude to the authors of these repositories.
97 | 
98 | 
99 | ### Citation
100 | 
101 | ```
102 | @inproceedings{ahmad2020summarization,
103 |     author = {Ahmad, Wasi Uddin and Chakraborty, Saikat and Ray, Baishakhi and Chang, Kai-Wei},
104 |     booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (ACL)},
105 |     title = {A Transformer-based Approach for Source Code Summarization},
106 |     year = {2020}
107 | }
108 | ```
109 | 
110 | 
-------------------------------------------------------------------------------- /c2nl/__init__.py: --------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
-------------------------------------------------------------------------------- /c2nl/decoders/__init__.py: --------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
3 | from .decoder import *
4 | from .rnn_decoder import *
5 | from .state import *
6 | from .transformer import *
7 | 
-------------------------------------------------------------------------------- /c2nl/decoders/decoder.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from c2nl.utils.misc import aeq
5 | from c2nl.decoders.state import RNNDecoderState
6 | from c2nl.modules.global_attention import GlobalAttention
7 | 
8 | 
9 | class DecoderBase(nn.Module):
10 |     """Abstract class for decoders.
11 |     Args:
12 |         attentional (bool): The decoder returns non-empty attention.
13 |     """
14 | 
15 |     def __init__(self, attentional=True):
16 |         super(DecoderBase, self).__init__()
17 |         self.attentional = attentional
18 | 
19 |     @classmethod
20 |     def from_opt(cls, opt, embeddings):
21 |         """Alternate constructor.
22 |         Subclasses should override this method.
23 |         """
24 | 
25 |         raise NotImplementedError
26 | 
27 | 
28 | # many parts of the code are copied from OpenNMT-py sources
29 | class RNNDecoderBase(nn.Module):
30 |     """
31 |     Base recurrent attention-based decoder class.
32 | 
33 |     .. mermaid::
34 |        graph BT
35 |           A[Input]
36 |           subgraph RNN
37 |              C[Pos 1]
38 |              D[Pos 2]
39 |              E[Pos N]
40 |           end
41 |           G[Decoder State]
42 |           H[Decoder State]
43 |           I[Outputs]
44 |           F[Memory_Bank]
45 |           A--emb-->C
46 |           A--emb-->D
47 |           A--emb-->E
48 |           H-->C
49 |           C-- attn --- F
50 |           D-- attn --- F
51 |           E-- attn --- F
52 |           C-->I
53 |           D-->I
54 |           E-->I
55 |           E-->G
56 |           F---I
57 | 
58 |     Args:
59 |        rnn_type (:obj:`str`):
60 |           style of recurrent unit to use, one of [LSTM, GRU]
61 |        bidirectional (bool) : use with a bidirectional encoder
62 |        num_layers (int) : number of stacked layers
63 |        hidden_size (int) : hidden size of each layer
64 |        attn_type (str) : see :obj:`c2nl.modules.global_attention.GlobalAttention`
65 |        dropout (float) : dropout value for :obj:`nn.Dropout`
66 |     """
67 | 
68 |     def __init__(self,
69 |                  rnn_type,
70 |                  input_size,
71 |                  bidirectional_encoder,
72 |                  num_layers,
73 |                  hidden_size,
74 |                  attn_type=None,
75 |                  coverage_attn=False,
76 |                  copy_attn=False,
77 |                  reuse_copy_attn=False,
78 |                  dropout=0.0):
79 | 
80 |         super(RNNDecoderBase, self).__init__()
81 | 
82 |         # Basic attributes.
83 |         self.decoder_type = 'rnn'
84 |         self.bidirectional_encoder = bidirectional_encoder
85 |         self.num_layers = num_layers
86 |         self.hidden_size = hidden_size
87 |         self.dropout = nn.Dropout(dropout)
88 | 
89 |         # Build the RNN.
90 |         kwargs = {'input_size': input_size,
91 |                   'hidden_size': hidden_size,
92 |                   'num_layers': num_layers,
93 |                   'dropout': dropout,
94 |                   'batch_first': True}
95 |         self.rnn = getattr(nn, rnn_type)(**kwargs)
96 | 
97 |         # Set up the standard attention.
98 |         self._coverage = coverage_attn
99 |         self.attn = None
100 |         if attn_type:
101 |             self.attn = GlobalAttention(
102 |                 hidden_size, coverage=coverage_attn,
103 |                 attn_type=attn_type
104 |             )
105 |         else:
106 |             assert not self._coverage
107 |             if copy_attn and reuse_copy_attn:
108 |                 raise RuntimeError('Attn is turned off, so reuse_copy_attn flag must be false')
109 | 
110 |         # Set up a separated copy attention layer, if needed.
111 |         self._copy = copy_attn
112 |         self._reuse_copy_attn = reuse_copy_attn
113 |         self.copy_attn = None
114 |         if copy_attn and not reuse_copy_attn:
115 |             self.copy_attn = GlobalAttention(
116 |                 hidden_size, attn_type=attn_type
117 |             )
118 | 
119 |     def count_parameters(self):
120 |         params = list(self.rnn.parameters())
121 |         if self.attn is not None:
122 |             params = params + list(self.attn.parameters())
123 |         if self.copy_attn is not None:
124 |             params = params + list(self.copy_attn.parameters())
125 |         return sum(p.numel() for p in params if p.requires_grad)
126 | 
127 |     def forward(self, tgt, memory_bank, state, memory_lengths=None):
128 |         """
129 |         Args:
130 |             tgt (`LongTensor`): sequences of padded tokens
131 |                 `[batch x tgt_len x nfeats]`.
132 |             memory_bank (`FloatTensor`): vectors from the encoder
133 |                 `[batch x src_len x hidden]`.
134 |             state (:obj:`RNNDecoderState`):
135 |                 decoder state object to initialize the decoder
136 |             memory_lengths (`LongTensor`): the padded source lengths
137 |                 `[batch]`.
138 |         Returns:
139 |             (`FloatTensor`, :obj:`RNNDecoderState`, `dict`):
140 |                 * decoder_outputs: output from the decoder (after attn)
141 |                   `[batch x tgt_len x hidden]`.
142 |                 * decoder_state: final hidden state from the decoder
143 |                 * attns: dict of attention distributions over src at each tgt
144 |                   `[batch x tgt_len x src_len]`.
145 | """ 146 | # Check 147 | assert isinstance(state, RNNDecoderState) 148 | # tgt.size() returns tgt length and batch 149 | tgt_batch, _, _ = tgt.size() 150 | if self.attn is not None: 151 | memory_batch, _, _ = memory_bank.size() 152 | aeq(tgt_batch, memory_batch) 153 | # END 154 | 155 | # Run the forward pass of the RNN. 156 | decoder_final, decoder_outputs, attns = self._run_forward_pass( 157 | tgt, memory_bank, state, memory_lengths=memory_lengths) 158 | 159 | coverage = None 160 | if "coverage" in attns: 161 | coverage = attns["coverage"] 162 | # Update the state with the result. 163 | state.update_state(decoder_final, coverage) 164 | 165 | return decoder_outputs, state, attns 166 | 167 | def init_decoder_state(self, encoder_final): 168 | """ Init decoder state with last state of the encoder """ 169 | 170 | def _fix_enc_hidden(hidden): 171 | # The encoder hidden is (layers*directions) x batch x dim. 172 | # We need to convert it to layers x batch x (directions*dim). 173 | if self.bidirectional_encoder: 174 | hidden = torch.cat([hidden[0:hidden.size(0):2], 175 | hidden[1:hidden.size(0):2]], 2) 176 | return hidden 177 | 178 | if isinstance(encoder_final, tuple): # LSTM 179 | return RNNDecoderState(self.hidden_size, 180 | tuple([_fix_enc_hidden(enc_hid) 181 | for enc_hid in encoder_final])) 182 | else: # GRU 183 | return RNNDecoderState(self.hidden_size, 184 | _fix_enc_hidden(encoder_final)) 185 | -------------------------------------------------------------------------------- /c2nl/decoders/rnn_decoder.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/decoders/decoder.py 2 | import torch 3 | import torch.nn as nn 4 | 5 | from c2nl.decoders.decoder import RNNDecoderBase 6 | from c2nl.utils.misc import aeq 7 | 8 | 9 | class RNNDecoder(RNNDecoderBase): 10 | """ 11 | Standard fully batched RNN decoder with attention. 12 | Faster implementation, uses CuDNN for implementation. 13 | See :obj:`RNNDecoderBase` for options. 14 | Based around the approach from 15 | "Neural Machine Translation By Jointly Learning To Align and Translate" 16 | :cite:`Bahdanau2015` 17 | """ 18 | 19 | def _run_forward_pass(self, tgt, memory_bank, state, memory_lengths=None): 20 | """ 21 | Private helper for running the specific RNN forward pass. 22 | Must be overriden by all subclasses. 23 | Args: 24 | tgt (LongTensor): a sequence of input tokens tensors 25 | [batch x len x nfeats]. 26 | memory_bank (FloatTensor): output(tensor sequence) from the encoder 27 | RNN of size (batch x src_len x hidden_size). 28 | state (FloatTensor): hidden state from the encoder RNN for 29 | initializing the decoder. 30 | memory_lengths (LongTensor): the source memory_bank lengths. 31 | Returns: 32 | decoder_final (Tensor): final hidden state from the decoder. 33 | decoder_outputs (Tensor): output from the decoder (after attn) 34 | `[batch x tgt_len x hidden]`. 35 | attns (Tensor): distribution over src at each tgt 36 | `[batch x tgt_len x src_len]`. 37 | """ 38 | # Initialize local and return variables. 
39 |         attns = {}
40 | 
41 |         emb = tgt
42 |         assert emb.dim() == 3
43 | 
44 |         coverage = state.coverage
45 | 
46 |         if isinstance(self.rnn, nn.GRU):
47 |             rnn_output, decoder_final = self.rnn(emb, state.hidden[0])
48 |         else:
49 |             rnn_output, decoder_final = self.rnn(emb, state.hidden)
50 | 
51 |         # Check
52 |         tgt_batch, tgt_len, _ = tgt.size()
53 |         output_batch, output_len, _ = rnn_output.size()
54 |         aeq(tgt_len, output_len)
55 |         aeq(tgt_batch, output_batch)
56 |         # END
57 | 
58 |         # Calculate the attention.
59 |         if self.attn is not None:
60 |             decoder_outputs, p_attn, coverage_v = self.attn(
61 |                 rnn_output.contiguous(),
62 |                 memory_bank,
63 |                 memory_lengths=memory_lengths,
64 |                 coverage=coverage,
65 |                 softmax_weights=False
66 |             )
67 |             attns["std"] = p_attn
68 |         else:
69 |             decoder_outputs = rnn_output.contiguous()
70 | 
71 |         # Update the coverage attention.
72 |         if self._coverage:
73 |             if coverage_v is None:
74 |                 coverage = coverage + p_attn \
75 |                     if coverage is not None else p_attn
76 |             else:
77 |                 coverage = coverage + coverage_v \
78 |                     if coverage is not None else coverage_v
79 |             attns["coverage"] = coverage
80 | 
81 |         decoder_outputs = self.dropout(decoder_outputs)
82 |         # Run the forward pass of the copy attention layer.
83 |         if self._copy and not self._reuse_copy_attn:
84 |             _, copy_attn, _ = self.copy_attn(decoder_outputs,
85 |                                              memory_bank,
86 |                                              memory_lengths=memory_lengths,
87 |                                              softmax_weights=False)
88 |             attns["copy"] = copy_attn
89 |         elif self._copy:
90 |             attns["copy"] = attns["std"]
91 | 
92 |         return decoder_final, decoder_outputs, attns
93 | 
-------------------------------------------------------------------------------- /c2nl/decoders/state.py: --------------------------------------------------------------------------------
1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/decoders/decoder.py
2 | 
3 | 
4 | class DecoderState(object):
5 |     """Interface for grouping together the current state of a recurrent
6 |     decoder. In the simplest case just represents the hidden state of
7 |     the model. But can also be used for implementing various forms of
8 |     input_feeding and non-recurrent models.
9 |     Modules need to implement this to utilize beam search decoding.
10 |     """
11 | 
12 |     def detach(self):
13 |         """ Detach the hidden state from the computation graph. """
14 |         self.hidden = tuple([_.detach() for _ in self.hidden])
15 | 
16 |     def beam_update(self, idx, positions, beam_size):
17 |         """ Reorder the beams of batch entry `idx` according to `positions`. """
18 |         for e in self._all:
19 |             sizes = e.size()
20 |             br = sizes[1]
21 |             if len(sizes) == 3:
22 |                 sent_states = e.view(sizes[0], beam_size, br // beam_size,
23 |                                      sizes[2])[:, :, idx]
24 |             else:
25 |                 sent_states = e.view(sizes[0], beam_size,
26 |                                      br // beam_size,
27 |                                      sizes[2],
28 |                                      sizes[3])[:, :, idx]
29 | 
30 |             sent_states.data.copy_(
31 |                 sent_states.data.index_select(1, positions))
32 | 
33 |     def map_batch_fn(self, fn):
34 |         raise NotImplementedError()
35 | 
36 | 
37 | class RNNDecoderState(DecoderState):
38 |     """ Base class for RNN decoder state """
39 | 
40 |     def __init__(self, hidden_size, rnnstate):
41 |         """
42 |         Args:
43 |             hidden_size (int): the size of hidden layer of the decoder.
44 |             rnnstate: final hidden state from the encoder,
45 |                 transformed to shape: layers x batch x (directions*dim).
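                For an LSTM encoder this is an (h_n, c_n) tuple; for a GRU
                it is a single tensor, which is wrapped into a 1-tuple so
                that `self.hidden` is always a tuple.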
46 | """ 47 | if not isinstance(rnnstate, tuple): 48 | self.hidden = (rnnstate,) 49 | else: 50 | self.hidden = rnnstate 51 | self.coverage = None 52 | 53 | @property 54 | def _all(self): 55 | return self.hidden 56 | 57 | def update_state(self, rnnstate, coverage): 58 | """ Update decoder state """ 59 | if not isinstance(rnnstate, tuple): 60 | self.hidden = (rnnstate,) 61 | else: 62 | self.hidden = rnnstate 63 | self.coverage = coverage 64 | 65 | def repeat_beam_size_times(self, beam_size): 66 | """ Repeat beam_size times along batch dimension. """ 67 | vars = [e.data.repeat(1, beam_size, 1) 68 | for e in self._all] 69 | self.hidden = tuple(vars) 70 | 71 | def map_batch_fn(self, fn): 72 | self.hidden = tuple(map(lambda x: fn(x, 1), self.hidden)) 73 | -------------------------------------------------------------------------------- /c2nl/decoders/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from c2nl.decoders.decoder import DecoderBase 9 | from c2nl.modules.multi_head_attn import MultiHeadedAttention 10 | from c2nl.modules.position_ffn import PositionwiseFeedForward 11 | from c2nl.utils.misc import sequence_mask 12 | from c2nl.modules.util_class import LayerNorm 13 | 14 | 15 | class TransformerDecoderLayer(nn.Module): 16 | """ 17 | Args: 18 | d_model (int): the dimension of keys/values/queries in 19 | :class:`MultiHeadedAttention`, also the input size of 20 | the first-layer of the :class:`PositionwiseFeedForward`. 21 | heads (int): the number of heads for MultiHeadedAttention. 22 | d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`. 23 | dropout (float): dropout probability. 
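        d_k (int): dimension of the keys per attention head.
        d_v (int): dimension of the values per attention head.
        max_relative_positions (int): maximum distance covered by relative
            position representations (0 disables them).
        coverage_attn (bool): maintain a coverage vector over the context
            attention.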
24 |     """
25 | 
26 |     def __init__(self,
27 |                  d_model,
28 |                  heads,
29 |                  d_k,
30 |                  d_v,
31 |                  d_ff,
32 |                  dropout,
33 |                  max_relative_positions=0,
34 |                  coverage_attn=False):
35 |         super(TransformerDecoderLayer, self).__init__()
36 | 
37 |         self.attention = MultiHeadedAttention(
38 |             heads, d_model, d_k, d_v, dropout=dropout,
39 |             max_relative_positions=max_relative_positions)
40 |         self.layer_norm = LayerNorm(d_model)
41 | 
42 |         self.context_attn = MultiHeadedAttention(
43 |             heads, d_model, d_k, d_v, dropout=dropout,
44 |             coverage=coverage_attn)
45 |         self.layer_norm_2 = LayerNorm(d_model)
46 | 
47 |         self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
48 |         self.drop = nn.Dropout(dropout)
49 | 
50 |     def forward(self,
51 |                 inputs,
52 |                 memory_bank,
53 |                 src_pad_mask,
54 |                 tgt_pad_mask,
55 |                 layer_cache=None,
56 |                 step=None,
57 |                 coverage=None):
58 |         """
59 |         Args:
60 |             inputs (FloatTensor): ``(batch_size, 1, model_dim)``
61 |             memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
62 |             src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
63 |             tgt_pad_mask (LongTensor): ``(batch_size, 1, 1)``
64 |         Returns:
65 |             (FloatTensor, FloatTensor, FloatTensor or None):
66 |             * output ``(batch_size, 1, model_dim)``
67 |             * attn ``(batch_size, 1, src_len)``, plus the updated coverage vector
68 |         """
69 |         dec_mask = None
70 |         if step is None:
71 |             tgt_len = tgt_pad_mask.size(-1)
72 |             future_mask = torch.ones(
73 |                 [tgt_len, tgt_len],
74 |                 device=tgt_pad_mask.device,
75 |                 dtype=torch.uint8)
76 |             future_mask = future_mask.triu_(1).view(1, tgt_len, tgt_len)
77 |             dec_mask = torch.gt(tgt_pad_mask + future_mask, 0)
78 | 
79 |         query, _, _ = self.attention(inputs,
80 |                                      inputs,
81 |                                      inputs,
82 |                                      mask=dec_mask,
83 |                                      layer_cache=layer_cache,
84 |                                      attn_type="self")
85 |         query_norm = self.layer_norm(self.drop(query) + inputs)
86 | 
87 |         mid, attn, coverage = self.context_attn(memory_bank,
88 |                                                 memory_bank,
89 |                                                 query_norm,
90 |                                                 mask=src_pad_mask,
91 |                                                 layer_cache=layer_cache,
92 |                                                 attn_type="context",
93 |                                                 step=step,
94 |                                                 coverage=coverage)
95 |         mid_norm = self.layer_norm_2(self.drop(mid) + query_norm)
96 | 
97 |         output = self.feed_forward(mid_norm)
98 |         return output, attn, coverage
99 | 
100 | 
101 | class TransformerDecoder(DecoderBase):
102 |     """The Transformer decoder from "Attention is All You Need".
103 |     :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`
104 |     .. mermaid::
105 |        graph BT
106 |           A[input]
107 |           B[multi-head self-attn]
108 |           BB[multi-head src-attn]
109 |           C[feed forward]
110 |           O[output]
111 |           A --> B
112 |           B --> BB
113 |           BB --> C
114 |           C --> O
115 |     Args:
116 |         num_layers (int): number of decoder layers.
117 | d_model (int): size of the model 118 | heads (int): number of heads 119 | d_ff (int): size of the inner FF layer 120 | copy_attn (bool): if using a separate copy attention 121 | dropout (float): dropout parameters 122 | embeddings (onmt.modules.Embeddings): 123 | embeddings to use, should have positional encodings 124 | """ 125 | 126 | def __init__(self, 127 | num_layers, 128 | d_model=512, 129 | heads=8, 130 | d_k=64, 131 | d_v=64, 132 | d_ff=2048, 133 | dropout=0.2, 134 | max_relative_positions=0, 135 | coverage_attn=False): 136 | super(TransformerDecoder, self).__init__() 137 | 138 | self.num_layers = num_layers 139 | if isinstance(max_relative_positions, int): 140 | max_relative_positions = [max_relative_positions] * self.num_layers 141 | assert len(max_relative_positions) == self.num_layers 142 | 143 | self._coverage = coverage_attn 144 | self.layer = nn.ModuleList( 145 | [TransformerDecoderLayer(d_model, 146 | heads, 147 | d_k, 148 | d_v, 149 | d_ff, 150 | dropout, 151 | max_relative_positions=max_relative_positions[i], 152 | coverage_attn=coverage_attn) 153 | for i in range(num_layers)]) 154 | 155 | def init_state(self, src_len, max_len): 156 | """Initialize decoder state.""" 157 | state = dict() 158 | state["src_len"] = src_len # [B] 159 | state["src_max_len"] = max_len # an integer 160 | state["cache"] = None 161 | return state 162 | 163 | def count_parameters(self): 164 | params = list(self.layer.parameters()) 165 | return sum(p.numel() for p in params if p.requires_grad) 166 | 167 | def forward(self, 168 | tgt_pad_mask, 169 | emb, 170 | memory_bank, 171 | state, 172 | step=None, 173 | layer_wise_coverage=None): 174 | if step == 0: 175 | self._init_cache(state) 176 | 177 | assert emb.dim() == 3 # batch x len x embedding_dim 178 | output = emb 179 | 180 | src_pad_mask = ~sequence_mask(state["src_len"], 181 | max_len=state["src_max_len"]).unsqueeze(1) 182 | tgt_pad_mask = tgt_pad_mask.unsqueeze(1) # [B, 1, T_tgt] 183 | 184 | new_layer_wise_coverage = [] 185 | representations = [] 186 | std_attentions = [] 187 | for i, layer in enumerate(self.layer): 188 | layer_cache = state["cache"]["layer_{}".format(i)] \ 189 | if step is not None else None 190 | mem_bank = memory_bank[i] if isinstance(memory_bank, list) else memory_bank 191 | output, attn, coverage = layer( 192 | output, 193 | mem_bank, 194 | src_pad_mask, 195 | tgt_pad_mask, 196 | layer_cache=layer_cache, 197 | step=step, 198 | coverage=None if layer_wise_coverage is None 199 | else layer_wise_coverage[i] 200 | ) 201 | representations.append(output) 202 | std_attentions.append(attn) 203 | new_layer_wise_coverage.append(coverage) 204 | 205 | attns = dict() 206 | attns["std"] = std_attentions[-1] 207 | attns["coverage"] = None 208 | if self._coverage: 209 | attns["coverage"] = new_layer_wise_coverage 210 | 211 | return representations, attns 212 | 213 | def _init_cache(self, state): 214 | state["cache"] = {} 215 | for i, layer in enumerate(self.layer): 216 | layer_cache = dict() 217 | layer_cache["memory_keys"] = None 218 | layer_cache["memory_values"] = None 219 | layer_cache["self_keys"] = None 220 | layer_cache["self_values"] = None 221 | state["cache"]["layer_{}".format(i)] = layer_cache 222 | -------------------------------------------------------------------------------- /c2nl/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .encoder import * 4 | from .rnn_encoder import * 5 | from .transformer import * 6 | 
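The re-exported encoder classes can be smoke-tested in isolation. Below is a minimal sketch exercising the Transformer encoder defined later in this package; the shapes follow its `forward()` contract, and all hyper-parameter values are illustrative.

```python
# Minimal smoke test for the Transformer encoder (illustrative values only).
import torch

from c2nl.encoders import TransformerEncoder

encoder = TransformerEncoder(num_layers=2, d_model=512, heads=8)

# A batch of 4 already-embedded sequences of length 30: [batch x src_len x d_model].
src = torch.randn(4, 30, 512)
lengths = torch.full((4,), 30, dtype=torch.long)

# Returns one representation and one attention tensor per layer.
layer_outputs, attention_scores = encoder(src, lengths)
assert layer_outputs[-1].shape == (4, 30, 512)
```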
-------------------------------------------------------------------------------- /c2nl/encoders/encoder.py: --------------------------------------------------------------------------------
1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/encoders/encoder.py
2 | """Base class for encoders and generic multi encoders."""
3 | 
4 | from __future__ import division
5 | 
6 | import torch.nn as nn
7 | from c2nl.utils.misc import aeq
8 | 
9 | 
10 | 
11 | 
12 | 
13 | class EncoderBase(nn.Module):
14 |     """
15 |     Base encoder class. Specifies the interface used by different encoder types
16 |     and required by :obj:`onmt.Models.NMTModel`.
17 |     .. mermaid::
18 |        graph BT
19 |           A[Input]
20 |           subgraph RNN
21 |              C[Pos 1]
22 |              D[Pos 2]
23 |              E[Pos N]
24 |           end
25 |           F[Memory_Bank]
26 |           G[Final]
27 |           A-->C
28 |           A-->D
29 |           A-->E
30 |           C-->F
31 |           D-->F
32 |           E-->F
33 |           E-->G
34 |     """
35 | 
36 |     def _check_args(self, src, lengths=None, hidden=None):
37 |         n_batch, _, _ = src.size()
38 |         if lengths is not None:
39 |             n_batch_, = lengths.size()
40 |             aeq(n_batch, n_batch_)
41 | 
42 |     def forward(self, src, lengths=None):
43 |         """
44 |         Args:
45 |             src (:obj:`LongTensor`):
46 |                padded sequences of sparse indices `[src_len x batch x nfeat]`
47 |             lengths (:obj:`LongTensor`): length of each sequence `[batch]`
48 |         Returns:
49 |             (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`):
50 |                 * final encoder state, used to initialize decoder
51 |                 * memory bank for attention, `[src_len x batch x hidden]`
52 |         """
53 |         raise NotImplementedError
54 | 
-------------------------------------------------------------------------------- /c2nl/encoders/rnn_encoder.py: --------------------------------------------------------------------------------
1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/encoders/rnn_encoder.py
2 | """Define RNN-based encoders."""
3 | from __future__ import division
4 | 
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | 
9 | from c2nl.encoders.encoder import EncoderBase
10 | from torch.nn.utils.rnn import pack_padded_sequence as pack
11 | from torch.nn.utils.rnn import pad_packed_sequence as unpack
12 | 
13 | 
14 | class RNNEncoder(EncoderBase):
15 |     """ A generic recurrent neural network encoder.
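    Layers are stacked one at a time, so the encoder can return either only
    the last layer's outputs (`use_last=True`) or the concatenation of all
    layers' outputs, and can optionally pass the final state through a bridge.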
16 | Args: 17 | rnn_type (:obj:`str`): 18 | style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] 19 | bidirectional (bool) : use a bidirectional RNN 20 | num_layers (int) : number of stacked layers 21 | hidden_size (int) : hidden size of each layer 22 | dropout (float) : dropout value for :obj:`nn.Dropout` 23 | """ 24 | 25 | def __init__(self, 26 | rnn_type, 27 | input_size, 28 | bidirectional, 29 | num_layers, 30 | hidden_size, 31 | dropout=0.0, 32 | use_bridge=False, 33 | use_last=True): 34 | super(RNNEncoder, self).__init__() 35 | 36 | num_directions = 2 if bidirectional else 1 37 | assert hidden_size % num_directions == 0 38 | hidden_size = hidden_size // num_directions 39 | 40 | # Saves preferences for layer 41 | self.nlayers = num_layers 42 | self.use_last = use_last 43 | 44 | self.rnns = nn.ModuleList() 45 | for i in range(self.nlayers): 46 | input_size = input_size if i == 0 else hidden_size * num_directions 47 | kwargs = {'input_size': input_size, 48 | 'hidden_size': hidden_size, 49 | 'num_layers': 1, 50 | 'bidirectional': bidirectional, 51 | 'batch_first': True} 52 | rnn = getattr(nn, rnn_type)(**kwargs) 53 | self.rnns.append(rnn) 54 | 55 | self.dropout = nn.Dropout(dropout) 56 | # Initialize the bridge layer 57 | self.use_bridge = use_bridge 58 | if self.use_bridge: 59 | nl = 1 if self.use_last else num_layers 60 | self._initialize_bridge(rnn_type, hidden_size, nl) 61 | 62 | def count_parameters(self): 63 | params = list(self.rnns.parameters()) 64 | if self.use_bridge: 65 | params = params + list(self.bridge.parameters()) 66 | return sum(p.numel() for p in params if p.requires_grad) 67 | 68 | def forward(self, emb, lengths=None): 69 | "See :obj:`EncoderBase.forward()`" 70 | self._check_args(emb, lengths) 71 | 72 | packed_emb = emb 73 | if lengths is not None: 74 | # Lengths data is wrapped inside a Tensor. 
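            # pack_padded_sequence expects batches sorted by decreasing length,
            # so sort here and keep the inverse permutation (_indices) to
            # restore the original order of the outputs afterwards.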
75 |             lengths, indices = torch.sort(lengths, 0, True)  # Sort by length (keep idx)
76 |             packed_emb = pack(packed_emb[indices], lengths.tolist(), batch_first=True)
77 |             _, _indices = torch.sort(indices, 0)  # Un-sort by length
78 | 
79 |         memory_bank, encoder_final = [], {'h_n': [], 'c_n': []}
80 |         for i in range(self.nlayers):
81 |             if i != 0:
82 |                 packed_emb = self.dropout(packed_emb)
83 |                 if lengths is not None:
84 |                     packed_emb = pack(packed_emb, lengths.tolist(), batch_first=True)
85 | 
86 |             packed_emb, states = self.rnns[i](packed_emb)
87 |             if isinstance(states, tuple):
88 |                 h_n, c_n = states
89 |                 encoder_final['c_n'].append(c_n)
90 |             else:
91 |                 h_n = states
92 |             encoder_final['h_n'].append(h_n)
93 | 
94 |             packed_emb = unpack(packed_emb, batch_first=True)[0] if lengths is not None else packed_emb
95 |             if not self.use_last or i == self.nlayers - 1:
96 |                 memory_bank += [packed_emb[_indices]] if lengths is not None else [packed_emb]
97 | 
98 |         assert len(encoder_final['h_n']) != 0
99 |         if self.use_last:
100 |             memory_bank = memory_bank[-1]
101 |             if len(encoder_final['c_n']) == 0:
102 |                 encoder_final = encoder_final['h_n'][-1]
103 |             else:
104 |                 encoder_final = encoder_final['h_n'][-1], encoder_final['c_n'][-1]
105 |         else:
106 |             memory_bank = torch.cat(memory_bank, dim=2)
107 |             if len(encoder_final['c_n']) == 0:
108 |                 encoder_final = torch.cat(encoder_final['h_n'], dim=0)
109 |             else:
110 |                 encoder_final = torch.cat(encoder_final['h_n'], dim=0), \
111 |                                 torch.cat(encoder_final['c_n'], dim=0)
112 | 
113 |         if self.use_bridge:
114 |             encoder_final = self._bridge(encoder_final)
115 | 
116 |         # TODO: a temporary hack adopted to be compatible with DataParallel
117 |         # reference: https://github.com/pytorch/pytorch/issues/1591
118 |         if memory_bank.size(1) < emb.size(1):
119 |             dummy_tensor = torch.zeros(memory_bank.size(0),
120 |                                        emb.size(1) - memory_bank.size(1),
121 |                                        memory_bank.size(2)).type_as(memory_bank)
122 |             memory_bank = torch.cat([memory_bank, dummy_tensor], 1)
123 | 
124 |         return encoder_final, memory_bank
125 | 
126 |     def _initialize_bridge(self,
127 |                            rnn_type,
128 |                            hidden_size,
129 |                            num_layers):
130 | 
131 |         # LSTM has both hidden and cell states; other RNNs have only one
132 |         number_of_states = 2 if rnn_type == "LSTM" else 1
133 |         # Total number of states
134 |         self.total_hidden_dim = hidden_size * num_layers
135 | 
136 |         # Build a linear layer for each
137 |         self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim,
138 |                                                self.total_hidden_dim,
139 |                                                bias=True)
140 |                                      for _ in range(number_of_states)])
141 | 
142 |     def _bridge(self, hidden):
143 |         """
144 |         Forward hidden state through bridge
145 |         """
146 | 
147 |         def bottle_hidden(linear, states):
148 |             """
149 |             Transform from 3D to 2D, apply linear and return initial size
150 |             """
151 |             size = states.size()
152 |             result = linear(states.view(-1, self.total_hidden_dim))
153 |             return F.relu(result).view(size)
154 | 
155 |         if isinstance(hidden, tuple):  # LSTM
156 |             outs = tuple([bottle_hidden(layer, hidden[ix])
157 |                           for ix, layer in enumerate(self.bridge)])
158 |         else:
159 |             outs = bottle_hidden(self.bridge[0], hidden)
160 |         return outs
161 | 
-------------------------------------------------------------------------------- /c2nl/encoders/transformer.py: --------------------------------------------------------------------------------
1 | """
2 | Implementation of "Attention is All You Need"
3 | """
4 | 
5 | import torch.nn as nn
6 | 
7 | from c2nl.modules.util_class import LayerNorm
8 | from c2nl.modules.multi_head_attn import MultiHeadedAttention
9 | from c2nl.modules.position_ffn import PositionwiseFeedForward
10 | from c2nl.encoders.encoder import EncoderBase
11 | from c2nl.utils.misc import sequence_mask
12 | 
13 | 
14 | class TransformerEncoderLayer(nn.Module):
15 |     """
16 |     A single layer of the transformer encoder.
17 |     Args:
18 |         d_model (int): the dimension of keys/values/queries in
19 |                    MultiHeadedAttention, also the input size of
20 |                    the first-layer of the PositionwiseFeedForward.
21 |         heads (int): the number of heads for MultiHeadedAttention.
22 |         d_ff (int): the second-layer of the PositionwiseFeedForward.
23 |         dropout (float): dropout probability (0-1.0).
24 |     """
25 | 
26 |     def __init__(self,
27 |                  d_model,
28 |                  heads,
29 |                  d_ff,
30 |                  d_k,
31 |                  d_v,
32 |                  dropout,
33 |                  max_relative_positions=0,
34 |                  use_neg_dist=True):
35 |         super(TransformerEncoderLayer, self).__init__()
36 | 
37 |         self.attention = MultiHeadedAttention(heads,
38 |                                               d_model,
39 |                                               d_k,
40 |                                               d_v,
41 |                                               dropout=dropout,
42 |                                               max_relative_positions=max_relative_positions,
43 |                                               use_neg_dist=use_neg_dist)
44 | 
45 |         self.dropout = nn.Dropout(dropout)
46 |         self.layer_norm = LayerNorm(d_model)
47 |         self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)
48 | 
49 |     def forward(self, inputs, mask):
50 |         """
51 |         Transformer Encoder Layer definition.
52 |         Args:
53 |             inputs (`FloatTensor`): `[batch_size x src_len x model_dim]`
54 |             mask (`LongTensor`): `[batch_size x src_len x src_len]`
55 |         Returns:
56 |             (`FloatTensor`):
57 |             * outputs `[batch_size x src_len x model_dim]`
58 |         """
59 |         context, attn_per_head, _ = self.attention(inputs, inputs, inputs,
60 |                                                    mask=mask, attn_type="self")
61 |         out = self.layer_norm(self.dropout(context) + inputs)
62 |         return self.feed_forward(out), attn_per_head
63 | 
64 | 
65 | class TransformerEncoder(EncoderBase):
66 |     """
67 |     The Transformer encoder from "Attention is All You Need".
68 |     .. mermaid::
69 |        graph BT
70 |           A[input]
71 |           B[multi-head self-attn]
72 |           C[feed forward]
73 |           O[output]
74 |           A --> B
75 |           B --> C
76 |           C --> O
77 |     Args:
78 |         num_layers (int): number of encoder layers
79 |         d_model (int): size of the model
80 |         heads (int): number of heads
81 |         d_ff (int): size of the inner FF layer
82 |         dropout (float): dropout parameters
83 |         d_k (int): dimension of keys per attention head
84 |         d_v (int): dimension of values per attention head
85 |     Returns:
86 |         (list of `FloatTensor`, list of `FloatTensor`):
87 |         * representations: per-layer outputs `[batch_size x src_len x model_dim]`
88 |         * attention_scores: per-layer attention weights
89 |     """
90 | 
91 |     def __init__(self,
92 |                  num_layers,
93 |                  d_model=512,
94 |                  heads=8,
95 |                  d_k=64,
96 |                  d_v=64,
97 |                  d_ff=2048,
98 |                  dropout=0.2,
99 |                  max_relative_positions=0,
100 |                  use_neg_dist=True):
101 |         super(TransformerEncoder, self).__init__()
102 | 
103 |         self.num_layers = num_layers
104 |         if isinstance(max_relative_positions, int):
105 |             max_relative_positions = [max_relative_positions] * self.num_layers
106 |         assert len(max_relative_positions) == self.num_layers
107 | 
108 |         self.layer = nn.ModuleList(
109 |             [TransformerEncoderLayer(d_model,
110 |                                      heads,
111 |                                      d_ff,
112 |                                      d_k,
113 |                                      d_v,
114 |                                      dropout,
115 |                                      max_relative_positions=max_relative_positions[i],
116 |                                      use_neg_dist=use_neg_dist)
117 |              for i in range(num_layers)])
118 | 
119 |     def count_parameters(self):
120 |         params = list(self.layer.parameters())
121 |         return sum(p.numel() for p in params if p.requires_grad)
122 | 
123 |     def forward(self, src, lengths=None):
124 |         """
125 |         Args:
126 |             src (`FloatTensor`): `[batch_size x src_len x model_dim]`
127 |             lengths (`LongTensor`): length of each sequence `[batch]`
128 |         Returns:
129 |             (list of `FloatTensor`, list of `FloatTensor`):
130 |             * per-layer outputs `[batch_size x src_len x model_dim]` and attention scores
131 |         """
132 |         self._check_args(src, lengths)
133 | 
134 |         out = src
135 |         mask = None if lengths is None else \
136 |             ~sequence_mask(lengths, out.shape[1]).unsqueeze(1)
137 |         # Run the forward pass of every layer of the transformer.
138 |         representations = []
139 |         attention_scores = []
140 |         for i in range(self.num_layers):
141 |             out, attn_per_head = self.layer[i](out, mask)
142 |             representations.append(out)
143 |             attention_scores.append(attn_per_head)
144 | 
145 |         return representations, attention_scores
146 | 
-------------------------------------------------------------------------------- /c2nl/eval/__init__.py: --------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
3 | from . import bleu
4 | from . import rouge
5 | from . import meteor
6 | 
-------------------------------------------------------------------------------- /c2nl/eval/bleu/__init__.py: --------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
3 | from .bleu import *
4 | from .bleu_scorer import *
5 | from .nltk_bleu import *
6 | from .google_bleu import *
7 | 
-------------------------------------------------------------------------------- /c2nl/eval/bleu/bleu.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 | 
11 | from c2nl.eval.bleu.bleu_scorer import BleuScorer
12 | 
13 | 
14 | class Bleu:
15 |     def __init__(self, n=4):
16 |         # by default, compute BLEU score up to 4-grams
17 |         self._n = n
18 |         self._hypo_for_image = {}
19 |         self.ref_for_image = {}
20 | 
21 |     def compute_score(self, gts, res, verbose):
22 |         assert (sorted(gts.keys()) == sorted(res.keys()))
23 |         imgIds = list(gts.keys())
24 | 
25 |         bleu_scorer = BleuScorer(n=self._n)
26 |         for id in imgIds:
27 |             hypo = res[id]
28 |             ref = gts[id]
29 | 
30 |             # Sanity check.
31 |             assert (type(hypo) is list)
32 |             assert (len(hypo) == 1)
33 |             assert (type(ref) is list)
34 |             assert (len(ref) >= 1)
35 | 
36 |             bleu_scorer += (hypo[0], ref)
37 | 
38 |         # score, scores = bleu_scorer.compute_score(option='shortest')
39 |         score, scores, bleu = bleu_scorer.compute_score(option='closest', verbose=verbose)
40 |         # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
41 | 
42 |         return score, scores, bleu
43 | 
44 |     def method(self):
45 |         return "Bleu"
46 | 
-------------------------------------------------------------------------------- /c2nl/eval/bleu/bleu_scorer.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | # bleu_scorer.py
4 | # David Chiang
5 | 
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 | 
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 | 
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 | 
19 | import copy
20 | import math
21 | from collections import defaultdict
22 | 
23 | 
24 | def precook(s, n=4, out=False):
25 |     """Takes a string as input and returns an object that can be given to
26 |     either cook_refs or cook_test. This is optional: cook_refs and cook_test
27 |     can take string arguments as well."""
28 |     words = s.split()
29 |     counts = defaultdict(int)
30 |     for k in range(1, n + 1):
31 |         for i in range(len(words) - k + 1):
32 |             ngram = tuple(words[i:i + k])
33 |             counts[ngram] += 1
34 |     return (len(words), counts)
35 | 
36 | 
37 | def cook_refs(refs, eff=None, n=4):  ## lhuang: oracle will call with "average"
38 |     '''Takes a list of reference sentences for a single segment
39 |     and returns an object that encapsulates everything that BLEU
40 |     needs to know about them.'''
41 | 
42 |     reflen = []
43 |     maxcounts = {}
44 |     for ref in refs:
45 |         rl, counts = precook(ref, n)
46 |         reflen.append(rl)
47 |         for (ngram, count) in counts.items():
48 |             maxcounts[ngram] = max(maxcounts.get(ngram, 0), count)
49 | 
50 |     # Calculate effective reference sentence length.
51 |     if eff == "shortest":
52 |         reflen = min(reflen)
53 |     elif eff == "average":
54 |         reflen = float(sum(reflen)) / len(reflen)
55 | 
56 |     ## lhuang: N.B.: leave reflen computation to the very end!!
57 | 
58 |     ## lhuang: N.B.: in case of "closest", keep a list of reflens!!
(bad design) 59 | 60 | return (reflen, maxcounts) 61 | 62 | 63 | def cook_test(test, xxx_todo_changeme, eff=None, n=4): 64 | '''Takes a test sentence and returns an object that 65 | encapsulates everything that BLEU needs to know about it.''' 66 | (reflen, refmaxcounts) = xxx_todo_changeme 67 | testlen, counts = precook(test, n, True) 68 | 69 | result = {} 70 | 71 | # Calculate effective reference sentence length. 72 | 73 | if eff == "closest": 74 | result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1] 75 | else: ## i.e., "average" or "shortest" or None 76 | result["reflen"] = reflen 77 | 78 | result["testlen"] = testlen 79 | 80 | result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)] 81 | 82 | result['correct'] = [0] * n 83 | for (ngram, count) in counts.items(): 84 | result["correct"][len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count) 85 | 86 | return result 87 | 88 | 89 | class BleuScorer(object): 90 | """Bleu scorer. 91 | """ 92 | 93 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen" 94 | 95 | # special_reflen is used in oracle (proportional effective ref len for a node). 96 | 97 | def copy(self): 98 | ''' copy the refs.''' 99 | new = BleuScorer(n=self.n) 100 | new.ctest = copy.copy(self.ctest) 101 | new.crefs = copy.copy(self.crefs) 102 | new._score = None 103 | return new 104 | 105 | def __init__(self, test=None, refs=None, n=4, special_reflen=None): 106 | ''' singular instance ''' 107 | 108 | self.n = n 109 | self.crefs = [] 110 | self.ctest = [] 111 | self.cook_append(test, refs) 112 | self.special_reflen = special_reflen 113 | 114 | def cook_append(self, test, refs): 115 | '''called by constructor and __iadd__ to avoid creating new instances.''' 116 | 117 | if refs is not None: 118 | self.crefs.append(cook_refs(refs)) 119 | if test is not None: 120 | cooked_test = cook_test(test, self.crefs[-1]) 121 | self.ctest.append(cooked_test) ## N.B.: -1 122 | else: 123 | self.ctest.append(None) # lens of crefs and ctest have to match 124 | 125 | self._score = None ## need to recompute 126 | 127 | def ratio(self, option=None): 128 | self.compute_score(option=option) 129 | return self._ratio 130 | 131 | def score_ratio(self, option=None): 132 | '''return (bleu, len_ratio) pair''' 133 | return (self.fscore(option=option), self.ratio(option=option)) 134 | 135 | def score_ratio_str(self, option=None): 136 | return "%.4f (%.2f)" % self.score_ratio(option) 137 | 138 | def reflen(self, option=None): 139 | self.compute_score(option=option) 140 | return self._reflen 141 | 142 | def testlen(self, option=None): 143 | self.compute_score(option=option) 144 | return self._testlen 145 | 146 | def retest(self, new_test): 147 | if type(new_test) is str: 148 | new_test = [new_test] 149 | assert len(new_test) == len(self.crefs), new_test 150 | self.ctest = [] 151 | for t, rs in zip(new_test, self.crefs): 152 | self.ctest.append(cook_test(t, rs)) 153 | self._score = None 154 | 155 | return self 156 | 157 | def rescore(self, new_test): 158 | ''' replace test(s) with new test(s), and returns the new score.''' 159 | 160 | return self.retest(new_test).compute_score() 161 | 162 | def size(self): 163 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 164 | return len(self.crefs) 165 | 166 | def __iadd__(self, other): 167 | '''add an instance (e.g., from another sentence).''' 168 | 169 | if type(other) is tuple: 170 | ## avoid creating new BleuScorer instances 171 | self.cook_append(other[0], other[1]) 172 | else: 173 | assert self.compatible(other), "incompatible BLEUs." 174 | self.ctest.extend(other.ctest) 175 | self.crefs.extend(other.crefs) 176 | self._score = None ## need to recompute 177 | 178 | return self 179 | 180 | def compatible(self, other): 181 | return isinstance(other, BleuScorer) and self.n == other.n 182 | 183 | def single_reflen(self, option="average"): 184 | return self._single_reflen(self.crefs[0][0], option) 185 | 186 | def _single_reflen(self, reflens, option=None, testlen=None): 187 | 188 | if option == "shortest": 189 | reflen = min(reflens) 190 | elif option == "average": 191 | reflen = float(sum(reflens)) / len(reflens) 192 | elif option == "closest": 193 | reflen = min((abs(l - testlen), l) for l in reflens)[1] 194 | else: 195 | assert False, "unsupported reflen option %s" % option 196 | 197 | return reflen 198 | 199 | def recompute_score(self, option=None, verbose=0): 200 | self._score = None 201 | return self.compute_score(option, verbose) 202 | 203 | def compute_score(self, option=None, verbose=0): 204 | n = self.n 205 | small = 1e-9 206 | tiny = 1e-15 ## so that if guess is 0 still return 0 207 | bleu_list = [[] for _ in range(n)] 208 | 209 | if self._score is not None: 210 | return self._score 211 | 212 | if option is None: 213 | option = "average" if len(self.crefs) == 1 else "closest" 214 | 215 | self._testlen = 0 216 | self._reflen = 0 217 | totalcomps = {'testlen': 0, 'reflen': 0, 'guess': [0] * n, 'correct': [0] * n} 218 | 219 | # for each sentence 220 | for comps in self.ctest: 221 | testlen = comps['testlen'] 222 | self._testlen += testlen 223 | 224 | if self.special_reflen is None: ## need computation 225 | reflen = self._single_reflen(comps['reflen'], option, testlen) 226 | else: 227 | reflen = self.special_reflen 228 | 229 | self._reflen += reflen 230 | 231 | for key in ['guess', 'correct']: 232 | for k in range(n): 233 | totalcomps[key][k] += comps[key][k] 234 | 235 | # append per image bleu score 236 | bleu = 1. 237 | for k in range(n): 238 | bleu *= (float(comps['correct'][k]) + tiny) \ 239 | / (float(comps['guess'][k]) + small) 240 | bleu_list[k].append(bleu ** (1. / (k + 1))) 241 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division 242 | if ratio < 1: 243 | for k in range(n): 244 | bleu_list[k][-1] *= math.exp(1 - 1 / ratio) 245 | 246 | if verbose > 1: 247 | print(comps, reflen) 248 | 249 | totalcomps['reflen'] = self._reflen 250 | totalcomps['testlen'] = self._testlen 251 | 252 | bleus = [] 253 | bleu = 1. 254 | for k in range(n): 255 | bleu *= float(totalcomps['correct'][k] + tiny) \ 256 | / (totalcomps['guess'][k] + small) 257 | bleus.append(bleu ** (1. / (k + 1))) 258 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division 259 | 260 | # modification by Wasi Ahmad 261 | if ratio > 1.0: 262 | bp = 1. 263 | else: 264 | bp = math.exp(1 - 1. / ratio) 265 | bleus = [bleus[k] * bp for k in range(n)] 266 | 267 | if verbose > 0: 268 | print(totalcomps) 269 | print("ratio:", ratio) 270 | 271 | if min(bleus) > 0: 272 | p_log_sum = sum((1. 
/ n) * math.log(p) for p in bleus) 273 | geo_mean = math.exp(p_log_sum) 274 | else: 275 | geo_mean = 0 276 | 277 | bleu = geo_mean * bp 278 | self._score = bleus 279 | return self._score, bleu_list, bleu 280 | -------------------------------------------------------------------------------- /c2nl/eval/bleu/google_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 17 | This module provides a Python implementation of BLEU and smooth-BLEU. 18 | Smooth BLEU is computed following the method outlined in the paper: 19 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 20 | evaluation metrics for machine translation. COLING 2004. 21 | """ 22 | 23 | import collections 24 | import math 25 | 26 | 27 | def _get_ngrams(segment, max_order): 28 | """Extracts all n-grams upto a given maximum order from an input segment. 29 | Args: 30 | segment: text segment from which n-grams will be extracted. 31 | max_order: maximum length in tokens of the n-grams returned by this 32 | methods. 33 | Returns: 34 | The Counter containing all n-grams upto max_order in segment 35 | with a count of how many times each n-gram occurred. 36 | """ 37 | ngram_counts = collections.Counter() 38 | for order in range(1, max_order + 1): 39 | for i in range(0, len(segment) - order + 1): 40 | ngram = tuple(segment[i:i + order]) 41 | ngram_counts[ngram] += 1 42 | return ngram_counts 43 | 44 | 45 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 46 | smooth=False): 47 | """Computes BLEU score of translated segments against one or more references. 48 | Args: 49 | reference_corpus: list of lists of references for each translation. Each 50 | reference should be tokenized into a list of tokens. 51 | translation_corpus: list of translations to score. Each translation 52 | should be tokenized into a list of tokens. 53 | max_order: Maximum n-gram order to use when computing BLEU score. 54 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 55 | Returns: 56 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 57 | precisions and brevity penalty. 
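    Note: as implemented, the return value is actually a 6-tuple:
    (bleu, precisions, bp, ratio, translation_length, reference_length).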
58 | """ 59 | matches_by_order = [0] * max_order 60 | possible_matches_by_order = [0] * max_order 61 | reference_length = 0 62 | translation_length = 0 63 | for (references, translation) in zip(reference_corpus, 64 | translation_corpus): 65 | reference_length += min(len(r) for r in references) 66 | translation_length += len(translation) 67 | 68 | merged_ref_ngram_counts = collections.Counter() 69 | for reference in references: 70 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 71 | translation_ngram_counts = _get_ngrams(translation, max_order) 72 | overlap = translation_ngram_counts & merged_ref_ngram_counts 73 | for ngram in overlap: 74 | matches_by_order[len(ngram) - 1] += overlap[ngram] 75 | for order in range(1, max_order + 1): 76 | possible_matches = len(translation) - order + 1 77 | if possible_matches > 0: 78 | possible_matches_by_order[order - 1] += possible_matches 79 | 80 | precisions = [0] * max_order 81 | for i in range(0, max_order): 82 | if smooth: 83 | precisions[i] = ((matches_by_order[i] + 1.) / 84 | (possible_matches_by_order[i] + 1.)) 85 | else: 86 | if possible_matches_by_order[i] > 0: 87 | precisions[i] = (float(matches_by_order[i]) / 88 | possible_matches_by_order[i]) 89 | else: 90 | precisions[i] = 0.0 91 | 92 | if min(precisions) > 0: 93 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 94 | geo_mean = math.exp(p_log_sum) 95 | else: 96 | geo_mean = 0 97 | 98 | ratio = float(translation_length) / reference_length 99 | 100 | if ratio > 1.0: 101 | bp = 1. 102 | else: 103 | bp = math.exp(1 - 1. / ratio) 104 | 105 | bleu = geo_mean * bp 106 | 107 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 108 | 109 | 110 | def corpus_bleu(hypotheses, references): 111 | refs = [] 112 | hyps = [] 113 | count = 0 114 | total_score = 0.0 115 | 116 | assert (sorted(hypotheses.keys()) == sorted(references.keys())) 117 | Ids = list(hypotheses.keys()) 118 | ind_score = dict() 119 | 120 | for id in Ids: 121 | hyp = hypotheses[id][0].split() 122 | ref = [r.split() for r in references[id]] 123 | hyps.append(hyp) 124 | refs.append(ref) 125 | 126 | score = compute_bleu([ref], [hyp], smooth=True)[0] 127 | total_score += score 128 | count += 1 129 | ind_score[id] = score 130 | 131 | avg_score = total_score / count 132 | corpus_bleu = compute_bleu(refs, hyps, smooth=True)[0] 133 | return corpus_bleu, avg_score, ind_score 134 | -------------------------------------------------------------------------------- /c2nl/eval/bleu/nltk_bleu.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | from nltk.translate.bleu_score import SmoothingFunction 3 | 4 | 5 | def nltk_sentence_bleu(hypothesis, reference, order=4): 6 | cc = SmoothingFunction() 7 | return nltk.translate.bleu([reference], hypothesis, smoothing_function=cc.method4) 8 | 9 | 10 | def nltk_corpus_bleu(hypotheses, references, order=4): 11 | refs = [] 12 | hyps = [] 13 | count = 0 14 | total_score = 0.0 15 | 16 | cc = SmoothingFunction() 17 | 18 | assert (sorted(hypotheses.keys()) == sorted(references.keys())) 19 | Ids = list(hypotheses.keys()) 20 | ind_score = dict() 21 | 22 | for id in Ids: 23 | hyp = hypotheses[id][0].split() 24 | ref = [r.split() for r in references[id]] 25 | hyps.append(hyp) 26 | refs.append(ref) 27 | 28 | score = nltk.translate.bleu(ref, hyp, smoothing_function=cc.method4) 29 | total_score += score 30 | count += 1 31 | ind_score[id] = score 32 | 33 | avg_score = total_score / count 34 | corpus_bleu = 
nltk.translate.bleu_score.corpus_bleu(refs, hyps) 35 | return corpus_bleu, avg_score, ind_score 36 | -------------------------------------------------------------------------------- /c2nl/eval/meteor/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .meteor import * 4 | -------------------------------------------------------------------------------- /c2nl/eval/meteor/data/paraphrase-en.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wasiahmad/NeuralCodeSum/ffc43a415718d80aba4fe8372a438e0e492ddc6d/c2nl/eval/meteor/data/paraphrase-en.gz -------------------------------------------------------------------------------- /c2nl/eval/meteor/meteor-1.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wasiahmad/NeuralCodeSum/ffc43a415718d80aba4fe8372a438e0e492ddc6d/c2nl/eval/meteor/meteor-1.5.jar -------------------------------------------------------------------------------- /c2nl/eval/meteor/meteor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Python wrapper for METEOR implementation, by Xinlei Chen 4 | # Acknowledge Michael Denkowski for the generous discussion and help 5 | from __future__ import division 6 | 7 | import atexit 8 | import logging 9 | import os 10 | import subprocess 11 | import sys 12 | import threading 13 | 14 | import psutil 15 | 16 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed. 17 | METEOR_JAR = 'meteor-1.5.jar' 18 | 19 | 20 | def enc(s): 21 | return s.encode('utf-8') 22 | 23 | 24 | def dec(s): 25 | return s.decode('utf-8') 26 | 27 | 28 | class Meteor: 29 | 30 | def __init__(self): 31 | # Used to guarantee thread safety 32 | self.lock = threading.Lock() 33 | 34 | mem = '2G' 35 | mem_available_G = psutil.virtual_memory().available / 1E9 36 | if mem_available_G < 2: 37 | logging.warning("There is less than 2GB of available memory.\n" 38 | "Will try with limiting Meteor to 1GB of memory but this might cause issues.\n" 39 | "If you have problems using Meteor, " 40 | "then you can try to lower the `mem` variable in meteor.py") 41 | mem = '1G' 42 | 43 | meteor_cmd = ['java', '-jar', '-Xmx{}'.format(mem), METEOR_JAR, 44 | '-', '-', '-stdio', '-l', 'en', '-norm'] 45 | env = os.environ.copy() 46 | env['LC_ALL'] = "C" 47 | self.meteor_p = subprocess.Popen(meteor_cmd, 48 | cwd=os.path.dirname(os.path.abspath(__file__)), 49 | env=env, 50 | stdin=subprocess.PIPE, 51 | stdout=subprocess.PIPE, 52 | stderr=subprocess.PIPE) 53 | 54 | atexit.register(self.close) 55 | 56 | def close(self): 57 | with self.lock: 58 | if self.meteor_p: 59 | self.meteor_p.kill() 60 | self.meteor_p.wait() 61 | self.meteor_p = None 62 | # if the user calls close() manually, remove the 63 | # reference from atexit so the object can be garbage-collected. 
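                # (During interpreter shutdown, module globals such as `atexit`
                # may already have been set to None, hence the defensive checks
                # below.)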
64 | if atexit is not None and atexit.unregister is not None: 65 | atexit.unregister(self.close) 66 | 67 | def compute_score(self, gts, res): 68 | assert (gts.keys() == res.keys()) 69 | imgIds = gts.keys() 70 | scores = [] 71 | 72 | eval_line = 'EVAL' 73 | with self.lock: 74 | for i in imgIds: 75 | assert (len(res[i]) == 1) 76 | stat = self._stat(res[i][0], gts[i]) 77 | eval_line += ' ||| {}'.format(stat) 78 | 79 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 80 | self.meteor_p.stdin.flush() 81 | for i in range(0, len(imgIds)): 82 | v = self.meteor_p.stdout.readline() 83 | try: 84 | scores.append(float(dec(v.strip()))) 85 | except: 86 | sys.stderr.write("Error handling value: {}\n".format(v)) 87 | sys.stderr.write("Decoded value: {}\n".format(dec(v.strip()))) 88 | sys.stderr.write("eval_line: {}\n".format(eval_line)) 89 | # You can try uncommenting the next code line to show stderr from the Meteor JAR. 90 | # If the Meteor JAR is not writing to stderr, then the line will just hang. 91 | # sys.stderr.write("Error from Meteor:\n{}".format(self.meteor_p.stderr.read())) 92 | raise 93 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 94 | 95 | return score, scores 96 | 97 | def method(self): 98 | return "METEOR" 99 | 100 | def _stat(self, hypothesis_str, reference_list): 101 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 102 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 103 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 104 | self.meteor_p.stdin.write(enc(score_line)) 105 | self.meteor_p.stdin.write(enc('\n')) 106 | self.meteor_p.stdin.flush() 107 | return dec(self.meteor_p.stdout.readline()).strip() 108 | 109 | def _score(self, hypothesis_str, reference_list): 110 | with self.lock: 111 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words 112 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ') 113 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str)) 114 | self.meteor_p.stdin.write(enc('{}\n'.format(score_line))) 115 | self.meteor_p.stdin.flush() 116 | stats = dec(self.meteor_p.stdout.readline()).strip() 117 | eval_line = 'EVAL ||| {}'.format(stats) 118 | # EVAL ||| stats 119 | self.meteor_p.stdin.write(enc('{}\n'.format(eval_line))) 120 | self.meteor_p.stdin.flush() 121 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 122 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice 123 | # thanks for Andrej for pointing this out 124 | score = float(dec(self.meteor_p.stdout.readline()).strip()) 125 | return score 126 | 127 | def __del__(self): 128 | self.close() 129 | -------------------------------------------------------------------------------- /c2nl/eval/rouge/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .rouge import * 4 | -------------------------------------------------------------------------------- /c2nl/eval/rouge/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | import numpy as np 11 | 12 | 13 | def my_lcs(string, sub): 14 | """ 15 | Calculates longest common subsequence for a pair of 
tokenized strings 16 | :param string : list of str : tokens from a string split using whitespace 17 | :param sub : list of str : shorter string, also split using whitespace 18 | :returns: length (list of int): length of the longest common subsequence between the two strings 19 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 20 | """ 21 | if len(string) < len(sub): 22 | sub, string = string, sub 23 | 24 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 25 | 26 | for j in range(1, len(sub) + 1): 27 | for i in range(1, len(string) + 1): 28 | if string[i - 1] == sub[j - 1]: 29 | lengths[i][j] = lengths[i - 1][j - 1] + 1 30 | else: 31 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 32 | 33 | return lengths[len(string)][len(sub)] 34 | 35 | 36 | class Rouge(): 37 | ''' 38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set 39 | ''' 40 | 41 | def __init__(self): 42 | # vrama91: updated the value below based on discussion with Hovey 43 | self.beta = 1.2 44 | 45 | def calc_score(self, candidate, refs): 46 | """ 47 | Compute ROUGE-L score given one candidate and references for an image 48 | :param candidate: str : candidate sentence to be evaluated 49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated 50 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 51 | """ 52 | assert (len(candidate) == 1) 53 | assert (len(refs) > 0) 54 | prec = [] 55 | rec = [] 56 | 57 | # split into tokens 58 | token_c = candidate[0].split(" ") 59 | 60 | for reference in refs: 61 | # split into tokens 62 | token_r = reference.split(" ") 63 | # compute the longest common subsequence 64 | lcs = my_lcs(token_r, token_c) 65 | prec.append(lcs / float(len(token_c))) 66 | rec.append(lcs / float(len(token_r))) 67 | 68 | prec_max = max(prec) 69 | rec_max = max(rec) 70 | 71 | if prec_max != 0 and rec_max != 0: 72 | score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max) 73 | else: 74 | score = 0.0 75 | return score 76 | 77 | def compute_score(self, gts, res): 78 | """ 79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset 80 | Invoked by evaluate_captions.py 81 | :param gts: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values 82 | :param res: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values 83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images) 84 | """ 85 | assert (sorted(gts.keys()) == sorted(res.keys())) 86 | imgIds = list(gts.keys()) 87 | 88 | score = dict() 89 | for id in imgIds: 90 | hypo = res[id] 91 | ref = gts[id] 92 | 93 | # Sanity check. 
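# Worked example for calc_score below (illustrative, not from the original
# file): for hypo = ['returns the depth .'] and
# ref = ['returns the current depth .'], my_lcs gives an LCS length of 4,
# so prec = 4/4 = 1.0 and rec = 4/5 = 0.8; with beta = 1.2 the ROUGE-L
# F-score is ((1 + 1.44) * 1.0 * 0.8) / (0.8 + 1.44 * 1.0) ~= 0.871.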
94 | assert (type(hypo) is list) 95 | assert (len(hypo) == 1) 96 | assert (type(ref) is list) 97 | assert (len(ref) > 0) 98 | 99 | score[id] = self.calc_score(hypo, ref) 100 | 101 | average_score = np.mean(np.array(list(score.values()))) 102 | return average_score, score 103 | 104 | def method(self): 105 | return "Rouge" 106 | -------------------------------------------------------------------------------- /c2nl/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .dataset import * 4 | from .vector import * 5 | from .utils import * 6 | from .constants import * 7 | from .vocabulary import * 8 | -------------------------------------------------------------------------------- /c2nl/inputters/constants.py: -------------------------------------------------------------------------------- 1 | PAD = 0 2 | UNK = 1 3 | BOS = 2 4 | EOS = 3 5 | 6 | PAD_WORD = '<blank>' 7 | UNK_WORD = '<unk>' 8 | BOS_WORD = '<s>' 9 | EOS_WORD = '</s>' 10 | 11 | TOKEN_TYPE_MAP = { 12 | # Java 13 | '<blank>': 0, 14 | '<unk>': 1, 15 | 'other': 2, 16 | 'var': 3, 17 | 'method': 4, 18 | # Python 19 | 's': 5, 20 | 'None': 6, 21 | 'value': 7, 22 | 'asname': 8, 23 | 'n': 9, 24 | 'level': 10, 25 | 'is_async': 11, 26 | 'arg': 12, 27 | 'attr': 13, 28 | 'id': 14, 29 | 'name': 15, 30 | 'module': 16 31 | } 32 | 33 | AST_TYPE_MAP = { 34 | '<blank>': 0, 35 | 'N': 1, 36 | 'T': 2 37 | } 38 | 39 | DATA_LANG_MAP = { 40 | 'java': 'java', 41 | 'python': 'python' 42 | } 43 | 44 | LANG_ID_MAP = { 45 | 'java': 0, 46 | 'python': 1, 47 | 'c#': 2 48 | } 49 | -------------------------------------------------------------------------------- /c2nl/inputters/dataset.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/facebookresearch/DrQA/blob/master/drqa/reader/data.py 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torch.utils.data.sampler import Sampler 5 | 6 | from c2nl.inputters.vector import vectorize 7 | 8 | 9 | # ------------------------------------------------------------------------------ 10 | # PyTorch dataset class for (code, summary) pairs. 11 | # ------------------------------------------------------------------------------ 12 | 13 | 14 | class CommentDataset(Dataset): 15 | def __init__(self, examples, model): 16 | self.model = model 17 | self.examples = examples 18 | 19 | def __len__(self): 20 | return len(self.examples) 21 | 22 | def __getitem__(self, index): 23 | return vectorize(self.examples[index], self.model) 24 | 25 | def lengths(self): 26 | return [(len(ex['code'].tokens), len(ex['summary'].tokens)) 27 | for ex in self.examples] 28 | 29 | 30 | # ------------------------------------------------------------------------------ 31 | # PyTorch sampler returning batches of indices sorted by length (by code and summary).
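# Illustrative wiring of the dataset and sampler defined in this file
# (a sketch, not part of the original source; `examples` is assumed to come
# from c2nl.inputters.utils.load_data and `model` from the repo's Model setup;
# batchify is defined in c2nl/inputters/vector.py):
#
#   from torch.utils.data import DataLoader
#   from c2nl.inputters.vector import batchify
#   dataset = CommentDataset(examples, model)
#   sampler = SortedBatchSampler(dataset.lengths(), batch_size=32, shuffle=True)
#   loader = DataLoader(dataset, batch_size=32, sampler=sampler,
#                       collate_fn=batchify)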
32 | # ------------------------------------------------------------------------------ 33 | 34 | 35 | class SortedBatchSampler(Sampler): 36 | def __init__(self, lengths, batch_size, shuffle=True): 37 | self.lengths = lengths 38 | self.batch_size = batch_size 39 | self.shuffle = shuffle 40 | 41 | def __iter__(self): 42 | lengths = np.array( 43 | [(-l[0], -l[1], np.random.random()) for l in self.lengths], 44 | dtype=[('l1', np.int_), ('l2', np.int_), ('rand', np.float_)] 45 | ) 46 | indices = np.argsort(lengths, order=('l1', 'l2', 'rand')) 47 | batches = [indices[i:i + self.batch_size] 48 | for i in range(0, len(indices), self.batch_size)] 49 | if self.shuffle: 50 | np.random.shuffle(batches) 51 | return iter([i for batch in batches for i in batch]) 52 | 53 | def __len__(self): 54 | return len(self.lengths) 55 | -------------------------------------------------------------------------------- /c2nl/inputters/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | # ------------------------------------------------------------------------------ 5 | # Utility classes 6 | # ------------------------------------------------------------------------------ 7 | 8 | 9 | class AverageMeter(object): 10 | """Computes and stores the average and current value.""" 11 | 12 | def __init__(self): 13 | self.reset() 14 | 15 | def reset(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def update(self, val, n=1): 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | class Timer(object): 29 | """Computes elapsed time.""" 30 | 31 | def __init__(self): 32 | self.running = True 33 | self.total = 0 34 | self.start = time.time() 35 | 36 | def reset(self): 37 | self.running = True 38 | self.total = 0 39 | self.start = time.time() 40 | return self 41 | 42 | def resume(self): 43 | if not self.running: 44 | self.running = True 45 | self.start = time.time() 46 | return self 47 | 48 | def stop(self): 49 | if self.running: 50 | self.running = False 51 | self.total += time.time() - self.start 52 | return self 53 | 54 | def time(self): 55 | if self.running: 56 | return self.total + time.time() - self.start 57 | return self.total 58 | -------------------------------------------------------------------------------- /c2nl/inputters/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import string 4 | from collections import Counter 5 | from tqdm import tqdm 6 | 7 | from c2nl.objects import Code, Summary 8 | from c2nl.inputters.vocabulary import Vocabulary, UnicodeCharsVocabulary 9 | from c2nl.inputters.constants import BOS_WORD, EOS_WORD, PAD_WORD, \ 10 | UNK_WORD, TOKEN_TYPE_MAP, AST_TYPE_MAP, DATA_LANG_MAP, LANG_ID_MAP 11 | from c2nl.utils.misc import count_file_lines 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def is_number(n): 17 | try: 18 | float(n) 19 | except ValueError: 20 | return False 21 | return True 22 | 23 | 24 | def generate_random_string(N=8): 25 | return ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(N)) 26 | 27 | 28 | # ------------------------------------------------------------------------------ 29 | # Data loading 30 | # ------------------------------------------------------------------------------ 31 | 32 | def process_examples(lang_id, 33 | source, 34 | source_tag, 35 | target, 36 | max_src_len, 37 | max_tgt_len, 38 | 
code_tag_type, 39 | uncase=False, 40 | test_split=True): 41 | code_tokens = source.split() 42 | code_type = [] 43 | if source_tag is not None: 44 | code_type = source_tag.split() 45 | if len(code_tokens) != len(code_type): 46 | return None 47 | 48 | code_tokens = code_tokens[:max_src_len] 49 | code_type = code_type[:max_src_len] 50 | if len(code_tokens) == 0: 51 | return None 52 | 53 | TAG_TYPE_MAP = TOKEN_TYPE_MAP if \ 54 | code_tag_type == 'subtoken' else AST_TYPE_MAP 55 | code = Code() 56 | code.text = source 57 | code.language = lang_id 58 | code.tokens = code_tokens 59 | code.type = [TAG_TYPE_MAP.get(ct, 1) for ct in code_type] 60 | if code_tag_type != 'subtoken': 61 | code.mask = [1 if ct == 'N' else 0 for ct in code_type] 62 | 63 | if target is not None: 64 | summ = target.lower() if uncase else target 65 | summ_tokens = summ.split() 66 | if not test_split: 67 | summ_tokens = summ_tokens[:max_tgt_len] 68 | if len(summ_tokens) == 0: 69 | return None 70 | summary = Summary() 71 | summary.text = ' '.join(summ_tokens) 72 | summary.tokens = summ_tokens 73 | summary.prepend_token(BOS_WORD) 74 | summary.append_token(EOS_WORD) 75 | else: 76 | summary = None 77 | 78 | example = dict() 79 | example['code'] = code 80 | example['summary'] = summary 81 | return example 82 | 83 | 84 | def load_data(args, filenames, max_examples=-1, dataset_name='java', 85 | test_split=False): 86 | """Load examples from preprocessed file. One example per line, JSON encoded.""" 87 | 88 | with open(filenames['src']) as f: 89 | sources = [line.strip() for line in 90 | tqdm(f, total=count_file_lines(filenames['src']))] 91 | 92 | if filenames['tgt'] is not None: 93 | with open(filenames['tgt']) as f: 94 | targets = [line.strip() for line in 95 | tqdm(f, total=count_file_lines(filenames['tgt']))] 96 | else: 97 | targets = [None] * len(sources) 98 | 99 | if filenames['src_tag'] is not None: 100 | with open(filenames['src_tag']) as f: 101 | source_tags = [line.strip() for line in 102 | tqdm(f, total=count_file_lines(filenames['src_tag']))] 103 | else: 104 | source_tags = [None] * len(sources) 105 | 106 | assert len(sources) == len(source_tags) == len(targets) 107 | 108 | examples = [] 109 | for src, src_tag, tgt in tqdm(zip(sources, source_tags, targets), 110 | total=len(sources)): 111 | if dataset_name in ['java', 'python']: 112 | _ex = process_examples(LANG_ID_MAP[DATA_LANG_MAP[dataset_name]], 113 | src, 114 | src_tag, 115 | tgt, 116 | args.max_src_len, 117 | args.max_tgt_len, 118 | args.code_tag_type, 119 | uncase=args.uncase, 120 | test_split=test_split) 121 | if _ex is not None: 122 | examples.append(_ex) 123 | 124 | if max_examples != -1 and len(examples) > max_examples: 125 | break 126 | 127 | return examples 128 | 129 | 130 | # ------------------------------------------------------------------------------ 131 | # Dictionary building 132 | # ------------------------------------------------------------------------------ 133 | 134 | 135 | def index_embedding_words(embedding_file): 136 | """Put all the words in embedding_file into a set.""" 137 | words = set() 138 | with open(embedding_file) as f: 139 | for line in tqdm(f, total=count_file_lines(embedding_file)): 140 | w = Vocabulary.normalize(line.rstrip().split(' ')[0]) 141 | words.add(w) 142 | 143 | words.update([BOS_WORD, EOS_WORD, PAD_WORD, UNK_WORD]) 144 | return words 145 | 146 | 147 | def load_words(args, examples, fields, dict_size=None): 148 | """Iterate and index all the words in examples (documents + questions).""" 149 | 150 | def _insert(iterable): 151 | 
words = [] 152 | for w in iterable: 153 | w = Vocabulary.normalize(w) 154 | words.append(w) 155 | word_count.update(words) 156 | 157 | word_count = Counter() 158 | for ex in tqdm(examples): 159 | for field in fields: 160 | _insert(ex[field].tokens) 161 | 162 | # -2 to reserve spots for PAD and UNK token 163 | dict_size = dict_size - 2 if dict_size and dict_size > 2 else dict_size 164 | most_common = word_count.most_common(dict_size) 165 | words = set(word for word, _ in most_common) 166 | return words 167 | 168 | 169 | def build_word_dict(args, examples, fields, dict_size=None, 170 | no_special_token=False): 171 | """Return a dictionary built from the code and summary words in the 172 | provided examples. 173 | """ 174 | word_dict = Vocabulary(no_special_token) 175 | for w in load_words(args, examples, fields, dict_size): 176 | word_dict.add(w) 177 | return word_dict 178 | 179 | 180 | def build_word_and_char_dict(args, examples, fields, dict_size=None, 181 | no_special_token=False): 182 | """Return a word- and character-level dictionary built from the code and 183 | summary words in the provided examples. 184 | """ 185 | words = load_words(args, examples, fields, dict_size) 186 | dictionary = UnicodeCharsVocabulary(words, 187 | args.max_characters_per_token, 188 | no_special_token) 189 | return dictionary 190 | 191 | 192 | def top_summary_words(args, examples, word_dict): 193 | """Count and return the most common summary words in the provided examples.""" 194 | word_count = Counter() 195 | for ex in examples: 196 | for w in ex['summary'].tokens: 197 | w = Vocabulary.normalize(w) 198 | if w in word_dict: 199 | word_count.update([w]) 200 | return word_count.most_common(args.tune_partial) 201 | -------------------------------------------------------------------------------- /c2nl/inputters/vector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def vectorize(ex, model): 5 | """Vectorize a single example.""" 6 | src_dict = model.src_dict 7 | tgt_dict = model.tgt_dict 8 | 9 | code, summary = ex['code'], ex['summary'] 10 | vectorized_ex = dict() 11 | vectorized_ex['id'] = code.id 12 | vectorized_ex['language'] = code.language 13 | 14 | vectorized_ex['code'] = code.text 15 | vectorized_ex['code_tokens'] = code.tokens 16 | vectorized_ex['code_char_rep'] = None 17 | vectorized_ex['code_type_rep'] = None 18 | vectorized_ex['code_mask_rep'] = None 19 | vectorized_ex['use_code_mask'] = False 20 | 21 | vectorized_ex['code_word_rep'] = torch.LongTensor(code.vectorize(word_dict=src_dict)) 22 | if model.args.use_src_char: 23 | vectorized_ex['code_char_rep'] = torch.LongTensor(code.vectorize(word_dict=src_dict, _type='char')) 24 | if model.args.use_code_type: 25 | vectorized_ex['code_type_rep'] = torch.LongTensor(code.type) 26 | if code.mask: 27 | vectorized_ex['code_mask_rep'] = torch.LongTensor(code.mask) 28 | vectorized_ex['use_code_mask'] = True 29 | 30 | vectorized_ex['summ'] = None 31 | vectorized_ex['summ_tokens'] = None 32 | vectorized_ex['stype'] = None 33 | vectorized_ex['summ_word_rep'] = None 34 | vectorized_ex['summ_char_rep'] = None 35 | vectorized_ex['target'] = None 36 | 37 | if summary is not None: 38 | vectorized_ex['summ'] = summary.text 39 | vectorized_ex['summ_tokens'] = summary.tokens 40 | vectorized_ex['stype'] = summary.type 41 | vectorized_ex['summ_word_rep'] = torch.LongTensor(summary.vectorize(word_dict=tgt_dict)) 42 | if model.args.use_tgt_char: 43 | vectorized_ex['summ_char_rep'] = torch.LongTensor(summary.vectorize(word_dict=tgt_dict, _type='char')) 44 | #
target is only used to compute loss during training 45 | vectorized_ex['target'] = torch.LongTensor(summary.vectorize(tgt_dict)) 46 | 47 | vectorized_ex['src_vocab'] = code.src_vocab 48 | vectorized_ex['use_src_word'] = model.args.use_src_word 49 | vectorized_ex['use_tgt_word'] = model.args.use_tgt_word 50 | vectorized_ex['use_src_char'] = model.args.use_src_char 51 | vectorized_ex['use_tgt_char'] = model.args.use_tgt_char 52 | vectorized_ex['use_code_type'] = model.args.use_code_type 53 | 54 | return vectorized_ex 55 | 56 | 57 | def batchify(batch): 58 | """Gather a batch of individual examples into one batch.""" 59 | 60 | # batch is a list of vectorized examples 61 | batch_size = len(batch) 62 | use_src_word = batch[0]['use_src_word'] 63 | use_tgt_word = batch[0]['use_tgt_word'] 64 | use_src_char = batch[0]['use_src_char'] 65 | use_tgt_char = batch[0]['use_tgt_char'] 66 | use_code_type = batch[0]['use_code_type'] 67 | use_code_mask = batch[0]['use_code_mask'] 68 | ids = [ex['id'] for ex in batch] 69 | language = [ex['language'] for ex in batch] 70 | 71 | # --------- Prepare Code tensors --------- 72 | code_words = [ex['code_word_rep'] for ex in batch] 73 | code_chars = [ex['code_char_rep'] for ex in batch] 74 | code_type = [ex['code_type_rep'] for ex in batch] 75 | code_mask = [ex['code_mask_rep'] for ex in batch] 76 | max_code_len = max([d.size(0) for d in code_words]) 77 | if use_src_char: 78 | max_char_in_code_token = code_chars[0].size(1) 79 | 80 | # Batch Code Representations 81 | code_len_rep = torch.zeros(batch_size, dtype=torch.long) 82 | code_word_rep = torch.zeros(batch_size, max_code_len, dtype=torch.long) \ 83 | if use_src_word else None 84 | code_type_rep = torch.zeros(batch_size, max_code_len, dtype=torch.long) \ 85 | if use_code_type else None 86 | code_mask_rep = torch.zeros(batch_size, max_code_len, dtype=torch.long) \ 87 | if use_code_mask else None 88 | code_char_rep = torch.zeros(batch_size, max_code_len, max_char_in_code_token, dtype=torch.long) \ 89 | if use_src_char else None 90 | 91 | source_maps = [] 92 | src_vocabs = [] 93 | for i in range(batch_size): 94 | code_len_rep[i] = code_words[i].size(0) 95 | if use_src_word: 96 | code_word_rep[i, :code_words[i].size(0)].copy_(code_words[i]) 97 | if use_code_type: 98 | code_type_rep[i, :code_type[i].size(0)].copy_(code_type[i]) 99 | if use_code_mask: 100 | code_mask_rep[i, :code_mask[i].size(0)].copy_(code_mask[i]) 101 | if use_src_char: 102 | code_char_rep[i, :code_chars[i].size(0), :].copy_(code_chars[i]) 103 | # 104 | context = batch[i]['code_tokens'] 105 | vocab = batch[i]['src_vocab'] 106 | src_vocabs.append(vocab) 107 | # Mapping source tokens to indices in the dynamic dict. 
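# Illustrative example (a sketch, not from the original file): the
# per-example src_vocab (built in c2nl/objects/code.py) keeps PAD=0 and
# UNK=1 but drops BOS/EOS, so for code_tokens ['return', 'x'] the first
# new token gets index 2 and src_map becomes tensor([2, 3]); the copy
# generator later scatters attention mass through these indices.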
108 | src_map = torch.LongTensor([vocab[w] for w in context]) 109 | source_maps.append(src_map) 110 | 111 | # --------- Prepare Summary tensors --------- 112 | no_summary = batch[0]['summ_word_rep'] is None 113 | if no_summary: 114 | summ_len_rep = None 115 | summ_word_rep = None 116 | summ_char_rep = None 117 | tgt_tensor = None 118 | alignments = None 119 | else: 120 | summ_words = [ex['summ_word_rep'] for ex in batch] 121 | summ_chars = [ex['summ_char_rep'] for ex in batch] 122 | max_sum_len = max([q.size(0) for q in summ_words]) 123 | if use_tgt_char: 124 | max_char_in_summ_token = summ_chars[0].size(1) 125 | 126 | summ_len_rep = torch.zeros(batch_size, dtype=torch.long) 127 | summ_word_rep = torch.zeros(batch_size, max_sum_len, dtype=torch.long) \ 128 | if use_tgt_word else None 129 | summ_char_rep = torch.zeros(batch_size, max_sum_len, max_char_in_summ_token, dtype=torch.long) \ 130 | if use_tgt_char else None 131 | 132 | max_tgt_length = max([ex['target'].size(0) for ex in batch]) 133 | tgt_tensor = torch.zeros(batch_size, max_tgt_length, dtype=torch.long) 134 | alignments = [] 135 | for i in range(batch_size): 136 | summ_len_rep[i] = summ_words[i].size(0) 137 | if use_tgt_word: 138 | summ_word_rep[i, :summ_words[i].size(0)].copy_(summ_words[i]) 139 | if use_tgt_char: 140 | summ_char_rep[i, :summ_chars[i].size(0), :].copy_(summ_chars[i]) 141 | # 142 | tgt_len = batch[i]['target'].size(0) 143 | tgt_tensor[i, :tgt_len].copy_(batch[i]['target']) 144 | target = batch[i]['summ_tokens'] 145 | align_mask = torch.LongTensor([src_vocabs[i][w] for w in target]) 146 | alignments.append(align_mask) 147 | 148 | return { 149 | 'ids': ids, 150 | 'language': language, 151 | 'batch_size': batch_size, 152 | 'code_word_rep': code_word_rep, 153 | 'code_char_rep': code_char_rep, 154 | 'code_type_rep': code_type_rep, 155 | 'code_mask_rep': code_mask_rep, 156 | 'code_len': code_len_rep, 157 | 'summ_word_rep': summ_word_rep, 158 | 'summ_char_rep': summ_char_rep, 159 | 'summ_len': summ_len_rep, 160 | 'tgt_seq': tgt_tensor, 161 | 'code_text': [ex['code'] for ex in batch], 162 | 'code_tokens': [ex['code_tokens'] for ex in batch], 163 | 'summ_text': [ex['summ'] for ex in batch], 164 | 'summ_tokens': [ex['summ_tokens'] for ex in batch], 165 | 'src_vocab': src_vocabs, 166 | 'src_map': source_maps, 167 | 'alignment': alignments, 168 | 'stype': [ex['stype'] for ex in batch] 169 | } 170 | -------------------------------------------------------------------------------- /c2nl/inputters/vocabulary.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/facebookresearch/DrQA/blob/master/drqa/reader/data.py 2 | import unicodedata 3 | import numpy as np 4 | from c2nl.inputters.constants import PAD, PAD_WORD, UNK, UNK_WORD, \ 5 | BOS, BOS_WORD, EOS, EOS_WORD 6 | 7 | 8 | class Vocabulary(object): 9 | def __init__(self, no_special_token=False): 10 | if no_special_token: 11 | self.tok2ind = {PAD_WORD: PAD, 12 | UNK_WORD: UNK} 13 | self.ind2tok = {PAD: PAD_WORD, 14 | UNK: UNK_WORD} 15 | else: 16 | self.tok2ind = {PAD_WORD: PAD, 17 | UNK_WORD: UNK, 18 | BOS_WORD: BOS, 19 | EOS_WORD: EOS} 20 | self.ind2tok = {PAD: PAD_WORD, 21 | UNK: UNK_WORD, 22 | BOS: BOS_WORD, 23 | EOS: EOS_WORD} 24 | 25 | @staticmethod 26 | def normalize(token): 27 | return unicodedata.normalize('NFD', token) 28 | 29 | def __len__(self): 30 | return len(self.tok2ind) 31 | 32 | def __iter__(self): 33 | return iter(self.tok2ind) 34 | 35 | def __contains__(self, key): 36 | if type(key) == int: 37 | 
return key in self.ind2tok 38 | elif type(key) == str: 39 | return self.normalize(key) in self.tok2ind 40 | 41 | def __getitem__(self, key): 42 | if type(key) == int: 43 | return self.ind2tok.get(key, UNK_WORD) 44 | elif type(key) == str: 45 | return self.tok2ind.get(self.normalize(key), 46 | self.tok2ind.get(UNK_WORD)) 47 | else: 48 | raise RuntimeError('Invalid key type.') 49 | 50 | def __setitem__(self, key, item): 51 | if type(key) == int and type(item) == str: 52 | self.ind2tok[key] = item 53 | elif type(key) == str and type(item) == int: 54 | self.tok2ind[key] = item 55 | else: 56 | raise RuntimeError('Invalid (key, item) types.') 57 | 58 | def add(self, token): 59 | token = self.normalize(token) 60 | if token not in self.tok2ind: 61 | index = len(self.tok2ind) 62 | self.tok2ind[token] = index 63 | self.ind2tok[index] = token 64 | 65 | def add_tokens(self, token_list): 66 | assert isinstance(token_list, list) 67 | for token in token_list: 68 | self.add(token) 69 | 70 | def tokens(self): 71 | """Get dictionary tokens. 72 | Return all the words indexed by this dictionary, except for special 73 | tokens. 74 | """ 75 | tokens = [k for k in self.tok2ind.keys() 76 | if k not in {PAD_WORD, UNK_WORD}] 77 | return tokens 78 | 79 | def remove(self, key): 80 | if key in self.tok2ind: 81 | ind = self.tok2ind[key] 82 | del self.tok2ind[key] 83 | del self.ind2tok[ind] 84 | return True 85 | return False 86 | 87 | 88 | class UnicodeCharsVocabulary(Vocabulary): 89 | """Vocabulary containing character-level and word level information. 90 | Has a word vocabulary that is used to lookup word ids and 91 | a character id that is used to map words to arrays of character ids. 92 | The character ids are defined by ord(c) for c in word.encode('utf-8') 93 | This limits the total number of possible char ids to 256. 94 | To this we add 5 additional special ids: begin sentence, end sentence, 95 | begin word, end word and padding. 96 | """ 97 | 98 | def __init__(self, words, max_word_length, 99 | no_special_token): 100 | super(UnicodeCharsVocabulary, self).__init__(no_special_token) 101 | self._max_word_length = max_word_length 102 | 103 | # char ids 0-255 come from utf-8 encoding bytes 104 | # assign 256-300 to special chars 105 | self.bow_char = 256 # 106 | self.eow_char = 257 # 107 | self.pad_char = 258 # 108 | 109 | for w in words: 110 | self.add(w) 111 | num_words = len(self.ind2tok) 112 | 113 | self._word_char_ids = np.zeros([num_words, max_word_length], 114 | dtype=np.int32) 115 | 116 | for i, word in self.ind2tok.items(): 117 | self._word_char_ids[i] = self._convert_word_to_char_ids(word) 118 | 119 | @property 120 | def word_char_ids(self): 121 | return self._word_char_ids 122 | 123 | @property 124 | def max_word_length(self): 125 | return self._max_word_length 126 | 127 | def _convert_word_to_char_ids(self, word): 128 | code = np.zeros([self.max_word_length], dtype=np.int32) 129 | code[:] = self.pad_char 130 | 131 | word_encoded = word.encode('utf-8', 'ignore')[:(self.max_word_length - 2)] 132 | code[0] = self.bow_char 133 | for k, chr_id in enumerate(word_encoded, start=1): 134 | code[k] = chr_id 135 | code[k + 1] = self.eow_char 136 | 137 | return code 138 | 139 | def word_to_char_ids(self, word): 140 | if word in self.tok2ind: 141 | return self._word_char_ids[self.tok2ind[word]] 142 | else: 143 | return self._convert_word_to_char_ids(word) 144 | 145 | def encode_chars(self, sentence, split=True): 146 | ''' 147 | Encode the sentence as a white space delimited string of tokens. 
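Example (illustrative, added for clarity; not part of the original docstring):
    >>> vocab = UnicodeCharsVocabulary(['hello'], 10, no_special_token=False)
    >>> ids = vocab.encode_chars('hello world')  # 'world' is encoded on the fly
    >>> len(ids), ids[0].shape
    (2, (10,))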
148 | ''' 149 | if split: 150 | chars_ids = [self.word_to_char_ids(cur_word) 151 | for cur_word in sentence.split()] 152 | else: 153 | chars_ids = [self.word_to_char_ids(cur_word) 154 | for cur_word in sentence] 155 | 156 | return chars_ids 157 | -------------------------------------------------------------------------------- /c2nl/models/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .seq2seq import * 4 | from .transformer import * 5 | -------------------------------------------------------------------------------- /c2nl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .char_embedding import * 4 | from .copy_generator import * 5 | from .embeddings import * 6 | from .global_attention import * 7 | from .highway import * 8 | from .multi_head_attn import * 9 | from .position_ffn import * 10 | from .util_class import * 11 | -------------------------------------------------------------------------------- /c2nl/modules/char_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharEmbedding(nn.Module): 6 | """Embeds words based on character embeddings using CNN.""" 7 | 8 | def __init__(self, vocab_size, emsize, filter_size, nfilters): 9 | super(CharEmbedding, self).__init__() 10 | self.embedding = nn.Embedding(vocab_size, emsize) 11 | self.convolution = nn.ModuleList([nn.Conv1d(emsize, int(num_filter), int(k)) 12 | for (k, num_filter) in zip(filter_size, nfilters)]) 13 | 14 | def forward(self, inputs): 15 | """ 16 | Embed words from character embeddings using CNN. 17 | Parameters 18 | -------------------- 19 | inputs -- 3d tensor (N,sentence_len,word_len) 20 | Returns 21 | -------------------- 22 | loss -- total loss over the input mini-batch (N,sentence_len,char_embed_size) 23 | """ 24 | # step1: embed the characters 25 | char_emb = self.embedding(inputs.view(-1, inputs.size(2))) # (N*sentence_len,word_len,char_emb_size) 26 | 27 | # step2: apply convolution to form word embeddings 28 | char_emb = char_emb.transpose(1, 2) # (N*sentence_len,char_emb_size,word_len) 29 | output = [] 30 | for conv in self.convolution: 31 | cnn_out = conv(char_emb).transpose(1, 2) # (N*sentence_len,word_len-filter_size,num_filters) 32 | cnn_out = torch.max(cnn_out, 1)[0] # (N*sentence_len,num_filters) 33 | output.append(cnn_out.view(*inputs.size()[:2], -1)) # appended (N,sentence_len,num_filters) 34 | 35 | output = torch.cat(output, 2) 36 | return output 37 | -------------------------------------------------------------------------------- /c2nl/modules/copy_generator.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/copy_generator.py 2 | """ Generator module """ 3 | import torch.nn as nn 4 | import torch 5 | 6 | from c2nl.inputters import constants 7 | from c2nl.utils.misc import aeq 8 | 9 | 10 | class CopyGenerator(nn.Module): 11 | """Generator module that additionally considers copying 12 | words directly from the source. 13 | The main idea is that we have an extended "dynamic dictionary". 14 | It contains `|tgt_dict|` words plus an arbitrary number of 15 | additional words introduced by the source sentence. 
16 | For each source sentence we have a `src_map` that maps 17 | each source word to an index in `tgt_dict` if it known, or 18 | else to an extra word. 19 | The copy generator is an extended version of the standard 20 | generator that computes three values. 21 | * :math:`p_{softmax}` the standard softmax over `tgt_dict` 22 | * :math:`p(z)` the probability of copying a word from 23 | the source 24 | * :math:`p_{copy}` the probility of copying a particular word. 25 | taken from the attention distribution directly. 26 | The model returns a distribution over the extend dictionary, 27 | computed as 28 | :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)` 29 | .. mermaid:: 30 | graph BT 31 | A[input] 32 | S[src_map] 33 | B[softmax] 34 | BB[switch] 35 | C[attn] 36 | D[copy] 37 | O[output] 38 | A --> B 39 | A --> BB 40 | S --> D 41 | C --> D 42 | D --> O 43 | B --> O 44 | BB --> O 45 | Args: 46 | input_size (int): size of input representation 47 | tgt_dict (Vocab): output target dictionary 48 | """ 49 | 50 | def __init__(self, input_size, tgt_dict, generator, eps=1e-20): 51 | super(CopyGenerator, self).__init__() 52 | self.linear = generator 53 | self.linear_copy = nn.Linear(input_size, 1) 54 | self.tgt_dict = tgt_dict 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.sigmoid = nn.Sigmoid() 57 | self.eps = eps 58 | 59 | def forward(self, hidden, attn, src_map): 60 | """ 61 | Compute a distribution over the target dictionary 62 | extended by the dynamic dictionary implied by compying 63 | source words. 64 | Args: 65 | hidden (`FloatTensor`): hidden outputs `[batch, tlen, input_size]` 66 | attn (`FloatTensor`): attn for each `[batch, tlen, slen]` 67 | src_map (`FloatTensor`): 68 | A sparse indicator matrix mapping each source word to 69 | its index in the "extended" vocab containing. 70 | `[batch, src_len, extra_words]` 71 | """ 72 | # CHECKS 73 | batch, tlen, _ = hidden.size() 74 | batch_, tlen_, slen = attn.size() 75 | batch, slen_, cvocab = src_map.size() 76 | aeq(tlen, tlen_) 77 | aeq(slen, slen_) 78 | 79 | # Original probabilities. 80 | logits = self.linear(hidden) 81 | logits[:, :, self.tgt_dict[constants.PAD_WORD]] = -self.eps 82 | prob = self.softmax(logits) 83 | 84 | # Probability of copying p(z=1) batch. 
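# Restating the class docstring at the point of use (illustrative): with
# p_copy = p(z=1) computed below,
#   p(w) = p_copy * p_{copy}(w) + (1 - p_copy) * softmax(logits)(w),
# where the copy term is realized as bmm(attn * p_copy, src_map) and
# appended after the vocabulary-sized slice of the distribution.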
85 | p_copy = self.sigmoid(self.linear_copy(hidden)) 86 | # Probibility of not copying: p_{word}(w) * (1 - p(z)) 87 | out_prob = torch.mul(prob, 1 - p_copy.expand_as(prob)) 88 | mul_attn = torch.mul(attn, p_copy.expand_as(attn)) 89 | copy_prob = torch.bmm(mul_attn, src_map) # `[batch, tlen, extra_words]` 90 | return torch.cat([out_prob, copy_prob], 2) 91 | 92 | 93 | class CopyGeneratorCriterion(object): 94 | """ Copy generator criterion """ 95 | 96 | def __init__(self, vocab_size, force_copy, eps=1e-20): 97 | self.force_copy = force_copy 98 | self.eps = eps 99 | self.offset = vocab_size 100 | 101 | def __call__(self, scores, align, target): 102 | # CHECKS 103 | batch, tlen, _ = scores.size() 104 | _, _tlen = target.size() 105 | aeq(tlen, _tlen) 106 | _, _tlen = align.size() 107 | aeq(tlen, _tlen) 108 | 109 | align = align.view(-1) 110 | target = target.view(-1) 111 | scores = scores.view(-1, scores.size(2)) 112 | 113 | # Compute unks in align and target for readability 114 | align_unk = align.eq(constants.UNK).float() 115 | align_not_unk = align.ne(constants.UNK).float() 116 | target_unk = target.eq(constants.UNK).float() 117 | target_not_unk = target.ne(constants.UNK).float() 118 | 119 | # Copy probability of tokens in source 120 | out = scores.gather(1, align.view(-1, 1) + self.offset).view(-1) 121 | # Set scores for unk to 0 and add eps 122 | out = out.mul(align_not_unk) + self.eps 123 | # Get scores for tokens in target 124 | tmp = scores.gather(1, target.view(-1, 1)).view(-1) 125 | 126 | # Regular prob (no unks and unks that can't be copied) 127 | if not self.force_copy: 128 | # Add score for non-unks in target 129 | out = out + tmp.mul(target_not_unk) 130 | # Add score for when word is unk in both align and tgt 131 | out = out + tmp.mul(align_unk).mul(target_unk) 132 | else: 133 | # Forced copy. Add only probability for not-copied tokens 134 | out = out + tmp.mul(align_unk) 135 | 136 | loss = -out.log() 137 | return loss 138 | -------------------------------------------------------------------------------- /c2nl/modules/global_attention.py: -------------------------------------------------------------------------------- 1 | # copied from https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/global_attention.py 2 | 3 | """" Global attention modules (Luong / Bahdanau) """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from c2nl.utils.misc import aeq, sequence_mask 8 | 9 | 10 | # This class is mainly used by decoder.py for RNNs but also 11 | # by the CNN / transformer decoder when copy attention is used 12 | # CNN has its own attention mechanism ConvMultiStepAttention 13 | # Transformer has its own MultiHeadedAttention 14 | 15 | class GlobalAttention(nn.Module): 16 | """ 17 | Global attention takes a matrix and a query vector. It 18 | then computes a parameterized convex combination of the matrix 19 | based on the input query. 20 | Constructs a unit mapping a query `q` of size `dim` 21 | and a source matrix `H` of size `n x dim`, to an output 22 | of size `dim`. 23 | .. mermaid:: 24 | graph BT 25 | A[Query] 26 | subgraph RNN 27 | C[H 1] 28 | D[H 2] 29 | E[H N] 30 | end 31 | F[Attn] 32 | G[Output] 33 | A --> F 34 | C --> F 35 | D --> F 36 | E --> F 37 | C -.-> G 38 | D -.-> G 39 | E -.-> G 40 | F --> G 41 | All models compute the output as 42 | :math:`c = sum_{j=1}^{SeqLength} a_j H_j` where 43 | :math:`a_j` is the softmax of a score function. 44 | Then then apply a projection layer to [q, c]. 45 | However they 46 | differ on how they compute the attention score. 
47 | * Luong Attention (dot, general): 48 | * dot: :math:`score(H_j,q) = H_j^T q` 49 | * general: :math:`score(H_j, q) = H_j^T W_a q` 50 | * Bahdanau Attention (mlp): 51 | * :math:`score(H_j, q) = v_a^T tanh(W_a q + U_a h_j)` 52 | Args: 53 | dim (int): dimensionality of query and key 54 | coverage (bool): use coverage term 55 | attn_type (str): type of attention to use, options [dot,general,mlp] 56 | """ 57 | 58 | def __init__(self, dim, coverage=False, attn_type="dot"): 59 | super(GlobalAttention, self).__init__() 60 | 61 | self.dim = dim 62 | self.attn_type = attn_type 63 | assert (self.attn_type in ["dot", "general", "mlp"]), ( 64 | "Please select a valid attention type.") 65 | 66 | if self.attn_type == "general": 67 | self.linear_in = nn.Linear(dim, dim, bias=False) 68 | elif self.attn_type == "mlp": 69 | self.linear_context = nn.Linear(dim, dim, bias=False) 70 | self.linear_query = nn.Linear(dim, dim, bias=True) 71 | self.v = nn.Linear(dim, 1, bias=False) 72 | # mlp wants it with bias 73 | out_bias = self.attn_type == "mlp" 74 | self.linear_out = nn.Linear(dim * 2, dim, bias=out_bias) 75 | 76 | self.softmax = nn.Softmax(dim=-1) 77 | self.tanh = nn.Tanh() 78 | self._coverage = coverage 79 | 80 | def score(self, h_t, h_s): 81 | """ 82 | Args: 83 | h_t (`FloatTensor`): sequence of queries `[batch x tgt_len x dim]` 84 | h_s (`FloatTensor`): sequence of sources `[batch x src_len x dim]` 85 | Returns: 86 | :obj:`FloatTensor`: 87 | raw attention scores (unnormalized) for each src index 88 | `[batch x tgt_len x src_len]` 89 | """ 90 | # Check input sizes 91 | src_batch, src_len, src_dim = h_s.size() 92 | tgt_batch, tgt_len, tgt_dim = h_t.size() 93 | aeq(src_batch, tgt_batch) 94 | aeq(src_dim, tgt_dim) 95 | aeq(self.dim, src_dim) 96 | 97 | if self.attn_type in ["general", "dot"]: 98 | if self.attn_type == "general": 99 | h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim) 100 | h_t_ = self.linear_in(h_t_) 101 | h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim) 102 | h_s_ = h_s.transpose(1, 2) 103 | # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len) 104 | return torch.bmm(h_t, h_s_) 105 | else: 106 | dim = self.dim 107 | wq = self.linear_query(h_t.view(-1, dim)) 108 | wq = wq.view(tgt_batch, tgt_len, 1, dim) 109 | wq = wq.expand(tgt_batch, tgt_len, src_len, dim) 110 | 111 | uh = self.linear_context(h_s.contiguous().view(-1, dim)) 112 | uh = uh.view(src_batch, 1, src_len, dim) 113 | uh = uh.expand(src_batch, tgt_len, src_len, dim) 114 | 115 | # (batch, t_len, s_len, d) 116 | wquh = self.tanh(wq + uh) 117 | 118 | return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len) 119 | 120 | def forward(self, source, memory_bank, memory_lengths=None, 121 | coverage=None, softmax_weights=True): 122 | """ 123 | Args: 124 | input (`FloatTensor`): query vectors `[batch x tgt_len x dim]` 125 | memory_bank (`FloatTensor`): source vectors `[batch x src_len x dim]` 126 | memory_lengths (`LongTensor`): the source context lengths `[batch]` 127 | coverage (`FloatTensor`): None (not supported yet) 128 | Returns: 129 | (`FloatTensor`, `FloatTensor`): 130 | * Computed vector `[batch x tgt_len x dim]` 131 | * Attention distribtutions for each query 132 | `[batch x tgt_len x src_len]` 133 | """ 134 | 135 | # one step input 136 | assert source.dim() == 3 137 | one_step = True if source.size(1) == 1 else False 138 | 139 | batch, source_l, dim = memory_bank.size() 140 | batch_, target_l, dim_ = source.size() 141 | aeq(batch, batch_) 142 | aeq(dim, dim_) 143 | aeq(self.dim, dim) 144 | 145 | # compute attention 
scores, as in Luong et al. 146 | align = self.score(source, memory_bank) 147 | 148 | if memory_lengths is not None: 149 | mask = sequence_mask(memory_lengths, max_len=align.size(-1)) 150 | mask = mask.unsqueeze(1) # Make it broadcastable. 151 | align.data.masked_fill_(~mask, -float('inf')) 152 | 153 | # We adopt coverage attn described in Paulus et al., 2018 154 | # REF: https://arxiv.org/abs/1705.04304 155 | if self._coverage: 156 | maxes = torch.max(align, 2, keepdim=True)[0] 157 | exp_score = torch.exp(align - maxes) 158 | 159 | if one_step: 160 | if coverage is None: 161 | # t = 1 in Eq(3) from Paulus et al., 2018 162 | unnormalized_score = exp_score 163 | else: 164 | # t = otherwise in Eq(3) from Paulus et al., 2018 165 | assert coverage.dim() == 3 # B x 1 x slen 166 | unnormalized_score = exp_score.div(coverage + 1e-20) 167 | else: 168 | multiplier = torch.tril(torch.ones(target_l - 1, target_l - 1)) 169 | multiplier = multiplier.unsqueeze(0).expand(batch, *multiplier.size()) 170 | multiplier = multiplier.cuda() if align.is_cuda else multiplier 171 | 172 | penalty = torch.bmm(multiplier, exp_score[:, :-1, :]) # B x tlen-1 x slen 173 | no_penalty = torch.ones_like(penalty[:, -1, :]) # B x slen 174 | penalty = torch.cat([no_penalty.unsqueeze(1), penalty], dim=1) # B x tlen x slen 175 | assert exp_score.size() == penalty.size() 176 | unnormalized_score = exp_score.div(penalty + 1e-20) 177 | 178 | # Eq.(4) from Paulus et al., 2018 179 | align_vectors = unnormalized_score.div(unnormalized_score.sum(2, keepdim=True)) 180 | 181 | # Softmax to normalize attention weights 182 | else: 183 | align_vectors = self.softmax(align) 184 | 185 | # each context vector c_t is the weighted average 186 | # over all the source hidden states 187 | c = torch.bmm(align_vectors, memory_bank) 188 | 189 | # concatenate 190 | concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2) 191 | attn_h = self.linear_out(concat_c).view(batch, target_l, dim) 192 | if self.attn_type in ["general", "dot"]: 193 | attn_h = self.tanh(attn_h) 194 | 195 | # Check output sizes 196 | batch_, target_l_, dim_ = attn_h.size() 197 | aeq(target_l, target_l_) 198 | aeq(batch, batch_) 199 | aeq(dim, dim_) 200 | batch_, target_l_, source_l_ = align_vectors.size() 201 | aeq(target_l, target_l_) 202 | aeq(batch, batch_) 203 | aeq(source_l, source_l_) 204 | 205 | covrage_vector = None 206 | if self._coverage and one_step: 207 | covrage_vector = exp_score # B x 1 x slen 208 | 209 | if softmax_weights: 210 | return attn_h, align_vectors, covrage_vector 211 | else: 212 | return attn_h, align, covrage_vector 213 | -------------------------------------------------------------------------------- /c2nl/modules/highway.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as f 3 | 4 | 5 | # reference: https://github.com/allenai/allennlp/blob/master/allennlp/modules/highway.py 6 | class Highway(nn.Module): 7 | """ 8 | A `Highway layer `_ does a gated combination of a linear 9 | transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) * 10 | f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise 11 | non-linearity, and :math:`g` is an element-wise gate, computed as :math:`sigmoid(B(x))`. 12 | This module will apply a fixed number of highway layers to its input, returning the final 13 | result. 14 | Parameters 15 | ---------- 16 | input_dim : ``int`` 17 | The dimensionality of :math:`x`. 
We assume the input has shape ``(batch_size, 18 | input_dim)``. 19 | num_layers : ``int``, optional (default=``1``) 20 | The number of highway layers to apply to the input. 21 | activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``f.relu``) 22 | The non-linearity to use in the highway layers. 23 | """ 24 | 25 | def __init__(self, input_dim, num_layers=1, activation=f.relu): 26 | super(Highway, self).__init__() 27 | self._input_dim = input_dim 28 | self._layers = nn.ModuleList([nn.Linear(input_dim, input_dim * 2) for _ in range(num_layers)]) 29 | self._activation = activation 30 | for layer in self._layers: 31 | # We should bias the highway layer to just carry its input forward. We do that by 32 | # setting the bias on `B(x)` to be positive, because that means `g` will be biased to 33 | # be high, to we will carry the input forward. The bias on `B(x)` is the second half 34 | # of the bias vector in each Linear layer. 35 | layer.bias[input_dim:].data.fill_(1) 36 | 37 | def forward(self, inputs): 38 | current_input = inputs 39 | for layer in self._layers: 40 | projected_input = layer(current_input) 41 | linear_part = current_input 42 | nonlinear_part, gate = projected_input.chunk(2, dim=-1) 43 | nonlinear_part = self._activation(nonlinear_part) 44 | gate = f.sigmoid(gate) 45 | current_input = gate * linear_part + (1 - gate) * nonlinear_part 46 | return current_input 47 | -------------------------------------------------------------------------------- /c2nl/modules/multi_head_attn.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/multi_headed_attn.py 2 | """ Multi-Head Attention module """ 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | from c2nl.utils.misc import generate_relative_positions_matrix, \ 7 | relative_matmul 8 | 9 | 10 | class MultiHeadedAttention(nn.Module): 11 | """ 12 | Multi-Head Attention module from 13 | "Attention is All You Need" 14 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17`. 15 | Similar to standard `dot` attention but uses 16 | multiple attention distributions simulataneously 17 | to select relevant items. 18 | .. mermaid:: 19 | graph BT 20 | A[key] 21 | B[value] 22 | C[query] 23 | O[output] 24 | subgraph Attn 25 | D[Attn 1] 26 | E[Attn 2] 27 | F[Attn N] 28 | end 29 | A --> D 30 | C --> D 31 | A --> E 32 | C --> E 33 | A --> F 34 | C --> F 35 | D --> O 36 | E --> O 37 | F --> O 38 | B --> O 39 | Also includes several additional tricks. 
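Example (an illustrative sketch, not part of the original docstring):
    >>> attn = MultiHeadedAttention(head_count=8, model_dim=512, d_k=64, d_v=64)
    >>> x = torch.randn(2, 10, 512)            # (batch, len, model_dim)
    >>> out, attn_per_head, _ = attn(x, x, x)  # self-attention over x
    >>> out.shape
    torch.Size([2, 10, 512])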
40 | Args: 41 | head_count (int): number of parallel heads 42 | model_dim (int): the dimension of keys/values/queries, 43 | must be divisible by head_count 44 | dropout (float): dropout parameter 45 | """ 46 | 47 | def __init__(self, head_count, model_dim, d_k, d_v, dropout=0.1, 48 | max_relative_positions=0, use_neg_dist=True, coverage=False): 49 | super(MultiHeadedAttention, self).__init__() 50 | 51 | self.head_count = head_count 52 | self.model_dim = model_dim 53 | self.d_k = d_k 54 | self.d_v = d_v 55 | 56 | self.key = nn.Linear(model_dim, head_count * self.d_k) 57 | self.query = nn.Linear(model_dim, head_count * self.d_k) 58 | self.value = nn.Linear(model_dim, head_count * self.d_v) 59 | 60 | self.softmax = nn.Softmax(dim=-1) 61 | self.dropout = nn.Dropout(dropout) 62 | self.output = nn.Linear(self.head_count * d_v, model_dim) 63 | self._coverage = coverage 64 | 65 | self.max_relative_positions = max_relative_positions 66 | self.use_neg_dist = use_neg_dist 67 | 68 | if max_relative_positions > 0: 69 | vocab_size = max_relative_positions * 2 + 1 \ 70 | if self.use_neg_dist else max_relative_positions + 1 71 | self.relative_positions_embeddings_k = nn.Embedding( 72 | vocab_size, self.d_k) 73 | self.relative_positions_embeddings_v = nn.Embedding( 74 | vocab_size, self.d_v) 75 | 76 | def forward(self, key, value, query, mask=None, layer_cache=None, 77 | attn_type=None, step=None, coverage=None): 78 | """ 79 | Compute the context vector and the attention vectors. 80 | Args: 81 | key (FloatTensor): set of `key_len` 82 | key vectors ``(batch, key_len, dim)`` 83 | value (FloatTensor): set of `key_len` 84 | value vectors ``(batch, key_len, dim)`` 85 | query (FloatTensor): set of `query_len` 86 | query vectors ``(batch, query_len, dim)`` 87 | mask: binary mask 1/0 indicating which keys have 88 | zero / non-zero attention ``(batch, query_len, key_len)`` 89 | Returns: 90 | (FloatTensor, FloatTensor): 91 | * output context vectors ``(batch, query_len, dim)`` 92 | * one of the attention vectors ``(batch, query_len, key_len)`` 93 | """ 94 | 95 | # CHECKS 96 | # batch, k_len, d = key.size() 97 | # batch_, k_len_, d_ = value.size() 98 | # aeq(batch, batch_) 99 | # aeq(k_len, k_len_) 100 | # aeq(d, d_) 101 | # batch_, q_len, d_ = query.size() 102 | # aeq(batch, batch_) 103 | # aeq(d, d_) 104 | # aeq(self.model_dim % 8, 0) 105 | # if mask is not None: 106 | # batch_, q_len_, k_len_ = mask.size() 107 | # aeq(batch_, batch) 108 | # aeq(k_len_, k_len) 109 | # aeq(q_len_ == q_len) 110 | # END CHECKS 111 | 112 | batch_size = key.size(0) 113 | head_count = self.head_count 114 | key_len = key.size(1) 115 | query_len = query.size(1) 116 | use_gpu = key.is_cuda 117 | 118 | def shape(x, dim): 119 | """ projection """ 120 | return x.view(batch_size, -1, head_count, dim).transpose(1, 2) 121 | 122 | def unshape(x, dim): 123 | """ compute context """ 124 | return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim) 125 | 126 | # 1) Project key, value, and query. 127 | if layer_cache is not None: 128 | if attn_type == "self": 129 | # 1) Project key, value, and query. 
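# Note (added for clarity): shape() turns (batch, len, head_count * d_k)
# into (batch, head_count, len, d_k) so each head attends independently;
# unshape() inverts this before the final output projection.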
130 | key = shape(self.key(key), self.d_k) 131 | value = shape(self.value(value), self.d_v) 132 | query = shape(self.query(query), self.d_k) 133 | 134 | if layer_cache["self_keys"] is not None: 135 | key = torch.cat( 136 | (layer_cache["self_keys"], key), 137 | dim=2) 138 | if layer_cache["self_values"] is not None: 139 | value = torch.cat( 140 | (layer_cache["self_values"], value), 141 | dim=2) 142 | layer_cache["self_keys"] = key 143 | layer_cache["self_values"] = value 144 | 145 | elif attn_type == "context": 146 | query = shape(self.query(query), self.d_k) 147 | if layer_cache["memory_keys"] is None: 148 | key = shape(self.key(key), self.d_k) 149 | value = shape(self.value(value), self.d_v) 150 | else: 151 | key, value = layer_cache["memory_keys"], \ 152 | layer_cache["memory_values"] 153 | layer_cache["memory_keys"] = key 154 | layer_cache["memory_values"] = value 155 | else: 156 | key = shape(self.key(key), self.d_k) 157 | value = shape(self.value(value), self.d_v) 158 | query = shape(self.query(query), self.d_k) 159 | 160 | if self.max_relative_positions > 0 and attn_type == "self": 161 | key_len = key.size(2) 162 | # 1 or key_len x key_len 163 | relative_positions_matrix = generate_relative_positions_matrix( 164 | key_len, self.max_relative_positions, self.use_neg_dist, 165 | cache=True if layer_cache is not None else False) 166 | # 1 or key_len x key_len x dim_per_head 167 | relations_keys = self.relative_positions_embeddings_k( 168 | relative_positions_matrix.to(key.device)) 169 | # 1 or key_len x key_len x dim_per_head 170 | relations_values = self.relative_positions_embeddings_v( 171 | relative_positions_matrix.to(key.device)) 172 | 173 | key_len = key.size(2) 174 | query_len = query.size(2) 175 | 176 | # 2) Calculate and scale scores. 177 | query = query / math.sqrt(self.d_k) 178 | # batch x num_heads x query_len x key_len 179 | query_key = torch.matmul(query, key.transpose(2, 3)) 180 | 181 | if self.max_relative_positions > 0 and attn_type == "self": 182 | scores = query_key + relative_matmul(query, relations_keys, True) 183 | else: 184 | scores = query_key 185 | scores = scores.float() 186 | 187 | if mask is not None: 188 | mask = mask.unsqueeze(1) # [B, 1, 1, T_values] 189 | scores = scores.masked_fill(mask, -1e18) 190 | 191 | # ---------------------------- 192 | # We adopt coverage attn described in Paulus et al., 2018 193 | # REF: https://arxiv.org/abs/1705.04304 194 | exp_score = None 195 | if self._coverage and attn_type == 'context': 196 | # batch x num_heads x query_len x 1 197 | maxes = torch.max(scores, 3, keepdim=True)[0] 198 | # batch x num_heads x query_len x key_len 199 | exp_score = torch.exp(scores - maxes) 200 | 201 | if step is not None: # indicates inference mode (one-step at a time) 202 | if coverage is None: 203 | # t = 1 in Eq(3) from Paulus et al., 2018 204 | unnormalized_score = exp_score 205 | else: 206 | # t = otherwise in Eq(3) from Paulus et al., 2018 207 | assert coverage.dim() == 4 # B x num_heads x 1 x key_len 208 | unnormalized_score = exp_score.div(coverage + 1e-20) 209 | else: 210 | multiplier = torch.tril(torch.ones(query_len - 1, query_len - 1)) 211 | # batch x num_heads x query_len-1 x query_len-1 212 | multiplier = multiplier.unsqueeze(0).unsqueeze(0). 
\ 213 | expand(batch_size, head_count, *multiplier.size()) 214 | multiplier = multiplier.cuda() if scores.is_cuda else multiplier 215 | 216 | # B x num_heads x query_len-1 x key_len 217 | penalty = torch.matmul(multiplier, exp_score[:, :, :-1, :]) 218 | # B x num_heads x key_len 219 | no_penalty = torch.ones_like(penalty[:, :, -1, :]) 220 | # B x num_heads x query_len x key_len 221 | penalty = torch.cat([no_penalty.unsqueeze(2), penalty], dim=2) 222 | assert exp_score.size() == penalty.size() 223 | unnormalized_score = exp_score.div(penalty + 1e-20) 224 | 225 | # Eq.(4) from Paulus et al., 2018 226 | attn = unnormalized_score.div(unnormalized_score.sum(3, keepdim=True)) 227 | 228 | # Softmax to normalize attention weights 229 | else: 230 | # 3) Apply attention dropout and compute context vectors. 231 | attn = self.softmax(scores).to(query.dtype) 232 | 233 | # ---------------------------- 234 | 235 | # 3) Apply attention dropout and compute context vectors. 236 | # attn = self.softmax(scores).to(query.dtype) 237 | drop_attn = self.dropout(attn) 238 | 239 | context_original = torch.matmul(drop_attn, value) 240 | 241 | if self.max_relative_positions > 0 and attn_type == "self": 242 | context = unshape(context_original 243 | + relative_matmul(drop_attn, 244 | relations_values, 245 | False), 246 | self.d_v) 247 | else: 248 | context = unshape(context_original, self.d_v) 249 | 250 | final_output = self.output(context) 251 | # CHECK 252 | # batch_, q_len_, d_ = output.size() 253 | # aeq(q_len, q_len_) 254 | # aeq(batch, batch_) 255 | # aeq(d, d_) 256 | 257 | # a list of size num_heads containing tensors 258 | # of shape `batch x query_len x key_len` 259 | attn_per_head = [attn.squeeze(1) 260 | for attn in attn.chunk(head_count, dim=1)] 261 | 262 | covrage_vector = None 263 | if (self._coverage and attn_type == 'context') and step is not None: 264 | covrage_vector = exp_score # B x num_heads x 1 x key_len 265 | 266 | return final_output, attn_per_head, covrage_vector 267 | 268 | def update_dropout(self, dropout): 269 | self.dropout.p = dropout 270 | -------------------------------------------------------------------------------- /c2nl/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Position feed-forward network from "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | from c2nl.modules.util_class import LayerNorm 7 | 8 | 9 | class PositionwiseFeedForward(nn.Module): 10 | """ A two-layer Feed-Forward-Network with residual layer norm. 11 | Args: 12 | d_model (int): the size of input for the first-layer of the FFN. 13 | d_ff (int): the hidden layer size of the second-layer 14 | of the FNN. 15 | dropout (float): dropout probability(0-1.0). 16 | """ 17 | 18 | def __init__(self, d_model, d_ff, dropout=0.1): 19 | super(PositionwiseFeedForward, self).__init__() 20 | self.intermediate = nn.Linear(d_model, d_ff) 21 | self.output = nn.Linear(d_ff, d_model) 22 | self.layer_norm = LayerNorm(d_model) 23 | self.dropout_1 = nn.Dropout(dropout) 24 | self.relu = nn.ReLU() 25 | self.dropout_2 = nn.Dropout(dropout) 26 | 27 | def forward(self, x): 28 | """ 29 | Layer definition. 
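Computes a pre-norm residual block (illustrative restatement of the code below):
    x + Dropout(W_2(Dropout(ReLU(W_1(LayerNorm(x))))))
with W_1: d_model -> d_ff and W_2: d_ff -> d_model.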
--------------------------------------------------------------------------------
/c2nl/modules/position_ffn.py:
--------------------------------------------------------------------------------
1 | """
2 | Position feed-forward network from "Attention is All You Need"
3 | """
4 | 
5 | import torch.nn as nn
6 | from c2nl.modules.util_class import LayerNorm
7 | 
8 | 
9 | class PositionwiseFeedForward(nn.Module):
10 | """ A two-layer Feed-Forward-Network with residual layer norm.
11 | Args:
12 | d_model (int): the size of input for the first layer of the FFN.
13 | d_ff (int): the hidden layer size of the second layer
14 | of the FFN.
15 | dropout (float): dropout probability (0-1.0).
16 | """
17 | 
18 | def __init__(self, d_model, d_ff, dropout=0.1):
19 | super(PositionwiseFeedForward, self).__init__()
20 | self.intermediate = nn.Linear(d_model, d_ff)
21 | self.output = nn.Linear(d_ff, d_model)
22 | self.layer_norm = LayerNorm(d_model)
23 | self.dropout_1 = nn.Dropout(dropout)
24 | self.relu = nn.ReLU()
25 | self.dropout_2 = nn.Dropout(dropout)
26 | 
27 | def forward(self, x):
28 | """
29 | Layer definition.
30 | Args:
31 | x: [ batch_size, input_len, model_dim ]
32 | Returns:
33 | output: [ batch_size, input_len, model_dim ]
34 | """
35 | inter = self.dropout_1(self.relu(self.intermediate(self.layer_norm(x))))
36 | output = self.dropout_2(self.output(inter))
37 | return output + x
38 | 
--------------------------------------------------------------------------------
/c2nl/modules/util_class.py:
--------------------------------------------------------------------------------
1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/util_class.py
2 | 
3 | """ Misc classes """
4 | import torch
5 | import torch.nn as nn
6 | 
7 | 
8 | class LayerNorm(nn.Module):
9 | """
10 | Layer Normalization class
11 | """
12 | 
13 | def __init__(self, features, eps=1e-6):
14 | super(LayerNorm, self).__init__()
15 | self.weight = nn.Parameter(torch.ones(features))
16 | self.bias = nn.Parameter(torch.zeros(features))
17 | self.eps = eps
18 | 
19 | def forward(self, x):
20 | mean = x.mean(-1, keepdim=True)
21 | std = x.std(-1, keepdim=True)
22 | return self.weight * (x - mean) / (std + self.eps) + self.bias
23 | 
24 | 
25 | # At the moment this class is only used by embeddings.Embeddings look-up tables
26 | class Elementwise(nn.ModuleList):
27 | """
28 | A simple network container.
29 | Parameters are a list of modules.
30 | Inputs are a 3d Tensor whose last dimension is the same length
31 | as the list.
32 | Outputs are the result of applying modules to inputs elementwise.
33 | An optional merge parameter allows the outputs to be reduced to a
34 | single Tensor.
35 | """
36 | 
37 | def __init__(self, merge=None, *args):
38 | assert merge in [None, 'first', 'concat', 'sum', 'mlp']
39 | self.merge = merge
40 | super(Elementwise, self).__init__(*args)
41 | 
42 | def forward(self, inputs):
43 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)]
44 | assert len(self) == len(inputs_)
45 | outputs = [f(x) for f, x in zip(self, inputs_)]
46 | if self.merge == 'first':
47 | return outputs[0]
48 | elif self.merge == 'concat' or self.merge == 'mlp':
49 | return torch.cat(outputs, 2)
50 | elif self.merge == 'sum':
51 | return sum(outputs)
52 | else:
53 | return outputs
54 | 
--------------------------------------------------------------------------------
/c2nl/objects/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
3 | from .summary import *
4 | from .code import *
5 | 
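A quick sanity check of the pre-norm residual block defined by `PositionwiseFeedForward` and `LayerNorm` above: the input is normalized, expanded to `d_ff`, projected back to `d_model`, and added to the residual, so the shape is preserved. This sketch assumes the package is installed as described in the README.

```python
import torch
from c2nl.modules.position_ffn import PositionwiseFeedForward

ffn = PositionwiseFeedForward(d_model=512, d_ff=2048, dropout=0.1)
x = torch.randn(8, 150, 512)   # [batch_size, input_len, model_dim]
out = ffn(x)
assert out.shape == x.shape    # residual connection preserves the shape
```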
8 | """ 9 | 10 | def __init__(self, _id=None): 11 | self._id = _id 12 | self._language = None 13 | self._text = None 14 | self._tokens = [] 15 | self._type = [] 16 | self._mask = [] 17 | self.src_vocab = None # required for Copy Attention 18 | 19 | @property 20 | def id(self) -> str: 21 | return self._id 22 | 23 | @property 24 | def language(self) -> str: 25 | return self._language 26 | 27 | @language.setter 28 | def language(self, param: str) -> None: 29 | self._language = param 30 | 31 | @property 32 | def text(self) -> str: 33 | return self._text 34 | 35 | @text.setter 36 | def text(self, param: str) -> None: 37 | self._text = param 38 | 39 | @property 40 | def type(self) -> list: 41 | return self._type 42 | 43 | @type.setter 44 | def type(self, param: list) -> None: 45 | assert isinstance(param, list) 46 | self._type = param 47 | 48 | @property 49 | def mask(self) -> list: 50 | return self._mask 51 | 52 | @mask.setter 53 | def mask(self, param: list) -> None: 54 | assert isinstance(param, list) 55 | self._mask = param 56 | 57 | @property 58 | def tokens(self) -> list: 59 | return self._tokens 60 | 61 | @tokens.setter 62 | def tokens(self, param: list) -> None: 63 | assert isinstance(param, list) 64 | self._tokens = param 65 | self.form_src_vocab() 66 | 67 | def form_src_vocab(self) -> None: 68 | self.src_vocab = Vocabulary() 69 | assert self.src_vocab.remove(BOS_WORD) 70 | assert self.src_vocab.remove(EOS_WORD) 71 | self.src_vocab.add_tokens(self.tokens) 72 | 73 | def vectorize(self, word_dict, _type='word') -> list: 74 | if _type == 'word': 75 | return [word_dict[w] for w in self.tokens] 76 | elif _type == 'char': 77 | return [word_dict.word_to_char_ids(w).tolist() for w in self.tokens] 78 | else: 79 | assert False 80 | -------------------------------------------------------------------------------- /c2nl/objects/summary.py: -------------------------------------------------------------------------------- 1 | from c2nl.inputters.vocabulary import EOS_WORD, BOS_WORD 2 | 3 | 4 | class Summary(object): 5 | """ 6 | Summary containing annotated text, original text, a list of 7 | candidate documents, answers and well formed answers. 
8 | """ 9 | 10 | def __init__(self, _id=None): 11 | self._id = _id 12 | self._text = None 13 | self._tokens = [] 14 | self._type = None # summary, comment etc 15 | 16 | @property 17 | def id(self) -> str: 18 | return self._id 19 | 20 | @property 21 | def text(self) -> str: 22 | return self._text 23 | 24 | @text.setter 25 | def text(self, param: str) -> None: 26 | self._text = param 27 | 28 | @property 29 | def tokens(self) -> list: 30 | return self._tokens 31 | 32 | @tokens.setter 33 | def tokens(self, param: list) -> None: 34 | assert isinstance(param, list) 35 | self._tokens = param 36 | 37 | def append_token(self, tok=EOS_WORD): 38 | assert isinstance(tok, str) 39 | self._tokens.append(tok) 40 | 41 | def prepend_token(self, tok=BOS_WORD): 42 | assert isinstance(tok, str) 43 | self._tokens.insert(0, tok) 44 | 45 | @property 46 | def type(self) -> str: 47 | return self._type 48 | 49 | @type.setter 50 | def type(self, param: str) -> None: 51 | assert isinstance(param, str) 52 | self._type = param 53 | 54 | def vectorize(self, word_dict, _type='word') -> list: 55 | if _type == 'word': 56 | return [word_dict[w] for w in self.tokens] 57 | elif _type == 'char': 58 | return [word_dict.word_to_char_ids(w).tolist() for w in self.tokens] 59 | else: 60 | assert False 61 | -------------------------------------------------------------------------------- /c2nl/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .tokenizer import * 4 | from .code_tokenizer import * 5 | from .simple_tokenizer import * 6 | -------------------------------------------------------------------------------- /c2nl/tokenizers/code_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Author : Saikat Chakraborty (saikatc@cs.columbia.edu) 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | """Basic tokenizer that splits text into alpha-numeric tokens and 6 | non-whitespace tokens. 7 | """ 8 | 9 | import logging 10 | from .tokenizer import Tokens, Tokenizer 11 | import re 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def tokenize_with_camel_case(token): 17 | matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', token) 18 | return [m.group(0) for m in matches] 19 | 20 | 21 | def tokenize_with_snake_case(token): 22 | return token.split('_') 23 | 24 | 25 | class CodeTokenizer(Tokenizer): 26 | def __init__(self, camel_case=True, snake_case=True, **kwargs): 27 | """ 28 | Args: 29 | camel_case: Boolean denoting whether CamelCase split is desired 30 | snake_case: Boolean denoting whether snake_case split is desired 31 | annotators: None or empty set (only tokenizes). 32 | """ 33 | super(CodeTokenizer, self).__init__() 34 | self.snake_case = snake_case 35 | self.camel_case = camel_case 36 | assert self.snake_case or self.camel_case, \ 37 | 'To call CodeIdentifierTokenizer at least one of camel_case or ' \ 38 | 'snake_case flag has to be turned on in the initializer' 39 | if len(kwargs.get('annotators', {})) > 0: 40 | logger.warning('%s only tokenizes! 
Skipping annotators: %s' % 41 | (type(self).__name__, kwargs.get('annotators'))) 42 | self.annotators = set() 43 | 44 | def tokenize(self, text): 45 | tokens = text.split() 46 | snake_case_tokenized = [] 47 | if self.snake_case: 48 | for token in tokens: 49 | snake_case_tokenized.extend(tokenize_with_snake_case(token)) 50 | else: 51 | snake_case_tokenized = tokens 52 | camel_case_tokenized = [] 53 | if self.camel_case: 54 | for token in snake_case_tokenized: 55 | camel_case_tokenized.extend(tokenize_with_camel_case(token)) 56 | else: 57 | camel_case_tokenized = snake_case_tokenized 58 | data = [] 59 | for token in camel_case_tokenized: 60 | data.append((token, token, token)) 61 | 62 | return Tokens(data, self.annotators) 63 | -------------------------------------------------------------------------------- /c2nl/tokenizers/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | """Basic tokenizer that splits text into alpha-numeric tokens and 8 | non-whitespace tokens. 9 | """ 10 | 11 | import regex 12 | import logging 13 | from .tokenizer import Tokens, Tokenizer 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class SimpleTokenizer(Tokenizer): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self, **kwargs): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | if len(kwargs.get('annotators', {})) > 0: 32 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 33 | (type(self).__name__, kwargs.get('annotators'))) 34 | self.annotators = set() 35 | 36 | def tokenize(self, text): 37 | data = [] 38 | matches = [m for m in self._regexp.finditer(text)] 39 | for i in range(len(matches)): 40 | # Get text 41 | token = matches[i].group() 42 | 43 | # Get whitespace 44 | span = matches[i].span() 45 | start_ws = span[0] 46 | if i + 1 < len(matches): 47 | end_ws = matches[i + 1].span()[0] 48 | else: 49 | end_ws = span[1] 50 | 51 | # Format data 52 | data.append(( 53 | token, 54 | text[start_ws: end_ws], 55 | span, 56 | )) 57 | return Tokens(data, self.annotators) 58 | -------------------------------------------------------------------------------- /c2nl/tokenizers/tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright 2017-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
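Before the base `Tokens`/`Tokenizer` classes below, a small illustration of the sub-token splitting that `CodeTokenizer` performs; this is what produces the space-separated sub-tokens visible in `data/java/sample.code` (assuming the package is installed):

```python
from c2nl.tokenizers import CodeTokenizer

tok = CodeTokenizer(camel_case=True, snake_case=True)
print(tok.tokenize('getHTTPResponse max_relative_positions').words())
# -> ['get', 'HTTP', 'Response', 'max', 'relative', 'positions']
```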
7 | """Base tokenizer/tokens classes and utilities.""" 8 | 9 | import copy 10 | 11 | 12 | class Tokens(object): 13 | """A class to represent a list of tokenized text.""" 14 | TEXT = 0 15 | TEXT_WS = 1 16 | SPAN = 2 17 | POS = 3 18 | LEMMA = 4 19 | NER = 5 20 | 21 | def __init__(self, data, annotators, opts=None): 22 | self.data = data 23 | self.annotators = annotators 24 | self.opts = opts or {} 25 | 26 | def __len__(self): 27 | """The number of tokens.""" 28 | return len(self.data) 29 | 30 | def slice(self, i=None, j=None): 31 | """Return a view of the list of tokens from [i, j).""" 32 | new_tokens = copy.copy(self) 33 | new_tokens.data = self.data[i: j] 34 | return new_tokens 35 | 36 | def untokenize(self): 37 | """Returns the original text (with whitespace reinserted).""" 38 | return ''.join([t[self.TEXT_WS] for t in self.data]).strip() 39 | 40 | def words(self, uncased=False): 41 | """Returns a list of the text of each token 42 | Args: 43 | uncased: lower cases text 44 | """ 45 | if uncased: 46 | return [t[self.TEXT].lower() for t in self.data] 47 | else: 48 | return [t[self.TEXT] for t in self.data] 49 | 50 | def offsets(self): 51 | """Returns a list of [start, end) character offsets of each token.""" 52 | return [t[self.SPAN] for t in self.data] 53 | 54 | def pos(self): 55 | """Returns a list of part-of-speech tags of each token. 56 | Returns None if this annotation was not included. 57 | """ 58 | if 'pos' not in self.annotators: 59 | return None 60 | return [t[self.POS] for t in self.data] 61 | 62 | def lemmas(self): 63 | """Returns a list of the lemmatized text of each token. 64 | Returns None if this annotation was not included. 65 | """ 66 | if 'lemma' not in self.annotators: 67 | return None 68 | return [t[self.LEMMA] for t in self.data] 69 | 70 | def entities(self): 71 | """Returns a list of named-entity-recognition tags of each token. 72 | Returns None if this annotation was not included. 73 | """ 74 | if 'ner' not in self.annotators: 75 | return None 76 | return [t[self.NER] for t in self.data] 77 | 78 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 79 | """Returns a list of all ngrams from length 1 to n. 
80 | Args: 81 | n: upper limit of ngram length 82 | uncased: lower cases text 83 | filter_fn: user function that takes in an ngram list and returns 84 | True or False to keep or not keep the ngram 85 | as_string: return the ngram as a string vs list 86 | """ 87 | 88 | def _skip(gram): 89 | if not filter_fn: 90 | return False 91 | return filter_fn(gram) 92 | 93 | words = self.words(uncased) 94 | ngrams = [(s, e + 1) 95 | for s in range(len(words)) 96 | for e in range(s, min(s + n, len(words))) 97 | if not _skip(words[s:e + 1])] 98 | 99 | # Concatenate into strings 100 | if as_strings: 101 | ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams] 102 | 103 | return ngrams 104 | 105 | def entity_groups(self): 106 | """Group consecutive entity tokens with the same NER tag.""" 107 | entities = self.entities() 108 | if not entities: 109 | return None 110 | non_ent = self.opts.get('non_ent', 'O') 111 | groups = [] 112 | idx = 0 113 | while idx < len(entities): 114 | ner_tag = entities[idx] 115 | # Check for entity tag 116 | if ner_tag != non_ent: 117 | # Chomp the sequence 118 | start = idx 119 | while idx < len(entities) and entities[idx] == ner_tag: 120 | idx += 1 121 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 122 | else: 123 | idx += 1 124 | return groups 125 | 126 | 127 | class Tokenizer(object): 128 | """Base tokenizer class. 129 | Tokenizers implement tokenize, which should return a Tokens class. 130 | """ 131 | 132 | def tokenize(self, text): 133 | raise NotImplementedError 134 | 135 | def shutdown(self): 136 | pass 137 | 138 | def __del__(self): 139 | self.shutdown() 140 | -------------------------------------------------------------------------------- /c2nl/translator/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | 3 | from .beam import * 4 | from .penalties import * 5 | from .translator import * 6 | from .translation import * 7 | -------------------------------------------------------------------------------- /c2nl/translator/beam.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import warnings 5 | from c2nl.translator import penalties 6 | 7 | 8 | class Beam(object): 9 | """ 10 | Class for managing the internals of the beam search process. 11 | Takes care of beams, back pointers, and scores. 12 | Args: 13 | size (int): beam size 14 | pad, bos, eos (int): indices of padding, beginning, and ending. 15 | n_best (int): nbest size to use 16 | cuda (bool): use gpu 17 | global_scorer (:obj:`GlobalScorer`) 18 | """ 19 | 20 | def __init__(self, size, pad, bos, eos, 21 | n_best=1, cuda=False, 22 | global_scorer=None, 23 | min_length=0, 24 | stepwise_penalty=False, 25 | block_ngram_repeat=0, 26 | exclusion_tokens=set()): 27 | 28 | self.size = size 29 | self.tt = torch.cuda if cuda else torch 30 | 31 | # The score for each translation on the beam. 32 | self.scores = self.tt.FloatTensor(size).zero_() 33 | self.all_scores = [] 34 | 35 | # The backpointers at each time-step. 36 | self.prev_ks = [] 37 | 38 | # The outputs at each time-step. 39 | self.next_ys = [self.tt.LongTensor(size) 40 | .fill_(pad)] 41 | self.next_ys[0][0] = bos 42 | 43 | # Has EOS topped the beam yet. 44 | self._eos = eos 45 | self.eos_top = False 46 | 47 | # The attentions (matrix) for each time. 48 | self.attn = [] 49 | 50 | # Time and k pair for finished. 
51 | self.finished = [] 52 | self.n_best = n_best 53 | 54 | # Information for global scoring. 55 | self.global_scorer = global_scorer 56 | self.global_state = {} 57 | 58 | # Minimum prediction length 59 | self.min_length = min_length 60 | 61 | # Apply Penalty at every step 62 | self.stepwise_penalty = stepwise_penalty 63 | self.block_ngram_repeat = block_ngram_repeat 64 | self.exclusion_tokens = exclusion_tokens 65 | 66 | def get_current_state(self): 67 | "Get the outputs for the current timestep." 68 | return self.next_ys[-1] 69 | 70 | def get_current_origin(self): 71 | "Get the backpointers for the current timestep." 72 | return self.prev_ks[-1] 73 | 74 | def advance(self, word_probs, attn_out): 75 | """ 76 | Given prob over words for every last beam `wordLk` and attention 77 | `attn_out`: Compute and update the beam search. 78 | Parameters: 79 | * `word_probs`- probs of advancing from the last step (K x words) 80 | * `attn_out`- attention at the last step 81 | Returns: True if beam search is complete. 82 | """ 83 | num_words = word_probs.size(1) 84 | if self.stepwise_penalty: 85 | self.global_scorer.update_score(self, attn_out) 86 | # force the output to be longer than self.min_length 87 | cur_len = len(self.next_ys) 88 | if cur_len < self.min_length: 89 | for k in range(len(word_probs)): 90 | word_probs[k][self._eos] = -1e20 91 | # Sum the previous scores. 92 | if len(self.prev_ks) > 0: 93 | beam_scores = word_probs + \ 94 | self.scores.unsqueeze(1).expand_as(word_probs) 95 | # Don't let EOS have children. 96 | for i in range(self.next_ys[-1].size(0)): 97 | if self.next_ys[-1][i] == self._eos: 98 | beam_scores[i] = -1e20 99 | 100 | # Block ngram repeats 101 | if self.block_ngram_repeat > 0: 102 | le = len(self.next_ys) 103 | for j in range(self.next_ys[-1].size(0)): 104 | hyp, _ = self.get_hyp(le - 1, j) 105 | ngrams = set() 106 | fail = False 107 | gram = [] 108 | for i in range(le - 1): 109 | # Last n tokens, n = block_ngram_repeat 110 | gram = (gram + [hyp[i]])[-self.block_ngram_repeat:] 111 | # Skip the blocking if it is in the exclusion list 112 | if set(gram) & self.exclusion_tokens: 113 | continue 114 | if tuple(gram) in ngrams: 115 | fail = True 116 | ngrams.add(tuple(gram)) 117 | if fail: 118 | beam_scores[j] = -1e20 119 | else: 120 | beam_scores = word_probs[0] 121 | 122 | flat_beam_scores = beam_scores.view(-1) 123 | best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, 124 | True, True) 125 | 126 | self.all_scores.append(self.scores) 127 | self.scores = best_scores 128 | 129 | # best_scores_id is flattened beam x word array, so calculate which 130 | # word and beam each score came from 131 | prev_k = best_scores_id / num_words 132 | self.prev_ks.append(prev_k) 133 | self.next_ys.append((best_scores_id - prev_k * num_words)) 134 | self.attn.append(attn_out.index_select(0, prev_k)) 135 | self.global_scorer.update_global_state(self) 136 | 137 | for i in range(self.next_ys[-1].size(0)): 138 | if self.next_ys[-1][i] == self._eos: 139 | global_scores = self.global_scorer.score(self, self.scores) 140 | s = global_scores[i] 141 | self.finished.append((s, len(self.next_ys) - 1, i)) 142 | 143 | # End condition is when top-of-beam is EOS and no global score. 
144 | if self.next_ys[-1][0] == self._eos: 145 | self.all_scores.append(self.scores) 146 | self.eos_top = True 147 | 148 | @property 149 | def done(self): 150 | return self.eos_top and len(self.finished) >= self.n_best 151 | 152 | def sort_finished(self, minimum=None): 153 | if minimum is not None: 154 | i = 0 155 | # Add from beam until we have minimum outputs. 156 | while len(self.finished) < minimum: 157 | global_scores = self.global_scorer.score(self, self.scores) 158 | s = global_scores[i] 159 | self.finished.append((s, len(self.next_ys) - 1, i)) 160 | i += 1 161 | 162 | self.finished.sort(key=lambda a: -a[0]) 163 | scores = [sc for sc, _, _ in self.finished] 164 | ks = [(t, k) for _, t, k in self.finished] 165 | return scores, ks 166 | 167 | def get_hyp(self, timestep, k): 168 | """ 169 | Walk back to construct the full hypothesis. 170 | """ 171 | hyp, attn = [], [] 172 | for j in range(len(self.prev_ks[:timestep]) - 1, -1, -1): 173 | hyp.append(self.next_ys[j + 1][k]) 174 | attn.append(self.attn[j][k]) 175 | k = self.prev_ks[j][k] 176 | return hyp[::-1], torch.stack(attn[::-1]) 177 | 178 | 179 | class GNMTGlobalScorer(object): 180 | """ 181 | NMT re-ranking score from 182 | "Google's Neural Machine Translation System" :cite:`wu2016google` 183 | Args: 184 | alpha (float): length parameter 185 | beta (float): coverage parameter 186 | """ 187 | 188 | def __init__(self, alpha, beta, cov_penalty, length_penalty): 189 | self._validate(alpha, beta, length_penalty, cov_penalty) 190 | self.alpha = alpha 191 | self.beta = beta 192 | penalty_builder = penalties.PenaltyBuilder(cov_penalty, 193 | length_penalty) 194 | self.has_cov_pen = penalty_builder.has_cov_pen 195 | # Term will be subtracted from probability 196 | self.cov_penalty = penalty_builder.coverage_penalty 197 | 198 | self.has_len_pen = penalty_builder.has_len_pen 199 | # Probability will be divided by this 200 | self.length_penalty = penalty_builder.length_penalty 201 | 202 | @classmethod 203 | def _validate(cls, alpha, beta, length_penalty, coverage_penalty): 204 | # these warnings indicate that either the alpha/beta 205 | # forces a penalty to be a no-op, or a penalty is a no-op but 206 | # the alpha/beta would suggest otherwise. 207 | if length_penalty is None or length_penalty == "none": 208 | if alpha != 0: 209 | warnings.warn("Non-default `alpha` with no length penalty. " 210 | "`alpha` has no effect.") 211 | else: 212 | # using some length penalty 213 | if length_penalty == "wu" and alpha == 0.: 214 | warnings.warn("Using length penalty Wu with alpha==0 " 215 | "is equivalent to using length penalty none.") 216 | if coverage_penalty is None or coverage_penalty == "none": 217 | if beta != 0: 218 | warnings.warn("Non-default `beta` with no coverage penalty. 
" 219 | "`beta` has no effect.") 220 | else: 221 | # using some coverage penalty 222 | if beta == 0.: 223 | warnings.warn("Non-default coverage penalty with beta==0 " 224 | "is equivalent to using coverage penalty none.") 225 | 226 | def score(self, beam, logprobs): 227 | """Rescore a prediction based on penalty functions.""" 228 | len_pen = self.length_penalty(len(beam.next_ys), self.alpha) 229 | normalized_probs = logprobs / len_pen 230 | if not beam.stepwise_penalty: 231 | penalty = self.cov_penalty(beam.global_state["coverage"], 232 | self.beta) 233 | normalized_probs -= penalty 234 | 235 | return normalized_probs 236 | 237 | def update_score(self, beam, attn): 238 | """Update scores of a Beam that is not finished.""" 239 | if "prev_penalty" in beam.global_state.keys(): 240 | beam.scores.add_(beam.global_state["prev_penalty"]) 241 | penalty = self.cov_penalty(beam.global_state["coverage"] + attn, 242 | self.beta) 243 | beam.scores.sub_(penalty) 244 | 245 | def update_global_state(self, beam): 246 | """Keeps the coverage vector as sum of attentions.""" 247 | if len(beam.prev_ks) == 1: 248 | beam.global_state["prev_penalty"] = beam.scores.clone().fill_(0.0) 249 | beam.global_state["coverage"] = beam.attn[-1] 250 | self.cov_total = beam.attn[-1].sum(1) 251 | else: 252 | self.cov_total += torch.min(beam.attn[-1], 253 | beam.global_state['coverage']).sum(1) 254 | beam.global_state["coverage"] = beam.global_state["coverage"] \ 255 | .index_select(0, beam.prev_ks[-1]).add(beam.attn[-1]) 256 | 257 | prev_penalty = self.cov_penalty(beam.global_state["coverage"], 258 | self.beta) 259 | beam.global_state["prev_penalty"] = prev_penalty 260 | -------------------------------------------------------------------------------- /c2nl/translator/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """Returns the Length and Coverage Penalty function for Beam Search. 7 | Args: 8 | length_pen (str): option name of length pen 9 | cov_pen (str): option name of cov pen 10 | Attributes: 11 | has_cov_pen (bool): Whether coverage penalty is None (applying it 12 | is a no-op). Note that the converse isn't true. Setting beta 13 | to 0 should force coverage length to be a no-op. 14 | has_len_pen (bool): Whether length penalty is None (applying it 15 | is a no-op). Note that the converse isn't true. Setting alpha 16 | to 1 should force length penalty to be a no-op. 17 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 18 | Calculates the coverage penalty. 19 | length_penalty (callable[[int, float], float]): Calculates 20 | the length penalty. 
21 | """ 22 | 23 | def __init__(self, cov_pen, length_pen): 24 | self.has_cov_pen = not self._pen_is_none(cov_pen) 25 | self.coverage_penalty = self._coverage_penalty(cov_pen) 26 | self.has_len_pen = not self._pen_is_none(length_pen) 27 | self.length_penalty = self._length_penalty(length_pen) 28 | 29 | @staticmethod 30 | def _pen_is_none(pen): 31 | return pen == "none" or pen is None 32 | 33 | def _coverage_penalty(self, cov_pen): 34 | if cov_pen == "wu": 35 | return self.coverage_wu 36 | elif cov_pen == "summary": 37 | return self.coverage_summary 38 | elif self._pen_is_none(cov_pen): 39 | return self.coverage_none 40 | else: 41 | raise NotImplementedError("No '{:s}' coverage penalty.".format( 42 | cov_pen)) 43 | 44 | def _length_penalty(self, length_pen): 45 | if length_pen == "wu": 46 | return self.length_wu 47 | elif length_pen == "avg": 48 | return self.length_average 49 | elif self._pen_is_none(length_pen): 50 | return self.length_none 51 | else: 52 | raise NotImplementedError("No '{:s}' length penalty.".format( 53 | length_pen)) 54 | 55 | # Below are all the different penalty terms implemented so far. 56 | # Subtract coverage penalty from topk log probs. 57 | # Divide topk log probs by length penalty. 58 | 59 | def coverage_wu(self, cov, beta=0.): 60 | """GNMT coverage re-ranking score. 61 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 62 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 63 | probably ``batch_size x beam_size`` but could be several 64 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 65 | then the ``seq_len`` axis probably sums to (almost) 1. 66 | """ 67 | 68 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 69 | return beta * penalty 70 | 71 | def coverage_summary(self, cov, beta=0.): 72 | """Our summary penalty.""" 73 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 74 | penalty -= cov.size(-1) 75 | return beta * penalty 76 | 77 | def coverage_none(self, cov, beta=0.): 78 | """Returns zero as penalty""" 79 | none = torch.zeros((1,), device=cov.device, 80 | dtype=torch.float) 81 | if cov.dim() == 3: 82 | none = none.unsqueeze(0) 83 | return none 84 | 85 | def length_wu(self, cur_len, alpha=0.): 86 | """GNMT length re-ranking score. 87 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 88 | """ 89 | 90 | return ((5 + cur_len) / 6.0) ** alpha 91 | 92 | def length_average(self, cur_len, alpha=0.): 93 | """Returns the current sequence length.""" 94 | return cur_len 95 | 96 | def length_none(self, cur_len, alpha=0.): 97 | """Returns unmodified scores.""" 98 | return 1.0 99 | -------------------------------------------------------------------------------- /c2nl/translator/translation.py: -------------------------------------------------------------------------------- 1 | # https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/translation.py 2 | """ Translation main class """ 3 | from __future__ import division, unicode_literals 4 | from __future__ import print_function 5 | 6 | from c2nl.inputters import constants 7 | 8 | 9 | class TranslationBuilder(object): 10 | """ 11 | Build a word-based translation from the batch output 12 | of translator and the underlying dictionaries. 
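To make the two GNMT-style penalties above concrete, a toy worked example (numbers are illustrative only):

```python
import torch
from c2nl.translator.penalties import PenaltyBuilder

pb = PenaltyBuilder(cov_pen='wu', length_pen='wu')

# Length penalty: ((5 + len) / 6) ** alpha; log-probs are divided by this,
# so longer hypotheses are discounted more gently than a plain average.
print(pb.length_penalty(10, alpha=0.9))   # ((5 + 10) / 6) ** 0.9 ~= 2.28

# Coverage penalty: beta * -sum(log(min(cov, 1))); it grows when some
# source positions receive little total attention mass.
cov = torch.tensor([[0.5, 1.2, 0.3]])     # summed attention per source token
print(pb.coverage_penalty(cov, beta=5.))  # tensor([9.4856])
```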
13 | Replacement based on "Addressing the Rare Word 14 | Problem in Neural Machine Translation" :cite:`Luong2015b` 15 | Args: 16 | data (DataSet): 17 | tgt_vocab : Vocabulary 18 | n_best (int): number of translations produced 19 | replace_unk (bool): replace unknown words using attention 20 | """ 21 | 22 | def __init__(self, tgt_vocab, n_best=1, replace_unk=False): 23 | self.tgt_vocab = tgt_vocab 24 | self.n_best = n_best 25 | self.replace_unk = replace_unk 26 | 27 | def _build_target_tokens(self, src_vocab, src_raw, pred, attn): 28 | tokens = [] 29 | for tok in pred: 30 | tok = tok if isinstance(tok, int) \ 31 | else tok.item() 32 | if tok == constants.BOS: 33 | continue 34 | if tok == constants.EOS: 35 | break 36 | 37 | if tok < len(self.tgt_vocab): 38 | tokens.append(self.tgt_vocab[tok]) 39 | else: 40 | tokens.append(src_vocab[tok - len(self.tgt_vocab)]) 41 | 42 | if self.replace_unk and (attn is not None): 43 | for i in range(len(tokens)): 44 | if tokens[i] == constants.UNK_WORD: 45 | _, max_index = attn[i].max(0) 46 | tokens[i] = src_raw[max_index.item()] 47 | return tokens 48 | 49 | def from_batch(self, translation_batch, src_raw, targets, src_vocabs): 50 | batch_size = len(translation_batch["predictions"]) 51 | preds = translation_batch["predictions"] 52 | pred_score = translation_batch["scores"] 53 | attn = translation_batch["attention"] 54 | 55 | translations = [] 56 | for b in range(batch_size): 57 | src_vocab = src_vocabs[b] if src_vocabs else None 58 | pred_sents = [self._build_target_tokens( 59 | src_vocab, src_raw[b], 60 | preds[b][n], attn[b][n]) 61 | for n in range(self.n_best)] 62 | translation = Translation(targets[b], pred_sents, 63 | attn[b], pred_score[b]) 64 | translations.append(translation) 65 | 66 | return translations 67 | 68 | 69 | class Translation(object): 70 | """ 71 | Container for a translated sentence. 72 | Attributes: 73 | target ([str]): list of targets 74 | pred_sents ([[str]]): words from the n-best translations 75 | pred_scores ([[float]]): log-probs of n-best translations 76 | attns ([`FloatTensor`]) : attention dist for each translation 77 | """ 78 | 79 | def __init__(self, targets, pred_sents, attn, pred_scores): 80 | self.targets = targets 81 | self.pred_sents = pred_sents 82 | self.attns = attn 83 | self.pred_scores = pred_scores 84 | 85 | def log(self, sent_number): 86 | """ 87 | Log translation. 
88 | """
89 | output = '\nTARGET {}: {}\n'.format(sent_number, '\t'.join(self.targets))
90 | 
91 | best_pred = self.pred_sents[0]
92 | best_score = self.pred_scores[0]
93 | pred_sent = ' '.join(best_pred)
94 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent)
95 | output += "PRED SCORE: {:.4f}\n".format(best_score)
96 | 
97 | if len(self.pred_sents) > 1:
98 | output += '\nBEST HYP:\n'
99 | for score, sent in zip(self.pred_scores, self.pred_sents):
100 | output += "[{:.4f}] {}\n".format(score, sent)
101 | 
102 | return output
103 | 
--------------------------------------------------------------------------------
/c2nl/utils/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'wasi'
2 | 
3 | from .logging import *
4 | from .misc import *
5 | 
--------------------------------------------------------------------------------
/c2nl/utils/copy_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from c2nl.inputters import constants
3 | 
4 | 
5 | def collapse_copy_scores(tgt_dict, src_vocabs):
6 | """
7 | Given scores from an expanded dictionary
8 | corresponding to a batch, sums together copies,
9 | with a dictionary word when it is ambiguous.
10 | """
11 | offset = len(tgt_dict)
12 | blank_arr, fill_arr = [], []
13 | for b in range(len(src_vocabs)):
14 | blank = []
15 | fill = []
16 | src_vocab = src_vocabs[b]
17 | # Starting from 2 to ignore PAD and UNK token
18 | for i in range(2, len(src_vocab)):
19 | sw = src_vocab[i]
20 | ti = tgt_dict[sw]
21 | if ti != constants.UNK:
22 | blank.append(offset + i)
23 | fill.append(ti)
24 | 
25 | blank_arr.append(blank)
26 | fill_arr.append(fill)
27 | 
28 | return blank_arr, fill_arr
29 | 
30 | 
31 | def make_src_map(data):
32 | """Builds a one-hot alignment tensor (batch x src_len x src_vocab_size) mapping each source position to its index in the example's source vocabulary."""
33 | src_size = max([t.size(0) for t in data])
34 | src_vocab_size = max([t.max() for t in data]) + 1
35 | alignment = torch.zeros(len(data), src_size, src_vocab_size)
36 | for i, sent in enumerate(data):
37 | for j, t in enumerate(sent):
38 | alignment[i, j, t] = 1
39 | return alignment
40 | 
41 | 
42 | def align(data):
43 | """Pads the per-example copy-alignment index vectors into a single batch x tgt_len LongTensor."""
44 | tgt_size = max([t.size(0) for t in data])
45 | alignment = torch.zeros(len(data), tgt_size).long()
46 | for i, sent in enumerate(data):
47 | alignment[i, :sent.size(0)] = sent
48 | return alignment
49 | 
50 | 
51 | def replace_unknown(prediction, attn, src_raw):
52 | """Replaces each UNK token in the prediction with the source token that received the maximum attention weight.
53 | attn: tgt_len x src_len 54 | """ 55 | tokens = prediction.split() 56 | for i in range(len(tokens)): 57 | if tokens[i] == constants.UNK_WORD: 58 | _, max_index = attn[i].max(0) 59 | tokens[i] = src_raw[max_index.item()] 60 | return ' '.join(tokens) 61 | -------------------------------------------------------------------------------- /c2nl/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | if log_file and log_file != '': 15 | file_handler = logging.FileHandler(log_file) 16 | file_handler.setFormatter(log_format) 17 | logger.addHandler(file_handler) 18 | 19 | console_handler = logging.StreamHandler() 20 | console_handler.setFormatter(log_format) 21 | logger.addHandler(console_handler) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /c2nl/utils/misc.py: -------------------------------------------------------------------------------- 1 | # src: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/utils/misc.py 2 | # -*- coding: utf-8 -*- 3 | 4 | import string 5 | import torch 6 | import subprocess 7 | from nltk.stem import PorterStemmer 8 | from c2nl.inputters import constants 9 | 10 | ps = PorterStemmer() 11 | 12 | 13 | def normalize_string(s, dostem=False): 14 | """Lower text and remove punctuation, and extra whitespace.""" 15 | 16 | def white_space_fix(text): 17 | return ' '.join(text.split()) 18 | 19 | def remove_punc(text): 20 | exclude = set(string.punctuation) 21 | return ''.join(ch for ch in text if ch not in exclude) 22 | 23 | def lower(text): 24 | return text.lower() 25 | 26 | def stem(text): 27 | if not dostem: 28 | return text 29 | return ' '.join([ps.stem(w) for w in text.split()]) 30 | 31 | return stem(white_space_fix(remove_punc(lower(s)))) 32 | 33 | 34 | def aeq(*args): 35 | """ 36 | Assert all arguments have the same value 37 | """ 38 | arguments = (arg for arg in args) 39 | first = next(arguments) 40 | assert all(arg == first for arg in arguments), \ 41 | "Not all arguments have the same value: " + str(args) 42 | 43 | 44 | def validate(sequence): 45 | seq_wo_punc = sequence.translate(str.maketrans('', '', string.punctuation)) 46 | return len(seq_wo_punc.strip()) > 0 47 | 48 | 49 | def tens2sen(t, word_dict=None, src_vocabs=None): 50 | sentences = [] 51 | # loop over the batch elements 52 | for idx, s in enumerate(t): 53 | sentence = [] 54 | for wt in s: 55 | word = wt if isinstance(wt, int) \ 56 | else wt.item() 57 | if word in [constants.BOS]: 58 | continue 59 | if word in [constants.EOS]: 60 | break 61 | if word_dict and word < len(word_dict): 62 | sentence += [word_dict[word]] 63 | elif src_vocabs: 64 | word = word - len(word_dict) 65 | sentence += [src_vocabs[idx][word]] 66 | else: 67 | sentence += [str(word)] 68 | 69 | if len(sentence) == 0: 70 | # NOTE: just a trick not to score empty sentence 71 | # this has no consequence 72 | sentence = [str(constants.PAD)] 73 | 74 | sentence = ' '.join(sentence) 75 | # if not validate(sentence): 76 | # sentence = str(constants.PAD) 77 | sentences += [sentence] 78 | return sentences 79 | 80 | 81 | def sequence_mask(lengths, max_len=None): 82 | """ 83 | Creates a boolean mask from sequence 
lengths. 84 | :param lengths: 1d tensor [batch_size] 85 | :param max_len: int 86 | """ 87 | batch_size = lengths.numel() 88 | max_len = max_len or lengths.max() 89 | return (torch.arange(0, max_len, device=lengths.device) # (0 for pad positions) 90 | .type_as(lengths) 91 | .repeat(batch_size, 1) 92 | .lt(lengths.unsqueeze(1))) 93 | 94 | 95 | def tile(x, count, dim=0): 96 | """ 97 | Tiles x on dimension dim count times. 98 | """ 99 | perm = list(range(len(x.size()))) 100 | if dim != 0: 101 | perm[0], perm[dim] = perm[dim], perm[0] 102 | x = x.permute(perm).contiguous() 103 | out_size = list(x.size()) 104 | out_size[0] *= count 105 | batch = x.size(0) 106 | x = x.view(batch, -1) \ 107 | .transpose(0, 1) \ 108 | .repeat(count, 1) \ 109 | .transpose(0, 1) \ 110 | .contiguous() \ 111 | .view(*out_size) 112 | if dim != 0: 113 | x = x.permute(perm).contiguous() 114 | return x 115 | 116 | 117 | def use_gpu(opt): 118 | """ 119 | Creates a boolean if gpu used 120 | """ 121 | return (hasattr(opt, 'gpuid') and len(opt.gpuid) > 0) or \ 122 | (hasattr(opt, 'gpu') and opt.gpu > -1) 123 | 124 | 125 | def generate_relative_positions_matrix(length, 126 | max_relative_positions, 127 | use_neg_dist, 128 | cache=False): 129 | """Generate the clipped relative positions matrix 130 | for a given length and maximum relative positions""" 131 | if cache: 132 | distance_mat = torch.arange(-length + 1, 1, 1).unsqueeze(0) 133 | else: 134 | range_vec = torch.arange(length) 135 | range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) 136 | distance_mat = range_mat - range_mat.transpose(0, 1) 137 | 138 | distance_mat_clipped = torch.clamp(distance_mat, 139 | min=-max_relative_positions, 140 | max=max_relative_positions) 141 | 142 | # Shift values to be >= 0 143 | if use_neg_dist: 144 | final_mat = distance_mat_clipped + max_relative_positions 145 | else: 146 | final_mat = torch.abs(distance_mat_clipped) 147 | 148 | return final_mat 149 | 150 | 151 | def relative_matmul(x, z, transpose): 152 | """Helper function for relative positions attention.""" 153 | batch_size = x.shape[0] 154 | heads = x.shape[1] 155 | length = x.shape[2] 156 | x_t = x.permute(2, 0, 1, 3) 157 | x_t_r = x_t.reshape(length, heads * batch_size, -1) 158 | if transpose: 159 | z_t = z.transpose(1, 2) 160 | x_tz_matmul = torch.matmul(x_t_r, z_t) 161 | else: 162 | x_tz_matmul = torch.matmul(x_t_r, z) 163 | x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1) 164 | x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) 165 | return x_tz_matmul_r_t 166 | 167 | 168 | def count_file_lines(file_path): 169 | """ 170 | Counts the number of lines in a file using wc utility. 171 | :param file_path: path to file 172 | :return: int, no of lines 173 | """ 174 | num = subprocess.check_output(['wc', '-l', file_path]) 175 | num = num.decode('utf-8').split(' ') 176 | return int(num[0]) 177 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Dataset: python-method 2 | 3 | - Paper: https://arxiv.org/abs/1707.02275 4 | - Data source: https://github.com/EdinburghNLP/code-docstring-corpus 5 | 6 | Run the `get_data.sh` script inside the `python` directory. Once finished, we will see a pretty table summarizing the data statistics. 
7 | 
8 | ```
9 | +------------------------+---------+--------+--------+---------+
10 | | Attribute              |   Train |  Valid |   Test | Fullset |
11 | +------------------------+---------+--------+--------+---------+
12 | | Records                |   55538 |  18505 |  18502 |   92545 |
13 | | Function Tokens        | 2670849 | 887331 | 882013 | 4440193 |
14 | | Javadoc Tokens         |  525524 | 175429 | 176673 |  877626 |
15 | | Unique Function Tokens |  159968 |  73862 |  73766 |  307596 |
16 | | Unique Javadoc Tokens  |   27197 |  14462 |  14530 |   56189 |
17 | | Avg. Function Length   |   48.09 |  47.95 |  47.67 |   47.98 |
18 | | Avg. Javadoc Length    |    9.46 |   9.48 |   9.55 |    9.48 |
19 | +------------------------+---------+--------+--------+---------+
20 | ```
21 | 
22 | **Acknowledgement**: We thank the authors of [Wei et al., 2019](https://arxiv.org/abs/1910.05923) for sharing the preprocessed python dataset that we used in our experiments.
23 | 
24 | ### Dataset: tlcodesum
25 | 
26 | - Paper: https://xin-xia.github.io/publication/ijcai18.pdf
27 | - Data source: https://github.com/xing-hu/TL-CodeSum
28 | 
29 | Run the `get_data.sh` script inside the `java` directory. Once finished, we will see a pretty table summarizing the data statistics.
30 | 
31 | ```
32 | +------------------------+---------+---------+---------+----------+
33 | | Attribute              |   Train |   Valid |    Test |  Fullset |
34 | +------------------------+---------+---------+---------+----------+
35 | | Records                |   69708 |    8714 |    8714 |    87136 |
36 | | Function Tokens        | 8371911 | 1042466 | 1055733 | 10470110 |
37 | | Javadoc Tokens         | 1235295 |  155876 |  153407 |  1544578 |
38 | | Unique Function Tokens |   36202 |   15317 |   15131 |    66650 |
39 | | Unique Javadoc Tokens  |   28047 |    9555 |    9293 |    46895 |
40 | | Avg. Function Length   |  120.10 |  119.63 |  121.15 |   120.16 |
41 | | Avg. Javadoc Length    |   17.72 |   17.89 |   17.60 |    17.73 |
42 | +------------------------+---------+---------+---------+----------+
43 | ```
44 | 
45 | ### Direct Data Download
46 | 
47 | You can directly download our experiment dataset from [here](https://drive.google.com/drive/folders/1Mx0xEPZfQzb5h0z753XV-JgoWUuxiuKZ?usp=sharing).
48 | 
--------------------------------------------------------------------------------
/data/java/get_data.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | echo "Downloading TL-CodeSum dataset"
4 | FILE=java.zip
5 | if [[ -f "$FILE" ]]; then
6 | echo "$FILE exists, skipping download"
7 | else
8 | # https://drive.google.com/open?id=13o4MiELiQoomlly2TCTpbtGee_HdQZxl
9 | fileid="13o4MiELiQoomlly2TCTpbtGee_HdQZxl"
10 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
11 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${FILE}
12 | rm ./cookie
13 | unzip ${FILE} && rm ${FILE}
14 | fi
15 | 
16 | echo "Aggregating statistics of the dataset"
17 | python get_stat.py
18 | 
--------------------------------------------------------------------------------
/data/java/get_stat.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from tqdm import tqdm
3 | from prettytable import PrettyTable
4 | 
5 | 
6 | def count_file_lines(file_path):
7 | """
8 | Counts the number of lines in a file using wc utility.
9 | :param file_path: path to file 10 | :return: int, no of lines 11 | """ 12 | num = subprocess.check_output(['wc', '-l', file_path]) 13 | num = num.decode('utf-8').split(' ') 14 | return int(num[0]) 15 | 16 | 17 | def main(): 18 | records = {'train': 0, 'dev': 0, 'test': 0} 19 | function_tokens = {'train': 0, 'dev': 0, 'test': 0} 20 | javadoc_tokens = {'train': 0, 'dev': 0, 'test': 0} 21 | unique_function_tokens = {'train': set(), 'dev': set(), 'test': set()} 22 | unique_javadoc_tokens = {'train': set(), 'dev': set(), 'test': set()} 23 | 24 | attribute_list = ["Records", "Function Tokens", "Javadoc Tokens", 25 | "Unique Function Tokens", "Unique Javadoc Tokens"] 26 | 27 | def read_data(split): 28 | source = '%s/code.original_subtoken' % split 29 | target = '%s/javadoc.original' % split 30 | with open(source) as f1, open(target) as f2: 31 | for src, tgt in tqdm(zip(f1, f2), 32 | total=count_file_lines(source)): 33 | func_tokens = src.strip().split() 34 | comm_tokens = tgt.strip().split() 35 | records[split] += 1 36 | function_tokens[split] += len(func_tokens) 37 | javadoc_tokens[split] += len(comm_tokens) 38 | unique_function_tokens[split].update(func_tokens) 39 | unique_javadoc_tokens[split].update(comm_tokens) 40 | 41 | read_data('train') 42 | read_data('dev') 43 | read_data('test') 44 | 45 | table = PrettyTable() 46 | table.field_names = ["Attribute", "Train", "Valid", "Test", "Fullset"] 47 | table.align["Attribute"] = "l" 48 | table.align["Train"] = "r" 49 | table.align["Valid"] = "r" 50 | table.align["Test"] = "r" 51 | table.align["Fullset"] = "r" 52 | for attr in attribute_list: 53 | var = eval('_'.join(attr.lower().split())) 54 | val1 = len(var['train']) if isinstance(var['train'], set) else var['train'] 55 | val2 = len(var['dev']) if isinstance(var['dev'], set) else var['dev'] 56 | val3 = len(var['test']) if isinstance(var['test'], set) else var['test'] 57 | fullset = val1 + val2 + val3 58 | table.add_row([attr, val1, val2, val3, fullset]) 59 | 60 | avg = (function_tokens['train'] + function_tokens['dev'] + function_tokens['test']) / ( 61 | records['train'] + records['dev'] + records['test']) 62 | table.add_row([ 63 | 'Avg. Function Length', 64 | '%.2f' % (function_tokens['train'] / records['train']), 65 | '%.2f' % (function_tokens['dev'] / records['dev']), 66 | '%.2f' % (function_tokens['test'] / records['test']), 67 | '%.2f' % avg 68 | ]) 69 | avg = (javadoc_tokens['train'] + javadoc_tokens['dev'] + javadoc_tokens['test']) / ( 70 | records['train'] + records['dev'] + records['test']) 71 | table.add_row([ 72 | 'Avg. Javadoc Length', 73 | '%.2f' % (javadoc_tokens['train'] / records['train']), 74 | '%.2f' % (javadoc_tokens['dev'] / records['dev']), 75 | '%.2f' % (javadoc_tokens['test'] / records['test']), 76 | '%.2f' % avg 77 | ]) 78 | print(table) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /data/java/sample.code: -------------------------------------------------------------------------------- 1 | private int current Depth ( ) { try { Integer one Based = ( ( Integer ) DEPTH FIELD . 
get ( this ) ) ; return one Based - NUM ; } catch ( Illegal Access Exception e ) { throw new Assertion Error ( e ) ; } } 2 | protected boolean [ ] dataset Integrity ( boolean nominal Predictor , boolean numeric Predictor , boolean string Predictor , boolean date Predictor , boolean relational Predictor , boolean multi Instance , int class Type , boolean predictor Missing , boolean class Missing ) { print ( STRING ) ; print Attribute Summary ( nominal Predictor , numeric Predictor , string Predictor , date Predictor , relational Predictor , multi Instance , class Type ) ; print ( STRING ) ; int num Train = get Num Instances ( ) , num Classes = NUM , missing Level = NUM ; boolean [ ] result = new boolean [ NUM ] ; Instances train = null ; Kernel kernel = null ; try { train = make Test Dataset ( NUM , num Train , nominal Predictor ? get Num Nominal ( ) : NUM , numeric Predictor ? get Num Numeric ( ) : NUM , string Predictor ? get Num String ( ) : NUM , date Predictor ? get Num Date ( ) : NUM , relational Predictor ? get Num Relational ( ) : NUM , num Classes , class Type , multi Instance ) ; if ( missing Level > NUM ) { add Missing ( train , missing Level , predictor Missing , class Missing ) ; } kernel = Kernel . make Copies ( get Kernel ( ) , NUM ) [ NUM ] ; } catch ( Exception ex ) { throw new Error ( STRING + ex . get Message ( ) ) ; } try { Instances train Copy = new Instances ( train ) ; kernel . build Kernel ( train Copy ) ; compare Datasets ( train , train Copy ) ; println ( STRING ) ; result [ NUM ] = BOOL ; } catch ( Exception ex ) { println ( STRING ) ; result [ NUM ] = BOOL ; if ( m Debug ) { println ( STRING ) ; print ( STRING ) ; println ( STRING + ex . get Message ( ) + STRING ) ; println ( STRING ) ; println ( STRING + train . to String ( ) + STRING ) ; } } return result ; } 3 | public static int union Size ( long [ ] x , long [ ] y ) { final int lx = x . length , ly = y . length ; final int min = ( lx < ly ) ? lx : ly ; int i = NUM , res = NUM ; for ( ; i < min ; i ++ ) { res += Long . bit Count ( x [ i ] | y [ i ] ) ; } for ( ; i < lx ; i ++ ) { res += Long . bit Count ( x [ i ] ) ; } for ( ; i < ly ; i ++ ) { res += Long . bit Count ( y [ i ] ) ; } return res ; } 4 | public void test Reverse Order 4 ( ) throws Exception { UUID id = UUID . random UUID ( ) ; Grid Cache Adapter < String , String > cache = grid . internal Cache ( ) ; Grid Cache Context < String , String > ctx = cache . context ( ) ; Grid Cache Test Entry Ex entry 1 = new Grid Cache Test Entry Ex ( ctx , STRING ) ; Grid Cache Test Entry Ex entry 2 = new Grid Cache Test Entry Ex ( ctx , STRING ) ; Grid Cache Version ver 1 = version ( NUM ) ; Grid Cache Version ver 2 = version ( NUM ) ; Grid Cache Version ver 3 = version ( NUM ) ; Grid Cache Mvcc Candidate v3 k 1 = entry 1 . add Local ( NUM , ver 3 , NUM , BOOL , BOOL ) ; Grid Cache Mvcc Candidate v3 k 2 = entry 2 . add Local ( NUM , ver 3 , NUM , BOOL , BOOL ) ; link Candidates ( ctx , v3 k 1 , v3 k 2 ) ; entry 1 . ready Local ( ver 3 ) ; check Local ( v3 k 1 , ver 3 , BOOL , BOOL , BOOL ) ; check Local ( v3 k 2 , ver 3 , BOOL , BOOL , BOOL ) ; Grid Cache Mvcc Candidate v1 k 1 = entry 1 . add Local ( NUM , ver 1 , NUM , BOOL , BOOL ) ; Grid Cache Mvcc Candidate v1 k 2 = entry 2 . add Local ( NUM , ver 1 , NUM , BOOL , BOOL ) ; link Candidates ( ctx , v1 k 1 , v1 k 2 ) ; entry 1 . ready Local ( ver 1 ) ; entry 2 . 
ready Local ( ver 1 ) ; check Local ( v3 k 1 , ver 3 , BOOL , BOOL , BOOL ) ; check Local ( v3 k 2 , ver 3 , BOOL , BOOL , BOOL ) ; check Local ( v1 k 1 , ver 1 , BOOL , BOOL , BOOL ) ; check Local ( v1 k 2 , ver 1 , BOOL , BOOL , BOOL ) ; Grid Cache Mvcc Candidate v2 k 2 = entry 2 . add Remote ( id , NUM , ver 2 , NUM , BOOL , BOOL ) ; check Remote ( v2 k 2 , ver 2 , BOOL , BOOL ) ; entry 2 . ready Local ( v3 k 2 ) ; check Local ( v3 k 1 , ver 3 , BOOL , BOOL , BOOL ) ; check Local ( v3 k 2 , ver 3 , BOOL , BOOL , BOOL ) ; } 5 | @ Override public void closing OK ( ) { List < Add User Fields . Attribute Spec > specs = new Array List < Add User Fields . Attribute Spec > ( ) ; for ( int i = NUM ; i < m list Model . size ( ) ; i ++ ) { Add User Fields . Attribute Spec a = ( Add User Fields . Attribute Spec ) m list Model . element At ( i ) ; specs . add ( a ) ; } if ( m modify L != null ) { m modify L . set Modified Status ( Add User Fields Customizer . this , BOOL ) ; } m filter . set Attribute Specs ( specs ) ; } 6 | public String to String ( ) { String Buffer text = new String Buffer ( ) ; if ( ( m class Attribute == null ) ) { return STRING ; } try { text . append ( STRING ) ; text . append ( STRING + m kernel . to String ( ) + STRING ) ; for ( int i = NUM ; i < m class Attribute . num Values ( ) ; i ++ ) { for ( int j = i + NUM ; j < m class Attribute . num Values ( ) ; j ++ ) { text . append ( STRING + m class Attribute . value ( i ) + STRING + m class Attribute . value ( j ) + STRING ) ; text . append ( m classifiers [ i ] [ j ] ) ; if ( m fit Logistic Models ) { text . append ( STRING ) ; if ( m classifiers [ i ] [ j ] . m logistic == null ) { text . append ( STRING ) ; } else { text . append ( m classifiers [ i ] [ j ] . m logistic ) ; } } text . append ( STRING ) ; } } } catch ( Exception e ) { return STRING ; } return text . to String ( ) ; } 7 | public final Sector union ( Sector that ) { if ( that == null ) return this ; Angle min Lat = this . min Latitude ; Angle max Lat = this . max Latitude ; Angle min Lon = this . min Longitude ; Angle max Lon = this . max Longitude ; if ( that . min Latitude . degrees < this . min Latitude . degrees ) min Lat = that . min Latitude ; if ( that . max Latitude . degrees > this . max Latitude . degrees ) max Lat = that . max Latitude ; if ( that . min Longitude . degrees < this . min Longitude . degrees ) min Lon = that . min Longitude ; if ( that . max Longitude . degrees > this . max Longitude . degrees ) max Lon = that . max Longitude ; return new Sector ( min Lat , max Lat , min Lon , max Lon ) ; } 8 | private static boolean is Double Equal ( double value , double value To Compare ) { return ( Math . abs ( value - value To Compare ) < NUM ) ; } 9 | public void test Int Value Pos ( ) { String a = STRING ; Big Decimal a Number = new Big Decimal ( a ) ; int result = - NUM ; assert True ( STRING , a Number . int Value ( ) == result ) ; } 10 | private void walk ( File directory , int depth , Collection < T > results ) throws IO Exception { check If Cancelled ( directory , depth , results ) ; if ( handle Directory ( directory , depth , results ) ) { handle Directory Start ( directory , depth , results ) ; int child Depth = depth + NUM ; if ( depth Limit < NUM || child Depth <= depth Limit ) { check If Cancelled ( directory , depth , results ) ; File [ ] child Files = filter == null ? directory . list Files ( ) : directory . 
list Files ( filter ) ; child Files = filter Directory Contents ( directory , depth , child Files ) ; if ( child Files == null ) { handle Restricted ( directory , child Depth , results ) ; } else { for ( File child File : child Files ) { if ( child File . is Directory ( ) ) { walk ( child File , child Depth , results ) ; } else { check If Cancelled ( child File , child Depth , results ) ; handle File ( child File , child Depth , results ) ; check If Cancelled ( child File , child Depth , results ) ; } } } } handle Directory End ( directory , depth , results ) ; } check If Cancelled ( directory , depth , results ) ; } 11 | public static Object [ ] ordinal Array ( Tuple Set tuples , String field ) { return ordinal Array ( tuples , field , Default Literal Comparator . get Instance ( ) ) ; } 12 | public void test Divide Exception Invalid RM ( ) { String a = STRING ; int a Scale = NUM ; String b = STRING ; int b Scale = NUM ; Big Decimal a Number = new Big Decimal ( new Big Integer ( a ) , a Scale ) ; Big Decimal b Number = new Big Decimal ( new Big Integer ( b ) , b Scale ) ; try { a Number . divide ( b Number , NUM ) ; fail ( STRING ) ; } catch ( Illegal Argument Exception e ) { assert Equals ( STRING , STRING , e . get Message ( ) ) ; } } 13 | @ Override public void dataset Changed ( Dataset Change Event event ) { super . dataset Changed ( event ) ; if ( this . subplots == null ) { return ; } XY Dataset dataset = null ; if ( event . get Dataset ( ) instanceof XY Dataset ) { dataset = ( XY Dataset ) event . get Dataset ( ) ; } for ( XY Plot subplot : this . subplots ) { if ( subplot . index Of ( dataset ) >= NUM ) { subplot . configure Range Axes ( ) ; } } } 14 | public boolean on Schedule As Library ( Config config , Config runtime , I Scheduler scheduler , Packing Plan packing ) { boolean ret = BOOL ; try { scheduler . initialize ( config , runtime ) ; ret = scheduler . on Schedule ( packing ) ; if ( ret ) { ret = Scheduler Utils . set Lib Scheduler Location ( runtime , scheduler , BOOL ) ; } else { LOG . severe ( STRING ) ; } } finally { scheduler . close ( ) ; } return ret ; } 15 | public static boolean is String Type ( Type t ) { return t . equals ( Ref Type . v ( STRING ) ) ; } 16 | public Entry update Or Create Source ( User user , String id , String url , String title , Long mod Time , Long create Time , boolean is Admin , Errors errors ) { if ( user == null ) { Errors . add ( errors , error Messages . error User Is Null ( ) ) ; return null ; } if ( url == null ) { Errors . add ( errors , error Messages . error Url Is Null ( ) ) ; return null ; } Entry source = get Entry By User Id And Url ( user . get Id ( ) , url ) ; if ( source == null ) { if ( url . is Empty ( ) ) { Errors . add ( errors , error Messages . error Url Is Empty ( ) ) ; return null ; } if ( title == null ) { Errors . add ( errors , error Messages . error Title Is Null ( ) ) ; return null ; } if ( title . is Empty ( ) ) { Errors . add ( errors , error Messages . error Title Is Empty ( ) ) ; return null ; } if ( mod Time == null ) { Errors . add ( errors , error Messages . error Mod Time Is Null ( ) ) ; return null ; } if ( create Time == null ) { Errors . add ( errors , error Messages . error Create Time Is Null ( ) ) ; return null ; } if ( id != null && ! id Generator . is Id Well Formed ( id ) ) { Errors . add ( errors , error Messages . error Id Is Invalid ( ) ) ; return null ; } if ( create Time . long Value ( ) > mod Time . 
long Value ( ) ) { mod Time = create Time ; } if ( url != null ) { url = clean Up Text ( url ) ; } if ( title != null ) { title = clean Up Text ( title ) ; } source = new Entry ( ) ; source . set Db ( db ) ; if ( id == null ) { id = id Generator . get Another Id ( ) ; } source . set Id ( id ) ; source . set Source Url ( url ) ; source . set Source Title ( title ) ; source . set Create Time ( create Time ) ; source . set Type ( Constants . source ) ; source . set User Id ( user . get Id ( ) ) ; db . persist Entry ( source ) ; } else if ( ! can User Modify Entry ( user , source , is Admin ) ) { Errors . add ( errors , error Messages . error User Is Not Entitled To Modify The Source ( ) ) ; return null ; } source . set Mod Time ( mod Time ) ; return source ; } 17 | private void validate Sql Statement ( String sql , int jdbc Statement Index ) { Assert . is True ( String Utils . is Not Blank ( sql ) , STRING + jdbc Statement Index + STRING ) ; } 18 | public static Long [ ] values Of ( long [ ] array ) { Long [ ] dest = new Long [ array . length ] ; for ( int i = NUM ; i < array . length ; i ++ ) { dest [ i ] = Long . value Of ( array [ i ] ) ; } return dest ; } 19 | @ Override public boolean is Trace Enabled ( ) { return logger . is Loggable ( Level . FINEST ) ; } 20 | private static void use Missile ( Player player ) { Stackable Item projectiles Item = null ; if ( player . get Range Weapon ( ) != null ) { projectiles Item = player . get Ammunition ( ) ; } if ( projectiles Item == null ) { projectiles Item = player . get Missile If Not Holding Other Weapon ( ) ; } if ( projectiles Item != null ) { projectiles Item . remove One ( ) ; } } 21 | -------------------------------------------------------------------------------- /data/python/get_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | echo "Downloading python-method dataset" 4 | FILE=python.zip 5 | if [[ -f "$FILE" ]]; then 6 | echo "$FILE exists, skipping download" 7 | else 8 | # https://drive.google.com/open?id=1XPE1txk9VI0aOT_TdqbAeI58Q8puKVl2 9 | fileid="1XPE1txk9VI0aOT_TdqbAeI58Q8puKVl2" 10 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null 11 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${FILE} 12 | rm ./cookie 13 | unzip ${FILE} && rm ${FILE} 14 | fi 15 | 16 | echo "Aggregating statistics of the dataset" 17 | python get_stat.py 18 | -------------------------------------------------------------------------------- /data/python/get_stat.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from tqdm import tqdm 3 | from prettytable import PrettyTable 4 | 5 | 6 | def count_file_lines(file_path): 7 | """ 8 | Counts the number of lines in a file using wc utility. 
9 | :param file_path: path to file 10 | :return: int, no of lines 11 | """ 12 | num = subprocess.check_output(['wc', '-l', file_path]) 13 | num = num.decode('utf-8').split()  # split on any whitespace; wc may pad the count 14 | return int(num[0]) 15 | 16 | 17 | def main(): 18 | records = {'train': 0, 'dev': 0, 'test': 0} 19 | function_tokens = {'train': 0, 'dev': 0, 'test': 0} 20 | javadoc_tokens = {'train': 0, 'dev': 0, 'test': 0} 21 | unique_function_tokens = {'train': set(), 'dev': set(), 'test': set()} 22 | unique_javadoc_tokens = {'train': set(), 'dev': set(), 'test': set()} 23 | 24 | attribute_list = ["Records", "Function Tokens", "Javadoc Tokens", 25 | "Unique Function Tokens", "Unique Javadoc Tokens"] 26 | 27 | def read_data(split): 28 | source = '%s/code.original_subtoken' % split 29 | target = '%s/javadoc.original' % split 30 | with open(source) as f1, open(target) as f2: 31 | for src, tgt in tqdm(zip(f1, f2), 32 | total=count_file_lines(source)): 33 | func_tokens = src.strip().split() 34 | comm_tokens = tgt.strip().split() 35 | records[split] += 1 36 | function_tokens[split] += len(func_tokens) 37 | javadoc_tokens[split] += len(comm_tokens) 38 | unique_function_tokens[split].update(func_tokens) 39 | unique_javadoc_tokens[split].update(comm_tokens) 40 | 41 | read_data('train') 42 | read_data('dev') 43 | read_data('test') 44 | 45 | table = PrettyTable() 46 | table.field_names = ["Attribute", "Train", "Valid", "Test", "Fullset"] 47 | table.align["Attribute"] = "l" 48 | table.align["Train"] = "r" 49 | table.align["Valid"] = "r" 50 | table.align["Test"] = "r" 51 | table.align["Fullset"] = "r" 52 | for attr in attribute_list: 53 | var = {'Records': records, 'Function Tokens': function_tokens, 'Javadoc Tokens': javadoc_tokens, 'Unique Function Tokens': unique_function_tokens, 'Unique Javadoc Tokens': unique_javadoc_tokens}[attr]  # explicit lookup instead of eval() 54 | val1 = len(var['train']) if isinstance(var['train'], set) else var['train'] 55 | val2 = len(var['dev']) if isinstance(var['dev'], set) else var['dev'] 56 | val3 = len(var['test']) if isinstance(var['test'], set) else var['test'] 57 | fullset = val1 + val2 + val3 58 | table.add_row([attr, val1, val2, val3, fullset]) 59 | 60 | avg = (function_tokens['train'] + function_tokens['dev'] + function_tokens['test']) / ( 61 | records['train'] + records['dev'] + records['test']) 62 | table.add_row([ 63 | 'Avg. Function Length', 64 | '%.2f' % (function_tokens['train'] / records['train']), 65 | '%.2f' % (function_tokens['dev'] / records['dev']), 66 | '%.2f' % (function_tokens['test'] / records['test']), 67 | '%.2f' % avg 68 | ]) 69 | avg = (javadoc_tokens['train'] + javadoc_tokens['dev'] + javadoc_tokens['test']) / ( 70 | records['train'] + records['dev'] + records['test']) 71 | table.add_row([ 72 | 'Avg. 
Javadoc Length', 73 | '%.2f' % (javadoc_tokens['train'] / records['train']), 74 | '%.2f' % (javadoc_tokens['dev'] / records['dev']), 75 | '%.2f' % (javadoc_tokens['test'] / records['test']), 76 | '%.2f' % avg 77 | ]) 78 | print(table) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /data/python/sample.code: -------------------------------------------------------------------------------- 1 | def resource patch context data dict check access 'resource patch' context data dict show context {'model' context['model'] 'session' context['session'] 'user' context['user'] 'auth user obj' context['auth user obj']}resource dict get action 'resource show' show context {'id' get or bust data dict 'id' } patched dict resource dict patched update data dict return update resource update context patched 2 | def pyramid laplacian image max layer -1 downscale 2 sigma None order 1 mode 'reflect' cval 0 check factor downscale image img as float image if sigma is None sigma 2 * downscale / 6 0 layer 0rows image shape[ 0 ]cols image shape[ 1 ]smoothed image smooth image sigma mode cval yield image - smoothed image while layer max layer layer + 1out rows math ceil rows / float downscale out cols math ceil cols / float downscale resized image resize smoothed image out rows out cols order order mode mode cval cval smoothed image smooth resized image sigma mode cval prev rows rowsprev cols colsrows resized image shape[ 0 ]cols resized image shape[ 1 ]if prev rows rows and prev cols cols break yield resized image - smoothed image 3 | def get fun fun with get serv ret None commit True as cur sql 'SELEC Ts id s jid s full ret\n FRO Msalt returnss\n JOIN SELECTMAX `jid` asjid\nfromsalt returns GROUPB Yfun id max\n O Ns jid max jid\n WHER Es fun %s\n'cur execute sql fun data cur fetchall ret {}if data for minion full ret in data ret[minion] full retreturn ret 4 | def get svc avail path return AVAIL SVR DIRS 5 | def store temp file filedata filename path None filename get filename from path filename filename filename[ 100 ]options Config if path target path pathelse tmp path options cuckoo get 'tmppath' '/tmp' target path os path join tmp path 'cuckoo-tmp' if not os path exists target path os mkdir target path tmp dir tempfile mkdtemp prefix 'upload ' dir target path tmp file path os path join tmp dir filename with open tmp file path 'wb' as tmp file if hasattr filedata 'read' chunk filedata read 1024 while chunk tmp file write chunk chunk filedata read 1024 else tmp file write filedata return tmp file path 6 | def create Target Dirs if not os path isdir paths POCSUITE OUTPUT PATH try if not os path isdir paths POCSUITE OUTPUT PATH os makedirs paths POCSUITE OUTPUT PATH 493 warn Msg "using'%s'astheoutputdirectory" % paths POCSUITE OUTPUT PATH logger log CUSTOM LOGGING WARNING warn Msg except OS Error IO Error as ex try temp Dir tempfile mkdtemp prefix 'pocsuiteoutput' except Exception as err Msg "unabletowritetothetemporarydirectory '%s' " % err Msg + ' Pleasemakesurethatyourdiskisnotfulland'err Msg + 'thatyouhavesufficientwritepermissionsto'err Msg + 'createtemporaryfilesand/ordirectories'raise Pocsuite System Exception err Msg warn Msg 'unabletocreateregularoutputdirectory'warn Msg + "'%s' %s " % paths POCSUITE OUTPUT PATH get Unicode ex warn Msg + " Usingtemporarydirectory'%s'instead" % get Unicode temp Dir logger log CUSTOM LOGGING WARNING warn Msg paths POCUSITE OUTPUT PATH temp Dir 7 | def unhex s bits 0for c in s c bytes c 
if '0 ' < c < '9 ' i ord '0 ' elif 'a' < c < 'f' i ord 'a' - 10 elif 'A' < c < 'F' i ord 'A' - 10 else assert False 'non-hexdigit' + repr c bits bits * 16 + ord c - i return bits 8 | def setwindowposition folder alias x y finder getfinder args {}attrs {}aeobj 0 aetypes Object Specifier want aetypes Type 'cfol' form 'alis' seld folder alias fr None aeobj 1 aetypes Object Specifier want aetypes Type 'prop' form 'prop' seld aetypes Type 'cwnd' fr aeobj 0 aeobj 2 aetypes Object Specifier want aetypes Type 'prop' form 'prop' seld aetypes Type 'posn' fr aeobj 1 args['----'] aeobj 2args['data'] [x y] reply args attrs finder send 'core' 'setd' args attrs if args has key 'errn' raise Error aetools decodeerror args if args has key '----' return args['----'] 9 | def walk top topdown True followlinks False names os listdir top dirs nondirs [] [] for name in names if path isdir path join top name dirs append name else nondirs append name if topdown yield top dirs nondirs for name in dirs fullpath path join top name if followlinks or not path islink fullpath for x in walk fullpath topdown followlinks yield x if not topdown yield top dirs nondirs 10 | def url filename url match upload title re match url if match return match group 'filename' else return url 11 | @testing requires testing data@requires mnedef test other volume source spaces tempdir Temp Dir temp name op join tempdir 'temp-src fif' run subprocess ['mne volume source space' '--grid' '7 0' '--src' temp name '--mri' fname mri] src read source spaces temp name src new setup volume source space None pos 7 0 mri fname mri subjects dir subjects dir compare source spaces src src new mode 'approx' assert true 'volume shape' in repr src del srcdel src newassert raises Value Error setup volume source space 'sample' temp name pos 7 0 sphere [1 0 1 0] mri fname mri subjects dir subjects dir run subprocess ['mne volume source space' '--grid' '7 0' '--src' temp name] assert raises Value Error read source spaces temp name 12 | def get preamble latex preamble rc Params get u'pgf preamble' u'' if type latex preamble list latex preamble u'\n' join latex preamble return latex preamble 13 | def sdm spoly f g O K phantom None if not f or not g return sdm zero LM 1 sdm LM f LM 2 sdm LM g if LM 1 [ 0 ] LM 2 [ 0 ] return sdm zero LM 1 LM 1 [ 1 ]LM 2 LM 2 [ 1 ]lcm monomial lcm LM 1 LM 2 m1 monomial div lcm LM 1 m2 monomial div lcm LM 2 c K quo - sdm LC f K sdm LC g K r1 sdm add sdm mul term f m1 K one O K sdm mul term g m2 c O K O K if phantom is None return r1 r 2 sdm add sdm mul term phantom[ 0 ] m1 K one O K sdm mul term phantom[ 1 ] m2 c O K O K return r1 r2 14 | def upgrade migrate engine meta Meta Data meta bind migrate enginevolume type projects Table 'volume type projects' meta autoload True if migrate engine name 'postgresql' sql 'ALTERTABL Evolume type projects ALTERCOLUM Ndeleted' + 'TYPEINTEGERUSIN Gdeleted integer' migrate engine execute sql else volume type projects c deleted alter Integer 15 | def analyze modules project task handle taskhandle Null Task Handle resources project get python files job set task handle create jobset ' Analyzing Modules' len resources for resource in resources job set started job resource path analyze module project resource job set finished job 16 | def get sw login version return '-' join get sw version strip build num True split '-' [1 -2 ] 17 | def Set Help Menu Other Help main Menu global help ID Mapif help ID Map is None help ID Map {}cmd ID win 32 ui ID HELP OTHE Rexclude List [' Main Python Documentation' ' 
Pythonwin Reference']first List List All Help Files exclude Fnames []for desc fname in first List if desc in exclude List exclude Fnames append fname help Descs []for desc fname in first List if fname not in exclude Fnames help ID Map[cmd ID] desc fname win 32 ui Get Main Frame Hook Command Handle Help Other Command cmd ID cmd ID cmd ID + 1 help Menu main Menu Get Sub Menu main Menu Get Menu Item Count - 1 other Help Menu Pos 2other Menu help Menu Get Sub Menu other Help Menu Pos while other Menu Get Menu Item Count other Menu Delete Menu 0 win 32 con MF BYPOSITION if help ID Map for id desc fname in help ID Map iteritems other Menu Append Menu win 32 con MF ENABLED win 32 con MF STRING id desc else help Menu Enable Menu Item other Help Menu Pos win 32 con MF BYPOSITION win 32 con MF GRAYED 18 | def to location code falcon HTTP 302 raise falcon http status HTTP Status code {'location' location} 19 | def select command corrected commands try selector Command Selector corrected commands except No Rule Matched logs failed ' Nofucksgiven' returnif not settings require confirmation logs show corrected command selector value return selector valuelogs confirm text selector value for action in read actions if action const ACTION SELECT sys stderr write '\n' return selector valueelif action const ACTION ABORT logs failed '\n Aborted' returnelif action const ACTION PREVIOUS selector previous logs confirm text selector value elif action const ACTION NEXT selector next logs confirm text selector value 20 | def partial project endog exog x1 x2 endog exog params np linalg pinv x2 dot x1 predicted x2 dot params residual x1 - predicted res Bunch params params fittedvalues predicted resid residual return res 21 | -------------------------------------------------------------------------------- /main/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'wasi' 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | nltk 4 | prettytable 5 | torch>=1.3.0 6 | -------------------------------------------------------------------------------- /scripts/generate.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function make_dir () { 4 | if [[ ! -d "$1" ]]; then 5 | mkdir $1 6 | fi 7 | } 8 | 9 | SRC_DIR=.. 
10 | DATA_DIR=${SRC_DIR}/data 11 | MODEL_DIR=${SRC_DIR}/tmp 12 | 13 | make_dir $MODEL_DIR 14 | 15 | DATASET=java 16 | 17 | function generate () { 18 | 19 | echo "============Generating (Beam)============" 20 | 21 | RGPU=$1 22 | MODEL_NAME=$2 23 | 24 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/test.py \ 25 | --only_generate True \ 26 | --data_workers 5 \ 27 | --dataset_name $DATASET \ 28 | --data_dir ${DATA_DIR}/ \ 29 | --model_dir $MODEL_DIR \ 30 | --model_name $MODEL_NAME \ 31 | --dev_src $3 \ 32 | --uncase True \ 33 | --max_examples -1 \ 34 | --max_src_len 150 \ 35 | --max_tgt_len 50 \ 36 | --test_batch_size 64 \ 37 | --beam_size 4 \ 38 | --n_best 1 \ 39 | --block_ngram_repeat 3 \ 40 | --stepwise_penalty False \ 41 | --coverage_penalty none \ 42 | --length_penalty none \ 43 | --beta 0 \ 44 | --gamma 0 \ 45 | --replace_unk 46 | 47 | } 48 | 49 | # run: bash generate.sh 0 MODEL_NAME source_code_filename 50 | generate $1 $2 $3 51 | -------------------------------------------------------------------------------- /scripts/java/rnn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function make_dir () { 4 | if [[ ! -d "$1" ]]; then 5 | mkdir $1 6 | fi 7 | } 8 | 9 | SRC_DIR=../.. 10 | DATA_DIR=${SRC_DIR}/data 11 | MODEL_DIR=${SRC_DIR}/tmp 12 | 13 | make_dir $MODEL_DIR 14 | 15 | DATASET=java 16 | CODE_EXTENSION=original_subtoken 17 | JAVADOC_EXTENSION=original 18 | 19 | 20 | function train () { 21 | 22 | echo "============TRAINING============" 23 | 24 | RGPU=$1 25 | MODEL_NAME=$2 26 | 27 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 28 | --data_workers 5 \ 29 | --dataset_name $DATASET \ 30 | --data_dir ${DATA_DIR}/ \ 31 | --model_dir $MODEL_DIR \ 32 | --model_name $MODEL_NAME \ 33 | --train_src train/code.${CODE_EXTENSION} \ 34 | --train_tgt train/javadoc.${JAVADOC_EXTENSION} \ 35 | --dev_src dev/code.${CODE_EXTENSION} \ 36 | --dev_tgt dev/javadoc.${JAVADOC_EXTENSION} \ 37 | --uncase True \ 38 | --use_src_word True \ 39 | --use_src_char False \ 40 | --use_tgt_word True \ 41 | --use_tgt_char False \ 42 | --max_src_len 150 \ 43 | --max_tgt_len 50 \ 44 | --emsize 512 \ 45 | --fix_embeddings False \ 46 | --src_vocab_size 50000 \ 47 | --tgt_vocab_size 30000 \ 48 | --share_decoder_embeddings True \ 49 | --conditional_decoding False \ 50 | --max_examples -1 \ 51 | --batch_size 32 \ 52 | --test_batch_size 64 \ 53 | --num_epochs 200 \ 54 | --model_type rnn \ 55 | --nhid 512 \ 56 | --nlayers 2 \ 57 | --use_all_enc_layers False \ 58 | --dropout_rnn 0.2 \ 59 | --dropout_emb 0.2 \ 60 | --dropout 0.2 \ 61 | --copy_attn True \ 62 | --reuse_copy_attn True \ 63 | --early_stop 20 \ 64 | --optimizer adam \ 65 | --learning_rate 0.002 \ 66 | --lr_decay 0.99 \ 67 | --grad_clipping 5.0 \ 68 | --valid_metric bleu \ 69 | --checkpoint True 70 | 71 | } 72 | 73 | function test () { 74 | 75 | echo "============TESTING============" 76 | 77 | RGPU=$1 78 | MODEL_NAME=$2 79 | 80 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 81 | --only_test True \ 82 | --data_workers 5 \ 83 | --dataset_name $DATASET \ 84 | --data_dir ${DATA_DIR}/ \ 85 | --model_dir $MODEL_DIR \ 86 | --model_name $MODEL_NAME \ 87 | --dev_src test/code.${CODE_EXTENSION} \ 88 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 89 | --uncase True \ 90 | --max_src_len 150 \ 91 | --max_tgt_len 50 \ 92 | --max_examples -1 \ 93 | --test_batch_size 64 94 | 95 | } 96 | 97 | 
function beam_search () { 98 | 99 | echo "============Beam Search TESTING============" 100 | 101 | RGPU=$1 102 | MODEL_NAME=$2 103 | 104 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/test.py \ 105 | --data_workers 5 \ 106 | --dataset_name $DATASET \ 107 | --data_dir ${DATA_DIR}/ \ 108 | --model_dir $MODEL_DIR \ 109 | --model_name $MODEL_NAME \ 110 | --dev_src test/code.${CODE_EXTENSION} \ 111 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 112 | --uncase True \ 113 | --max_examples -1 \ 114 | --max_src_len 150 \ 115 | --max_tgt_len 50 \ 116 | --test_batch_size 64 \ 117 | --beam_size 4 \ 118 | --n_best 1 \ 119 | --block_ngram_repeat 3 \ 120 | --stepwise_penalty False \ 121 | --coverage_penalty none \ 122 | --length_penalty none \ 123 | --beta 0 \ 124 | --gamma 0 \ 125 | --replace_unk 126 | 127 | } 128 | 129 | train $1 $2 130 | test $1 $2 131 | beam_search $1 $2 132 | -------------------------------------------------------------------------------- /scripts/java/transformer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function make_dir () { 4 | if [[ ! -d "$1" ]]; then 5 | mkdir $1 6 | fi 7 | } 8 | 9 | SRC_DIR=../.. 10 | DATA_DIR=${SRC_DIR}/data 11 | MODEL_DIR=${SRC_DIR}/tmp 12 | 13 | make_dir $MODEL_DIR 14 | 15 | DATASET=java 16 | CODE_EXTENSION=original_subtoken 17 | JAVADOC_EXTENSION=original 18 | 19 | 20 | function train () { 21 | 22 | echo "============TRAINING============" 23 | 24 | RGPU=$1 25 | MODEL_NAME=$2 26 | 27 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 28 | --data_workers 5 \ 29 | --dataset_name $DATASET \ 30 | --data_dir ${DATA_DIR}/ \ 31 | --model_dir $MODEL_DIR \ 32 | --model_name $MODEL_NAME \ 33 | --train_src train/code.${CODE_EXTENSION} \ 34 | --train_tgt train/javadoc.${JAVADOC_EXTENSION} \ 35 | --dev_src dev/code.${CODE_EXTENSION} \ 36 | --dev_tgt dev/javadoc.${JAVADOC_EXTENSION} \ 37 | --uncase True \ 38 | --use_src_word True \ 39 | --use_src_char False \ 40 | --use_tgt_word True \ 41 | --use_tgt_char False \ 42 | --max_src_len 150 \ 43 | --max_tgt_len 50 \ 44 | --emsize 512 \ 45 | --fix_embeddings False \ 46 | --src_vocab_size 50000 \ 47 | --tgt_vocab_size 30000 \ 48 | --share_decoder_embeddings True \ 49 | --max_examples -1 \ 50 | --batch_size 32 \ 51 | --test_batch_size 64 \ 52 | --num_epochs 200 \ 53 | --model_type transformer \ 54 | --num_head 8 \ 55 | --d_k 64 \ 56 | --d_v 64 \ 57 | --d_ff 2048 \ 58 | --src_pos_emb False \ 59 | --tgt_pos_emb True \ 60 | --max_relative_pos 32 \ 61 | --use_neg_dist True \ 62 | --nlayers 6 \ 63 | --trans_drop 0.2 \ 64 | --dropout_emb 0.2 \ 65 | --dropout 0.2 \ 66 | --copy_attn True \ 67 | --early_stop 20 \ 68 | --warmup_steps 2000 \ 69 | --optimizer adam \ 70 | --learning_rate 0.0001 \ 71 | --lr_decay 0.99 \ 72 | --valid_metric bleu \ 73 | --checkpoint True \ 74 | --split_decoder False 75 | } 76 | 77 | function test () { 78 | 79 | echo "============TESTING============" 80 | 81 | RGPU=$1 82 | MODEL_NAME=$2 83 | 84 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 85 | --only_test True \ 86 | --data_workers 5 \ 87 | --dataset_name $DATASET \ 88 | --data_dir ${DATA_DIR}/ \ 89 | --model_dir $MODEL_DIR \ 90 | --model_name $MODEL_NAME \ 91 | --dev_src test/code.${CODE_EXTENSION} \ 92 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 93 | --uncase True \ 94 | --max_src_len 150 \ 95 | --max_tgt_len 50 \ 96 | --max_examples -1 \ 97 | 
--test_batch_size 64 98 | 99 | } 100 | 101 | function beam_search () { 102 | 103 | echo "============Beam Search TESTING============" 104 | 105 | RGPU=$1 106 | MODEL_NAME=$2 107 | 108 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/test.py \ 109 | --data_workers 5 \ 110 | --dataset_name $DATASET \ 111 | --data_dir ${DATA_DIR}/ \ 112 | --model_dir $MODEL_DIR \ 113 | --model_name $MODEL_NAME \ 114 | --dev_src test/code.${CODE_EXTENSION} \ 115 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 116 | --uncase True \ 117 | --max_examples -1 \ 118 | --max_src_len 150 \ 119 | --max_tgt_len 50 \ 120 | --test_batch_size 64 \ 121 | --beam_size 4 \ 122 | --n_best 1 \ 123 | --block_ngram_repeat 3 \ 124 | --stepwise_penalty False \ 125 | --coverage_penalty none \ 126 | --length_penalty none \ 127 | --beta 0 \ 128 | --gamma 0 \ 129 | --replace_unk 130 | 131 | } 132 | 133 | train $1 $2 134 | test $1 $2 135 | beam_search $1 $2 136 | -------------------------------------------------------------------------------- /scripts/python/rnn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function make_dir () { 4 | if [[ ! -d "$1" ]]; then 5 | mkdir $1 6 | fi 7 | } 8 | 9 | SRC_DIR=../.. 10 | DATA_DIR=${SRC_DIR}/data 11 | MODEL_DIR=${SRC_DIR}/tmp 12 | 13 | make_dir $MODEL_DIR 14 | 15 | DATASET=python 16 | CODE_EXTENSION=original_subtoken 17 | JAVADOC_EXTENSION=original 18 | 19 | 20 | function train () { 21 | 22 | echo "============TRAINING============" 23 | 24 | RGPU=$1 25 | MODEL_NAME=$2 26 | 27 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 28 | --data_workers 5 \ 29 | --dataset_name $DATASET \ 30 | --data_dir ${DATA_DIR}/ \ 31 | --model_dir $MODEL_DIR \ 32 | --model_name $MODEL_NAME \ 33 | --train_src train/code.${CODE_EXTENSION} \ 34 | --train_tgt train/javadoc.${JAVADOC_EXTENSION} \ 35 | --dev_src dev/code.${CODE_EXTENSION} \ 36 | --dev_tgt dev/javadoc.${JAVADOC_EXTENSION} \ 37 | --uncase True \ 38 | --use_src_word True \ 39 | --use_src_char False \ 40 | --use_tgt_word True \ 41 | --use_tgt_char False \ 42 | --max_src_len 400 \ 43 | --max_tgt_len 30 \ 44 | --emsize 512 \ 45 | --fix_embeddings False \ 46 | --src_vocab_size 50000 \ 47 | --tgt_vocab_size 30000 \ 48 | --share_decoder_embeddings True \ 49 | --conditional_decoding False \ 50 | --max_examples -1 \ 51 | --batch_size 32 \ 52 | --test_batch_size 64 \ 53 | --num_epochs 200 \ 54 | --model_type rnn \ 55 | --nhid 512 \ 56 | --nlayers 2 \ 57 | --dropout_rnn 0.2 \ 58 | --dropout_emb 0.2 \ 59 | --dropout 0.2 \ 60 | --copy_attn True \ 61 | --reuse_copy_attn True \ 62 | --early_stop 20 \ 63 | --optimizer adam \ 64 | --learning_rate 0.002 \ 65 | --lr_decay 0.99 \ 66 | --grad_clipping 5.0 \ 67 | --valid_metric bleu \ 68 | --checkpoint True 69 | 70 | } 71 | 72 | function test () { 73 | 74 | echo "============TESTING============" 75 | 76 | RGPU=$1 77 | MODEL_NAME=$2 78 | 79 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 80 | --only_test True \ 81 | --data_workers 5 \ 82 | --dataset_name $DATASET \ 83 | --data_dir ${DATA_DIR}/ \ 84 | --model_dir $MODEL_DIR \ 85 | --model_name $MODEL_NAME \ 86 | --dev_src test/code.${CODE_EXTENSION} \ 87 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 88 | --uncase True \ 89 | --max_src_len 400 \ 90 | --max_tgt_len 30 \ 91 | --max_examples -1 \ 92 | --test_batch_size 64 93 | 94 | } 95 | 96 | function beam_search () { 97 | 98 | echo 
"============Beam Search TESTING============" 99 | 100 | RGPU=$1 101 | MODEL_NAME=$2 102 | 103 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/test.py \ 104 | --data_workers 5 \ 105 | --dataset_name $DATASET \ 106 | --data_dir ${DATA_DIR}/ \ 107 | --model_dir $MODEL_DIR \ 108 | --model_name $MODEL_NAME \ 109 | --dev_src test/code.${CODE_EXTENSION} \ 110 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 111 | --uncase True \ 112 | --max_examples -1 \ 113 | --max_src_len 400 \ 114 | --max_tgt_len 30 \ 115 | --test_batch_size 64 \ 116 | --beam_size 4 \ 117 | --n_best 1 \ 118 | --block_ngram_repeat 3 \ 119 | --stepwise_penalty False \ 120 | --coverage_penalty none \ 121 | --length_penalty none \ 122 | --beta 0 \ 123 | --gamma 0 \ 124 | --replace_unk 125 | 126 | } 127 | 128 | train $1 $2 129 | test $1 $2 130 | beam_search $1 $2 131 | -------------------------------------------------------------------------------- /scripts/python/transformer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function make_dir () { 4 | if [[ ! -d "$1" ]]; then 5 | mkdir $1 6 | fi 7 | } 8 | 9 | SRC_DIR=../.. 10 | DATA_DIR=${SRC_DIR}/data 11 | MODEL_DIR=${SRC_DIR}/tmp 12 | 13 | make_dir $MODEL_DIR 14 | 15 | DATASET=python 16 | CODE_EXTENSION=original_subtoken 17 | JAVADOC_EXTENSION=original 18 | 19 | 20 | function train () { 21 | 22 | echo "============TRAINING============" 23 | 24 | RGPU=$1 25 | MODEL_NAME=$2 26 | 27 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 28 | --data_workers 5 \ 29 | --dataset_name $DATASET \ 30 | --data_dir ${DATA_DIR}/ \ 31 | --model_dir $MODEL_DIR \ 32 | --model_name $MODEL_NAME \ 33 | --train_src train/code.${CODE_EXTENSION} \ 34 | --train_tgt train/javadoc.${JAVADOC_EXTENSION} \ 35 | --dev_src dev/code.${CODE_EXTENSION} \ 36 | --dev_tgt dev/javadoc.${JAVADOC_EXTENSION} \ 37 | --uncase True \ 38 | --use_src_word True \ 39 | --use_src_char False \ 40 | --use_tgt_word True \ 41 | --use_tgt_char False \ 42 | --max_src_len 400 \ 43 | --max_tgt_len 30 \ 44 | --emsize 512 \ 45 | --fix_embeddings False \ 46 | --src_vocab_size 50000 \ 47 | --tgt_vocab_size 30000 \ 48 | --share_decoder_embeddings True \ 49 | --max_examples -1 \ 50 | --batch_size 32 \ 51 | --test_batch_size 64 \ 52 | --num_epochs 200 \ 53 | --model_type transformer \ 54 | --num_head 8 \ 55 | --d_k 64 \ 56 | --d_v 64 \ 57 | --d_ff 2048 \ 58 | --src_pos_emb False \ 59 | --tgt_pos_emb True \ 60 | --max_relative_pos 32 \ 61 | --use_neg_dist True \ 62 | --nlayers 6 \ 63 | --trans_drop 0.2 \ 64 | --dropout_emb 0.2 \ 65 | --dropout 0.2 \ 66 | --copy_attn True \ 67 | --early_stop 20 \ 68 | --warmup_steps 0 \ 69 | --optimizer adam \ 70 | --learning_rate 0.0001 \ 71 | --lr_decay 0.99 \ 72 | --valid_metric bleu \ 73 | --checkpoint True \ 74 | --split_decoder False 75 | } 76 | 77 | 78 | function test () { 79 | 80 | echo "============TESTING============" 81 | 82 | RGPU=$1 83 | MODEL_NAME=$2 84 | 85 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/train.py \ 86 | --only_test True \ 87 | --data_workers 5 \ 88 | --dataset_name $DATASET \ 89 | --data_dir ${DATA_DIR}/ \ 90 | --model_dir $MODEL_DIR \ 91 | --model_name $MODEL_NAME \ 92 | --dev_src test/code.${CODE_EXTENSION} \ 93 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 94 | --uncase True \ 95 | --max_src_len 400 \ 96 | --max_tgt_len 30 \ 97 | --max_examples -1 \ 98 | --test_batch_size 64 99 | 100 | } 101 | 102 | 
function beam_search () { 103 | 104 | echo "============Beam Search TESTING============" 105 | 106 | RGPU=$1 107 | MODEL_NAME=$2 108 | 109 | PYTHONPATH=$SRC_DIR CUDA_VISIBLE_DEVICES=$RGPU python -W ignore ${SRC_DIR}/main/test.py \ 110 | --data_workers 5 \ 111 | --dataset_name $DATASET \ 112 | --data_dir ${DATA_DIR}/ \ 113 | --model_dir $MODEL_DIR \ 114 | --model_name $MODEL_NAME \ 115 | --dev_src test/code.${CODE_EXTENSION} \ 116 | --dev_tgt test/javadoc.${JAVADOC_EXTENSION} \ 117 | --uncase True \ 118 | --max_examples -1 \ 119 | --max_src_len 400 \ 120 | --max_tgt_len 30 \ 121 | --test_batch_size 64 \ 122 | --beam_size 4 \ 123 | --n_best 1 \ 124 | --block_ngram_repeat 3 \ 125 | --stepwise_penalty False \ 126 | --coverage_penalty none \ 127 | --length_penalty none \ 128 | --beta 0 \ 129 | --gamma 0 \ 130 | --replace_unk 131 | 132 | } 133 | 134 | train $1 $2 135 | test $1 $2 136 | beam_search $1 $2 137 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from setuptools import setup, find_packages 7 | 8 | with open('README.md') as f: 9 | readme = f.read() 10 | 11 | with open('LICENSE') as f: 12 | license = f.read() 13 | 14 | with open('requirements.txt') as f: 15 | reqs = f.read() 16 | 17 | setup( 18 | name='c2nl', 19 | version='0.1.0', 20 | description='Code to Natural Language Generation', 21 | long_description=readme, 22 | license=license, 23 | python_requires='>=3.6', 24 | packages=find_packages(exclude=('data',)),  # trailing comma needed: ('data') is a plain string, not a tuple 25 | install_requires=reqs.strip().split('\n'), 26 | ) 27 | --------------------------------------------------------------------------------