├── .gitignore
├── Bi-selective Encoding
├── LICENSE.md
├── onmt
│   ├── __init__.py
│   ├── decoders
│   │   ├── __init__.py
│   │   ├── cnn_decoder.py
│   │   ├── decoder.py
│   │   ├── ensemble.py
│   │   └── transformer.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── audio_encoder.py
│   │   ├── biset.py
│   │   ├── cnn_encoder.py
│   │   ├── encoder.py
│   │   ├── image_encoder.py
│   │   ├── mean_encoder.py
│   │   └── transformer.py
│   ├── inputters
│   │   ├── __init__.py
│   │   ├── audio_dataset.py
│   │   ├── dataset_base.py
│   │   ├── image_dataset.py
│   │   ├── inputter.py
│   │   └── text_dataset.py
│   ├── model_builder.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── model_saver.py
│   │   ├── sru.py
│   │   └── stacked_rnn.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── average_attn.py
│   │   ├── conv_multi_step_attention.py
│   │   ├── copy_generator.py
│   │   ├── embeddings.py
│   │   ├── gate.py
│   │   ├── global_attention.py
│   │   ├── multi_headed_attn.py
│   │   ├── position_ffn.py
│   │   ├── sparse_activations.py
│   │   ├── sparse_losses.py
│   │   ├── structured_attention.py
│   │   ├── util_class.py
│   │   └── weight_norm.py
│   ├── opts.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── output_hyp.txt
│   │   ├── pull_request_chk.sh
│   │   ├── rebuild_test_models.sh
│   │   ├── test_attention.py
│   │   ├── test_model.pt
│   │   ├── test_model2.pt
│   │   ├── test_models.py
│   │   ├── test_models.sh
│   │   ├── test_preprocess.py
│   │   └── test_simple.py
│   ├── train_multi.py
│   ├── train_single.py
│   ├── trainer.py
│   ├── translate
│   │   ├── __init__.py
│   │   ├── beam.py
│   │   ├── penalties.py
│   │   ├── translation.py
│   │   ├── translation_server.py
│   │   └── translator.py
│   └── utils
│   │   ├── __init__.py
│   │   ├── cnn_factory.py
│   │   ├── distributed.py
│   │   ├── logging.py
│   │   ├── loss.py
│   │   ├── misc.py
│   │   ├── optimizers.py
│   │   ├── report_manager.py
│   │   ├── rnn_factory.py
│   │   └── statistics.py
├── preprocess.py
├── requirements.opt.txt
├── requirements.txt
├── server.py
├── setup.py
├── tools
│   ├── README.md
│   ├── apply_bpe.py
│   ├── average_models.py
│   ├── bpe_pipeline.sh
│   ├── detokenize.perl
│   ├── embeddings_to_torch.py
│   ├── extract_embeddings.py
│   ├── learn_bpe.py
│   ├── multi-bleu-detok.perl
│   ├── multi-bleu.perl
│   ├── nonbreaking_prefixes
│   │   ├── README.txt
│   │   ├── nonbreaking_prefix.ca
│   │   ├── nonbreaking_prefix.cs
│   │   ├── nonbreaking_prefix.de
│   │   ├── nonbreaking_prefix.el
│   │   ├── nonbreaking_prefix.en
│   │   ├── nonbreaking_prefix.es
│   │   ├── nonbreaking_prefix.fi
│   │   ├── nonbreaking_prefix.fr
│   │   ├── nonbreaking_prefix.ga
│   │   ├── nonbreaking_prefix.hu
│   │   ├── nonbreaking_prefix.is
│   │   ├── nonbreaking_prefix.it
│   │   ├── nonbreaking_prefix.lt
│   │   ├── nonbreaking_prefix.lv
│   │   ├── nonbreaking_prefix.nl
│   │   ├── nonbreaking_prefix.pl
│   │   ├── nonbreaking_prefix.ro
│   │   ├── nonbreaking_prefix.ru
│   │   ├── nonbreaking_prefix.sk
│   │   ├── nonbreaking_prefix.sl
│   │   ├── nonbreaking_prefix.sv
│   │   ├── nonbreaking_prefix.ta
│   │   ├── nonbreaking_prefix.yue
│   │   └── nonbreaking_prefix.zh
│   ├── release_model.py
│   ├── test_rouge.py
│   └── tokenizer.perl
├── train.py
└── translate.py
├── FastRerank
├── PyRouge
│   ├── 1775.txt
│   ├── LICENSE
│   ├── README.md
│   ├── Rouge
│   │   ├── Rouge.py
│   │   ├── Rouge_current.py
│   │   └── __init__.py
│   ├── compute.py
│   ├── dev.out.102
│   ├── input-word.txt
│   ├── ref-word.txt
│   ├── ref.txt
│   ├── some.1.txt
│   ├── task1_ref0.txt
│   └── valid.title.filter.txt
├── config.py
├── main.py
├── model.py
└── preprocess.py
├── LICENSE
├── README.md
└── Retrieve
└── src
└── com
└── wk
└── lucene
├── Constants.java
├── Indexer.java
└── Searcher.java

/.gitignore:
--------------------------------------------------------------------------------
1 | 
.idea/ 2 | -------------------------------------------------------------------------------- /Bi-selective Encoding/LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017-Present OpenNMT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/__init__.py: -------------------------------------------------------------------------------- 1 | """ Main entry point of the ONMT library """ 2 | from __future__ import division, print_function 3 | 4 | import onmt.inputters 5 | import onmt.encoders 6 | import onmt.decoders 7 | import onmt.models 8 | import onmt.utils 9 | import onmt.modules 10 | from onmt.trainer import Trainer 11 | import sys 12 | import onmt.utils.optimizers 13 | onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer 14 | sys.modules["onmt.Optim"] = onmt.utils.optimizers 15 | 16 | # For Flake 17 | __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, 18 | onmt.utils, onmt.modules, "Trainer"] 19 | 20 | __version__ = "0.2.0" 21 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining decoders.""" 2 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/decoders/cnn_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of the CNN Decoder part of 3 | "Convolutional Sequence to Sequence Learning" 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | import onmt.modules 9 | from onmt.decoders.decoder import DecoderState 10 | from onmt.utils.misc import aeq 11 | from onmt.utils.cnn_factory import shape_transform, GatedConv 12 | 13 | SCALE_WEIGHT = 0.5 ** 0.5 14 | 15 | 16 | class CNNDecoder(nn.Module): 17 | """ 18 | Decoder built on CNN, based on :cite:`DBLP:journals/corr/GehringAGYD17`. 19 | 20 | 21 | Consists of residual convolutional layers, with ConvMultiStepAttention. 22 | """ 23 | 24 | def __init__(self, num_layers, hidden_size, attn_type, 25 | copy_attn, cnn_kernel_width, dropout, embeddings): 26 | super(CNNDecoder, self).__init__() 27 | 28 | # Basic attributes. 
29 | self.decoder_type = 'cnn' 30 | self.num_layers = num_layers 31 | self.hidden_size = hidden_size 32 | self.cnn_kernel_width = cnn_kernel_width 33 | self.embeddings = embeddings 34 | self.dropout = dropout 35 | 36 | # Build the CNN. 37 | input_size = self.embeddings.embedding_size 38 | self.linear = nn.Linear(input_size, self.hidden_size) 39 | self.conv_layers = nn.ModuleList() 40 | for _ in range(self.num_layers): 41 | self.conv_layers.append( 42 | GatedConv(self.hidden_size, self.cnn_kernel_width, 43 | self.dropout, True)) 44 | 45 | self.attn_layers = nn.ModuleList() 46 | for _ in range(self.num_layers): 47 | self.attn_layers.append( 48 | onmt.modules.ConvMultiStepAttention(self.hidden_size)) 49 | 50 | # CNNDecoder has its own attention mechanism. 51 | # Set up a separated copy attention layer, if needed. 52 | self._copy = False 53 | if copy_attn: 54 | self.copy_attn = onmt.modules.GlobalAttention( 55 | hidden_size, attn_type=attn_type) 56 | self._copy = True 57 | 58 | def forward(self, tgt, memory_bank, state, memory_lengths=None, step=None): 59 | """ See :obj:`onmt.modules.RNNDecoderBase.forward()`""" 60 | # NOTE: memory_lengths is only here for compatibility reasons 61 | # with onmt.modules.RNNDecoderBase.forward() 62 | # CHECKS 63 | assert isinstance(state, CNNDecoderState) 64 | _, tgt_batch, _ = tgt.size() 65 | _, contxt_batch, _ = memory_bank.size() 66 | aeq(tgt_batch, contxt_batch) 67 | # END CHECKS 68 | 69 | if state.previous_input is not None: 70 | tgt = torch.cat([state.previous_input, tgt], 0) 71 | 72 | # Initialize return variables. 73 | outputs = [] 74 | attns = {"std": []} 75 | assert not self._copy, "Copy mechanism not yet tested in conv2conv" 76 | if self._copy: 77 | attns["copy"] = [] 78 | 79 | emb = self.embeddings(tgt) 80 | assert emb.dim() == 3 # len x batch x embedding_dim 81 | 82 | tgt_emb = emb.transpose(0, 1).contiguous() 83 | # The output of CNNEncoder. 84 | src_memory_bank_t = memory_bank.transpose(0, 1).contiguous() 85 | # The combination of output of CNNEncoder and source embeddings. 86 | src_memory_bank_c = state.init_src.transpose(0, 1).contiguous() 87 | 88 | # Run the forward pass of the CNNDecoder. 89 | emb_reshape = tgt_emb.contiguous().view( 90 | tgt_emb.size(0) * tgt_emb.size(1), -1) 91 | linear_out = self.linear(emb_reshape) 92 | x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1) 93 | x = shape_transform(x) 94 | 95 | pad = torch.zeros(x.size(0), x.size(1), 96 | self.cnn_kernel_width - 1, 1) 97 | 98 | pad = pad.type_as(x) 99 | base_target_emb = x 100 | 101 | for conv, attention in zip(self.conv_layers, self.attn_layers): 102 | new_target_input = torch.cat([pad, x], 2) 103 | out = conv(new_target_input) 104 | c, attn = attention(base_target_emb, out, 105 | src_memory_bank_t, src_memory_bank_c) 106 | x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT 107 | output = x.squeeze(3).transpose(1, 2) 108 | 109 | # Process the result and update the attentions. 110 | outputs = output.transpose(0, 1).contiguous() 111 | if state.previous_input is not None: 112 | outputs = outputs[state.previous_input.size(0):] 113 | attn = attn[:, state.previous_input.size(0):].squeeze() 114 | attn = torch.stack([attn]) 115 | attns["std"] = attn 116 | if self._copy: 117 | attns["copy"] = attn 118 | 119 | # Update the state. 120 | state.update_state(tgt) 121 | 122 | return outputs, state, attns 123 | 124 | def init_decoder_state(self, _, memory_bank, enc_hidden, with_cache=False): 125 | """ 126 | Init decoder state. 
127 | """ 128 | return CNNDecoderState(memory_bank, enc_hidden) 129 | 130 | 131 | class CNNDecoderState(DecoderState): 132 | """ 133 | Init CNN decoder state. 134 | """ 135 | 136 | def __init__(self, memory_bank, enc_hidden): 137 | self.init_src = (memory_bank + enc_hidden) * SCALE_WEIGHT 138 | self.previous_input = None 139 | 140 | @property 141 | def _all(self): 142 | """ 143 | Contains attributes that need to be updated in self.beam_update(). 144 | """ 145 | return (self.previous_input,) 146 | 147 | def detach(self): 148 | self.previous_input = self.previous_input.detach() 149 | 150 | def update_state(self, new_input): 151 | """ Called for every decoder forward pass. """ 152 | self.previous_input = new_input 153 | 154 | def repeat_beam_size_times(self, beam_size): 155 | """ Repeat beam_size times along batch dimension. """ 156 | self.init_src = self.init_src.data.repeat(1, beam_size, 1) 157 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/decoders/ensemble.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ensemble decoding. 3 | 4 | Decodes using multiple models simultaneously, 5 | combining their prediction distributions by averaging. 6 | All models in the ensemble must share a target vocabulary. 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from onmt.decoders.decoder import DecoderState 13 | from onmt.encoders.encoder import EncoderBase 14 | from onmt.models import NMTModel 15 | import onmt.model_builder 16 | 17 | 18 | class EnsembleDecoderState(DecoderState): 19 | """ Dummy DecoderState that wraps a tuple of real DecoderStates """ 20 | def __init__(self, model_decoder_states): 21 | self.model_decoder_states = tuple(model_decoder_states) 22 | 23 | def beam_update(self, idx, positions, beam_size): 24 | for model_state in self.model_decoder_states: 25 | model_state.beam_update(idx, positions, beam_size) 26 | 27 | def repeat_beam_size_times(self, beam_size): 28 | """ Repeat beam_size times along batch dimension. 
""" 29 | for model_state in self.model_decoder_states: 30 | model_state.repeat_beam_size_times(beam_size) 31 | 32 | def __getitem__(self, index): 33 | return self.model_decoder_states[index] 34 | 35 | 36 | class EnsembleDecoderOutput(object): 37 | """ Wrapper around multiple decoder final hidden states """ 38 | def __init__(self, model_outputs): 39 | self.model_outputs = tuple(model_outputs) 40 | 41 | def squeeze(self, dim=None): 42 | """ 43 | Delegate squeeze to avoid modifying 44 | :obj:`Translator.translate_batch()` 45 | """ 46 | return EnsembleDecoderOutput([ 47 | x.squeeze(dim) for x in self.model_outputs]) 48 | 49 | def __getitem__(self, index): 50 | return self.model_outputs[index] 51 | 52 | 53 | class EnsembleEncoder(EncoderBase): 54 | """ Dummy Encoder that delegates to individual real Encoders """ 55 | def __init__(self, model_encoders): 56 | super(EnsembleEncoder, self).__init__() 57 | self.model_encoders = nn.ModuleList(list(model_encoders)) 58 | 59 | def forward(self, src, lengths=None): 60 | enc_hidden, memory_bank = zip(*[ 61 | model_encoder.forward(src, lengths) 62 | for model_encoder in self.model_encoders]) 63 | return enc_hidden, memory_bank 64 | 65 | 66 | class EnsembleDecoder(nn.Module): 67 | """ Dummy Decoder that delegates to individual real Decoders """ 68 | def __init__(self, model_decoders): 69 | super(EnsembleDecoder, self).__init__() 70 | self.model_decoders = nn.ModuleList(list(model_decoders)) 71 | 72 | def forward(self, tgt, memory_bank, state, memory_lengths=None, 73 | step=None): 74 | """ See :obj:`RNNDecoderBase.forward()` """ 75 | # Memory_lengths is a single tensor shared between all models. 76 | # This assumption will not hold if Translator is modified 77 | # to calculate memory_lengths as something other than the length 78 | # of the input. 79 | outputs, states, attns = zip(*[ 80 | model_decoder.forward( 81 | tgt, memory_bank[i], state[i], memory_lengths, step=step) 82 | for (i, model_decoder) 83 | in enumerate(self.model_decoders)]) 84 | mean_attns = self.combine_attns(attns) 85 | return (EnsembleDecoderOutput(outputs), 86 | EnsembleDecoderState(states), 87 | mean_attns) 88 | 89 | def combine_attns(self, attns): 90 | result = {} 91 | for key in attns[0].keys(): 92 | result[key] = torch.stack([attn[key] for attn in attns]).mean(0) 93 | return result 94 | 95 | def init_decoder_state(self, src, memory_bank, enc_hidden): 96 | """ See :obj:`RNNDecoderBase.init_decoder_state()` """ 97 | return EnsembleDecoderState( 98 | [model_decoder.init_decoder_state(src, 99 | memory_bank[i], 100 | enc_hidden[i]) 101 | for (i, model_decoder) in enumerate(self.model_decoders)]) 102 | 103 | 104 | class EnsembleGenerator(nn.Module): 105 | """ 106 | Dummy Generator that delegates to individual real Generators, 107 | and then averages the resulting target distributions. 108 | """ 109 | def __init__(self, model_generators): 110 | self.model_generators = tuple(model_generators) 111 | super(EnsembleGenerator, self).__init__() 112 | 113 | def forward(self, hidden): 114 | """ 115 | Compute a distribution over the target dictionary 116 | by averaging distributions from models in the ensemble. 117 | All models in the ensemble must share a target vocabulary. 
118 | """ 119 | distributions = [model_generator.forward(hidden[i]) 120 | for (i, model_generator) 121 | in enumerate(self.model_generators)] 122 | return torch.stack(distributions).mean(0) 123 | 124 | 125 | class EnsembleModel(NMTModel): 126 | """ Dummy NMTModel wrapping individual real NMTModels """ 127 | def __init__(self, models): 128 | encoder = EnsembleEncoder(model.encoder for model in models) 129 | decoder = EnsembleDecoder(model.decoder for model in models) 130 | super(EnsembleModel, self).__init__(encoder, decoder) 131 | self.generator = EnsembleGenerator(model.generator for model in models) 132 | self.models = nn.ModuleList(models) 133 | 134 | 135 | def load_test_model(opt, dummy_opt): 136 | """ Read in multiple models for ensemble """ 137 | shared_fields = None 138 | shared_model_opt = None 139 | models = [] 140 | for model_path in opt.models: 141 | fields, model, model_opt = \ 142 | onmt.model_builder.load_test_model(opt, 143 | dummy_opt, 144 | model_path=model_path) 145 | if shared_fields is None: 146 | shared_fields = fields 147 | else: 148 | for key, field in fields.items(): 149 | if field is not None and 'vocab' in field.__dict__: 150 | assert field.vocab.stoi == shared_fields[key].vocab.stoi, \ 151 | 'Ensemble models must use the same preprocessed data' 152 | models.append(model) 153 | if shared_model_opt is None: 154 | shared_model_opt = model_opt 155 | ensemble_model = EnsembleModel(models) 156 | return shared_fields, ensemble_model, shared_model_opt 157 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining encoders.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.encoders.transformer import TransformerEncoder 4 | from onmt.encoders.biset import RNNEncoder 5 | from onmt.encoders.cnn_encoder import CNNEncoder 6 | from onmt.encoders.mean_encoder import MeanEncoder 7 | 8 | __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", 9 | "MeanEncoder"] 10 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/audio_encoder.py: -------------------------------------------------------------------------------- 1 | """ Audio encoder """ 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class AudioEncoder(nn.Module): 8 | """ 9 | A simple encoder convolutional -> recurrent neural network for 10 | audio input. 11 | 12 | Args: 13 | num_layers (int): number of encoder layers. 14 | bidirectional (bool): bidirectional encoder. 15 | rnn_size (int): size of hidden states of the rnn. 16 | dropout (float): dropout probablity. 
17 | sample_rate (float): input spec 18 | window_size (int): input spec 19 | 20 | """ 21 | 22 | def __init__(self, num_layers, bidirectional, rnn_size, dropout, 23 | sample_rate, window_size): 24 | super(AudioEncoder, self).__init__() 25 | self.num_layers = num_layers 26 | self.num_directions = 2 if bidirectional else 1 27 | self.hidden_size = rnn_size 28 | 29 | self.layer1 = nn.Conv2d(1, 32, kernel_size=(41, 11), 30 | padding=(0, 10), stride=(2, 2)) 31 | self.batch_norm1 = nn.BatchNorm2d(32) 32 | self.layer2 = nn.Conv2d(32, 32, kernel_size=(21, 11), 33 | padding=(0, 0), stride=(2, 1)) 34 | self.batch_norm2 = nn.BatchNorm2d(32) 35 | 36 | input_size = int(math.floor((sample_rate * window_size) / 2) + 1) 37 | input_size = int(math.floor(input_size - 41) / 2 + 1) 38 | input_size = int(math.floor(input_size - 21) / 2 + 1) 39 | input_size *= 32 40 | self.rnn = nn.LSTM(input_size, rnn_size, 41 | num_layers=num_layers, 42 | dropout=dropout, 43 | bidirectional=bidirectional) 44 | 45 | def load_pretrained_vectors(self, opt): 46 | """ Pass in needed options only when modify function definition.""" 47 | pass 48 | 49 | def forward(self, src, lengths=None): 50 | "See :obj:`onmt.encoders.encoder.EncoderBase.forward()`" 51 | # (batch_size, 1, nfft, t) 52 | # layer 1 53 | src = self.batch_norm1(self.layer1(src[:, :, :, :])) 54 | 55 | # (batch_size, 32, nfft/2, t/2) 56 | src = F.hardtanh(src, 0, 20, inplace=True) 57 | 58 | # (batch_size, 32, nfft/2/2, t/2) 59 | # layer 2 60 | src = self.batch_norm2(self.layer2(src)) 61 | 62 | # (batch_size, 32, nfft/2/2, t/2) 63 | src = F.hardtanh(src, 0, 20, inplace=True) 64 | 65 | batch_size = src.size(0) 66 | length = src.size(3) 67 | src = src.view(batch_size, -1, length) 68 | src = src.transpose(0, 2).transpose(1, 2) 69 | 70 | output, hidden = self.rnn(src) 71 | 72 | return hidden, output 73 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch.nn as nn 5 | 6 | from onmt.encoders.encoder import EncoderBase 7 | from onmt.utils.cnn_factory import shape_transform, StackedCNN 8 | 9 | SCALE_WEIGHT = 0.5 ** 0.5 10 | 11 | 12 | class CNNEncoder(EncoderBase): 13 | """ 14 | Encoder built on CNN based on 15 | :cite:`DBLP:journals/corr/GehringAGYD17`. 
16 | """ 17 | 18 | def __init__(self, num_layers, hidden_size, 19 | cnn_kernel_width, dropout, embeddings): 20 | super(CNNEncoder, self).__init__() 21 | 22 | self.embeddings = embeddings 23 | input_size = embeddings.embedding_size 24 | self.linear = nn.Linear(input_size, hidden_size) 25 | self.cnn = StackedCNN(num_layers, hidden_size, 26 | cnn_kernel_width, dropout) 27 | 28 | def forward(self, input, lengths=None, hidden=None): 29 | """ See :obj:`onmt.modules.EncoderBase.forward()`""" 30 | self._check_args(input, lengths, hidden) 31 | 32 | emb = self.embeddings(input) 33 | # s_len, batch, emb_dim = emb.size() 34 | 35 | emb = emb.transpose(0, 1).contiguous() 36 | emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) 37 | emb_remap = self.linear(emb_reshape) 38 | emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) 39 | emb_remap = shape_transform(emb_remap) 40 | out = self.cnn(emb_remap) 41 | 42 | return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \ 43 | out.squeeze(3).transpose(0, 1).contiguous() 44 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for encoders and generic multi encoders.""" 2 | 3 | from __future__ import division 4 | 5 | import torch.nn as nn 6 | 7 | from onmt.utils.misc import aeq 8 | 9 | 10 | class EncoderBase(nn.Module): 11 | """ 12 | Base encoder class. Specifies the interface used by different encoder types 13 | and required by :obj:`onmt.Models.NMTModel`. 14 | 15 | .. mermaid:: 16 | 17 | graph BT 18 | A[Input] 19 | subgraph RNN 20 | C[Pos 1] 21 | D[Pos 2] 22 | E[Pos N] 23 | end 24 | F[Memory_Bank] 25 | G[Final] 26 | A-->C 27 | A-->D 28 | A-->E 29 | C-->F 30 | D-->F 31 | E-->F 32 | E-->G 33 | """ 34 | 35 | def _check_args(self, src, lengths=None, hidden=None): 36 | _, n_batch, _ = src.size() 37 | if lengths is not None: 38 | n_batch_, = lengths.size() 39 | aeq(n_batch, n_batch_) 40 | 41 | def forward(self, src,template, lengths=None,template_lengths=None): 42 | """ 43 | Args: 44 | src (:obj:`LongTensor`): 45 | padded sequences of sparse indices `[src_len x batch x nfeat]` 46 | lengths (:obj:`LongTensor`): length of each sequence `[batch]` 47 | 48 | 49 | Returns: 50 | (tuple of :obj:`FloatTensor`, :obj:`FloatTensor`): 51 | * final encoder state, used to initialize decoder 52 | * memory bank for attention, `[src_len x batch x hidden]` 53 | """ 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/image_encoder.py: -------------------------------------------------------------------------------- 1 | """ Image Encoder """ 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | 7 | class ImageEncoder(nn.Module): 8 | """ 9 | A simple encoder convolutional -> recurrent neural network for 10 | image src. 11 | 12 | Args: 13 | num_layers (int): number of encoder layers. 14 | bidirectional (bool): bidirectional encoder. 15 | rnn_size (int): size of hidden states of the rnn. 16 | dropout (float): dropout probablity. 
17 | """ 18 | 19 | def __init__(self, num_layers, bidirectional, rnn_size, dropout): 20 | super(ImageEncoder, self).__init__() 21 | self.num_layers = num_layers 22 | self.num_directions = 2 if bidirectional else 1 23 | self.hidden_size = rnn_size 24 | 25 | self.layer1 = nn.Conv2d(3, 64, kernel_size=(3, 3), 26 | padding=(1, 1), stride=(1, 1)) 27 | self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3), 28 | padding=(1, 1), stride=(1, 1)) 29 | self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3), 30 | padding=(1, 1), stride=(1, 1)) 31 | self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3), 32 | padding=(1, 1), stride=(1, 1)) 33 | self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3), 34 | padding=(1, 1), stride=(1, 1)) 35 | self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3), 36 | padding=(1, 1), stride=(1, 1)) 37 | 38 | self.batch_norm1 = nn.BatchNorm2d(256) 39 | self.batch_norm2 = nn.BatchNorm2d(512) 40 | self.batch_norm3 = nn.BatchNorm2d(512) 41 | 42 | src_size = 512 43 | self.rnn = nn.LSTM(src_size, rnn_size, 44 | num_layers=num_layers, 45 | dropout=dropout, 46 | bidirectional=bidirectional) 47 | self.pos_lut = nn.Embedding(1000, src_size) 48 | 49 | def load_pretrained_vectors(self, opt): 50 | """ Pass in needed options only when modify function definition.""" 51 | pass 52 | 53 | def forward(self, src, lengths=None): 54 | "See :obj:`onmt.encoders.encoder.EncoderBase.forward()`" 55 | 56 | batch_size = src.size(0) 57 | # (batch_size, 64, imgH, imgW) 58 | # layer 1 59 | src = F.relu(self.layer1(src[:, :, :, :]-0.5), True) 60 | 61 | # (batch_size, 64, imgH/2, imgW/2) 62 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 63 | 64 | # (batch_size, 128, imgH/2, imgW/2) 65 | # layer 2 66 | src = F.relu(self.layer2(src), True) 67 | 68 | # (batch_size, 128, imgH/2/2, imgW/2/2) 69 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 70 | 71 | # (batch_size, 256, imgH/2/2, imgW/2/2) 72 | # layer 3 73 | # batch norm 1 74 | src = F.relu(self.batch_norm1(self.layer3(src)), True) 75 | 76 | # (batch_size, 256, imgH/2/2, imgW/2/2) 77 | # layer4 78 | src = F.relu(self.layer4(src), True) 79 | 80 | # (batch_size, 256, imgH/2/2/2, imgW/2/2) 81 | src = F.max_pool2d(src, kernel_size=(1, 2), stride=(1, 2)) 82 | 83 | # (batch_size, 512, imgH/2/2/2, imgW/2/2) 84 | # layer 5 85 | # batch norm 2 86 | src = F.relu(self.batch_norm2(self.layer5(src)), True) 87 | 88 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 89 | src = F.max_pool2d(src, kernel_size=(2, 1), stride=(2, 1)) 90 | 91 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 92 | src = F.relu(self.batch_norm3(self.layer6(src)), True) 93 | 94 | # # (batch_size, 512, H, W) 95 | all_outputs = [] 96 | for row in range(src.size(2)): 97 | inp = src[:, :, row, :].transpose(0, 2)\ 98 | .transpose(1, 2) 99 | row_vec = torch.Tensor(batch_size).type_as(inp.data)\ 100 | .long().fill_(row) 101 | pos_emb = self.pos_lut(row_vec) 102 | with_pos = torch.cat( 103 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0) 104 | outputs, hidden_t = self.rnn(with_pos) 105 | all_outputs.append(outputs) 106 | out = torch.cat(all_outputs, 0) 107 | 108 | return hidden_t, out 109 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/mean_encoder.py: -------------------------------------------------------------------------------- 1 | """Define a minimal encoder.""" 2 | from __future__ import division 3 | 4 | from onmt.encoders.encoder import EncoderBase 5 | 6 | 7 | class MeanEncoder(EncoderBase): 8 | """A 
trivial non-recurrent encoder. Simply applies mean pooling. 9 | 10 | Args: 11 | num_layers (int): number of replicated layers 12 | embeddings (:obj:`onmt.modules.Embeddings`): embedding module to use 13 | """ 14 | 15 | def __init__(self, num_layers, embeddings): 16 | super(MeanEncoder, self).__init__() 17 | self.num_layers = num_layers 18 | self.embeddings = embeddings 19 | 20 | def forward(self, src, lengths=None): 21 | "See :obj:`EncoderBase.forward()`" 22 | self._check_args(src, lengths) 23 | 24 | emb = self.embeddings(src) 25 | _, batch, emb_dim = emb.size() 26 | mean = emb.mean(0).expand(self.num_layers, batch, emb_dim) 27 | memory_bank = emb 28 | encoder_final = (mean, mean) 29 | return encoder_final, memory_bank 30 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | import onmt 8 | from onmt.encoders.encoder import EncoderBase 9 | # from onmt.utils.misc import aeq 10 | from onmt.modules.position_ffn import PositionwiseFeedForward 11 | 12 | 13 | class TransformerEncoderLayer(nn.Module): 14 | """ 15 | A single layer of the transformer encoder. 16 | 17 | Args: 18 | d_model (int): the dimension of keys/values/queries in 19 | MultiHeadedAttention, also the input size of 20 | the first-layer of the PositionwiseFeedForward. 21 | heads (int): the number of head for MultiHeadedAttention. 22 | d_ff (int): the second-layer of the PositionwiseFeedForward. 23 | dropout (float): dropout probability(0-1.0). 24 | """ 25 | 26 | def __init__(self, d_model, heads, d_ff, dropout): 27 | super(TransformerEncoderLayer, self).__init__() 28 | 29 | self.self_attn = onmt.modules.MultiHeadedAttention( 30 | heads, d_model, dropout=dropout) 31 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 32 | self.layer_norm = onmt.modules.LayerNorm(d_model) 33 | self.dropout = nn.Dropout(dropout) 34 | 35 | def forward(self, inputs, mask): 36 | """ 37 | Transformer Encoder Layer definition. 38 | 39 | Args: 40 | inputs (`FloatTensor`): `[batch_size x src_len x model_dim]` 41 | mask (`LongTensor`): `[batch_size x src_len x src_len]` 42 | 43 | Returns: 44 | (`FloatTensor`): 45 | 46 | * outputs `[batch_size x src_len x model_dim]` 47 | """ 48 | input_norm = self.layer_norm(inputs) 49 | context, _ = self.self_attn(input_norm, input_norm, input_norm, 50 | mask=mask) 51 | out = self.dropout(context) + inputs 52 | return self.feed_forward(out) 53 | 54 | 55 | class TransformerEncoder(EncoderBase): 56 | """ 57 | The Transformer encoder from "Attention is All You Need". 58 | 59 | 60 | .. 
mermaid:: 61 | 62 | graph BT 63 | A[input] 64 | B[multi-head self-attn] 65 | C[feed forward] 66 | O[output] 67 | A --> B 68 | B --> C 69 | C --> O 70 | 71 | Args: 72 | num_layers (int): number of encoder layers 73 | d_model (int): size of the model 74 | heads (int): number of heads 75 | d_ff (int): size of the inner FF layer 76 | dropout (float): dropout parameters 77 | embeddings (:obj:`onmt.modules.Embeddings`): 78 | embeddings to use, should have positional encodings 79 | 80 | Returns: 81 | (`FloatTensor`, `FloatTensor`): 82 | 83 | * embeddings `[src_len x batch_size x model_dim]` 84 | * memory_bank `[src_len x batch_size x model_dim]` 85 | """ 86 | 87 | def __init__(self, num_layers, d_model, heads, d_ff, 88 | dropout, embeddings): 89 | super(TransformerEncoder, self).__init__() 90 | 91 | self.num_layers = num_layers 92 | self.embeddings = embeddings 93 | self.transformer = nn.ModuleList( 94 | [TransformerEncoderLayer(d_model, heads, d_ff, dropout) 95 | for _ in range(num_layers)]) 96 | self.layer_norm = onmt.modules.LayerNorm(d_model) 97 | 98 | def forward(self, src, lengths=None): 99 | """ See :obj:`EncoderBase.forward()`""" 100 | self._check_args(src, lengths) 101 | 102 | emb = self.embeddings(src) 103 | 104 | out = emb.transpose(0, 1).contiguous() 105 | words = src[:, :, 0].transpose(0, 1) 106 | w_batch, w_len = words.size() 107 | padding_idx = self.embeddings.word_padding_idx 108 | mask = words.data.eq(padding_idx).unsqueeze(1) \ 109 | .expand(w_batch, w_len, w_len) 110 | # Run the forward pass of every layer of the tranformer. 111 | for i in range(self.num_layers): 112 | out = self.transformer[i](out, mask) 113 | out = self.layer_norm(out) 114 | 115 | return emb, out.transpose(0, 1).contiguous() 116 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings. 
5 | """ 6 | from onmt.inputters.inputter import collect_feature_vocabs, make_features, \ 7 | collect_features, get_num_features, \ 8 | load_fields_from_vocab, get_fields, \ 9 | save_fields_to_vocab, build_dataset, \ 10 | build_vocab, merge_vocabs, OrderedIterator 11 | from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \ 12 | EOS_WORD, UNK 13 | from onmt.inputters.text_dataset import TextDataset, ShardedTextCorpusIterator 14 | from onmt.inputters.image_dataset import ImageDataset 15 | from onmt.inputters.audio_dataset import AudioDataset 16 | 17 | 18 | __all__ = ['PAD_WORD', 'BOS_WORD', 'EOS_WORD', 'UNK', 'DatasetBase', 19 | 'collect_feature_vocabs', 'make_features', 20 | 'collect_features', 'get_num_features', 21 | 'load_fields_from_vocab', 'get_fields', 22 | 'save_fields_to_vocab', 'build_dataset', 23 | 'build_vocab', 'merge_vocabs', 'OrderedIterator', 24 | 'TextDataset', 'ImageDataset', 'AudioDataset', 25 | 'ShardedTextCorpusIterator'] 26 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/inputters/dataset_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Base dataset class and constants 4 | """ 5 | from itertools import chain 6 | import torchtext 7 | 8 | import onmt 9 | 10 | PAD_WORD = '' 11 | UNK_WORD = '' 12 | UNK = 0 13 | BOS_WORD = '' 14 | EOS_WORD = '' 15 | 16 | 17 | class DatasetBase(torchtext.data.Dataset): 18 | """ 19 | A dataset basically supports iteration over all the examples 20 | it contains. We currently have 3 datasets inheriting this base 21 | for 3 types of corpus respectively: "text", "img", "audio". 22 | 23 | Internally it initializes an `torchtext.data.Dataset` object with 24 | the following attributes: 25 | 26 | `examples`: a sequence of `torchtext.data.Example` objects. 27 | `fields`: a dictionary associating str keys with `torchtext.data.Field` 28 | objects, and not necessarily having the same keys as the input fields. 29 | """ 30 | 31 | def __getstate__(self): 32 | return self.__dict__ 33 | 34 | def __setstate__(self, _d): 35 | self.__dict__.update(_d) 36 | 37 | def __reduce_ex__(self, proto): 38 | "This is a hack. Something is broken with torch pickle." 39 | return super(DatasetBase, self).__reduce_ex__() 40 | 41 | def load_fields(self, vocab_dict): 42 | """ Load fields from vocab.pt, and set the `fields` attribute. 43 | 44 | Args: 45 | vocab_dict (dict): a dict of loaded vocab from vocab.pt file. 46 | """ 47 | fields = onmt.inputters.inputter.load_fields_from_vocab( 48 | vocab_dict.items(), self.data_type) 49 | self.fields = dict([(k, f) for (k, f) in fields.items() 50 | if k in self.examples[0].__dict__]) 51 | 52 | @staticmethod 53 | def extract_text_features(tokens): 54 | """ 55 | Args: 56 | tokens: A list of tokens, where each token consists of a word, 57 | optionally followed by u"│"-delimited features. 58 | Returns: 59 | A sequence of words, a sequence of features, and num of features. 
60 | """ 61 | if not tokens: 62 | return [], [], -1 63 | 64 | specials = [PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD] 65 | words = [] 66 | features = [] 67 | n_feats = None 68 | for token in tokens: 69 | split_token = token.split(u"│") 70 | assert all([special != split_token[0] for special in specials]), \ 71 | "Dataset cannot contain Special Tokens" 72 | 73 | if split_token[0]: 74 | words += [split_token[0]] 75 | features += [split_token[1:]] 76 | 77 | if n_feats is None: 78 | n_feats = len(split_token) 79 | else: 80 | assert len(split_token) == n_feats, \ 81 | "all words must have the same number of features" 82 | features = list(zip(*features)) 83 | return tuple(words), features, n_feats - 1 84 | 85 | # Below are helper functions for intra-class use only. 86 | 87 | def _join_dicts(self, *args): 88 | """ 89 | Args: 90 | dictionaries with disjoint keys. 91 | 92 | Returns: 93 | a single dictionary that has the union of these keys. 94 | """ 95 | return dict(chain(*[d.items() for d in args])) 96 | 97 | def _peek(self, seq): 98 | """ 99 | Args: 100 | seq: an iterator. 101 | 102 | Returns: 103 | the first thing returned by calling next() on the iterator 104 | and an iterator created by re-chaining that value to the beginning 105 | of the iterator. 106 | """ 107 | first = next(seq) 108 | return first, chain([first], seq) 109 | 110 | def _construct_example_fromlist(self, data, fields): 111 | """ 112 | Args: 113 | data: the data to be set as the value of the attributes of 114 | the to-be-created `Example`, associating with respective 115 | `Field` objects with same key. 116 | fields: a dict of `torchtext.data.Field` objects. The keys 117 | are attributes of the to-be-created `Example`. 118 | 119 | Returns: 120 | the created `Example` object. 121 | """ 122 | ex = torchtext.data.Example() 123 | for (name, field), val in zip(fields, data): 124 | if field is not None: 125 | setattr(ex, name, field.preprocess(val)) 126 | else: 127 | setattr(ex, name, val) 128 | return ex 129 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | 5 | __all__ = ["build_model_saver", "ModelSaver", 6 | "NMTModel", "check_sru_requirement"] 7 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/models/model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | 4 | 5 | class NMTModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 
9 | 10 | Args: 11 | encoder (:obj:`EncoderBase`): an encoder object 12 | decoder (:obj:`RNNDecoderBase`): a decoder object 13 | multi 0: 41 | self.checkpoint_queue = deque([], maxlen=keep_checkpoint) 42 | 43 | def maybe_save(self, step): 44 | """ 45 | Main entry point for model saver 46 | It wraps the `_save` method with checks and apply `keep_checkpoint` 47 | related logic 48 | """ 49 | if self.keep_checkpoint == 0: 50 | return 51 | 52 | if step % self.save_checkpoint_steps != 0: 53 | return 54 | 55 | chkpt, chkpt_name = self._save(step) 56 | 57 | if self.keep_checkpoint > 0: 58 | if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: 59 | todel = self.checkpoint_queue.popleft() 60 | self._rm_checkpoint(todel) 61 | self.checkpoint_queue.append(chkpt_name) 62 | 63 | def _save(self, step): 64 | """ Save a resumable checkpoint. 65 | 66 | Args: 67 | step (int): step number 68 | 69 | Returns: 70 | checkpoint: the saved object 71 | checkpoint_name: name (or path) of the saved checkpoint 72 | """ 73 | raise NotImplementedError() 74 | 75 | def _rm_checkpoint(self, name): 76 | """ 77 | Remove a checkpoint 78 | 79 | Args: 80 | name(str): name that indentifies the checkpoint 81 | (it may be a filepath) 82 | """ 83 | raise NotImplementedError() 84 | 85 | 86 | class ModelSaver(ModelSaverBase): 87 | """ 88 | Simple model saver to filesystem 89 | """ 90 | 91 | def __init__(self, base_path, model, model_opt, fields, optim, 92 | save_checkpoint_steps, keep_checkpoint=0): 93 | super(ModelSaver, self).__init__( 94 | base_path, model, model_opt, fields, optim, 95 | save_checkpoint_steps, keep_checkpoint) 96 | 97 | def _save(self, step): 98 | real_model = (self.model.module 99 | if isinstance(self.model, nn.DataParallel) 100 | else self.model) 101 | real_generator = (real_model.generator.module 102 | if isinstance(real_model.generator, nn.DataParallel) 103 | else real_model.generator) 104 | 105 | model_state_dict = real_model.state_dict() 106 | model_state_dict = {k: v for k, v in model_state_dict.items() 107 | if 'generator' not in k} 108 | generator_state_dict = real_generator.state_dict() 109 | checkpoint = { 110 | 'model': model_state_dict, 111 | 'generator': generator_state_dict, 112 | 'vocab': onmt.inputters.save_fields_to_vocab(self.fields), 113 | 'opt': self.model_opt, 114 | 'optim': self.optim, 115 | } 116 | 117 | logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) 118 | checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) 119 | torch.save(checkpoint, checkpoint_path) 120 | return checkpoint, checkpoint_path 121 | 122 | def _rm_checkpoint(self, name): 123 | os.remove(name) 124 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/models/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | """ Implementation of ONMT RNN for Input Feeding Decoding """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class StackedLSTM(nn.Module): 7 | """ 8 | Our own implementation of stacked LSTM. 9 | Needed for the decoder, because we do input feeding. 
10 | """ 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTM, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | self.layers = nn.ModuleList() 17 | 18 | for _ in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, input_feed, hidden): 23 | h_0, c_0 = hidden 24 | h_1, c_1 = [], [] 25 | for i, layer in enumerate(self.layers): 26 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 27 | input_feed = h_1_i 28 | if i + 1 != self.num_layers: 29 | input_feed = self.dropout(input_feed) 30 | h_1 += [h_1_i] 31 | c_1 += [c_1_i] 32 | 33 | h_1 = torch.stack(h_1) 34 | c_1 = torch.stack(c_1) 35 | 36 | return input_feed, (h_1, c_1) 37 | 38 | 39 | class StackedGRU(nn.Module): 40 | """ 41 | Our own implementation of stacked GRU. 42 | Needed for the decoder, because we do input feeding. 43 | """ 44 | 45 | def __init__(self, num_layers, input_size, rnn_size, dropout): 46 | super(StackedGRU, self).__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList() 50 | 51 | for _ in range(num_layers): 52 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 53 | input_size = rnn_size 54 | 55 | def forward(self, input_feed, hidden): 56 | h_1 = [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i = layer(input_feed, hidden[0][i]) 59 | input_feed = h_1_i 60 | if i + 1 != self.num_layers: 61 | input_feed = self.dropout(input_feed) 62 | h_1 += [h_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | return input_feed, (h_1,) 66 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention and normalization modules """ 2 | from onmt.modules.util_class import LayerNorm, Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 6 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLossCompute 7 | from onmt.modules.multi_headed_attn import MultiHeadedAttention 8 | from onmt.modules.embeddings import Embeddings, PositionalEncoding 9 | from onmt.modules.weight_norm import WeightNormConv2d 10 | from onmt.modules.average_attn import AverageAttention 11 | 12 | __all__ = ["LayerNorm", "Elementwise", "context_gate_factory", "ContextGate", 13 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 14 | "CopyGeneratorLossCompute", "MultiHeadedAttention", "Embeddings", 15 | "PositionalEncoding", "WeightNormConv2d", "AverageAttention"] 16 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/average_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ Average Attention module """ 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onmt.modules.position_ffn import PositionwiseFeedForward 8 | 9 | 10 | class AverageAttention(nn.Module): 11 | """ 12 | Average Attention module from 13 | "Accelerating Neural Transformer via an Average Attention Network" 14 | :cite:`https://arxiv.org/abs/1805.00631`. 
15 | 16 | Args: 17 | model_dim (int): the dimension of keys/values/queries, 18 | must be divisible by head_count 19 | dropout (float): dropout parameter 20 | """ 21 | 22 | def __init__(self, model_dim, dropout=0.1): 23 | self.model_dim = model_dim 24 | 25 | super(AverageAttention, self).__init__() 26 | 27 | self.average_layer = PositionwiseFeedForward(model_dim, model_dim, 28 | dropout) 29 | self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2) 30 | 31 | def cumulative_average_mask(self, batch_size, inputs_len): 32 | """ 33 | Builds the mask to compute the cumulative average as described in 34 | https://arxiv.org/abs/1805.00631 -- Figure 3 35 | 36 | Args: 37 | batch_size (int): batch size 38 | inputs_len (int): length of the inputs 39 | 40 | Returns: 41 | (`FloatTensor`): 42 | 43 | * A Tensor of shape `[batch_size x input_len x input_len]` 44 | """ 45 | 46 | triangle = torch.tril(torch.ones(inputs_len, inputs_len)) 47 | weights = torch.ones(1, inputs_len) / torch.arange( 48 | 1, inputs_len + 1, dtype=torch.float) 49 | mask = triangle * weights.transpose(0, 1) 50 | 51 | return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len) 52 | 53 | def cumulative_average(self, inputs, mask_or_step, 54 | layer_cache=None, step=None): 55 | """ 56 | Computes the cumulative average as described in 57 | https://arxiv.org/abs/1805.00631 -- Equations (1) (5) (6) 58 | 59 | Args: 60 | inputs (`FloatTensor`): sequence to average 61 | `[batch_size x input_len x dimension]` 62 | mask_or_step: if cache is set, this is assumed 63 | to be the current step of the 64 | dynamic decoding. Otherwise, it is the mask matrix 65 | used to compute the cumulative average. 66 | cache: a dictionary containing the cumulative average 67 | of the previous step. 68 | """ 69 | if layer_cache is not None: 70 | step = mask_or_step 71 | device = inputs.device 72 | average_attention = (inputs + step * 73 | layer_cache["prev_g"].to(device)) / (step + 1) 74 | layer_cache["prev_g"] = average_attention 75 | return average_attention 76 | else: 77 | mask = mask_or_step 78 | return torch.matmul(mask, inputs) 79 | 80 | def forward(self, inputs, mask=None, layer_cache=None, step=None): 81 | """ 82 | Args: 83 | inputs (`FloatTensor`): `[batch_size x input_len x model_dim]` 84 | 85 | Returns: 86 | (`FloatTensor`, `FloatTensor`): 87 | 88 | * gating_outputs `[batch_size x 1 x model_dim]` 89 | * average_outputs average attention `[batch_size x 1 x model_dim]` 90 | """ 91 | batch_size = inputs.size(0) 92 | inputs_len = inputs.size(1) 93 | 94 | device = inputs.device 95 | average_outputs = self.cumulative_average( 96 | inputs, self.cumulative_average_mask(batch_size, 97 | inputs_len).to(device).float() 98 | if layer_cache is None else step, layer_cache=layer_cache) 99 | average_outputs = self.average_layer(average_outputs) 100 | gating_outputs = self.gating_layer(torch.cat((inputs, 101 | average_outputs), -1)) 102 | input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2) 103 | gating_outputs = torch.sigmoid(input_gate) * inputs + \ 104 | torch.sigmoid(forget_gate) * average_outputs 105 | 106 | return gating_outputs, average_outputs 107 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/conv_multi_step_attention.py: -------------------------------------------------------------------------------- 1 | """ Multi Step Attention for CNN """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | 
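# Descriptive note on the constant defined just below (added comment, not upstream code):
# 0.5 ** 0.5 == 1 / sqrt(2). ConvS2S (Gehring et al., 2017) multiplies the sum of two
# activations by this factor so that the variance of the sum stays roughly equal to
# that of its inputs.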
SCALE_WEIGHT = 0.5 ** 0.5 9 | 10 | 11 | def seq_linear(linear, x): 12 | """ linear transform for 3-d tensor """ 13 | batch, hidden_size, length, _ = x.size() 14 | h = linear(torch.transpose(x, 1, 2).contiguous().view( 15 | batch * length, hidden_size)) 16 | return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) 17 | 18 | 19 | class ConvMultiStepAttention(nn.Module): 20 | """ 21 | 22 | Conv attention takes a key matrix, a value matrix and a query vector. 23 | Attention weight is calculated by key matrix with the query vector 24 | and sum on the value matrix. And the same operation is applied 25 | in each decode conv layer. 26 | 27 | """ 28 | 29 | def __init__(self, input_size): 30 | super(ConvMultiStepAttention, self).__init__() 31 | self.linear_in = nn.Linear(input_size, input_size) 32 | self.mask = None 33 | 34 | def apply_mask(self, mask): 35 | """ Apply mask """ 36 | self.mask = mask 37 | 38 | def forward(self, base_target_emb, input_from_dec, encoder_out_top, 39 | encoder_out_combine): 40 | """ 41 | Args: 42 | base_target_emb: target emb tensor 43 | input: output of decode conv 44 | encoder_out_t: the key matrix for calculation of attetion weight, 45 | which is the top output of encode conv 46 | encoder_out_combine: 47 | the value matrix for the attention-weighted sum, 48 | which is the combination of base emb and top output of encode 49 | 50 | """ 51 | # checks 52 | # batch, channel, height, width = base_target_emb.size() 53 | batch, _, height, _ = base_target_emb.size() 54 | # batch_, channel_, height_, width_ = input_from_dec.size() 55 | batch_, _, height_, _ = input_from_dec.size() 56 | aeq(batch, batch_) 57 | aeq(height, height_) 58 | 59 | # enc_batch, enc_channel, enc_height = encoder_out_top.size() 60 | enc_batch, _, enc_height = encoder_out_top.size() 61 | # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() 62 | enc_batch_, _, enc_height_ = encoder_out_combine.size() 63 | 64 | aeq(enc_batch, enc_batch_) 65 | aeq(enc_height, enc_height_) 66 | 67 | preatt = seq_linear(self.linear_in, input_from_dec) 68 | target = (base_target_emb + preatt) * SCALE_WEIGHT 69 | target = torch.squeeze(target, 3) 70 | target = torch.transpose(target, 1, 2) 71 | pre_attn = torch.bmm(target, encoder_out_top) 72 | 73 | if self.mask is not None: 74 | pre_attn.data.masked_fill_(self.mask, -float('inf')) 75 | 76 | pre_attn = pre_attn.transpose(0, 2) 77 | attn = F.softmax(pre_attn, dim=-1) 78 | attn = attn.transpose(0, 2).contiguous() 79 | context_output = torch.bmm( 80 | attn, torch.transpose(encoder_out_combine, 1, 2)) 81 | context_output = torch.transpose( 82 | torch.unsqueeze(context_output, 3), 1, 2) 83 | return context_output, attn 84 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/gate.py: -------------------------------------------------------------------------------- 1 | """ ContextGate module """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def context_gate_factory(gate_type, embeddings_size, decoder_size, 7 | attention_size, output_size): 8 | """Returns the correct ContextGate class""" 9 | 10 | gate_types = {'source': SourceContextGate, 11 | 'target': TargetContextGate, 12 | 'both': BothContextGate} 13 | 14 | assert gate_type in gate_types, "Not valid ContextGate type: {0}".format( 15 | gate_type) 16 | return gate_types[gate_type](embeddings_size, decoder_size, attention_size, 17 | output_size) 18 | 19 | 20 | class ContextGate(nn.Module): 21 | """ 22 | Context gate is a decoder module 
that takes as input the previous word 23 | embedding, the current decoder state and the attention state, and 24 | produces a gate. 25 | The gate can be used to select the input from the target side context 26 | (decoder state), from the source context (attention state) or both. 27 | """ 28 | 29 | def __init__(self, embeddings_size, decoder_size, 30 | attention_size, output_size): 31 | super(ContextGate, self).__init__() 32 | input_size = embeddings_size + decoder_size + attention_size 33 | self.gate = nn.Linear(input_size, output_size, bias=True) 34 | self.sig = nn.Sigmoid() 35 | self.source_proj = nn.Linear(attention_size, output_size) 36 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 37 | output_size) 38 | 39 | def forward(self, prev_emb, dec_state, attn_state): 40 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 41 | z = self.sig(self.gate(input_tensor)) 42 | proj_source = self.source_proj(attn_state) 43 | proj_target = self.target_proj( 44 | torch.cat((prev_emb, dec_state), dim=1)) 45 | return z, proj_source, proj_target 46 | 47 | 48 | class SourceContextGate(nn.Module): 49 | """Apply the context gate only to the source context""" 50 | 51 | def __init__(self, embeddings_size, decoder_size, 52 | attention_size, output_size): 53 | super(SourceContextGate, self).__init__() 54 | self.context_gate = ContextGate(embeddings_size, decoder_size, 55 | attention_size, output_size) 56 | self.tanh = nn.Tanh() 57 | 58 | def forward(self, prev_emb, dec_state, attn_state): 59 | z, source, target = self.context_gate( 60 | prev_emb, dec_state, attn_state) 61 | return self.tanh(target + z * source) 62 | 63 | 64 | class TargetContextGate(nn.Module): 65 | """Apply the context gate only to the target context""" 66 | 67 | def __init__(self, embeddings_size, decoder_size, 68 | attention_size, output_size): 69 | super(TargetContextGate, self).__init__() 70 | self.context_gate = ContextGate(embeddings_size, decoder_size, 71 | attention_size, output_size) 72 | self.tanh = nn.Tanh() 73 | 74 | def forward(self, prev_emb, dec_state, attn_state): 75 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 76 | return self.tanh(z * target + source) 77 | 78 | 79 | class BothContextGate(nn.Module): 80 | """Apply the context gate to both contexts""" 81 | 82 | def __init__(self, embeddings_size, decoder_size, 83 | attention_size, output_size): 84 | super(BothContextGate, self).__init__() 85 | self.context_gate = ContextGate(embeddings_size, decoder_size, 86 | attention_size, output_size) 87 | self.tanh = nn.Tanh() 88 | 89 | def forward(self, prev_emb, dec_state, attn_state): 90 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 91 | return self.tanh((1. - z) * target + z * source) 92 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Position feed-forward network from "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | import onmt 8 | 9 | 10 | class PositionwiseFeedForward(nn.Module): 11 | """ A two-layer Feed-Forward-Network with residual layer norm. 12 | 13 | Args: 14 | d_model (int): the size of input for the first-layer of the FFN. 15 | d_ff (int): the hidden layer size of the second-layer 16 | of the FNN. 17 | dropout (float): dropout probability(0-1.0). 
18 | """ 19 | 20 | def __init__(self, d_model, d_ff, dropout=0.1): 21 | super(PositionwiseFeedForward, self).__init__() 22 | self.w_1 = nn.Linear(d_model, d_ff) 23 | self.w_2 = nn.Linear(d_ff, d_model) 24 | self.layer_norm = onmt.modules.LayerNorm(d_model) 25 | self.dropout_1 = nn.Dropout(dropout) 26 | self.relu = nn.ReLU() 27 | self.dropout_2 = nn.Dropout(dropout) 28 | 29 | def forward(self, x): 30 | """ 31 | Layer definition. 32 | 33 | Args: 34 | input: [ batch_size, input_len, model_dim ] 35 | 36 | 37 | Returns: 38 | output: [ batch_size, input_len, model_dim ] 39 | """ 40 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 41 | output = self.dropout_2(self.w_2(inter)) 42 | return output + x 43 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/sparse_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of sparsemax (Martins & Astudillo, 2016). See 3 | https://arxiv.org/pdf/1602.02068 for detailed description. 4 | """ 5 | 6 | import torch 7 | from torch.autograd import Function 8 | import torch.nn as nn 9 | 10 | 11 | def threshold_and_support(z, dim=0): 12 | """ 13 | z: any dimension 14 | dim: dimension along which to apply the sparsemax 15 | """ 16 | sorted_z, _ = torch.sort(z, descending=True, dim=dim) 17 | z_sum = sorted_z.cumsum(dim) - 1 # sort of a misnomer 18 | k = torch.arange(1, sorted_z.size(dim) + 1, device=z.device).float().view( 19 | torch.Size([-1] + [1] * (z.dim() - 1)) 20 | ).transpose(0, dim) 21 | support = k * sorted_z > z_sum 22 | 23 | k_z_indices = support.sum(dim=dim).unsqueeze(dim) 24 | k_z = k_z_indices.float() 25 | tau_z = z_sum.gather(dim, k_z_indices - 1) / k_z 26 | return tau_z, k_z 27 | 28 | 29 | class SparsemaxFunction(Function): 30 | 31 | @staticmethod 32 | def forward(ctx, input, dim=0): 33 | """ 34 | input (FloatTensor): any shape 35 | returns (FloatTensor): same shape with sparsemax computed on given dim 36 | """ 37 | ctx.dim = dim 38 | tau_z, k_z = threshold_and_support(input, dim=dim) 39 | output = torch.clamp(input - tau_z, min=0) 40 | ctx.save_for_backward(k_z, output) 41 | return output 42 | 43 | @staticmethod 44 | def backward(ctx, grad_output): 45 | k_z, output = ctx.saved_tensors 46 | dim = ctx.dim 47 | grad_input = grad_output.clone() 48 | grad_input[output == 0] = 0 49 | 50 | v_hat = (grad_input.sum(dim=dim) / k_z.squeeze()).unsqueeze(dim) 51 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 52 | return grad_input, None 53 | 54 | 55 | sparsemax = SparsemaxFunction.apply 56 | 57 | 58 | class Sparsemax(nn.Module): 59 | 60 | def __init__(self, dim=0): 61 | self.dim = dim 62 | super(Sparsemax, self).__init__() 63 | 64 | def forward(self, input): 65 | return sparsemax(input, self.dim) 66 | 67 | 68 | class LogSparsemax(nn.Module): 69 | 70 | def __init__(self, dim=0): 71 | self.dim = dim 72 | super(LogSparsemax, self).__init__() 73 | 74 | def forward(self, input): 75 | return torch.log(sparsemax(input, self.dim)) 76 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/sparse_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from onmt.modules.sparse_activations import threshold_and_support 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class SparsemaxLossFunction(Function): 9 | 10 | 
@staticmethod 11 | def forward(ctx, input, target): 12 | """ 13 | input (FloatTensor): n x num_classes 14 | target (LongTensor): n, the indices of the target classes 15 | """ 16 | input_batch, classes = input.size() 17 | target_batch = target.size(0) 18 | aeq(input_batch, target_batch) 19 | 20 | z_k = input.gather(1, target.unsqueeze(1)).squeeze() 21 | tau_z, support_size = threshold_and_support(input, dim=1) 22 | support = input > tau_z 23 | x = torch.where( 24 | support, input**2 - tau_z**2, 25 | torch.tensor(0.0, device=input.device) 26 | ).sum(dim=1) 27 | ctx.save_for_backward(input, target, tau_z) 28 | # clamping necessary because of numerical errors: loss should be lower 29 | # bounded by zero, but negative values near zero are possible without 30 | # the clamp 31 | return torch.clamp(x / 2 - z_k + 0.5, min=0.0) 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, target, tau_z = ctx.saved_tensors 36 | sparsemax_out = torch.clamp(input - tau_z, min=0) 37 | delta = torch.zeros_like(sparsemax_out) 38 | delta.scatter_(1, target.unsqueeze(1), 1) 39 | return sparsemax_out - delta, None 40 | 41 | 42 | sparsemax_loss = SparsemaxLossFunction.apply 43 | 44 | 45 | class SparsemaxLoss(nn.Module): 46 | """ 47 | An implementation of sparsemax loss, first proposed in "From Softmax to 48 | Sparsemax: A Sparse Model of Attention and Multi-Label Classification" 49 | (Martins & Astudillo, 2016: https://arxiv.org/pdf/1602.02068). If using 50 | a sparse output layer, it is not possible to use negative log likelihood 51 | because the loss is infinite in the case the target is assigned zero 52 | probability. Inputs to SparsemaxLoss are arbitrary dense real-valued 53 | vectors (like in nn.CrossEntropyLoss), not probability vectors (like in 54 | nn.NLLLoss). 55 | """ 56 | 57 | def __init__(self, weight=None, ignore_index=-100, 58 | reduce=True, size_average=True): 59 | self.weight = weight 60 | self.ignore_index = ignore_index 61 | self.reduce = reduce 62 | self.size_average = size_average 63 | super(SparsemaxLoss, self).__init__() 64 | 65 | def forward(self, input, target): 66 | loss = sparsemax_loss(input, target) 67 | if self.ignore_index >= 0: 68 | ignored_positions = target == self.ignore_index 69 | size = float((target.size(0) - ignored_positions.sum()).item()) 70 | loss.masked_fill_(ignored_positions, 0.0) 71 | else: 72 | size = float(target.size(0)) 73 | if self.reduce: 74 | loss = loss.sum() 75 | if self.size_average: 76 | loss = loss / size 77 | return loss 78 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | from onmt.utils.logging import init_logger 5 | 6 | 7 | class MatrixTree(nn.Module): 8 | """Implementation of the matrix-tree theorem for computing marginals 9 | of non-projective dependency parsing. This attention layer is used 10 | in the paper "Learning Structured Text Representations." 
11 | 12 | 13 | :cite:`DBLP:journals/corr/LiuL17d` 14 | """ 15 | 16 | def __init__(self, eps=1e-5): 17 | self.eps = eps 18 | super(MatrixTree, self).__init__() 19 | 20 | def forward(self, input): 21 | laplacian = input.exp() + self.eps 22 | output = input.clone() 23 | for b in range(input.size(0)): 24 | lap = laplacian[b].masked_fill( 25 | torch.eye(input.size(1)).cuda().ne(0), 0) 26 | lap = -lap + torch.diag(lap.sum(0)) 27 | # store roots on diagonal 28 | lap[0] = input[b].diag().exp() 29 | inv_laplacian = lap.inverse() 30 | 31 | factor = inv_laplacian.diag().unsqueeze(1)\ 32 | .expand_as(input[b]).transpose(0, 1) 33 | term1 = input[b].exp().mul(factor).clone() 34 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 35 | term1[:, 0] = 0 36 | term2[0] = 0 37 | output[b] = term1 - term2 38 | roots_output = input[b].diag().exp().mul( 39 | inv_laplacian.transpose(0, 1)[0]) 40 | output[b] = output[b] + torch.diag(roots_output) 41 | return output 42 | 43 | 44 | if __name__ == "__main__": 45 | logger = init_logger('StructuredAttention.log') 46 | dtree = MatrixTree() 47 | q = torch.rand(1, 5, 5).cuda() 48 | marg = dtree.forward(q) 49 | logger.info(marg.sum(1)) 50 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class LayerNorm(nn.Module): 7 | """ 8 | Layer Normalization class 9 | """ 10 | 11 | def __init__(self, features, eps=1e-6): 12 | super(LayerNorm, self).__init__() 13 | self.a_2 = nn.Parameter(torch.ones(features)) 14 | self.b_2 = nn.Parameter(torch.zeros(features)) 15 | self.eps = eps 16 | 17 | def forward(self, x): 18 | mean = x.mean(-1, keepdim=True) 19 | std = x.std(-1, keepdim=True) 20 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 21 | 22 | 23 | # At the moment this class is only used by embeddings.Embeddings look-up tables 24 | class Elementwise(nn.ModuleList): 25 | """ 26 | A simple network container. 27 | Parameters are a list of modules. 28 | Inputs are a 3d Tensor whose last dimension is the same length 29 | as the list. 30 | Outputs are the result of applying modules to inputs elementwise. 31 | An optional merge parameter allows the outputs to be reduced to a 32 | single Tensor. 
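        For example, with merge='concat' an input of shape
        (seq_len, batch, n_feats) is split along the last dimension,
        the i-th feature column is passed through the i-th module, and
        the per-feature outputs are concatenated back on dimension 2.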
33 | """ 34 | 35 | def __init__(self, merge=None, *args): 36 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 37 | self.merge = merge 38 | super(Elementwise, self).__init__(*args) 39 | 40 | def forward(self, inputs): 41 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 42 | assert len(self) == len(inputs_) 43 | outputs = [f(x) for f, x in zip(self, inputs_)] 44 | if self.merge == 'first': 45 | return outputs[0] 46 | elif self.merge == 'concat' or self.merge == 'mlp': 47 | return torch.cat(outputs, 2) 48 | elif self.merge == 'sum': 49 | return sum(outputs) 50 | else: 51 | return outputs 52 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InitialBug/BiSET/a697a3c61014281bbd83cd37ede29b1263c8832f/Bi-selective Encoding/onmt/tests/__init__.py -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/rebuild_test_models.sh: -------------------------------------------------------------------------------- 1 | # # Retrain the models used for CI. 2 | # # Should be done rarely, indicates a major breaking change. 3 | my_python=python 4 | 5 | ############### TEST regular RNN choose either -rnn_type LSTM / GRU / SRU and set input_feed 0 for SRU 6 | if true; then 7 | rm data/*.pt 8 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 9 | 10 | $my_python train.py -data data/data -save_model tmp -gpuid 0 -rnn_size 256 -word_vec_size 256 -layers 1 -train_steps 10000 -optim adam -learning_rate 0.001 -rnn_type LSTM -input_feed 0 11 | #-truncated_decoder 5 12 | #-label_smoothing 0.1 13 | 14 | mv tmp*e10.pt onmt/tests/test_model.pt 15 | rm tmp*.pt 16 | fi 17 | # 18 | # 19 | ############### TEST CNN 20 | if false; then 21 | rm data/*.pt 22 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 23 | 24 | $my_python train.py -data data/data -save_model /tmp/tmp -gpuid 0 -rnn_size 256 -word_vec_size 256 -layers 2 -train_steps 10000 -optim adam -learning_rate 0.001 -encoder_type cnn -decoder_type cnn 25 | 26 | 27 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 28 | 29 | rm /tmp/tmp*.pt 30 | fi 31 | # 32 | ################# MORPH DATA 33 | if true; then 34 | rm data/morph/*.pt 35 | $my_python preprocess.py -train_src data/morph/src.train -train_tgt data/morph/tgt.train -valid_src data/morph/src.valid -valid_tgt data/morph/tgt.valid -save_data data/morph/data 36 | 37 | $my_python train.py -data data/morph/data -save_model tmp -gpuid 0 -rnn_size 400 -word_vec_size 100 -layers 1 -train_steps 8000 -optim adam -learning_rate 0.001 38 | 39 | 40 | mv tmp*e8.pt onmt/tests/test_model2.pt 41 | 42 | rm tmp*.pt 43 | fi 44 | ############### TEST TRANSFORMER 45 | if false; then 46 | rm data/*.pt 47 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 -share_vocab 48 | 49 | 50 | $my_python train.py -data data/data -save_model /tmp/tmp -batch_type tokens -batch_size 1024 -accum_count 4 \ 51 | -layers 4 -rnn_size 
256 -word_vec_size 256 -encoder_type transformer -decoder_type transformer -share_embedding \ 52 | -train_steps 10000 -gpuid 0 -max_generator_batches 4 -dropout 0.1 -normalization tokens \ 53 | -max_grad_norm 0 -optim adam -decay_method noam -learning_rate 2 -label_smoothing 0.1 \ 54 | -position_encoding -param_init 0 -warmup_steps 100 -param_init_glorot -adam_beta2 0.998 55 | # 56 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 57 | rm /tmp/tmp*.pt 58 | fi 59 | # 60 | if false; then 61 | $my_python translate.py -gpu 0 -model onmt/tests/test_model.pt \ 62 | -src data/src-val.txt -output onmt/tests/output_hyp.txt -beam 5 -batch_size 16 63 | 64 | fi 65 | 66 | 67 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Here come the tests for attention types and their compatibility 3 | """ 4 | import unittest 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | import onmt 9 | 10 | 11 | class TestAttention(unittest.TestCase): 12 | 13 | def test_masked_global_attention(self): 14 | 15 | source_lengths = torch.IntTensor([7, 3, 5, 2]) 16 | # illegal_weights_mask = torch.ByteTensor([ 17 | # [0, 0, 0, 0, 0, 0, 0], 18 | # [0, 0, 0, 1, 1, 1, 1], 19 | # [0, 0, 0, 0, 0, 1, 1], 20 | # [0, 0, 1, 1, 1, 1, 1]]) 21 | 22 | batch_size = source_lengths.size(0) 23 | dim = 20 24 | 25 | memory_bank = Variable(torch.randn(batch_size, 26 | source_lengths.max(), dim)) 27 | hidden = Variable(torch.randn(batch_size, dim)) 28 | 29 | attn = onmt.modules.GlobalAttention(dim) 30 | 31 | _, alignments = attn(hidden, memory_bank, 32 | memory_lengths=source_lengths) 33 | # TODO: fix for pytorch 0.3 34 | # illegal_weights = alignments.masked_select(illegal_weights_mask) 35 | 36 | # self.assertEqual(0.0, illegal_weights.data.sum()) 37 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/test_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InitialBug/BiSET/a697a3c61014281bbd83cd37ede29b1263c8832f/Bi-selective Encoding/onmt/tests/test_model.pt -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/test_model2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/InitialBug/BiSET/a697a3c61014281bbd83cd37ede29b1263c8832f/Bi-selective Encoding/onmt/tests/test_model2.pt -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import copy 7 | import unittest 8 | import glob 9 | import os 10 | import codecs 11 | from collections import Counter 12 | 13 | import torchtext 14 | 15 | import onmt 16 | import onmt.inputters 17 | import onmt.opts 18 | import preprocess 19 | 20 | 21 | parser = argparse.ArgumentParser(description='preprocess.py') 22 | onmt.opts.preprocess_opts(parser) 23 | 24 | SAVE_DATA_PREFIX = 'data/test_preprocess' 25 | 26 | default_opts = [ 27 | '-data_type', 'text', 28 | '-train_src', 'data/src-train.txt', 29 | '-train_tgt', 'data/tgt-train.txt', 30 | '-valid_src', 'data/src-val.txt', 31 | '-valid_tgt', 
'data/tgt-val.txt', 32 | '-save_data', SAVE_DATA_PREFIX 33 | ] 34 | 35 | opt = parser.parse_known_args(default_opts)[0] 36 | 37 | 38 | class TestData(unittest.TestCase): 39 | def __init__(self, *args, **kwargs): 40 | super(TestData, self).__init__(*args, **kwargs) 41 | self.opt = opt 42 | 43 | def dataset_build(self, opt): 44 | fields = onmt.inputters.get_fields("text", 0, 0) 45 | 46 | if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0: 47 | with codecs.open(opt.src_vocab, 'w', 'utf-8') as f: 48 | f.write('a\nb\nc\nd\ne\nf\n') 49 | if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0: 50 | with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f: 51 | f.write('a\nb\nc\nd\ne\nf\n') 52 | 53 | train_data_files = preprocess.build_save_dataset('train', fields, opt) 54 | 55 | preprocess.build_save_vocab(train_data_files, fields, opt) 56 | 57 | preprocess.build_save_dataset('valid', fields, opt) 58 | 59 | # Remove the generated *pt files. 60 | for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): 61 | os.remove(pt) 62 | if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab): 63 | os.remove(opt.src_vocab) 64 | if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab): 65 | os.remove(opt.tgt_vocab) 66 | 67 | def test_merge_vocab(self): 68 | va = torchtext.vocab.Vocab(Counter('abbccc')) 69 | vb = torchtext.vocab.Vocab(Counter('eeabbcccf')) 70 | 71 | merged = onmt.inputters.merge_vocabs([va, vb], 2) 72 | 73 | self.assertEqual(Counter({'c': 6, 'b': 4, 'a': 2, 'e': 2, 'f': 1}), 74 | merged.freqs) 75 | # 4 specicials + 2 words (since we pass 2 to merge_vocabs) 76 | self.assertEqual(6, len(merged.itos)) 77 | self.assertTrue('b' in merged.itos) 78 | 79 | 80 | def _add_test(param_setting, methodname): 81 | """ 82 | Adds a Test to TestData according to settings 83 | 84 | Args: 85 | param_setting: list of tuples of (param, setting) 86 | methodname: name of the method that gets called 87 | """ 88 | 89 | def test_method(self): 90 | if param_setting: 91 | opt = copy.deepcopy(self.opt) 92 | for param, setting in param_setting: 93 | setattr(opt, param, setting) 94 | else: 95 | opt = self.opt 96 | getattr(self, methodname)(opt) 97 | if param_setting: 98 | name = 'test_' + methodname + "_" + "_".join( 99 | str(param_setting).split()) 100 | else: 101 | name = 'test_' + methodname + '_standard' 102 | setattr(TestData, name, test_method) 103 | test_method.__name__ = name 104 | 105 | 106 | test_databuild = [[], 107 | [('src_vocab_size', 1), 108 | ('tgt_vocab_size', 1)], 109 | [('src_vocab_size', 10000), 110 | ('tgt_vocab_size', 10000)], 111 | [('src_seq_length', 1)], 112 | [('src_seq_length', 5000)], 113 | [('src_seq_length_trunc', 1)], 114 | [('src_seq_length_trunc', 5000)], 115 | [('tgt_seq_length', 1)], 116 | [('tgt_seq_length', 5000)], 117 | [('tgt_seq_length_trunc', 1)], 118 | [('tgt_seq_length_trunc', 5000)], 119 | [('shuffle', 0)], 120 | [('lower', True)], 121 | [('dynamic_dict', True)], 122 | [('share_vocab', True)], 123 | [('dynamic_dict', True), 124 | ('share_vocab', True)], 125 | [('dynamic_dict', True), 126 | ('max_shard_size', 500000)], 127 | [('src_vocab', '/tmp/src_vocab.txt'), 128 | ('tgt_vocab', '/tmp/tgt_vocab.txt')], 129 | ] 130 | 131 | for p in test_databuild: 132 | _add_test(p, 'dataset_build') 133 | 134 | # Test image preprocessing 135 | for p in copy.deepcopy(test_databuild): 136 | p.append(('data_type', 'img')) 137 | p.append(('src_dir', '/tmp/im2text/images')) 138 | p.append(('train_src', '/tmp/im2text/src-train-head.txt')) 139 | p.append(('train_tgt', 
'/tmp/im2text/tgt-train-head.txt')) 140 | p.append(('valid_src', '/tmp/im2text/src-val-head.txt')) 141 | p.append(('valid_tgt', '/tmp/im2text/tgt-val-head.txt')) 142 | _add_test(p, 'dataset_build') 143 | 144 | # Test audio preprocessing 145 | for p in copy.deepcopy(test_databuild): 146 | p.append(('data_type', 'audio')) 147 | p.append(('src_dir', '/tmp/speech/an4_dataset')) 148 | p.append(('train_src', '/tmp/speech/src-train-head.txt')) 149 | p.append(('train_tgt', '/tmp/speech/tgt-train-head.txt')) 150 | p.append(('valid_src', '/tmp/speech/src-val-head.txt')) 151 | p.append(('valid_tgt', '/tmp/speech/tgt-val-head.txt')) 152 | p.append(('sample_rate', 16000)) 153 | p.append(('window_size', 0.04)) 154 | p.append(('window_stride', 0.02)) 155 | p.append(('window', 'hamming')) 156 | _add_test(p, 'dataset_build') 157 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/tests/test_simple.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | 3 | 4 | def test_load(): 5 | onmt 6 | pass 7 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/train_multi.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Multi-GPU training 4 | """ 5 | import argparse 6 | import os 7 | import signal 8 | import torch 9 | 10 | import onmt.opts as opts 11 | import onmt.utils.distributed 12 | 13 | from onmt.utils.logging import logger 14 | from onmt.train_single import main as single_main 15 | 16 | 17 | def main(opt): 18 | """ Spawns 1 process per GPU """ 19 | nb_gpu = len(opt.gpuid) 20 | mp = torch.multiprocessing.get_context('spawn') 21 | 22 | # Create a thread to listen for errors in the child processes. 23 | error_queue = mp.SimpleQueue() 24 | error_handler = ErrorHandler(error_queue) 25 | 26 | # Train with multiprocessing. 
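    # One child process is spawned per entry in opt.gpuid; each child runs
    # run() -> single_main() after distributed initialization, and the
    # ErrorHandler thread created above forwards any child traceback back to
    # this parent process via SIGUSR1.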
27 | procs = [] 28 | for i in range(nb_gpu): 29 | opt.gpu_rank = i 30 | opt.device_id = i 31 | 32 | procs.append(mp.Process(target=run, args=( 33 | opt, error_queue, ), daemon=True)) 34 | procs[i].start() 35 | logger.info(" Starting process pid: %d " % procs[i].pid) 36 | error_handler.add_child(procs[i].pid) 37 | for p in procs: 38 | p.join() 39 | 40 | 41 | def run(opt, error_queue): 42 | """ run process """ 43 | try: 44 | opt.gpu_rank = onmt.utils.distributed.multi_init(opt) 45 | single_main(opt) 46 | except KeyboardInterrupt: 47 | pass # killed by parent, do nothing 48 | except Exception: 49 | # propagate exception to parent process, keeping original traceback 50 | import traceback 51 | error_queue.put((opt.gpu_rank, traceback.format_exc())) 52 | 53 | 54 | class ErrorHandler(object): 55 | """A class that listens for exceptions in children processes and propagates 56 | the tracebacks to the parent process.""" 57 | 58 | def __init__(self, error_queue): 59 | """ init error handler """ 60 | import signal 61 | import threading 62 | self.error_queue = error_queue 63 | self.children_pids = [] 64 | self.error_thread = threading.Thread( 65 | target=self.error_listener, daemon=True) 66 | self.error_thread.start() 67 | signal.signal(signal.SIGUSR1, self.signal_handler) 68 | 69 | def add_child(self, pid): 70 | """ error handler """ 71 | self.children_pids.append(pid) 72 | 73 | def error_listener(self): 74 | """ error listener """ 75 | (rank, original_trace) = self.error_queue.get() 76 | self.error_queue.put((rank, original_trace)) 77 | os.kill(os.getpid(), signal.SIGUSR1) 78 | 79 | def signal_handler(self, signalnum, stackframe): 80 | """ signal handler """ 81 | for pid in self.children_pids: 82 | os.kill(pid, signal.SIGINT) # kill children processes 83 | (rank, original_trace) = self.error_queue.get() 84 | msg = """\n\n-- Tracebacks above this line can probably 85 | be ignored --\n\n""" 86 | msg += original_trace 87 | raise Exception(msg) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser( 92 | description='train.py', 93 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 94 | 95 | opts.add_md_help_argument(parser) 96 | opts.model_opts(parser) 97 | opts.train_opts(parser) 98 | 99 | opt = parser.parse_args() 100 | main(opt) 101 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/train_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Training on a single process 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | import os 9 | import random 10 | import torch 11 | 12 | import onmt.opts as opts 13 | 14 | from onmt.inputters.inputter import build_dataset_iter, lazily_load_dataset, \ 15 | _load_fields, _collect_report_features 16 | from onmt.model_builder import build_model 17 | from onmt.utils.optimizers import build_optim 18 | from onmt.trainer import build_trainer 19 | from onmt.models import build_model_saver 20 | from onmt.utils.logging import init_logger, logger 21 | 22 | 23 | def _check_save_model_path(opt): 24 | save_model_path = os.path.abspath(opt.save_model) 25 | model_dirname = os.path.dirname(save_model_path) 26 | if not os.path.exists(model_dirname): 27 | os.makedirs(model_dirname) 28 | 29 | 30 | def _tally_parameters(model): 31 | n_params = sum([p.nelement() for p in model.parameters()]) 32 | enc = 0 33 | dec = 0 34 | for name, param in model.named_parameters(): 35 | if 'encoder' in name: 36 | enc 
+= param.nelement() 37 | elif 'decoder' or 'generator' in name: 38 | dec += param.nelement() 39 | return n_params, enc, dec 40 | 41 | 42 | def training_opt_postprocessing(opt): 43 | if opt.word_vec_size != -1: 44 | opt.src_word_vec_size = opt.word_vec_size 45 | opt.tgt_word_vec_size = opt.word_vec_size 46 | 47 | if opt.layers != -1: 48 | opt.enc_layers = opt.layers 49 | opt.dec_layers = opt.layers 50 | 51 | opt.brnn = (opt.encoder_type == "brnn") 52 | 53 | if opt.rnn_type == "SRU" and not opt.gpuid: 54 | raise AssertionError("Using SRU requires -gpuid set.") 55 | 56 | if torch.cuda.is_available() and not opt.gpuid: 57 | logger.info("WARNING: You have a CUDA device, should run with -gpuid") 58 | 59 | if opt.seed > 0: 60 | torch.manual_seed(opt.seed) 61 | # this one is needed for torchtext random call (shuffled iterator) 62 | # in multi gpu it ensures datasets are read in the same order 63 | random.seed(opt.seed) 64 | # some cudnn methods can be random even after fixing the seed 65 | # unless you tell it to be deterministic 66 | torch.backends.cudnn.deterministic = True 67 | 68 | if opt.gpuid: 69 | torch.cuda.set_device(opt.device_id) 70 | if opt.seed > 0: 71 | # These ensure same initialization in multi gpu mode 72 | torch.cuda.manual_seed(opt.seed) 73 | 74 | return opt 75 | 76 | 77 | def main(opt): 78 | opt = training_opt_postprocessing(opt) 79 | init_logger(opt.log_file) 80 | # Load checkpoint if we resume from a previous training. 81 | if opt.train_from: 82 | logger.info('Loading checkpoint from %s' % opt.train_from) 83 | checkpoint = torch.load(opt.train_from, 84 | map_location=lambda storage, loc: storage) 85 | model_opt = checkpoint['opt'] 86 | else: 87 | checkpoint = None 88 | model_opt = opt 89 | 90 | # Peek the first dataset to determine the data_type. 91 | # (All datasets have the same data_type). 92 | first_dataset = next(lazily_load_dataset("train", opt)) 93 | data_type = first_dataset.data_type 94 | 95 | # Load fields generated from preprocess phase. 96 | fields = _load_fields(first_dataset, data_type, opt, checkpoint) 97 | 98 | # Report src/tgt features. 99 | 100 | src_features, tgt_features = _collect_report_features(fields) 101 | for j, feat in enumerate(src_features): 102 | logger.info(' * src feature %d size = %d' 103 | % (j, len(fields[feat].vocab))) 104 | for j, feat in enumerate(tgt_features): 105 | logger.info(' * tgt feature %d size = %d' 106 | % (j, len(fields[feat].vocab))) 107 | 108 | # Build model. 109 | model = build_model(model_opt, opt, fields, checkpoint) 110 | n_params, enc, dec = _tally_parameters(model) 111 | logger.info('encoder: %d' % enc) 112 | logger.info('decoder: %d' % dec) 113 | logger.info('* number of parameters: %d' % n_params) 114 | _check_save_model_path(opt) 115 | 116 | # Build optimizer. 117 | optim = build_optim(model, opt, checkpoint) 118 | 119 | # Build model saver 120 | model_saver = build_model_saver(model_opt, opt, model, fields, optim) 121 | 122 | trainer = build_trainer( 123 | opt, model, fields, optim, data_type, model_saver=model_saver) 124 | 125 | def train_iter_fct(): return build_dataset_iter( 126 | lazily_load_dataset("train", opt), fields, opt) 127 | 128 | def valid_iter_fct(): return build_dataset_iter( 129 | lazily_load_dataset("valid", opt), fields, opt, is_train=False) 130 | 131 | # Do training. 
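    # train_iter_fct and valid_iter_fct are closures rather than iterators so
    # the trainer can rebuild a fresh iterator over the lazily loaded shards
    # whenever it needs a new pass over the data.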
132 | trainer.train(train_iter_fct, valid_iter_fct, opt.train_steps, 133 | opt.valid_steps) 134 | 135 | if opt.tensorboard: 136 | trainer.report_manager.tensorboard_writer.close() 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser( 141 | description='train.py', 142 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 143 | 144 | opts.add_md_help_argument(parser) 145 | opts.model_opts(parser) 146 | opts.train_opts(parser) 147 | 148 | opt = parser.parse_args() 149 | main(opt) 150 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/translate/__init__.py: -------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam import Beam, GNMTGlobalScorer 5 | from onmt.translate.penalties import PenaltyBuilder 6 | from onmt.translate.translation_server import TranslationServer, \ 7 | ServerModelError 8 | 9 | __all__ = ['Translator', 'Translation', 'Beam', 10 | 'GNMTGlobalScorer', 'TranslationBuilder', 11 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError'] 12 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """ 7 | Returns the Length and Coverage Penalty function for Beam Search. 8 | 9 | Args: 10 | length_pen (str): option name of length pen 11 | cov_pen (str): option name of cov pen 12 | """ 13 | 14 | def __init__(self, cov_pen, length_pen): 15 | self.length_pen = length_pen 16 | self.cov_pen = cov_pen 17 | 18 | def coverage_penalty(self): 19 | if self.cov_pen == "wu": 20 | return self.coverage_wu 21 | elif self.cov_pen == "summary": 22 | return self.coverage_summary 23 | else: 24 | return self.coverage_none 25 | 26 | def length_penalty(self): 27 | if self.length_pen == "wu": 28 | return self.length_wu 29 | elif self.length_pen == "avg": 30 | return self.length_average 31 | else: 32 | return self.length_none 33 | 34 | """ 35 | Below are all the different penalty terms implemented so far 36 | """ 37 | 38 | def coverage_wu(self, beam, cov, beta=0.): 39 | """ 40 | NMT coverage re-ranking score from 41 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 42 | """ 43 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(1) 44 | return beta * penalty 45 | 46 | def coverage_summary(self, beam, cov, beta=0.): 47 | """ 48 | Our summary penalty. 49 | """ 50 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(1) 51 | penalty -= cov.size(1) 52 | return beta * penalty 53 | 54 | def coverage_none(self, beam, cov, beta=0.): 55 | """ 56 | returns zero as penalty 57 | """ 58 | return beam.scores.clone().fill_(0.0) 59 | 60 | def length_wu(self, beam, logprobs, alpha=0.): 61 | """ 62 | NMT length re-ranking score from 63 | "Google's Neural Machine Translation System" :cite:`wu2016google`. 64 | """ 65 | 66 | modifier = (((5 + len(beam.next_ys)) ** alpha) / 67 | ((5 + 1) ** alpha)) 68 | return (logprobs / modifier) 69 | 70 | def length_average(self, beam, logprobs, alpha=0.): 71 | """ 72 | Returns the average probability of tokens in a sequence. 
73 | """ 74 | return logprobs / len(beam.next_ys) 75 | 76 | def length_none(self, beam, logprobs, alpha=0., beta=0.): 77 | """ 78 | Returns unmodified scores. 79 | """ 80 | return logprobs 81 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/translate/translation.py: -------------------------------------------------------------------------------- 1 | """ Translation main class """ 2 | from __future__ import division, unicode_literals 3 | from __future__ import print_function 4 | 5 | import torch 6 | import onmt.inputters as inputters 7 | 8 | 9 | class TranslationBuilder(object): 10 | """ 11 | Build a word-based translation from the batch output 12 | of translator and the underlying dictionaries. 13 | 14 | Replacement based on "Addressing the Rare Word 15 | Problem in Neural Machine Translation" :cite:`Luong2015b` 16 | 17 | Args: 18 | data (DataSet): 19 | fields (dict of Fields): data fields 20 | n_best (int): number of translations produced 21 | replace_unk (bool): replace unknown words using attention 22 | has_tgt (bool): will the batch have gold targets 23 | """ 24 | 25 | def __init__(self, data, fields, n_best=1, replace_unk=False, 26 | has_tgt=False): 27 | self.data = data 28 | self.fields = fields 29 | self.n_best = n_best 30 | self.replace_unk = replace_unk 31 | self.has_tgt = has_tgt 32 | 33 | def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn): 34 | vocab = self.fields["tgt"].vocab 35 | tokens = [] 36 | for tok in pred: 37 | if tok < len(vocab): 38 | tokens.append(vocab.itos[tok]) 39 | else: 40 | tokens.append(src_vocab.itos[tok - len(vocab)]) 41 | if tokens[-1] == inputters.EOS_WORD: 42 | tokens = tokens[:-1] 43 | break 44 | if self.replace_unk and (attn is not None) and (src is not None): 45 | for i in range(len(tokens)): 46 | if tokens[i] == vocab.itos[inputters.UNK]: 47 | _, max_index = attn[i].max(0) 48 | tokens[i] = src_raw[max_index.item()] 49 | return tokens 50 | 51 | def from_batch(self, translation_batch): 52 | batch = translation_batch["batch"] 53 | assert(len(translation_batch["gold_score"]) == 54 | len(translation_batch["predictions"])) 55 | batch_size = batch.batch_size 56 | 57 | preds, pred_score, attn, gold_score, indices = list(zip( 58 | *sorted(zip(translation_batch["predictions"], 59 | translation_batch["scores"], 60 | translation_batch["attention"], 61 | translation_batch["gold_score"], 62 | batch.indices.data), 63 | key=lambda x: x[-1]))) 64 | 65 | # Sorting 66 | inds, perm = torch.sort(batch.indices.data) 67 | data_type = self.data.data_type 68 | if data_type == 'text': 69 | src = batch.src[0].data.index_select(1, perm) 70 | else: 71 | src = None 72 | 73 | if self.has_tgt: 74 | tgt = batch.tgt.data.index_select(1, perm) 75 | else: 76 | tgt = None 77 | 78 | translations = [] 79 | for b in range(batch_size): 80 | if data_type == 'text': 81 | src_vocab = self.data.src_vocabs[inds[b]] \ 82 | if self.data.src_vocabs else None 83 | src_raw = self.data.examples[inds[b]].src 84 | 85 | else: 86 | src_vocab = None 87 | src_raw = None 88 | pred_sents = [self._build_target_tokens( 89 | src[:, b] if src is not None else None, 90 | src_vocab, src_raw, 91 | preds[b][n], attn[b][n]) 92 | for n in range(self.n_best)] 93 | gold_sent = None 94 | if tgt is not None: 95 | gold_sent = self._build_target_tokens( 96 | src[:, b] if src is not None else None, 97 | src_vocab, src_raw, 98 | tgt[1:, b] if tgt is not None else None, None) 99 | 100 | translation = Translation(src[:, b] if src is not None else None, 
101 | src_raw, pred_sents, 102 | attn[b], pred_score[b], gold_sent, 103 | gold_score[b]) 104 | translations.append(translation) 105 | 106 | return translations 107 | 108 | 109 | class Translation(object): 110 | """ 111 | Container for a translated sentence. 112 | 113 | Attributes: 114 | src (`LongTensor`): src word ids 115 | src_raw ([str]): raw src words 116 | 117 | pred_sents ([[str]]): words from the n-best translations 118 | pred_scores ([[float]]): log-probs of n-best translations 119 | attns ([`FloatTensor`]) : attention dist for each translation 120 | gold_sent ([str]): words from gold translation 121 | gold_score ([float]): log-prob of gold translation 122 | 123 | """ 124 | 125 | def __init__(self, src, src_raw, pred_sents, 126 | attn, pred_scores, tgt_sent, gold_score): 127 | self.src = src 128 | self.src_raw = src_raw 129 | self.pred_sents = pred_sents 130 | self.attns = attn 131 | self.pred_scores = pred_scores 132 | self.gold_sent = tgt_sent 133 | self.gold_score = gold_score 134 | 135 | def log(self, sent_number): 136 | """ 137 | Log translation. 138 | """ 139 | 140 | output = '\nSENT {}: {}\n'.format(sent_number, self.src_raw) 141 | 142 | best_pred = self.pred_sents[0] 143 | best_score = self.pred_scores[0] 144 | pred_sent = ' '.join(best_pred) 145 | output += 'PRED {}: {}\n'.format(sent_number, pred_sent) 146 | output += "PRED SCORE: {:.4f}\n".format(best_score) 147 | 148 | if self.gold_sent is not None: 149 | tgt_sent = ' '.join(self.gold_sent) 150 | output += 'GOLD {}: {}\n'.format(sent_number, tgt_sent) 151 | output += ("GOLD SCORE: {:.4f}\n".format(self.gold_score)) 152 | if len(self.pred_sents) > 1: 153 | output += '\nBEST HYP:\n' 154 | for score, sent in zip(self.pred_scores, self.pred_sents): 155 | output += "[{:.4f}] {}\n".format(score, sent) 156 | 157 | return output 158 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import aeq, use_gpu 3 | from onmt.utils.report_manager import ReportMgr, build_report_manager 4 | from onmt.utils.statistics import Statistics 5 | from onmt.utils.optimizers import build_optim, MultipleOptimizer, \ 6 | Optimizer 7 | 8 | __all__ = ["aeq", "use_gpu", "ReportMgr", 9 | "build_report_manager", "Statistics", 10 | "build_optim", "MultipleOptimizer", "Optimizer"] 11 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/cnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | import torch.nn.functional as F 8 | 9 | import onmt.modules 10 | 11 | SCALE_WEIGHT = 0.5 ** 0.5 12 | 13 | 14 | def shape_transform(x): 15 | """ Tranform the size of the tensors to fit for conv input. 
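    (batch, seq_len, hidden) -> (batch, hidden, seq_len, 1), i.e. the hidden
    size becomes the channel dimension expected by the 2-D convolutions.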
""" 16 | return torch.unsqueeze(torch.transpose(x, 1, 2), 3) 17 | 18 | 19 | class GatedConv(nn.Module): 20 | """ Gated convolution for CNN class """ 21 | 22 | def __init__(self, input_size, width=3, dropout=0.2, nopad=False): 23 | super(GatedConv, self).__init__() 24 | self.conv = onmt.modules.WeightNormConv2d( 25 | input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), 26 | padding=(width // 2 * (1 - nopad), 0)) 27 | init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5) 28 | self.dropout = nn.Dropout(dropout) 29 | 30 | def forward(self, x_var): 31 | x_var = self.dropout(x_var) 32 | x_var = self.conv(x_var) 33 | out, gate = x_var.split(int(x_var.size(1) / 2), 1) 34 | out = out * F.sigmoid(gate) 35 | return out 36 | 37 | 38 | class StackedCNN(nn.Module): 39 | """ Stacked CNN class """ 40 | 41 | def __init__(self, num_layers, input_size, cnn_kernel_width=3, 42 | dropout=0.2): 43 | super(StackedCNN, self).__init__() 44 | self.dropout = dropout 45 | self.num_layers = num_layers 46 | self.layers = nn.ModuleList() 47 | for _ in range(num_layers): 48 | self.layers.append( 49 | GatedConv(input_size, cnn_kernel_width, dropout)) 50 | 51 | def forward(self, x): 52 | for conv in self.layers: 53 | x = x + conv(x) 54 | x *= SCALE_WEIGHT 55 | return x 56 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | import torch.distributed 12 | 13 | from onmt.utils.logging import logger 14 | 15 | 16 | def is_master(opt): 17 | return opt.gpu_rank == 0 18 | 19 | 20 | def multi_init(opt): 21 | if len(opt.gpuid) == 1: 22 | raise ValueError('Cannot initialize multiprocess with one gpu only') 23 | dist_init_method = 'tcp://localhost:10000' 24 | dist_world_size = len(opt.gpuid) 25 | torch.distributed.init_process_group( 26 | backend=opt.gpu_backend, init_method=dist_init_method, 27 | world_size=dist_world_size, rank=opt.gpu_rank) 28 | opt.gpu_rank = torch.distributed.get_rank() 29 | if not is_master(opt): 30 | logger.disabled = True 31 | 32 | return opt.gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. 
# of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file) 20 | file_handler.setFormatter(log_format) 21 | logger.addHandler(file_handler) 22 | 23 | return logger 24 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/misc.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | 5 | 6 | def aeq(*args): 7 | """ 8 | Assert all arguments have the same value 9 | """ 10 | arguments = (arg for arg in args) 11 | first = next(arguments) 12 | assert all(arg == first for arg in arguments), \ 13 | "Not all arguments have the same value: " + str(args) 14 | 15 | 16 | def sequence_mask(lengths, max_len=None): 17 | """ 18 | Creates a boolean mask from sequence lengths. 19 | """ 20 | batch_size = lengths.numel() 21 | max_len = max_len or lengths.max() 22 | return (torch.arange(0, max_len) 23 | .type_as(lengths) 24 | .repeat(batch_size, 1) 25 | .lt(lengths.unsqueeze(1))) 26 | 27 | 28 | def tile(x, count, dim=0): 29 | """ 30 | Tiles x on dimension dim count times. 31 | """ 32 | perm = list(range(len(x.size()))) 33 | if dim != 0: 34 | perm[0], perm[dim] = perm[dim], perm[0] 35 | x = x.permute(perm).contiguous() 36 | out_size = list(x.size()) 37 | out_size[0] *= count 38 | batch = x.size(0) 39 | x = x.view(batch, -1) \ 40 | .transpose(0, 1) \ 41 | .repeat(count, 1) \ 42 | .transpose(0, 1) \ 43 | .contiguous() \ 44 | .view(*out_size) 45 | if dim != 0: 46 | x = x.permute(perm).contiguous() 47 | return x 48 | 49 | 50 | def use_gpu(opt): 51 | """ 52 | Creates a boolean if gpu used 53 | """ 54 | return (hasattr(opt, 'gpuid') and len(opt.gpuid) > 0) or \ 55 | (hasattr(opt, 'gpu') and opt.gpu > -1) 56 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/report_manager.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | import time 4 | from datetime import datetime 5 | 6 | import onmt 7 | 8 | from onmt.utils.logging import logger 9 | 10 | 11 | def build_report_manager(opt): 12 | if opt.tensorboard: 13 | from tensorboardX import SummaryWriter 14 | writer = SummaryWriter(opt.tensorboard_log_dir 15 | + datetime.now().strftime("/%b-%d_%H-%M-%S"), 16 | comment="Unmt") 17 | else: 18 | writer = None 19 | 20 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 21 | tensorboard_writer=writer) 22 | return report_mgr 23 | 24 | 25 | class ReportMgrBase(object): 26 | """ 27 | Report Manager Base class 28 | Inherited classes should override: 29 | * `_report_training` 30 | * `_report_step` 31 | """ 32 | 33 | def __init__(self, report_every, start_time=-1.): 34 | """ 35 | Args: 36 | report_every(int): Report status every this many sentences 37 | start_time(float): manually set report start time. Negative values 38 | means that you will need to set it later or use `start()` 39 | """ 40 | self.report_every = report_every 41 | self.progress_step = 0 42 | self.start_time = start_time 43 | 44 | def start(self): 45 | self.start_time = time.time() 46 | 47 | def log(self, *args, **kwargs): 48 | logger.info(*args, **kwargs) 49 | 50 | def report_training(self, step, num_steps, learning_rate, 51 | report_stats, multigpu=False): 52 | """ 53 | This is the user-defined batch-level traing progress 54 | report function. 55 | 56 | Args: 57 | step(int): current step count. 58 | num_steps(int): total number of batches. 59 | learning_rate(float): current learning rate. 60 | report_stats(Statistics): old Statistics instance. 61 | Returns: 62 | report_stats(Statistics): updated Statistics instance. 
63 | """ 64 | if self.start_time < 0: 65 | raise ValueError("""ReportMgr needs to be started 66 | (set 'start_time' or use 'start()'""") 67 | 68 | if multigpu: 69 | report_stats = onmt.utils.Statistics.all_gather_stats(report_stats) 70 | 71 | if step % self.report_every == 0: 72 | self._report_training( 73 | step, num_steps, learning_rate, report_stats) 74 | self.progress_step += 1 75 | return onmt.utils.Statistics() 76 | 77 | def _report_training(self, *args, **kwargs): 78 | """ To be overridden """ 79 | raise NotImplementedError() 80 | 81 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 82 | """ 83 | Report stats of a step 84 | 85 | Args: 86 | train_stats(Statistics): training stats 87 | valid_stats(Statistics): validation stats 88 | lr(float): current learning rate 89 | """ 90 | self._report_step( 91 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 92 | 93 | def _report_step(self, *args, **kwargs): 94 | raise NotImplementedError() 95 | 96 | 97 | class ReportMgr(ReportMgrBase): 98 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 99 | """ 100 | A report manager that writes statistics on standard output as well as 101 | (optionally) TensorBoard 102 | 103 | Args: 104 | report_every(int): Report status every this many sentences 105 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 106 | The TensorBoard Summary writer to use or None 107 | """ 108 | super(ReportMgr, self).__init__(report_every, start_time) 109 | self.tensorboard_writer = tensorboard_writer 110 | 111 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 112 | if self.tensorboard_writer is not None: 113 | stats.log_tensorboard( 114 | prefix, self.tensorboard_writer, learning_rate, step) 115 | 116 | def _report_training(self, step, num_steps, learning_rate, 117 | report_stats): 118 | """ 119 | See base class method `ReportMgrBase.report_training`. 120 | """ 121 | report_stats.output(step, num_steps, 122 | learning_rate, self.start_time) 123 | 124 | # Log the progress using the number of batches on the x-axis. 125 | self.maybe_log_tensorboard(report_stats, 126 | "progress", 127 | learning_rate, 128 | self.progress_step) 129 | report_stats = onmt.utils.Statistics() 130 | 131 | return report_stats 132 | 133 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 134 | """ 135 | See base class method `ReportMgrBase.report_step`. 136 | """ 137 | if train_stats is not None: 138 | self.log('Train perplexity: %g' % train_stats.ppl()) 139 | self.log('Train accuracy: %g' % train_stats.accuracy()) 140 | 141 | self.maybe_log_tensorboard(train_stats, 142 | "train", 143 | lr, 144 | step) 145 | 146 | if valid_stats is not None: 147 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 148 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 149 | 150 | self.maybe_log_tensorboard(valid_stats, 151 | "valid", 152 | lr, 153 | step) 154 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | from __future__ import division 5 | 6 | import torch.nn as nn 7 | import onmt.models 8 | 9 | 10 | def rnn_factory(rnn_type, **kwargs): 11 | """ rnn factory, Use pytorch version when available. """ 12 | no_pack_padded_seq = False 13 | if rnn_type == "SRU": 14 | # SRU doesn't support PackedSequence. 
15 | no_pack_padded_seq = True 16 | rnn = onmt.models.sru.SRU(**kwargs) 17 | else: 18 | rnn = getattr(nn, rnn_type)(**kwargs) 19 | return rnn, no_pack_padded_seq 20 | -------------------------------------------------------------------------------- /Bi-selective Encoding/onmt/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | import time 4 | import math 5 | import sys 6 | 7 | from torch.distributed import get_rank 8 | from onmt.utils.distributed import all_gather_list 9 | from onmt.utils.logging import logger 10 | 11 | 12 | class Statistics(object): 13 | """ 14 | Accumulator for loss statistics. 15 | Currently calculates: 16 | 17 | * accuracy 18 | * perplexity 19 | * elapsed time 20 | """ 21 | 22 | def __init__(self, loss=0, n_words=0, n_correct=0): 23 | self.loss = loss 24 | self.n_words = n_words 25 | self.n_correct = n_correct 26 | self.n_src_words = 0 27 | self.start_time = time.time() 28 | 29 | @staticmethod 30 | def all_gather_stats(stat, max_size=4096): 31 | """ 32 | Gather a `Statistics` object accross multiple process/nodes 33 | 34 | Args: 35 | stat(:obj:Statistics): the statistics object to gather 36 | accross all processes/nodes 37 | max_size(int): max buffer size to use 38 | 39 | Returns: 40 | `Statistics`, the update stats object 41 | """ 42 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 43 | return stats[0] 44 | 45 | @staticmethod 46 | def all_gather_stats_list(stat_list, max_size=4096): 47 | """ 48 | Gather a `Statistics` list accross all processes/nodes 49 | 50 | Args: 51 | stat_list(list([`Statistics`])): list of statistics objects to 52 | gather accross all processes/nodes 53 | max_size(int): max buffer size to use 54 | 55 | Returns: 56 | our_stats(list([`Statistics`])): list of updated stats 57 | """ 58 | # Get a list of world_size lists with len(stat_list) Statistics objects 59 | all_stats = all_gather_list(stat_list, max_size=max_size) 60 | 61 | our_rank = get_rank() 62 | our_stats = all_stats[our_rank] 63 | for other_rank, stats in enumerate(all_stats): 64 | if other_rank == our_rank: 65 | continue 66 | for i, stat in enumerate(stats): 67 | our_stats[i].update(stat, update_n_src_words=True) 68 | return our_stats 69 | 70 | def update(self, stat, update_n_src_words=False): 71 | """ 72 | Update statistics by suming values with another `Statistics` object 73 | 74 | Args: 75 | stat: another statistic object 76 | update_n_src_words(bool): whether to update (sum) `n_src_words` 77 | or not 78 | 79 | """ 80 | self.loss += stat.loss 81 | self.n_words += stat.n_words 82 | self.n_correct += stat.n_correct 83 | 84 | if update_n_src_words: 85 | self.n_src_words += stat.n_src_words 86 | 87 | def accuracy(self): 88 | """ compute accuracy """ 89 | return 100 * (self.n_correct / self.n_words) 90 | 91 | def xent(self): 92 | """ compute cross entropy """ 93 | return self.loss / self.n_words 94 | 95 | def ppl(self): 96 | """ compute perplexity """ 97 | return math.exp(min(self.loss / self.n_words, 100)) 98 | 99 | def elapsed_time(self): 100 | """ compute elapsed time """ 101 | return time.time() - self.start_time 102 | 103 | def output(self, step, num_steps, learning_rate, start): 104 | """Write out statistics to stdout. 105 | 106 | Args: 107 | step (int): current step 108 | n_batch (int): total batches 109 | start (int): start time of step. 
110 | """ 111 | t = self.elapsed_time() 112 | logger.info( 113 | ("Step %2d/%5d; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + 114 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 115 | % (step, num_steps, 116 | self.accuracy(), 117 | self.ppl(), 118 | self.xent(), 119 | learning_rate, 120 | self.n_src_words / (t + 1e-5), 121 | self.n_words / (t + 1e-5), 122 | time.time() - start)) 123 | sys.stdout.flush() 124 | 125 | def log_tensorboard(self, prefix, writer, learning_rate, step): 126 | """ display statistics to tensorboard """ 127 | t = self.elapsed_time() 128 | writer.add_scalar(prefix + "/xent", self.xent(), step) 129 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 130 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 131 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 132 | writer.add_scalar(prefix + "/lr", learning_rate, step) 133 | -------------------------------------------------------------------------------- /Bi-selective Encoding/requirements.opt.txt: -------------------------------------------------------------------------------- 1 | cffi 2 | torchvision==0.2.1 3 | joblib==0.11 4 | librosa==0.6.0 5 | Pillow 6 | git+https://github.com/pytorch/audio 7 | pyrouge 8 | -------------------------------------------------------------------------------- /Bi-selective Encoding/requirements.txt: -------------------------------------------------------------------------------- 1 | six 2 | tqdm 3 | torch>=0.4.0 4 | git+https://github.com/pytorch/text 5 | future 6 | -------------------------------------------------------------------------------- /Bi-selective Encoding/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | 4 | from flask import Flask, jsonify, request 5 | from onmt.translate import TranslationServer, ServerModelError 6 | 7 | STATUS_OK = "ok" 8 | STATUS_ERROR = "error" 9 | 10 | 11 | def start(config_file, 12 | url_root="./translator", 13 | host="0.0.0.0", 14 | port=5000, 15 | debug=True): 16 | def prefix_route(route_function, prefix='', mask='{0}{1}'): 17 | def newroute(route, *args, **kwargs): 18 | return route_function(mask.format(prefix, route), *args, **kwargs) 19 | return newroute 20 | 21 | app = Flask(__name__) 22 | app.route = prefix_route(app.route, url_root) 23 | translation_server = TranslationServer() 24 | translation_server.start(config_file) 25 | 26 | @app.route('/models', methods=['GET']) 27 | def get_models(): 28 | out = translation_server.list_models() 29 | return jsonify(out) 30 | 31 | @app.route('/clone_model/', methods=['POST']) 32 | def clone_model(model_id): 33 | out = {} 34 | data = request.get_json(force=True) 35 | timeout = -1 36 | if 'timeout' in data: 37 | timeout = data['timeout'] 38 | del data['timeout'] 39 | 40 | opt = data.get('opt', None) 41 | try: 42 | model_id, load_time = translation_server.clone_model( 43 | model_id, opt, timeout) 44 | except ServerModelError as e: 45 | out['status'] = STATUS_ERROR 46 | out['error'] = str(e) 47 | else: 48 | out['status'] = STATUS_OK 49 | out['model_id'] = model_id 50 | out['load_time'] = load_time 51 | 52 | return jsonify(out) 53 | 54 | @app.route('/unload_model/', methods=['GET']) 55 | def unload_model(model_id): 56 | out = {"model_id": model_id} 57 | 58 | try: 59 | translation_server.unload_model(model_id) 60 | out['status'] = STATUS_OK 61 | except Exception as e: 62 | out['status'] = STATUS_ERROR 63 | out['error'] = str(e) 64 | 65 | return jsonify(out) 66 | 67 | @app.route('/translate', 
methods=['POST']) 68 | def translate(): 69 | inputs = request.get_json(force=True) 70 | out = {} 71 | try: 72 | translation, scores, n_best, times = translation_server.run(inputs) 73 | assert len(translation) == len(inputs) 74 | assert len(scores) == len(inputs) 75 | 76 | out = [[{"src": inputs[i]['src'], "tgt": translation[i], 77 | "n_best": n_best, 78 | "pred_score": scores[i]} 79 | for i in range(len(translation))]] 80 | except ServerModelError as e: 81 | out['error'] = str(e) 82 | out['status'] = STATUS_ERROR 83 | 84 | return jsonify(out) 85 | 86 | @app.route('/to_cpu/', methods=['GET']) 87 | def to_cpu(model_id): 88 | out = {'model_id': model_id} 89 | translation_server.models[model_id].to_cpu() 90 | 91 | out['status'] = STATUS_OK 92 | return jsonify(out) 93 | 94 | @app.route('/to_gpu/', methods=['GET']) 95 | def to_gpu(model_id): 96 | out = {'model_id': model_id} 97 | translation_server.models[model_id].to_gpu() 98 | 99 | out['status'] = STATUS_OK 100 | return jsonify(out) 101 | 102 | app.run(debug=debug, host=host, port=port, use_reloader=False, 103 | threaded=True) 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser(description="OpenNMT-py REST Server") 108 | parser.add_argument("--ip", type=str, default="0.0.0.0") 109 | parser.add_argument("--port", type=int, default="5000") 110 | parser.add_argument("--url_root", type=str, default="/translator") 111 | parser.add_argument("--debug", "-d", action="store_true") 112 | parser.add_argument("--config", "-c", type=str, 113 | default="./available_models/conf.json") 114 | args = parser.parse_args() 115 | start(args.config, url_root=args.url_root, host=args.ip, port=args.port, 116 | debug=args.debug) 117 | -------------------------------------------------------------------------------- /Bi-selective Encoding/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | setup(name='OpenNMT-py', 6 | description='A python implementation of OpenNMT', 7 | version='0.2.1', 8 | 9 | packages=['onmt', 'onmt.encoders', 'onmt.modules', 'onmt.tests', 10 | 'onmt.translate', 'onmt.decoders', 'onmt.inputters', 11 | 'onmt.models', 'onmt.utils']) 12 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/README.md: -------------------------------------------------------------------------------- 1 | This directly contains scripts and tools adopted from other open source projects such as Apache Joshua and Moses Decoder. 
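
Typical invocations of the adopted Moses scripts look like the following
(the file names are placeholders, not files shipped with this repository):

    perl tools/tokenizer.perl -l en < raw.en > tokenized.en
    perl tools/multi-bleu.perl reference.txt < hypothesis.txt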
2 | 3 | TODO: credit the authors and resolve license issues (if any) 4 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/average_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | 6 | def average_models(model_files): 7 | vocab = None 8 | opt = None 9 | avg_model = None 10 | avg_generator = None 11 | 12 | for i, model_file in enumerate(model_files): 13 | m = torch.load(model_file) 14 | model_weights = m['model'] 15 | generator_weights = m['generator'] 16 | 17 | if i == 0: 18 | vocab, opt = m['vocab'], m['opt'] 19 | avg_model = model_weights 20 | avg_generator = generator_weights 21 | else: 22 | for (k, v) in avg_model.items(): 23 | avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1) 24 | 25 | for (k, v) in avg_generator.items(): 26 | avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1) 27 | 28 | final = {"vocab": vocab, "opt": opt, "optim": None, 29 | "generator": avg_generator, "model": avg_model} 30 | return final 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description="") 35 | parser.add_argument("-models", "-m", nargs="+", required=True, 36 | help="List of models") 37 | parser.add_argument("-output", "-o", required=True, 38 | help="Output file") 39 | opt = parser.parse_args() 40 | 41 | final = average_models(opt.models) 42 | torch.save(final, opt.output) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/bpe_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Author : Thamme Gowda 3 | # Created : Nov 06, 2017 4 | 5 | ONMT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" 6 | 7 | #======= EXPERIMENT SETUP ====== 8 | # Activate python environment if needed 9 | source ~/.bashrc 10 | # source activate py3 11 | 12 | # update these variables 13 | NAME="run1" 14 | OUT="onmt-runs/$NAME" 15 | 16 | DATA="$ONMT/onmt-runs/data" 17 | TRAIN_SRC=$DATA/*train.src 18 | TRAIN_TGT=$DATA/*train.tgt 19 | VALID_SRC=$DATA/*dev.src 20 | VALID_TGT=$DATA/*dev.tgt 21 | TEST_SRC=$DATA/*test.src 22 | TEST_TGT=$DATA/*test.tgt 23 | 24 | BPE="" # default 25 | BPE="src" # src, tgt, src+tgt 26 | 27 | # applicable only when BPE="src" or "src+tgt" 28 | BPE_SRC_OPS=10000 29 | 30 | # applicable only when BPE="tgt" or "src+tgt" 31 | BPE_TGT_OPS=10000 32 | 33 | GPUARG="" # default 34 | GPUARG="0" 35 | 36 | 37 | #====== EXPERIMENT BEGIN ====== 38 | 39 | # Check if input exists 40 | for f in $TRAIN_SRC $TRAIN_TGT $VALID_SRC $VALID_TGT $TEST_SRC $TEST_TGT; do 41 | if [[ ! -f "$f" ]]; then 42 | echo "Input File $f doesnt exist. 
Please fix the paths" 43 | exit 1 44 | fi 45 | done 46 | 47 | function lines_check { 48 | l1=`wc -l $1` 49 | l2=`wc -l $2` 50 | if [[ $l1 != $l2 ]]; then 51 | echo "ERROR: Record counts doesnt match between: $1 and $2" 52 | exit 2 53 | fi 54 | } 55 | lines_check $TRAIN_SRC $TRAIN_TGT 56 | lines_check $VALID_SRC $VALID_TGT 57 | lines_check $TEST_SRC $TEST_TGT 58 | 59 | 60 | echo "Output dir = $OUT" 61 | [ -d $OUT ] || mkdir -p $OUT 62 | [ -d $OUT/data ] || mkdir -p $OUT/data 63 | [ -d $OUT/models ] || mkdir $OUT/models 64 | [ -d $OUT/test ] || mkdir -p $OUT/test 65 | 66 | 67 | echo "Step 1a: Preprocess inputs" 68 | if [[ "$BPE" == *"src"* ]]; then 69 | echo "BPE on source" 70 | # Here we could use more monolingual data 71 | $ONMT/tools/learn_bpe.py -s $BPE_SRC_OPS < $TRAIN_SRC > $OUT/data/bpe-codes.src 72 | 73 | $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $TRAIN_SRC > $OUT/data/train.src 74 | $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $VALID_SRC > $OUT/data/valid.src 75 | $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $TEST_SRC > $OUT/data/test.src 76 | else 77 | ln -sf $TRAIN_SRC $OUT/data/train.src 78 | ln -sf $VALID_SRC $OUT/data/valid.src 79 | ln -sf $TEST_SRC $OUT/data/test.src 80 | fi 81 | 82 | 83 | if [[ "$BPE" == *"tgt"* ]]; then 84 | echo "BPE on target" 85 | # Here we could use more monolingual data 86 | $ONMT/tools/learn_bpe.py -s $BPE_SRC_OPS < $TRAIN_TGT > $OUT/data/bpe-codes.tgt 87 | 88 | $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $TRAIN_TGT > $OUT/data/train.tgt 89 | $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $VALID_TGT > $OUT/data/valid.tgt 90 | #$ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $TEST_TGT > $OUT/data/test.tgt 91 | # We dont touch the test References, No BPE on them! 92 | ln -sf $TEST_TGT $OUT/data/test.tgt 93 | else 94 | ln -sf $TRAIN_TGT $OUT/data/train.tgt 95 | ln -sf $VALID_TGT $OUT/data/valid.tgt 96 | ln -sf $TEST_TGT $OUT/data/test.tgt 97 | fi 98 | 99 | 100 | #: < maxv) {maxv=score; max=$0}} END{ print max}'` 124 | echo "Chosen Model = $model" 125 | if [[ -z "$model" ]]; then 126 | echo "Model not found. Looked in $OUT/models/" 127 | exit 1 128 | fi 129 | 130 | GPU_OPTS="" 131 | if [ ! 
-z $GPUARG ]; then 132 | GPU_OPTS="-gpu $GPUARG" 133 | fi 134 | 135 | echo "Step 3a: Translate Test" 136 | python $ONMT/translate.py -model $model \ 137 | -src $OUT/data/test.src \ 138 | -output $OUT/test/test.out \ 139 | -replace_unk -verbose $GPU_OPTS > $OUT/test/test.log 140 | 141 | echo "Step 3b: Translate Dev" 142 | python $ONMT/translate.py -model $model \ 143 | -src $OUT/data/valid.src \ 144 | -output $OUT/test/valid.out \ 145 | -replace_unk -verbose $GPU_OPTS > $OUT/test/valid.log 146 | 147 | if [[ "$BPE" == *"tgt"* ]]; then 148 | echo "BPE decoding/detokenising target to match with references" 149 | mv $OUT/test/test.out{,.bpe} 150 | mv $OUT/test/valid.out{,.bpe} 151 | cat $OUT/test/valid.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/valid.out 152 | cat $OUT/test/test.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/test.out 153 | fi 154 | 155 | echo "Step 4a: Evaluate Test" 156 | $ONMT/tools/multi-bleu-detok.perl $OUT/data/test.tgt < $OUT/test/test.out > $OUT/test/test.tc.bleu 157 | $ONMT/tools/multi-bleu-detok.perl -lc $OUT/data/test.tgt < $OUT/test/test.out > $OUT/test/test.lc.bleu 158 | 159 | echo "Step 4b: Evaluate Dev" 160 | $ONMT/tools/multi-bleu-detok.perl $OUT/data/valid.tgt < $OUT/test/valid.out > $OUT/test/valid.tc.bleu 161 | $ONMT/tools/multi-bleu-detok.perl -lc $OUT/data/valid.tgt < $OUT/test/valid.out > $OUT/test/valid.lc.bleu 162 | 163 | #===== EXPERIMENT END ====== 164 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/embeddings_to_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | from __future__ import division 5 | import six 6 | import sys 7 | import numpy as np 8 | import argparse 9 | import torch 10 | from onmt.utils.logging import init_logger, logger 11 | 12 | 13 | def get_vocabs(dict_file): 14 | vocabs = torch.load(dict_file) 15 | 16 | enc_vocab, dec_vocab = None, None 17 | 18 | # the vocab object is a list of tuple (name, torchtext.Vocab) 19 | # we iterate over this list and associate vocabularies based on the name 20 | for vocab in vocabs: 21 | if vocab[0] == 'src': 22 | enc_vocab = vocab[1] 23 | if vocab[0] == 'tgt': 24 | dec_vocab = vocab[1] 25 | assert enc_vocab is not None and dec_vocab is not None 26 | 27 | logger.info("From: %s" % dict_file) 28 | logger.info("\t* source vocab: %d words" % len(enc_vocab)) 29 | logger.info("\t* target vocab: %d words" % len(dec_vocab)) 30 | 31 | return enc_vocab, dec_vocab 32 | 33 | 34 | def get_embeddings(file_enc, opt, flag): 35 | embs = dict() 36 | if flag == 'enc': 37 | for (i, l) in enumerate(open(file_enc, 'rb')): 38 | if i < opt.skip_lines: 39 | continue 40 | if not l: 41 | break 42 | if len(l) == 0: 43 | continue 44 | 45 | l_split = l.decode('utf8').strip().split(' ') 46 | if len(l_split) == 2: 47 | continue 48 | embs[l_split[0]] = [float(em) for em in l_split[1:]] 49 | logger.info("Got {} encryption embeddings from {}".format(len(embs), 50 | file_enc)) 51 | else: 52 | 53 | for (i, l) in enumerate(open(file_enc, 'rb')): 54 | if not l: 55 | break 56 | if len(l) == 0: 57 | continue 58 | 59 | l_split = l.decode('utf8').strip().split(' ') 60 | if len(l_split) == 2: 61 | continue 62 | embs[l_split[0]] = [float(em) for em in l_split[1:]] 63 | logger.info("Got {} decryption embeddings from {}".format(len(embs), 64 | file_enc)) 65 | return embs 66 | 67 | 68 | def match_embeddings(vocab, emb, opt): 69 | dim = 
len(six.next(six.itervalues(emb))) 70 | filtered_embeddings = np.zeros((len(vocab), dim)) 71 | count = {"match": 0, "miss": 0} 72 | for w, w_id in vocab.stoi.items(): 73 | if w in emb: 74 | filtered_embeddings[w_id] = emb[w] 75 | count['match'] += 1 76 | else: 77 | if opt.verbose: 78 | logger.info(u"not found:\t{}".format(w), file=sys.stderr) 79 | count['miss'] += 1 80 | 81 | return torch.Tensor(filtered_embeddings), count 82 | 83 | 84 | TYPES = ["GloVe", "word2vec"] 85 | 86 | 87 | def main(): 88 | 89 | parser = argparse.ArgumentParser(description='embeddings_to_torch.py') 90 | parser.add_argument('-emb_file_enc', required=True, 91 | help="source Embeddings from this file") 92 | parser.add_argument('-emb_file_dec', required=True, 93 | help="target Embeddings from this file") 94 | parser.add_argument('-output_file', required=True, 95 | help="Output file for the prepared data") 96 | parser.add_argument('-dict_file', required=True, 97 | help="Dictionary file") 98 | parser.add_argument('-verbose', action="store_true", default=False) 99 | parser.add_argument('-skip_lines', type=int, default=0, 100 | help="Skip first lines of the embedding file") 101 | parser.add_argument('-type', choices=TYPES, default="GloVe") 102 | opt = parser.parse_args() 103 | 104 | enc_vocab, dec_vocab = get_vocabs(opt.dict_file) 105 | if opt.type == "word2vec": 106 | opt.skip_lines = 1 107 | 108 | embeddings_enc = get_embeddings(opt.emb_file_enc, opt, flag='enc') 109 | embeddings_dec = get_embeddings(opt.emb_file_dec, opt, flag='dec') 110 | 111 | filtered_enc_embeddings, enc_count = match_embeddings(enc_vocab, 112 | embeddings_enc, 113 | opt) 114 | filtered_dec_embeddings, dec_count = match_embeddings(dec_vocab, 115 | embeddings_dec, 116 | opt) 117 | logger.info("\nMatching: ") 118 | match_percent = [_['match'] / (_['match'] + _['miss']) * 100 119 | for _ in [enc_count, dec_count]] 120 | logger.info("\t* enc: %d match, %d missing, (%.2f%%)" 121 | % (enc_count['match'], 122 | enc_count['miss'], 123 | match_percent[0])) 124 | logger.info("\t* dec: %d match, %d missing, (%.2f%%)" 125 | % (dec_count['match'], 126 | dec_count['miss'], 127 | match_percent[1])) 128 | 129 | logger.info("\nFiltered embeddings:") 130 | logger.info("\t* enc: ", filtered_enc_embeddings.size()) 131 | logger.info("\t* dec: ", filtered_dec_embeddings.size()) 132 | 133 | enc_output_file = opt.output_file + ".enc.pt" 134 | dec_output_file = opt.output_file + ".dec.pt" 135 | logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" 136 | % (enc_output_file, dec_output_file)) 137 | torch.save(filtered_enc_embeddings, enc_output_file) 138 | torch.save(filtered_dec_embeddings, dec_output_file) 139 | logger.info("\nDone.") 140 | 141 | 142 | if __name__ == "__main__": 143 | init_logger('embeddings_to_torch.log') 144 | main() 145 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/extract_embeddings.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import argparse 4 | import onmt 5 | import onmt.model_builder 6 | import onmt.inputters 7 | import onmt.opts 8 | 9 | from onmt.utils.misc import use_gpu 10 | from onmt.utils.logging import init_logger, logger 11 | 12 | parser = argparse.ArgumentParser(description='translate.py') 13 | 14 | parser.add_argument('-model', required=True, 15 | help='Path to model .pt file') 16 | parser.add_argument('-output_dir', default='.', 17 | help="""Path to output the 
embeddings""") 18 | parser.add_argument('-gpu', type=int, default=-1, 19 | help="Device to run on") 20 | 21 | 22 | def write_embeddings(filename, dict, embeddings): 23 | with open(filename, 'wb') as file: 24 | for i in range(min(len(embeddings), len(dict.itos))): 25 | str = dict.itos[i].encode("utf-8") 26 | for j in range(len(embeddings[0])): 27 | str = str + (" %5f" % (embeddings[i][j])).encode("utf-8") 28 | file.write(str + b"\n") 29 | 30 | 31 | def main(): 32 | dummy_parser = argparse.ArgumentParser(description='train.py') 33 | onmt.opts.model_opts(dummy_parser) 34 | dummy_opt = dummy_parser.parse_known_args([])[0] 35 | opt = parser.parse_args() 36 | opt.cuda = opt.gpu > -1 37 | if opt.cuda: 38 | torch.cuda.set_device(opt.gpu) 39 | 40 | # Add in default model arguments, possibly added since training. 41 | checkpoint = torch.load(opt.model, 42 | map_location=lambda storage, loc: storage) 43 | model_opt = checkpoint['opt'] 44 | 45 | src_dict, tgt_dict = None, None 46 | 47 | # the vocab object is a list of tuple (name, torchtext.Vocab) 48 | # we iterate over this list and associate vocabularies based on the name 49 | for vocab in checkpoint['vocab']: 50 | if vocab[0] == 'src': 51 | src_dict = vocab[1] 52 | if vocab[0] == 'tgt': 53 | tgt_dict = vocab[1] 54 | assert src_dict is not None and tgt_dict is not None 55 | 56 | fields = onmt.inputters.load_fields_from_vocab(checkpoint['vocab']) 57 | 58 | model_opt = checkpoint['opt'] 59 | for arg in dummy_opt.__dict__: 60 | if arg not in model_opt: 61 | model_opt.__dict__[arg] = dummy_opt.__dict__[arg] 62 | 63 | model = onmt.model_builder.build_base_model( 64 | model_opt, fields, use_gpu(opt), checkpoint) 65 | encoder = model.encoder 66 | decoder = model.decoder 67 | 68 | encoder_embeddings = encoder.embeddings.word_lut.weight.data.tolist() 69 | decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist() 70 | 71 | logger.info("Writing source embeddings") 72 | write_embeddings(opt.output_dir + "/src_embeddings.txt", src_dict, 73 | encoder_embeddings) 74 | 75 | logger.info("Writing target embeddings") 76 | write_embeddings(opt.output_dir + "/tgt_embeddings.txt", tgt_dict, 77 | decoder_embeddings) 78 | 79 | logger.info('... done.') 80 | logger.info('Converting model...') 81 | 82 | 83 | if __name__ == "__main__": 84 | init_logger('extract_embeddings.log') 85 | main() 86 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/multi-bleu.perl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env perl 2 | # 3 | # This file is part of moses. Its use is licensed under the GNU Lesser General 4 | # Public License version 2.1 or, at your option, any later version. 
5 | 6 | # $Id$ 7 | use warnings; 8 | use strict; 9 | 10 | my $lowercase = 0; 11 | if ($ARGV[0] eq "-lc") { 12 | $lowercase = 1; 13 | shift; 14 | } 15 | 16 | my $stem = $ARGV[0]; 17 | if (!defined $stem) { 18 | print STDERR "usage: multi-bleu.pl [-lc] reference < hypothesis\n"; 19 | print STDERR "Reads the references from reference or reference0, reference1, ...\n"; 20 | exit(1); 21 | } 22 | 23 | $stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; 24 | 25 | my @REF; 26 | my $ref=0; 27 | while(-e "$stem$ref") { 28 | &add_to_ref("$stem$ref",\@REF); 29 | $ref++; 30 | } 31 | &add_to_ref($stem,\@REF) if -e $stem; 32 | die("ERROR: could not find reference file $stem") unless scalar @REF; 33 | 34 | # add additional references explicitly specified on the command line 35 | shift; 36 | foreach my $stem (@ARGV) { 37 | &add_to_ref($stem,\@REF) if -e $stem; 38 | } 39 | 40 | 41 | 42 | sub add_to_ref { 43 | my ($file,$REF) = @_; 44 | my $s=0; 45 | if ($file =~ /.gz$/) { 46 | open(REF,"gzip -dc $file|") or die "Can't read $file"; 47 | } else { 48 | open(REF,$file) or die "Can't read $file"; 49 | } 50 | while() { 51 | chop; 52 | push @{$$REF[$s++]}, $_; 53 | } 54 | close(REF); 55 | } 56 | 57 | my(@CORRECT,@TOTAL,$length_translation,$length_reference); 58 | my $s=0; 59 | while() { 60 | chop; 61 | $_ = lc if $lowercase; 62 | my @WORD = split; 63 | my %REF_NGRAM = (); 64 | my $length_translation_this_sentence = scalar(@WORD); 65 | my ($closest_diff,$closest_length) = (9999,9999); 66 | foreach my $reference (@{$REF[$s]}) { 67 | # print "$s $_ <=> $reference\n"; 68 | $reference = lc($reference) if $lowercase; 69 | my @WORD = split(' ',$reference); 70 | my $length = scalar(@WORD); 71 | my $diff = abs($length_translation_this_sentence-$length); 72 | if ($diff < $closest_diff) { 73 | $closest_diff = $diff; 74 | $closest_length = $length; 75 | # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; 76 | } elsif ($diff == $closest_diff) { 77 | $closest_length = $length if $length < $closest_length; 78 | # from two references with the same closeness to me 79 | # take the *shorter* into account, not the "first" one. 80 | } 81 | for(my $n=1;$n<=4;$n++) { 82 | my %REF_NGRAM_N = (); 83 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 84 | my $ngram = "$n"; 85 | for(my $w=0;$w<$n;$w++) { 86 | $ngram .= " ".$WORD[$start+$w]; 87 | } 88 | $REF_NGRAM_N{$ngram}++; 89 | } 90 | foreach my $ngram (keys %REF_NGRAM_N) { 91 | if (!defined($REF_NGRAM{$ngram}) || 92 | $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { 93 | $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; 94 | # print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; 95 | } 96 | } 97 | } 98 | } 99 | $length_translation += $length_translation_this_sentence; 100 | $length_reference += $closest_length; 101 | for(my $n=1;$n<=4;$n++) { 102 | my %T_NGRAM = (); 103 | for(my $start=0;$start<=$#WORD-($n-1);$start++) { 104 | my $ngram = "$n"; 105 | for(my $w=0;$w<$n;$w++) { 106 | $ngram .= " ".$WORD[$start+$w]; 107 | } 108 | $T_NGRAM{$ngram}++; 109 | } 110 | foreach my $ngram (keys %T_NGRAM) { 111 | $ngram =~ /^(\d+) /; 112 | my $n = $1; 113 | # my $corr = 0; 114 | # print "$i e $ngram $T_NGRAM{$ngram}
\n"; 115 | $TOTAL[$n] += $T_NGRAM{$ngram}; 116 | if (defined($REF_NGRAM{$ngram})) { 117 | if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { 118 | $CORRECT[$n] += $T_NGRAM{$ngram}; 119 | # $corr = $T_NGRAM{$ngram}; 120 | # print "$i e correct1 $T_NGRAM{$ngram}
\n"; 121 | } 122 | else { 123 | $CORRECT[$n] += $REF_NGRAM{$ngram}; 124 | # $corr = $REF_NGRAM{$ngram}; 125 | # print "$i e correct2 $REF_NGRAM{$ngram}
\n"; 126 | } 127 | } 128 | # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; 129 | # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" 130 | } 131 | } 132 | $s++; 133 | } 134 | my $brevity_penalty = 1; 135 | my $bleu = 0; 136 | 137 | my @bleu=(); 138 | 139 | for(my $n=1;$n<=4;$n++) { 140 | if (defined ($TOTAL[$n])){ 141 | $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; 142 | # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; 143 | }else{ 144 | $bleu[$n]=0; 145 | } 146 | } 147 | 148 | if ($length_reference==0){ 149 | printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; 150 | exit(1); 151 | } 152 | 153 | if ($length_translation<$length_reference) { 154 | $brevity_penalty = exp(1-$length_reference/$length_translation); 155 | } 156 | $bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + 157 | my_log( $bleu[2] ) + 158 | my_log( $bleu[3] ) + 159 | my_log( $bleu[4] ) ) / 4) ; 160 | printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", 161 | 100*$bleu, 162 | 100*$bleu[1], 163 | 100*$bleu[2], 164 | 100*$bleu[3], 165 | 100*$bleu[4], 166 | $brevity_penalty, 167 | $length_translation / $length_reference, 168 | $length_translation, 169 | $length_reference; 170 | 171 | sub my_log { 172 | return -9999999999 unless $_[0]; 173 | return log($_[0]); 174 | } 175 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/README.txt: -------------------------------------------------------------------------------- 1 | The language suffix can be found here: 2 | 3 | http://www.loc.gov/standards/iso639-2/php/code_list.php 4 | 5 | This code includes data from Daniel Naber's Language Tools (czech abbreviations). 6 | This code includes data from czech wiktionary (also czech abbreviations). 7 | 8 | 9 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.ca: -------------------------------------------------------------------------------- 1 | Dr 2 | Dra 3 | pàg 4 | p 5 | c 6 | av 7 | Sr 8 | Sra 9 | adm 10 | esq 11 | Prof 12 | S.A 13 | S.L 14 | p.e 15 | ptes 16 | Sta 17 | St 18 | pl 19 | màx 20 | cast 21 | dir 22 | nre 23 | fra 24 | admdora 25 | Emm 26 | Excma 27 | espf 28 | dc 29 | admdor 30 | tel 31 | angl 32 | aprox 33 | ca 34 | dept 35 | dj 36 | dl 37 | dt 38 | ds 39 | dg 40 | dv 41 | ed 42 | entl 43 | al 44 | i.e 45 | maj 46 | smin 47 | n 48 | núm 49 | pta 50 | A 51 | B 52 | C 53 | D 54 | E 55 | F 56 | G 57 | H 58 | I 59 | J 60 | K 61 | L 62 | M 63 | N 64 | O 65 | P 66 | Q 67 | R 68 | S 69 | T 70 | U 71 | V 72 | W 73 | X 74 | Y 75 | Z 76 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.cs: -------------------------------------------------------------------------------- 1 | Bc 2 | BcA 3 | Ing 4 | Ing.arch 5 | MUDr 6 | MVDr 7 | MgA 8 | Mgr 9 | JUDr 10 | PhDr 11 | RNDr 12 | PharmDr 13 | ThLic 14 | ThDr 15 | Ph.D 16 | Th.D 17 | prof 18 | doc 19 | CSc 20 | DrSc 21 | dr. h. 
c 22 | PaedDr 23 | Dr 24 | PhMr 25 | DiS 26 | abt 27 | ad 28 | a.i 29 | aj 30 | angl 31 | anon 32 | apod 33 | atd 34 | atp 35 | aut 36 | bd 37 | biogr 38 | b.m 39 | b.p 40 | b.r 41 | cca 42 | cit 43 | cizojaz 44 | c.k 45 | col 46 | čes 47 | čín 48 | čj 49 | ed 50 | facs 51 | fasc 52 | fol 53 | fot 54 | franc 55 | h.c 56 | hist 57 | hl 58 | hrsg 59 | ibid 60 | il 61 | ind 62 | inv.č 63 | jap 64 | jhdt 65 | jv 66 | koed 67 | kol 68 | korej 69 | kl 70 | krit 71 | lat 72 | lit 73 | m.a 74 | maď 75 | mj 76 | mp 77 | násl 78 | např 79 | nepubl 80 | něm 81 | no 82 | nr 83 | n.s 84 | okr 85 | odd 86 | odp 87 | obr 88 | opr 89 | orig 90 | phil 91 | pl 92 | pokrač 93 | pol 94 | port 95 | pozn 96 | př.kr 97 | př.n.l 98 | přel 99 | přeprac 100 | příl 101 | pseud 102 | pt 103 | red 104 | repr 105 | resp 106 | revid 107 | rkp 108 | roč 109 | roz 110 | rozš 111 | samost 112 | sect 113 | sest 114 | seš 115 | sign 116 | sl 117 | srv 118 | stol 119 | sv 120 | šk 121 | šk.ro 122 | špan 123 | tab 124 | t.č 125 | tis 126 | tj 127 | tř 128 | tzv 129 | univ 130 | uspoř 131 | vol 132 | vl.jm 133 | vs 134 | vyd 135 | vyobr 136 | zal 137 | zejm 138 | zkr 139 | zprac 140 | zvl 141 | n.p 142 | např 143 | než 144 | MUDr 145 | abl 146 | absol 147 | adj 148 | adv 149 | ak 150 | ak. sl 151 | akt 152 | alch 153 | amer 154 | anat 155 | angl 156 | anglosas 157 | arab 158 | arch 159 | archit 160 | arg 161 | astr 162 | astrol 163 | att 164 | bás 165 | belg 166 | bibl 167 | biol 168 | boh 169 | bot 170 | bulh 171 | círk 172 | csl 173 | č 174 | čas 175 | čes 176 | dat 177 | děj 178 | dep 179 | dět 180 | dial 181 | dór 182 | dopr 183 | dosl 184 | ekon 185 | epic 186 | etnonym 187 | eufem 188 | f 189 | fam 190 | fem 191 | fil 192 | film 193 | form 194 | fot 195 | fr 196 | fut 197 | fyz 198 | gen 199 | geogr 200 | geol 201 | geom 202 | germ 203 | gram 204 | hebr 205 | herald 206 | hist 207 | hl 208 | hovor 209 | hud 210 | hut 211 | chcsl 212 | chem 213 | ie 214 | imp 215 | impf 216 | ind 217 | indoevr 218 | inf 219 | instr 220 | interj 221 | ión 222 | iron 223 | it 224 | kanad 225 | katalán 226 | klas 227 | kniž 228 | komp 229 | konj 230 | 231 | konkr 232 | kř 233 | kuch 234 | lat 235 | lék 236 | les 237 | lid 238 | lit 239 | liturg 240 | lok 241 | log 242 | m 243 | mat 244 | meteor 245 | metr 246 | mod 247 | ms 248 | mysl 249 | n 250 | náb 251 | námoř 252 | neklas 253 | něm 254 | nesklon 255 | nom 256 | ob 257 | obch 258 | obyč 259 | ojed 260 | opt 261 | part 262 | pas 263 | pejor 264 | pers 265 | pf 266 | pl 267 | plpf 268 | 269 | práv 270 | prep 271 | předl 272 | přivl 273 | r 274 | rcsl 275 | refl 276 | reg 277 | rkp 278 | ř 279 | řec 280 | s 281 | samohl 282 | sg 283 | sl 284 | souhl 285 | spec 286 | srov 287 | stfr 288 | střv 289 | stsl 290 | subj 291 | subst 292 | superl 293 | sv 294 | sz 295 | táz 296 | tech 297 | telev 298 | teol 299 | trans 300 | typogr 301 | var 302 | vedl 303 | verb 304 | vl. jm 305 | voj 306 | vok 307 | vůb 308 | vulg 309 | výtv 310 | vztaž 311 | zahr 312 | zájm 313 | zast 314 | zejm 315 | 316 | zeměd 317 | zkr 318 | zř 319 | mj 320 | dl 321 | atp 322 | sport 323 | Mgr 324 | horn 325 | MVDr 326 | JUDr 327 | RSDr 328 | Bc 329 | PhDr 330 | ThDr 331 | Ing 332 | aj 333 | apod 334 | PharmDr 335 | pomn 336 | ev 337 | slang 338 | nprap 339 | odp 340 | dop 341 | pol 342 | st 343 | stol 344 | p. n. l 345 | před n. l 346 | n. l 347 | př. Kr 348 | po Kr 349 | př. n. l 350 | odd 351 | RNDr 352 | tzv 353 | atd 354 | tzn 355 | resp 356 | tj 357 | p 358 | br 359 | č. j 360 | čj 361 | č. p 362 | čp 363 | a. 
s 364 | s. r. o 365 | spol. s r. o 366 | p. o 367 | s. p 368 | v. o. s 369 | k. s 370 | o. p. s 371 | o. s 372 | v. r 373 | v z 374 | ml 375 | vč 376 | kr 377 | mld 378 | hod 379 | popř 380 | ap 381 | event 382 | rus 383 | slov 384 | rum 385 | švýc 386 | P. T 387 | zvl 388 | hor 389 | dol 390 | S.O.S -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.de: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | #no german words end in single lower-case letters, so we throw those in too. 7 | A 8 | B 9 | C 10 | D 11 | E 12 | F 13 | G 14 | H 15 | I 16 | J 17 | K 18 | L 19 | M 20 | N 21 | O 22 | P 23 | Q 24 | R 25 | S 26 | T 27 | U 28 | V 29 | W 30 | X 31 | Y 32 | Z 33 | a 34 | b 35 | c 36 | d 37 | e 38 | f 39 | g 40 | h 41 | i 42 | j 43 | k 44 | l 45 | m 46 | n 47 | o 48 | p 49 | q 50 | r 51 | s 52 | t 53 | u 54 | v 55 | w 56 | x 57 | y 58 | z 59 | 60 | 61 | #Roman Numerals. A dot after one of these is not a sentence break in German. 62 | I 63 | II 64 | III 65 | IV 66 | V 67 | VI 68 | VII 69 | VIII 70 | IX 71 | X 72 | XI 73 | XII 74 | XIII 75 | XIV 76 | XV 77 | XVI 78 | XVII 79 | XVIII 80 | XIX 81 | XX 82 | i 83 | ii 84 | iii 85 | iv 86 | v 87 | vi 88 | vii 89 | viii 90 | ix 91 | x 92 | xi 93 | xii 94 | xiii 95 | xiv 96 | xv 97 | xvi 98 | xvii 99 | xviii 100 | xix 101 | xx 102 | 103 | #Titles and Honorifics 104 | Adj 105 | Adm 106 | Adv 107 | Asst 108 | Bart 109 | Bldg 110 | Brig 111 | Bros 112 | Capt 113 | Cmdr 114 | Col 115 | Comdr 116 | Con 117 | Corp 118 | Cpl 119 | DR 120 | Dr 121 | Ens 122 | Gen 123 | Gov 124 | Hon 125 | Hosp 126 | Insp 127 | Lt 128 | MM 129 | MR 130 | MRS 131 | MS 132 | Maj 133 | Messrs 134 | Mlle 135 | Mme 136 | Mr 137 | Mrs 138 | Ms 139 | Msgr 140 | Op 141 | Ord 142 | Pfc 143 | Ph 144 | Prof 145 | Pvt 146 | Rep 147 | Reps 148 | Res 149 | Rev 150 | Rt 151 | Sen 152 | Sens 153 | Sfc 154 | Sgt 155 | Sr 156 | St 157 | Supt 158 | Surg 159 | 160 | #Misc symbols 161 | Mio 162 | Mrd 163 | bzw 164 | v 165 | vs 166 | usw 167 | d.h 168 | z.B 169 | u.a 170 | etc 171 | Mrd 172 | MwSt 173 | ggf 174 | d.J 175 | D.h 176 | m.E 177 | vgl 178 | I.F 179 | z.T 180 | sogen 181 | ff 182 | u.E 183 | g.U 184 | g.g.A 185 | c.-à-d 186 | Buchst 187 | u.s.w 188 | sog 189 | u.ä 190 | Std 191 | evtl 192 | Zt 193 | Chr 194 | u.U 195 | o.ä 196 | Ltd 197 | b.A 198 | z.Zt 199 | spp 200 | sen 201 | SA 202 | k.o 203 | jun 204 | i.H.v 205 | dgl 206 | dergl 207 | Co 208 | zzt 209 | usf 210 | s.p.a 211 | Dkr 212 | Corp 213 | bzgl 214 | BSE 215 | 216 | #Number indicators 217 | # add #NUMERIC_ONLY# after the word if it should ONLY be non-breaking when a 0-9 digit follows it 218 | No 219 | Nos 220 | Art 221 | Nr 222 | pp 223 | ca 224 | Ca 225 | 226 | #Ordinals are done with . in German - "1." 
= "1st" in English 227 | 1 228 | 2 229 | 3 230 | 4 231 | 5 232 | 6 233 | 7 234 | 8 235 | 9 236 | 10 237 | 11 238 | 12 239 | 13 240 | 14 241 | 15 242 | 16 243 | 17 244 | 18 245 | 19 246 | 20 247 | 21 248 | 22 249 | 23 250 | 24 251 | 25 252 | 26 253 | 27 254 | 28 255 | 29 256 | 30 257 | 31 258 | 32 259 | 33 260 | 34 261 | 35 262 | 36 263 | 37 264 | 38 265 | 39 266 | 40 267 | 41 268 | 42 269 | 43 270 | 44 271 | 45 272 | 46 273 | 47 274 | 48 275 | 49 276 | 50 277 | 51 278 | 52 279 | 53 280 | 54 281 | 55 282 | 56 283 | 57 284 | 58 285 | 59 286 | 60 287 | 61 288 | 62 289 | 63 290 | 64 291 | 65 292 | 66 293 | 67 294 | 68 295 | 69 296 | 70 297 | 71 298 | 72 299 | 73 300 | 74 301 | 75 302 | 76 303 | 77 304 | 78 305 | 79 306 | 80 307 | 81 308 | 82 309 | 83 310 | 84 311 | 85 312 | 86 313 | 87 314 | 88 315 | 89 316 | 90 317 | 91 318 | 92 319 | 93 320 | 94 321 | 95 322 | 96 323 | 97 324 | 98 325 | 99 326 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.en: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 34 | Adj 35 | Adm 36 | Adv 37 | Asst 38 | Bart 39 | Bldg 40 | Brig 41 | Bros 42 | Capt 43 | Cmdr 44 | Col 45 | Comdr 46 | Con 47 | Corp 48 | Cpl 49 | DR 50 | Dr 51 | Drs 52 | Ens 53 | Gen 54 | Gov 55 | Hon 56 | Hr 57 | Hosp 58 | Insp 59 | Lt 60 | MM 61 | MR 62 | MRS 63 | MS 64 | Maj 65 | Messrs 66 | Mlle 67 | Mme 68 | Mr 69 | Mrs 70 | Ms 71 | Msgr 72 | Op 73 | Ord 74 | Pfc 75 | Ph 76 | Prof 77 | Pvt 78 | Rep 79 | Reps 80 | Res 81 | Rev 82 | Rt 83 | Sen 84 | Sens 85 | Sfc 86 | Sgt 87 | Sr 88 | St 89 | Supt 90 | Surg 91 | 92 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 93 | v 94 | vs 95 | i.e 96 | rev 97 | e.g 98 | 99 | #Numbers only. These should only induce breaks when followed by a numeric sequence 100 | # add NUMERIC_ONLY after the word for this function 101 | #This case is mostly for the english "No." which can either be a sentence of its own, or 102 | #if followed by a number, a non-breaking prefix 103 | No #NUMERIC_ONLY# 104 | Nos 105 | Art #NUMERIC_ONLY# 106 | Nr 107 | pp #NUMERIC_ONLY# 108 | 109 | #month abbreviations 110 | Jan 111 | Feb 112 | Mar 113 | Apr 114 | #May is a full word 115 | Jun 116 | Jul 117 | Aug 118 | Sep 119 | Oct 120 | Nov 121 | Dec 122 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.es: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 
3 | 4 | #any single upper case letter followed by a period is not a sentence ender 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | # Period-final abbreviation list from http://www.ctspanish.com/words/abbreviations.htm 34 | 35 | A.C 36 | Apdo 37 | Av 38 | Bco 39 | CC.AA 40 | Da 41 | Dep 42 | Dn 43 | Dr 44 | Dra 45 | EE.UU 46 | Excmo 47 | FF.CC 48 | Fil 49 | Gral 50 | J.C 51 | Let 52 | Lic 53 | N.B 54 | P.D 55 | P.V.P 56 | Prof 57 | Pts 58 | Rte 59 | S.A 60 | S.A.R 61 | S.E 62 | S.L 63 | S.R.C 64 | Sr 65 | Sra 66 | Srta 67 | Sta 68 | Sto 69 | T.V.E 70 | Tel 71 | Ud 72 | Uds 73 | V.B 74 | V.E 75 | Vd 76 | Vds 77 | a/c 78 | adj 79 | admón 80 | afmo 81 | apdo 82 | av 83 | c 84 | c.f 85 | c.g 86 | cap 87 | cm 88 | cta 89 | dcha 90 | doc 91 | ej 92 | entlo 93 | esq 94 | etc 95 | f.c 96 | gr 97 | grs 98 | izq 99 | kg 100 | km 101 | mg 102 | mm 103 | núm 104 | núm 105 | p 106 | p.a 107 | p.ej 108 | ptas 109 | pág 110 | págs 111 | pág 112 | págs 113 | q.e.g.e 114 | q.e.s.m 115 | s 116 | s.s.s 117 | vid 118 | vol 119 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.fi: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT 2 | #indicate an end-of-sentence marker. Special cases are included for prefixes 3 | #that ONLY appear before 0-9 numbers. 4 | 5 | #This list is compiled from omorfi database 6 | #by Tommi A Pirinen. 7 | 8 | 9 | #any single upper case letter followed by a period is not a sentence ender 10 | A 11 | B 12 | C 13 | D 14 | E 15 | F 16 | G 17 | H 18 | I 19 | J 20 | K 21 | L 22 | M 23 | N 24 | O 25 | P 26 | Q 27 | R 28 | S 29 | T 30 | U 31 | V 32 | W 33 | X 34 | Y 35 | Z 36 | Å 37 | Ä 38 | Ö 39 | 40 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 41 | alik 42 | alil 43 | amir 44 | apul 45 | apul.prof 46 | arkkit 47 | ass 48 | assist 49 | dipl 50 | dipl.arkkit 51 | dipl.ekon 52 | dipl.ins 53 | dipl.kielenk 54 | dipl.kirjeenv 55 | dipl.kosm 56 | dipl.urk 57 | dos 58 | erikoiseläinl 59 | erikoishammasl 60 | erikoisl 61 | erikoist 62 | ev.luutn 63 | evp 64 | fil 65 | ft 66 | hallinton 67 | hallintot 68 | hammaslääket 69 | jatk 70 | jääk 71 | kansaned 72 | kapt 73 | kapt.luutn 74 | kenr 75 | kenr.luutn 76 | kenr.maj 77 | kers 78 | kirjeenv 79 | kom 80 | kom.kapt 81 | komm 82 | konst 83 | korpr 84 | luutn 85 | maist 86 | maj 87 | Mr 88 | Mrs 89 | Ms 90 | M.Sc 91 | neuv 92 | nimim 93 | Ph.D 94 | prof 95 | puh.joht 96 | pääll 97 | res 98 | san 99 | siht 100 | suom 101 | sähköp 102 | säv 103 | toht 104 | toim 105 | toim.apul 106 | toim.joht 107 | toim.siht 108 | tuom 109 | ups 110 | vänr 111 | vääp 112 | ye.ups 113 | ylik 114 | ylil 115 | ylim 116 | ylimatr 117 | yliop 118 | yliopp 119 | ylip 120 | yliv 121 | 122 | #misc - odd period-ending items that NEVER indicate breaks (p.m. 
does NOT fall 123 | #into this category - it sometimes ends a sentence) 124 | e.g 125 | ent 126 | esim 127 | huom 128 | i.e 129 | ilm 130 | l 131 | mm 132 | myöh 133 | nk 134 | nyk 135 | par 136 | po 137 | t 138 | v 139 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.fr: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | # 4 | #any single upper case letter followed by a period is not a sentence ender 5 | #usually upper case letters are initials in a name 6 | #no French words end in single lower-case letters, so we throw those in too? 7 | A 8 | B 9 | C 10 | D 11 | E 12 | F 13 | G 14 | H 15 | I 16 | J 17 | K 18 | L 19 | M 20 | N 21 | O 22 | P 23 | Q 24 | R 25 | S 26 | T 27 | U 28 | V 29 | W 30 | X 31 | Y 32 | Z 33 | #a 34 | b 35 | c 36 | d 37 | e 38 | f 39 | g 40 | h 41 | i 42 | j 43 | k 44 | l 45 | m 46 | n 47 | o 48 | p 49 | q 50 | r 51 | s 52 | t 53 | u 54 | v 55 | w 56 | x 57 | y 58 | z 59 | 60 | # Period-final abbreviation list for French 61 | A.C.N 62 | A.M 63 | art 64 | ann 65 | apr 66 | av 67 | auj 68 | lib 69 | B.P 70 | boul 71 | ca 72 | c.-à-d 73 | cf 74 | ch.-l 75 | chap 76 | contr 77 | C.P.I 78 | C.Q.F.D 79 | C.N 80 | C.N.S 81 | C.S 82 | dir 83 | éd 84 | e.g 85 | env 86 | al 87 | etc 88 | E.V 89 | ex 90 | fasc 91 | fém 92 | fig 93 | fr 94 | hab 95 | ibid 96 | id 97 | i.e 98 | inf 99 | LL.AA 100 | LL.AA.II 101 | LL.AA.RR 102 | LL.AA.SS 103 | L.D 104 | LL.EE 105 | LL.MM 106 | LL.MM.II.RR 107 | loc.cit 108 | masc 109 | MM 110 | ms 111 | N.B 112 | N.D.A 113 | N.D.L.R 114 | N.D.T 115 | n/réf 116 | NN.SS 117 | N.S 118 | N.D 119 | N.P.A.I 120 | p.c.c 121 | pl 122 | pp 123 | p.ex 124 | p.j 125 | P.S 126 | R.A.S 127 | R.-V 128 | R.P 129 | R.I.P 130 | SS 131 | S.S 132 | S.A 133 | S.A.I 134 | S.A.R 135 | S.A.S 136 | S.E 137 | sec 138 | sect 139 | sing 140 | S.M 141 | S.M.I.R 142 | sq 143 | sqq 144 | suiv 145 | sup 146 | suppl 147 | tél 148 | T.S.V.P 149 | vb 150 | vol 151 | vs 152 | X.O 153 | Z.I 154 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.ga: -------------------------------------------------------------------------------- 1 | 2 | A 3 | B 4 | C 5 | D 6 | E 7 | F 8 | G 9 | H 10 | I 11 | J 12 | K 13 | L 14 | M 15 | N 16 | O 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | U 23 | V 24 | W 25 | X 26 | Y 27 | Z 28 | Á 29 | É 30 | Í 31 | Ó 32 | Ú 33 | 34 | Uacht 35 | Dr 36 | B.Arch 37 | 38 | m.sh 39 | .i 40 | Co 41 | Cf 42 | cf 43 | i.e 44 | r 45 | Chr 46 | lch #NUMERIC_ONLY# 47 | lgh #NUMERIC_ONLY# 48 | uimh #NUMERIC_ONLY# 49 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.hu: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 
3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | Á 33 | É 34 | Í 35 | Ó 36 | Ö 37 | Ő 38 | Ú 39 | Ü 40 | Ű 41 | 42 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 43 | Dr 44 | dr 45 | kb 46 | Kb 47 | vö 48 | Vö 49 | pl 50 | Pl 51 | ca 52 | Ca 53 | min 54 | Min 55 | max 56 | Max 57 | ún 58 | Ún 59 | prof 60 | Prof 61 | de 62 | De 63 | du 64 | Du 65 | Szt 66 | St 67 | 68 | #Numbers only. These should only induce breaks when followed by a numeric sequence 69 | # add NUMERIC_ONLY after the word for this function 70 | #This case is mostly for the english "No." which can either be a sentence of its own, or 71 | #if followed by a number, a non-breaking prefix 72 | 73 | # Month name abbreviations 74 | jan #NUMERIC_ONLY# 75 | Jan #NUMERIC_ONLY# 76 | Feb #NUMERIC_ONLY# 77 | feb #NUMERIC_ONLY# 78 | márc #NUMERIC_ONLY# 79 | Márc #NUMERIC_ONLY# 80 | ápr #NUMERIC_ONLY# 81 | Ápr #NUMERIC_ONLY# 82 | máj #NUMERIC_ONLY# 83 | Máj #NUMERIC_ONLY# 84 | jún #NUMERIC_ONLY# 85 | Jún #NUMERIC_ONLY# 86 | Júl #NUMERIC_ONLY# 87 | júl #NUMERIC_ONLY# 88 | aug #NUMERIC_ONLY# 89 | Aug #NUMERIC_ONLY# 90 | Szept #NUMERIC_ONLY# 91 | szept #NUMERIC_ONLY# 92 | okt #NUMERIC_ONLY# 93 | Okt #NUMERIC_ONLY# 94 | nov #NUMERIC_ONLY# 95 | Nov #NUMERIC_ONLY# 96 | dec #NUMERIC_ONLY# 97 | Dec #NUMERIC_ONLY# 98 | 99 | # Other abbreviations 100 | tel #NUMERIC_ONLY# 101 | Tel #NUMERIC_ONLY# 102 | Fax #NUMERIC_ONLY# 103 | fax #NUMERIC_ONLY# 104 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.is: -------------------------------------------------------------------------------- 1 | no #NUMERIC_ONLY# 2 | No #NUMERIC_ONLY# 3 | nr #NUMERIC_ONLY# 4 | Nr #NUMERIC_ONLY# 5 | nR #NUMERIC_ONLY# 6 | NR #NUMERIC_ONLY# 7 | a 8 | b 9 | c 10 | d 11 | e 12 | f 13 | g 14 | h 15 | i 16 | j 17 | k 18 | l 19 | m 20 | n 21 | o 22 | p 23 | q 24 | r 25 | s 26 | t 27 | u 28 | v 29 | w 30 | x 31 | y 32 | z 33 | ^ 34 | í 35 | á 36 | ó 37 | æ 38 | A 39 | B 40 | C 41 | D 42 | E 43 | F 44 | G 45 | H 46 | I 47 | J 48 | K 49 | L 50 | M 51 | N 52 | O 53 | P 54 | Q 55 | R 56 | S 57 | T 58 | U 59 | V 60 | W 61 | X 62 | Y 63 | Z 64 | ab.fn 65 | a.fn 66 | afs 67 | al 68 | alm 69 | alg 70 | andh 71 | ath 72 | aths 73 | atr 74 | ao 75 | au 76 | aukaf 77 | áfn 78 | áhrl.s 79 | áhrs 80 | ákv.gr 81 | ákv 82 | bh 83 | bls 84 | dr 85 | e.Kr 86 | et 87 | ef 88 | efn 89 | ennfr 90 | eink 91 | end 92 | e.st 93 | erl 94 | fél 95 | fskj 96 | fh 97 | f.hl 98 | físl 99 | fl 100 | fn 101 | fo 102 | forl 103 | frb 104 | frl 105 | frh 106 | frt 107 | fsl 108 | fsh 109 | fs 110 | fsk 111 | fst 112 | f.Kr 113 | ft 114 | fv 115 | fyrrn 116 | fyrrv 117 | germ 118 | gm 119 | gr 120 | hdl 121 | hdr 122 | hf 123 | hl 124 | hlsk 125 | hljsk 126 | hljv 127 | hljóðv 128 | hr 129 | hv 130 | hvk 131 | holl 132 | Hos 133 | höf 134 | hk 135 | hrl 136 | ísl 137 | kaf 138 | kap 139 | Khöfn 140 | kk 141 | kg 142 | kk 143 | km 144 | kl 145 | klst 146 | kr 147 | kt 148 | kgúrsk 149 | kvk 150 | leturbr 151 | lh 152 | lh.nt 153 | lh.þt 154 | lo 155 | ltr 156 | mlja 157 | mljó 158 | millj 159 | mm 160 | mms 161 | m.fl 162 | miðm 163 | mgr 164 | mst 
165 | mín 166 | nf 167 | nh 168 | nhm 169 | nl 170 | nk 171 | nmgr 172 | no 173 | núv 174 | nt 175 | o.áfr 176 | o.m.fl 177 | ohf 178 | o.fl 179 | o.s.frv 180 | ófn 181 | ób 182 | óákv.gr 183 | óákv 184 | pfn 185 | PR 186 | pr 187 | Ritstj 188 | Rvík 189 | Rvk 190 | samb 191 | samhlj 192 | samn 193 | samn 194 | sbr 195 | sek 196 | sérn 197 | sf 198 | sfn 199 | sh 200 | sfn 201 | sh 202 | s.hl 203 | sk 204 | skv 205 | sl 206 | sn 207 | so 208 | ss.us 209 | s.st 210 | samþ 211 | sbr 212 | shlj 213 | sign 214 | skál 215 | st 216 | st.s 217 | stk 218 | sþ 219 | teg 220 | tbl 221 | tfn 222 | tl 223 | tvíhlj 224 | tvt 225 | till 226 | to 227 | umr 228 | uh 229 | us 230 | uppl 231 | útg 232 | vb 233 | Vf 234 | vh 235 | vkf 236 | Vl 237 | vl 238 | vlf 239 | vmf 240 | 8vo 241 | vsk 242 | vth 243 | þt 244 | þf 245 | þjs 246 | þgf 247 | þlt 248 | þolm 249 | þm 250 | þml 251 | þýð 252 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.it: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | B 8 | C 9 | D 10 | E 11 | F 12 | G 13 | H 14 | I 15 | J 16 | K 17 | L 18 | M 19 | N 20 | O 21 | P 22 | Q 23 | R 24 | S 25 | T 26 | U 27 | V 28 | W 29 | X 30 | Y 31 | Z 32 | 33 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 34 | Adj 35 | Adm 36 | Adv 37 | Amn 38 | Arch 39 | Asst 40 | Avv 41 | Bart 42 | Bcc 43 | Bldg 44 | Brig 45 | Bros 46 | C.A.P 47 | C.P 48 | Capt 49 | Cc 50 | Cmdr 51 | Co 52 | Col 53 | Comdr 54 | Con 55 | Corp 56 | Cpl 57 | DR 58 | Dott 59 | Dr 60 | Drs 61 | Egr 62 | Ens 63 | Gen 64 | Geom 65 | Gov 66 | Hon 67 | Hosp 68 | Hr 69 | Id 70 | Ing 71 | Insp 72 | Lt 73 | MM 74 | MR 75 | MRS 76 | MS 77 | Maj 78 | Messrs 79 | Mlle 80 | Mme 81 | Mo 82 | Mons 83 | Mr 84 | Mrs 85 | Ms 86 | Msgr 87 | N.B 88 | Op 89 | Ord 90 | P.S 91 | P.T 92 | Pfc 93 | Ph 94 | Prof 95 | Pvt 96 | RP 97 | RSVP 98 | Rag 99 | Rep 100 | Reps 101 | Res 102 | Rev 103 | Rif 104 | Rt 105 | S.A 106 | S.B.F 107 | S.P.M 108 | S.p.A 109 | S.r.l 110 | Sen 111 | Sens 112 | Sfc 113 | Sgt 114 | Sig 115 | Sigg 116 | Soc 117 | Spett 118 | Sr 119 | St 120 | Supt 121 | Surg 122 | V.P 123 | 124 | # other 125 | a.c 126 | acc 127 | all 128 | banc 129 | c.a 130 | c.c.p 131 | c.m 132 | c.p 133 | c.s 134 | c.v 135 | corr 136 | dott 137 | e.p.c 138 | ecc 139 | es 140 | fatt 141 | gg 142 | int 143 | lett 144 | ogg 145 | on 146 | p.c 147 | p.c.c 148 | p.es 149 | p.f 150 | p.r 151 | p.v 152 | post 153 | pp 154 | racc 155 | ric 156 | s.n.c 157 | seg 158 | sgg 159 | ss 160 | tel 161 | u.s 162 | v.r 163 | v.s 164 | 165 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 166 | v 167 | vs 168 | i.e 169 | rev 170 | e.g 171 | 172 | #Numbers only. These should only induce breaks when followed by a numeric sequence 173 | # add NUMERIC_ONLY after the word for this function 174 | #This case is mostly for the english "No." 
which can either be a sentence of its own, or 175 | #if followed by a number, a non-breaking prefix 176 | No #NUMERIC_ONLY# 177 | Nos 178 | Art #NUMERIC_ONLY# 179 | Nr 180 | pp #NUMERIC_ONLY# 181 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.lv: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | A 7 | Ā 8 | B 9 | C 10 | Č 11 | D 12 | E 13 | Ē 14 | F 15 | G 16 | Ģ 17 | H 18 | I 19 | Ī 20 | J 21 | K 22 | Ķ 23 | L 24 | Ļ 25 | M 26 | N 27 | Ņ 28 | O 29 | P 30 | Q 31 | R 32 | S 33 | Š 34 | T 35 | U 36 | Ū 37 | V 38 | W 39 | X 40 | Y 41 | Z 42 | Ž 43 | 44 | #List of titles. These are often followed by upper-case names, but do not indicate sentence breaks 45 | dr 46 | Dr 47 | med 48 | prof 49 | Prof 50 | inž 51 | Inž 52 | ist.loc 53 | Ist.loc 54 | kor.loc 55 | Kor.loc 56 | v.i 57 | vietn 58 | Vietn 59 | 60 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 61 | a.l 62 | t.p 63 | pārb 64 | Pārb 65 | vec 66 | Vec 67 | inv 68 | Inv 69 | sk 70 | Sk 71 | spec 72 | Spec 73 | vienk 74 | Vienk 75 | virz 76 | Virz 77 | māksl 78 | Māksl 79 | mūz 80 | Mūz 81 | akad 82 | Akad 83 | soc 84 | Soc 85 | galv 86 | Galv 87 | vad 88 | Vad 89 | sertif 90 | Sertif 91 | folkl 92 | Folkl 93 | hum 94 | Hum 95 | 96 | #Numbers only. These should only induce breaks when followed by a numeric sequence 97 | # add NUMERIC_ONLY after the word for this function 98 | #This case is mostly for the english "No." which can either be a sentence of its own, or 99 | #if followed by a number, a non-breaking prefix 100 | Nr #NUMERIC_ONLY# 101 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.nl: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | #Sources: http://nl.wikipedia.org/wiki/Lijst_van_afkortingen 4 | # http://nl.wikipedia.org/wiki/Aanspreekvorm 5 | # http://nl.wikipedia.org/wiki/Titulatuur_in_het_Nederlands_hoger_onderwijs 6 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 7 | #usually upper case letters are initials in a name 8 | A 9 | B 10 | C 11 | D 12 | E 13 | F 14 | G 15 | H 16 | I 17 | J 18 | K 19 | L 20 | M 21 | N 22 | O 23 | P 24 | Q 25 | R 26 | S 27 | T 28 | U 29 | V 30 | W 31 | X 32 | Y 33 | Z 34 | 35 | #List of titles. 
These are often followed by upper-case names, but do not indicate sentence breaks 36 | bacc 37 | bc 38 | bgen 39 | c.i 40 | dhr 41 | dr 42 | dr.h.c 43 | drs 44 | drs 45 | ds 46 | eint 47 | fa 48 | Fa 49 | fam 50 | gen 51 | genm 52 | ing 53 | ir 54 | jhr 55 | jkvr 56 | jr 57 | kand 58 | kol 59 | lgen 60 | lkol 61 | Lt 62 | maj 63 | Mej 64 | mevr 65 | Mme 66 | mr 67 | mr 68 | Mw 69 | o.b.s 70 | plv 71 | prof 72 | ritm 73 | tint 74 | Vz 75 | Z.D 76 | Z.D.H 77 | Z.E 78 | Z.Em 79 | Z.H 80 | Z.K.H 81 | Z.K.M 82 | Z.M 83 | z.v 84 | 85 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 86 | #we seem to have a lot of these in dutch i.e.: i.p.v - in plaats van (in stead of) never ends a sentence 87 | a.g.v 88 | bijv 89 | bijz 90 | bv 91 | d.w.z 92 | e.c 93 | e.g 94 | e.k 95 | ev 96 | i.p.v 97 | i.s.m 98 | i.t.t 99 | i.v.m 100 | m.a.w 101 | m.b.t 102 | m.b.v 103 | m.h.o 104 | m.i 105 | m.i.v 106 | v.w.t 107 | 108 | #Numbers only. These should only induce breaks when followed by a numeric sequence 109 | # add NUMERIC_ONLY after the word for this function 110 | #This case is mostly for the english "No." which can either be a sentence of its own, or 111 | #if followed by a number, a non-breaking prefix 112 | Nr #NUMERIC_ONLY# 113 | Nrs 114 | nrs 115 | nr #NUMERIC_ONLY# 116 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.pl: -------------------------------------------------------------------------------- 1 | adw 2 | afr 3 | akad 4 | al 5 | Al 6 | am 7 | amer 8 | arch 9 | art 10 | Art 11 | artyst 12 | astr 13 | austr 14 | bałt 15 | bdb 16 | bł 17 | bm 18 | br 19 | bryg 20 | bryt 21 | centr 22 | ces 23 | chem 24 | chiń 25 | chir 26 | c.k 27 | c.o 28 | cyg 29 | cyw 30 | cyt 31 | czes 32 | czw 33 | cd 34 | Cd 35 | czyt 36 | ćw 37 | ćwicz 38 | daw 39 | dcn 40 | dekl 41 | demokr 42 | det 43 | diec 44 | dł 45 | dn 46 | dot 47 | dol 48 | dop 49 | dost 50 | dosł 51 | h.c 52 | ds 53 | dst 54 | duszp 55 | dypl 56 | egz 57 | ekol 58 | ekon 59 | elektr 60 | em 61 | ew 62 | fab 63 | farm 64 | fot 65 | fr 66 | gat 67 | gastr 68 | geogr 69 | geol 70 | gimn 71 | głęb 72 | gm 73 | godz 74 | górn 75 | gosp 76 | gr 77 | gram 78 | hist 79 | hiszp 80 | hr 81 | Hr 82 | hot 83 | id 84 | in 85 | im 86 | iron 87 | jn 88 | kard 89 | kat 90 | katol 91 | k.k 92 | kk 93 | kol 94 | kl 95 | k.p.a 96 | kpc 97 | k.p.c 98 | kpt 99 | kr 100 | k.r 101 | krak 102 | k.r.o 103 | kryt 104 | kult 105 | laic 106 | łac 107 | niem 108 | woj 109 | nb 110 | np 111 | Nb 112 | Np 113 | pol 114 | pow 115 | m.in 116 | pt 117 | ps 118 | Pt 119 | Ps 120 | cdn 121 | jw 122 | ryc 123 | rys 124 | Ryc 125 | Rys 126 | tj 127 | tzw 128 | Tzw 129 | tzn 130 | zob 131 | ang 132 | ub 133 | ul 134 | pw 135 | pn 136 | pl 137 | al 138 | k 139 | n 140 | nr #NUMERIC_ONLY# 141 | Nr #NUMERIC_ONLY# 142 | ww 143 | wł 144 | ur 145 | zm 146 | żyd 147 | żarg 148 | żyw 149 | wył 150 | bp 151 | bp 152 | wyst 153 | tow 154 | Tow 155 | o 156 | sp 157 | Sp 158 | st 159 | spółdz 160 | Spółdz 161 | społ 162 | spółgł 163 | stoł 164 | stow 165 | Stoł 166 | Stow 167 | zn 168 | zew 169 | zewn 170 | zdr 171 | zazw 172 | zast 173 | zaw 174 | zał 175 | zal 176 | zam 177 | zak 178 | zakł 179 | zagr 180 | zach 181 | adw 182 | Adw 183 | lek 184 | Lek 185 | med 186 | mec 187 | Mec 188 | doc 189 | Doc 190 | dyw 191 | dyr 192 | Dyw 193 | Dyr 194 | inż 195 | Inż 196 | mgr 197 | Mgr 198 | dh 199 | dr 200 | Dh 201 | Dr 
202 | p 203 | P 204 | red 205 | Red 206 | prof 207 | prok 208 | Prof 209 | Prok 210 | hab 211 | płk 212 | Płk 213 | nadkom 214 | Nadkom 215 | podkom 216 | Podkom 217 | ks 218 | Ks 219 | gen 220 | Gen 221 | por 222 | Por 223 | reż 224 | Reż 225 | przyp 226 | Przyp 227 | śp 228 | św 229 | śW 230 | Śp 231 | Św 232 | ŚW 233 | szer 234 | Szer 235 | pkt #NUMERIC_ONLY# 236 | str #NUMERIC_ONLY# 237 | tab #NUMERIC_ONLY# 238 | Tab #NUMERIC_ONLY# 239 | tel 240 | ust #NUMERIC_ONLY# 241 | par #NUMERIC_ONLY# 242 | poz 243 | pok 244 | oo 245 | oO 246 | Oo 247 | OO 248 | r #NUMERIC_ONLY# 249 | l #NUMERIC_ONLY# 250 | s #NUMERIC_ONLY# 251 | najśw 252 | Najśw 253 | A 254 | B 255 | C 256 | D 257 | E 258 | F 259 | G 260 | H 261 | I 262 | J 263 | K 264 | L 265 | M 266 | N 267 | O 268 | P 269 | Q 270 | R 271 | S 272 | T 273 | U 274 | V 275 | W 276 | X 277 | Y 278 | Z 279 | Ś 280 | Ć 281 | Ż 282 | Ź 283 | Dz 284 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.ro: -------------------------------------------------------------------------------- 1 | A 2 | B 3 | C 4 | D 5 | E 6 | F 7 | G 8 | H 9 | I 10 | J 11 | K 12 | L 13 | M 14 | N 15 | O 16 | P 17 | Q 18 | R 19 | S 20 | T 21 | U 22 | V 23 | W 24 | X 25 | Y 26 | Z 27 | dpdv 28 | etc 29 | șamd 30 | M.Ap.N 31 | dl 32 | Dl 33 | d-na 34 | D-na 35 | dvs 36 | Dvs 37 | pt 38 | Pt 39 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.ru: -------------------------------------------------------------------------------- 1 | # added Cyrillic uppercase letters [А-Я] 2 | # removed 000D carriage return (this is not removed by chomp in tokenizer.perl, and prevents recognition of the prefixes) 3 | # edited by Kate Young (nspaceanalysis@earthlink.net) 21 May 2013 4 | А 5 | Б 6 | В 7 | Г 8 | Д 9 | Е 10 | Ж 11 | З 12 | И 13 | Й 14 | К 15 | Л 16 | М 17 | Н 18 | О 19 | П 20 | Р 21 | С 22 | Т 23 | У 24 | Ф 25 | Х 26 | Ц 27 | Ч 28 | Ш 29 | Щ 30 | Ъ 31 | Ы 32 | Ь 33 | Э 34 | Ю 35 | Я 36 | A 37 | B 38 | C 39 | D 40 | E 41 | F 42 | G 43 | H 44 | I 45 | J 46 | K 47 | L 48 | M 49 | N 50 | O 51 | P 52 | Q 53 | R 54 | S 55 | T 56 | U 57 | V 58 | W 59 | X 60 | Y 61 | Z 62 | 0гг 63 | 1гг 64 | 2гг 65 | 3гг 66 | 4гг 67 | 5гг 68 | 6гг 69 | 7гг 70 | 8гг 71 | 9гг 72 | 0г 73 | 1г 74 | 2г 75 | 3г 76 | 4г 77 | 5г 78 | 6г 79 | 7г 80 | 8г 81 | 9г 82 | Xвв 83 | Vвв 84 | Iвв 85 | Lвв 86 | Mвв 87 | Cвв 88 | Xв 89 | Vв 90 | Iв 91 | Lв 92 | Mв 93 | Cв 94 | 0м 95 | 1м 96 | 2м 97 | 3м 98 | 4м 99 | 5м 100 | 6м 101 | 7м 102 | 8м 103 | 9м 104 | 0мм 105 | 1мм 106 | 2мм 107 | 3мм 108 | 4мм 109 | 5мм 110 | 6мм 111 | 7мм 112 | 8мм 113 | 9мм 114 | 0см 115 | 1см 116 | 2см 117 | 3см 118 | 4см 119 | 5см 120 | 6см 121 | 7см 122 | 8см 123 | 9см 124 | 0дм 125 | 1дм 126 | 2дм 127 | 3дм 128 | 4дм 129 | 5дм 130 | 6дм 131 | 7дм 132 | 8дм 133 | 9дм 134 | 0л 135 | 1л 136 | 2л 137 | 3л 138 | 4л 139 | 5л 140 | 6л 141 | 7л 142 | 8л 143 | 9л 144 | 0км 145 | 1км 146 | 2км 147 | 3км 148 | 4км 149 | 5км 150 | 6км 151 | 7км 152 | 8км 153 | 9км 154 | 0га 155 | 1га 156 | 2га 157 | 3га 158 | 4га 159 | 5га 160 | 6га 161 | 7га 162 | 8га 163 | 9га 164 | 0кг 165 | 1кг 166 | 2кг 167 | 3кг 168 | 4кг 169 | 5кг 170 | 6кг 171 | 7кг 172 | 8кг 173 | 9кг 174 | 0т 175 | 1т 176 | 2т 177 | 3т 178 | 4т 179 | 5т 180 | 6т 181 | 7т 182 | 8т 183 | 9т 184 | 0г 185 | 1г 186 | 2г 187 | 3г 188 | 4г 189 | 5г 190 | 6г 191 | 7г 192 | 8г 193 | 9г 194 | 0мг 195 | 1мг 196 
| 2мг 197 | 3мг 198 | 4мг 199 | 5мг 200 | 6мг 201 | 7мг 202 | 8мг 203 | 9мг 204 | бульв 205 | в 206 | вв 207 | г 208 | га 209 | гг 210 | гл 211 | гос 212 | д 213 | дм 214 | доп 215 | др 216 | е 217 | ед 218 | ед 219 | зам 220 | и 221 | инд 222 | исп 223 | Исп 224 | к 225 | кап 226 | кг 227 | кв 228 | кл 229 | км 230 | кол 231 | комн 232 | коп 233 | куб 234 | л 235 | лиц 236 | лл 237 | м 238 | макс 239 | мг 240 | мин 241 | мл 242 | млн 243 | млрд 244 | мм 245 | н 246 | наб 247 | нач 248 | неуд 249 | ном 250 | о 251 | обл 252 | обр 253 | общ 254 | ок 255 | ост 256 | отл 257 | п 258 | пер 259 | перераб 260 | пл 261 | пос 262 | пр 263 | просп 264 | проф 265 | р 266 | ред 267 | руб 268 | с 269 | сб 270 | св 271 | см 272 | соч 273 | ср 274 | ст 275 | стр 276 | т 277 | тел 278 | Тел 279 | тех 280 | тт 281 | туп 282 | тыс 283 | уд 284 | ул 285 | уч 286 | физ 287 | х 288 | хор 289 | ч 290 | чел 291 | шт 292 | экз 293 | э 294 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.sk: -------------------------------------------------------------------------------- 1 | Bc 2 | Mgr 3 | RNDr 4 | PharmDr 5 | PhDr 6 | JUDr 7 | PaedDr 8 | ThDr 9 | Ing 10 | MUDr 11 | MDDr 12 | MVDr 13 | Dr 14 | ThLic 15 | PhD 16 | ArtD 17 | ThDr 18 | Dr 19 | DrSc 20 | CSs 21 | prof 22 | obr 23 | Obr 24 | Č 25 | č 26 | absol 27 | adj 28 | admin 29 | adr 30 | Adr 31 | adv 32 | advok 33 | afr 34 | ak 35 | akad 36 | akc 37 | akuz 38 | et 39 | al 40 | alch 41 | amer 42 | anat 43 | angl 44 | Angl 45 | anglosas 46 | anorg 47 | ap 48 | apod 49 | arch 50 | archeol 51 | archit 52 | arg 53 | art 54 | astr 55 | astrol 56 | astron 57 | atp 58 | atď 59 | austr 60 | Austr 61 | aut 62 | belg 63 | Belg 64 | bibl 65 | Bibl 66 | biol 67 | bot 68 | bud 69 | bás 70 | býv 71 | cest 72 | chem 73 | cirk 74 | csl 75 | čs 76 | Čs 77 | dat 78 | dep 79 | det 80 | dial 81 | diaľ 82 | dipl 83 | distrib 84 | dokl 85 | dosl 86 | dopr 87 | dram 88 | duš 89 | dv 90 | dvojčl 91 | dór 92 | ekol 93 | ekon 94 | el 95 | elektr 96 | elektrotech 97 | energet 98 | epic 99 | est 100 | etc 101 | etonym 102 | eufem 103 | európ 104 | Európ 105 | ev 106 | evid 107 | expr 108 | fa 109 | fam 110 | farm 111 | fem 112 | feud 113 | fil 114 | filat 115 | filoz 116 | fi 117 | fon 118 | form 119 | fot 120 | fr 121 | Fr 122 | franc 123 | Franc 124 | fraz 125 | fut 126 | fyz 127 | fyziol 128 | garb 129 | gen 130 | genet 131 | genpor 132 | geod 133 | geogr 134 | geol 135 | geom 136 | germ 137 | gr 138 | Gr 139 | gréc 140 | Gréc 141 | gréckokat 142 | hebr 143 | herald 144 | hist 145 | hlav 146 | hosp 147 | hromad 148 | hud 149 | hypok 150 | ident 151 | i.e 152 | ident 153 | imp 154 | impf 155 | indoeur 156 | inf 157 | inform 158 | instr 159 | int 160 | interj 161 | inšt 162 | inštr 163 | iron 164 | jap 165 | Jap 166 | jaz 167 | jedn 168 | juhoamer 169 | juhových 170 | juhozáp 171 | juž 172 | kanad 173 | Kanad 174 | kanc 175 | kapit 176 | kpt 177 | kart 178 | katastr 179 | knih 180 | kniž 181 | komp 182 | konj 183 | konkr 184 | kozmet 185 | krajč 186 | kresť 187 | kt 188 | kuch 189 | lat 190 | latinskoamer 191 | lek 192 | lex 193 | lingv 194 | lit 195 | litur 196 | log 197 | lok 198 | max 199 | Max 200 | maď 201 | Maď 202 | medzinár 203 | mest 204 | metr 205 | mil 206 | Mil 207 | min 208 | Min 209 | miner 210 | ml 211 | mld 212 | mn 213 | mod 214 | mytol 215 | napr 216 | nar 217 | Nar 218 | nasl 219 | nedok 220 | neg 221 | negat 222 | neklas 223 | nem 224 | Nem 225 | neodb 226 | neos 227 
| neskl 228 | nesklon 229 | nespis 230 | nespráv 231 | neved 232 | než 233 | niekt 234 | niž 235 | nom 236 | náb 237 | nákl 238 | námor 239 | nár 240 | obch 241 | obj 242 | obv 243 | obyč 244 | obč 245 | občian 246 | odb 247 | odd 248 | ods 249 | ojed 250 | okr 251 | Okr 252 | opt 253 | opyt 254 | org 255 | os 256 | osob 257 | ot 258 | ovoc 259 | par 260 | part 261 | pejor 262 | pers 263 | pf 264 | Pf 265 | P.f 266 | p.f 267 | pl 268 | Plk 269 | pod 270 | podst 271 | pokl 272 | polit 273 | politol 274 | polygr 275 | pomn 276 | popl 277 | por 278 | porad 279 | porov 280 | posch 281 | potrav 282 | použ 283 | poz 284 | pozit 285 | poľ 286 | poľno 287 | poľnohosp 288 | poľov 289 | pošt 290 | pož 291 | prac 292 | predl 293 | pren 294 | prep 295 | preuk 296 | priezv 297 | Priezv 298 | privl 299 | prof 300 | práv 301 | príd 302 | príj 303 | prík 304 | príp 305 | prír 306 | prísl 307 | príslov 308 | príč 309 | psych 310 | publ 311 | pís 312 | písm 313 | pôv 314 | refl 315 | reg 316 | rep 317 | resp 318 | rozk 319 | rozlič 320 | rozpráv 321 | roč 322 | Roč 323 | ryb 324 | rádiotech 325 | rím 326 | samohl 327 | semest 328 | sev 329 | severoamer 330 | severových 331 | severozáp 332 | sg 333 | skr 334 | skup 335 | sl 336 | Sloven 337 | soc 338 | soch 339 | sociol 340 | sp 341 | spol 342 | Spol 343 | spoloč 344 | spoluhl 345 | správ 346 | spôs 347 | st 348 | star 349 | starogréc 350 | starorím 351 | s.r.o 352 | stol 353 | stor 354 | str 355 | stredoamer 356 | stredoškol 357 | subj 358 | subst 359 | superl 360 | sv 361 | sz 362 | súkr 363 | súp 364 | súvzť 365 | tal 366 | Tal 367 | tech 368 | tel 369 | Tel 370 | telef 371 | teles 372 | telev 373 | teol 374 | trans 375 | turist 376 | tuzem 377 | typogr 378 | tzn 379 | tzv 380 | ukaz 381 | ul 382 | Ul 383 | umel 384 | univ 385 | ust 386 | ved 387 | vedľ 388 | verb 389 | veter 390 | vin 391 | viď 392 | vl 393 | vod 394 | vodohosp 395 | pnl 396 | vulg 397 | vyj 398 | vys 399 | vysokoškol 400 | vzťaž 401 | vôb 402 | vých 403 | výd 404 | výrob 405 | výsk 406 | výsl 407 | výtv 408 | výtvar 409 | význ 410 | včel 411 | vš 412 | všeob 413 | zahr 414 | zar 415 | zariad 416 | zast 417 | zastar 418 | zastaráv 419 | zb 420 | zdravot 421 | združ 422 | zjemn 423 | zlat 424 | zn 425 | Zn 426 | zool 427 | zr 428 | zried 429 | zv 430 | záhr 431 | zák 432 | zákl 433 | zám 434 | záp 435 | západoeur 436 | zázn 437 | územ 438 | účt 439 | čast 440 | čes 441 | Čes 442 | čl 443 | čísl 444 | živ 445 | pr 446 | fak 447 | Kr 448 | p.n.l 449 | A 450 | B 451 | C 452 | D 453 | E 454 | F 455 | G 456 | H 457 | I 458 | J 459 | K 460 | L 461 | M 462 | N 463 | O 464 | P 465 | Q 466 | R 467 | S 468 | T 469 | U 470 | V 471 | W 472 | X 473 | Y 474 | Z 475 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.sl: -------------------------------------------------------------------------------- 1 | dr 2 | Dr 3 | itd 4 | itn 5 | št #NUMERIC_ONLY# 6 | Št #NUMERIC_ONLY# 7 | d 8 | jan 9 | Jan 10 | feb 11 | Feb 12 | mar 13 | Mar 14 | apr 15 | Apr 16 | jun 17 | Jun 18 | jul 19 | Jul 20 | avg 21 | Avg 22 | sept 23 | Sept 24 | sep 25 | Sep 26 | okt 27 | Okt 28 | nov 29 | Nov 30 | dec 31 | Dec 32 | tj 33 | Tj 34 | npr 35 | Npr 36 | sl 37 | Sl 38 | op 39 | Op 40 | gl 41 | Gl 42 | oz 43 | Oz 44 | prev 45 | dipl 46 | ing 47 | prim 48 | Prim 49 | cf 50 | Cf 51 | gl 52 | Gl 53 | A 54 | B 55 | C 56 | D 57 | E 58 | F 59 | G 60 | H 61 | I 62 | J 63 | K 64 | L 65 | M 66 | N 67 | O 68 | P 69 | Q 70 | R 71 | S 72 | T 
73 | U 74 | V 75 | W 76 | X 77 | Y 78 | Z 79 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.sv: -------------------------------------------------------------------------------- 1 | #single upper case letter are usually initials 2 | A 3 | B 4 | C 5 | D 6 | E 7 | F 8 | G 9 | H 10 | I 11 | J 12 | K 13 | L 14 | M 15 | N 16 | O 17 | P 18 | Q 19 | R 20 | S 21 | T 22 | U 23 | V 24 | W 25 | X 26 | Y 27 | Z 28 | #misc abbreviations 29 | AB 30 | G 31 | VG 32 | dvs 33 | etc 34 | from 35 | iaf 36 | jfr 37 | kl 38 | kr 39 | mao 40 | mfl 41 | mm 42 | osv 43 | pga 44 | tex 45 | tom 46 | vs 47 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.ta: -------------------------------------------------------------------------------- 1 | #Anything in this file, followed by a period (and an upper-case word), does NOT indicate an end-of-sentence marker. 2 | #Special cases are included for prefixes that ONLY appear before 0-9 numbers. 3 | 4 | #any single upper case letter followed by a period is not a sentence ender (excluding I occasionally, but we leave it in) 5 | #usually upper case letters are initials in a name 6 | அ 7 | ஆ 8 | இ 9 | ஈ 10 | உ 11 | ஊ 12 | எ 13 | ஏ 14 | ஐ 15 | ஒ 16 | ஓ 17 | ஔ 18 | ஃ 19 | க 20 | கா 21 | கி 22 | கீ 23 | கு 24 | கூ 25 | கெ 26 | கே 27 | கை 28 | கொ 29 | கோ 30 | கௌ 31 | க் 32 | ச 33 | சா 34 | சி 35 | சீ 36 | சு 37 | சூ 38 | செ 39 | சே 40 | சை 41 | சொ 42 | சோ 43 | சௌ 44 | ச் 45 | ட 46 | டா 47 | டி 48 | டீ 49 | டு 50 | டூ 51 | டெ 52 | டே 53 | டை 54 | டொ 55 | டோ 56 | டௌ 57 | ட் 58 | த 59 | தா 60 | தி 61 | தீ 62 | து 63 | தூ 64 | தெ 65 | தே 66 | தை 67 | தொ 68 | தோ 69 | தௌ 70 | த் 71 | ப 72 | பா 73 | பி 74 | பீ 75 | பு 76 | பூ 77 | பெ 78 | பே 79 | பை 80 | பொ 81 | போ 82 | பௌ 83 | ப் 84 | ற 85 | றா 86 | றி 87 | றீ 88 | று 89 | றூ 90 | றெ 91 | றே 92 | றை 93 | றொ 94 | றோ 95 | றௌ 96 | ற் 97 | ய 98 | யா 99 | யி 100 | யீ 101 | யு 102 | யூ 103 | யெ 104 | யே 105 | யை 106 | யொ 107 | யோ 108 | யௌ 109 | ய் 110 | ர 111 | ரா 112 | ரி 113 | ரீ 114 | ரு 115 | ரூ 116 | ரெ 117 | ரே 118 | ரை 119 | ரொ 120 | ரோ 121 | ரௌ 122 | ர் 123 | ல 124 | லா 125 | லி 126 | லீ 127 | லு 128 | லூ 129 | லெ 130 | லே 131 | லை 132 | லொ 133 | லோ 134 | லௌ 135 | ல் 136 | வ 137 | வா 138 | வி 139 | வீ 140 | வு 141 | வூ 142 | வெ 143 | வே 144 | வை 145 | வொ 146 | வோ 147 | வௌ 148 | வ் 149 | ள 150 | ளா 151 | ளி 152 | ளீ 153 | ளு 154 | ளூ 155 | ளெ 156 | ளே 157 | ளை 158 | ளொ 159 | ளோ 160 | ளௌ 161 | ள் 162 | ழ 163 | ழா 164 | ழி 165 | ழீ 166 | ழு 167 | ழூ 168 | ழெ 169 | ழே 170 | ழை 171 | ழொ 172 | ழோ 173 | ழௌ 174 | ழ் 175 | ங 176 | ஙா 177 | ஙி 178 | ஙீ 179 | ஙு 180 | ஙூ 181 | ஙெ 182 | ஙே 183 | ஙை 184 | ஙொ 185 | ஙோ 186 | ஙௌ 187 | ங் 188 | ஞ 189 | ஞா 190 | ஞி 191 | ஞீ 192 | ஞு 193 | ஞூ 194 | ஞெ 195 | ஞே 196 | ஞை 197 | ஞொ 198 | ஞோ 199 | ஞௌ 200 | ஞ் 201 | ண 202 | ணா 203 | ணி 204 | ணீ 205 | ணு 206 | ணூ 207 | ணெ 208 | ணே 209 | ணை 210 | ணொ 211 | ணோ 212 | ணௌ 213 | ண் 214 | ந 215 | நா 216 | நி 217 | நீ 218 | நு 219 | நூ 220 | நெ 221 | நே 222 | நை 223 | நொ 224 | நோ 225 | நௌ 226 | ந் 227 | ம 228 | மா 229 | மி 230 | மீ 231 | மு 232 | மூ 233 | மெ 234 | மே 235 | மை 236 | மொ 237 | மோ 238 | மௌ 239 | ம் 240 | ன 241 | னா 242 | னி 243 | னீ 244 | னு 245 | னூ 246 | னெ 247 | னே 248 | னை 249 | னொ 250 | னோ 251 | னௌ 252 | ன் 253 | 254 | 255 | #List of titles. 
These are often followed by upper-case names, but do not indicate sentence breaks 256 | திரு 257 | திருமதி 258 | வண 259 | கௌரவ 260 | 261 | 262 | #misc - odd period-ending items that NEVER indicate breaks (p.m. does NOT fall into this category - it sometimes ends a sentence) 263 | உ.ம் 264 | #கா.ம் 265 | #எ.ம் 266 | 267 | 268 | #Numbers only. These should only induce breaks when followed by a numeric sequence 269 | # add NUMERIC_ONLY after the word for this function 270 | #This case is mostly for the english "No." which can either be a sentence of its own, or 271 | #if followed by a number, a non-breaking prefix 272 | No #NUMERIC_ONLY# 273 | Nos 274 | Art #NUMERIC_ONLY# 275 | Nr 276 | pp #NUMERIC_ONLY# 277 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.yue: -------------------------------------------------------------------------------- 1 | # 2 | # Cantonese (Chinese) 3 | # 4 | # Anything in this file, followed by a period, 5 | # does NOT indicate an end-of-sentence marker. 6 | # 7 | # English/Euro-language given-name initials (appearing in 8 | # news, periodicals, etc.) 9 | A 10 | Ā 11 | B 12 | C 13 | Č 14 | D 15 | E 16 | Ē 17 | F 18 | G 19 | Ģ 20 | H 21 | I 22 | Ī 23 | J 24 | K 25 | Ķ 26 | L 27 | Ļ 28 | M 29 | N 30 | Ņ 31 | O 32 | P 33 | Q 34 | R 35 | S 36 | Š 37 | T 38 | U 39 | Ū 40 | V 41 | W 42 | X 43 | Y 44 | Z 45 | Ž 46 | 47 | # Numbers only. These should only induce breaks when followed by 48 | # a numeric sequence. 49 | # Add NUMERIC_ONLY after the word for this function. This case is 50 | # mostly for the english "No." which can either be a sentence of its 51 | # own, or if followed by a number, a non-breaking prefix. 52 | No #NUMERIC_ONLY# 53 | Nr #NUMERIC_ONLY# 54 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/nonbreaking_prefixes/nonbreaking_prefix.zh: -------------------------------------------------------------------------------- 1 | # 2 | # Mandarin (Chinese) 3 | # 4 | # Anything in this file, followed by a period, 5 | # does NOT indicate an end-of-sentence marker. 6 | # 7 | # English/Euro-language given-name initials (appearing in 8 | # news, periodicals, etc.) 9 | A 10 | Ā 11 | B 12 | C 13 | Č 14 | D 15 | E 16 | Ē 17 | F 18 | G 19 | Ģ 20 | H 21 | I 22 | Ī 23 | J 24 | K 25 | Ķ 26 | L 27 | Ļ 28 | M 29 | N 30 | Ņ 31 | O 32 | P 33 | Q 34 | R 35 | S 36 | Š 37 | T 38 | U 39 | Ū 40 | V 41 | W 42 | X 43 | Y 44 | Z 45 | Ž 46 | 47 | # Numbers only. These should only induce breaks when followed by 48 | # a numeric sequence. 49 | # Add NUMERIC_ONLY after the word for this function. This case is 50 | # mostly for the english "No." which can either be a sentence of its 51 | # own, or if followed by a number, a non-breaking prefix. 
52 | No #NUMERIC_ONLY# 53 | Nr #NUMERIC_ONLY# 54 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/release_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser( 7 | description="Removes the optim data of PyTorch models") 8 | parser.add_argument("--model", "-m", 9 | help="The model filename (*.pt)", required=True) 10 | parser.add_argument("--output", "-o", 11 | help="The output filename (*.pt)", required=True) 12 | opt = parser.parse_args() 13 | 14 | model = torch.load(opt.model) 15 | model['optim'] = None 16 | torch.save(model, opt.output) 17 | -------------------------------------------------------------------------------- /Bi-selective Encoding/tools/test_rouge.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | import argparse 3 | import os 4 | import time 5 | import pyrouge 6 | import shutil 7 | import sys 8 | import codecs 9 | 10 | from onmt.utils.logging import init_logger, logger 11 | 12 | 13 | def test_rouge(cand, ref): 14 | """Calculate ROUGE scores of sequences passed as an iterator 15 | e.g. a list of str, an open file, StringIO or even sys.stdin 16 | """ 17 | current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 18 | tmp_dir = ".rouge-tmp-{}".format(current_time) 19 | try: 20 | if not os.path.isdir(tmp_dir): 21 | os.mkdir(tmp_dir) 22 | os.mkdir(tmp_dir + "/candidate") 23 | os.mkdir(tmp_dir + "/reference") 24 | candidates = [line.strip() for line in cand] 25 | references = [line.strip() for line in ref] 26 | assert len(candidates) == len(references) 27 | cnt = len(candidates) 28 | for i in range(cnt): 29 | if len(references[i]) < 1: 30 | continue 31 | with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", 32 | encoding="utf-8") as f: 33 | f.write(candidates[i]) 34 | with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", 35 | encoding="utf-8") as f: 36 | f.write(references[i]) 37 | r = pyrouge.Rouge155() 38 | r.model_dir = tmp_dir + "/reference/" 39 | r.system_dir = tmp_dir + "/candidate/" 40 | r.model_filename_pattern = 'ref.#ID#.txt' 41 | r.system_filename_pattern = 'cand.(\d+).txt' 42 | rouge_results = r.convert_and_evaluate() 43 | results_dict = r.output_to_dict(rouge_results) 44 | return results_dict 45 | finally: 46 | pass 47 | if os.path.isdir(tmp_dir): 48 | shutil.rmtree(tmp_dir) 49 | 50 | 51 | def rouge_results_to_str(results_dict): 52 | return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format( 53 | results_dict["rouge_1_f_score"] * 100, 54 | results_dict["rouge_2_f_score"] * 100, 55 | results_dict["rouge_3_f_score"] * 100, 56 | results_dict["rouge_l_f_score"] * 100, 57 | results_dict["rouge_su*_f_score"] * 100) 58 | 59 | 60 | if __name__ == "__main__": 61 | init_logger('test_rouge.log') 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('-c', type=str, default="candidate.txt", 64 | help='candidate file') 65 | parser.add_argument('-r', type=str, default="reference.txt", 66 | help='reference file') 67 | args = parser.parse_args() 68 | if args.c.upper() == "STDIN": 69 | candidates = sys.stdin 70 | else: 71 | candidates = codecs.open(args.c, encoding="utf-8") 72 | references = codecs.open(args.r, encoding="utf-8") 73 | 74 | results_dict = test_rouge(candidates, references) 75 | 
logger.info(rouge_results_to_str(results_dict)) 76 | -------------------------------------------------------------------------------- /Bi-selective Encoding/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Main training workflow 4 | """ 5 | from __future__ import division 6 | 7 | import argparse 8 | 9 | import onmt.opts as opts 10 | from onmt.train_multi import main as multi_main 11 | from onmt.train_single import main as single_main 12 | 13 | 14 | def main(opt): 15 | if opt.rnn_type == "SRU" and not opt.gpuid: 16 | raise AssertionError("Using SRU requires -gpuid set.") 17 | 18 | if opt.epochs: 19 | raise AssertionError("-epochs is deprecated please use -train_steps.") 20 | 21 | if opt.truncated_decoder > 0 and opt.accum_count > 1: 22 | raise AssertionError("BPTT is not compatible with -accum > 1") 23 | 24 | if len(opt.gpuid) > 1: 25 | multi_main(opt) 26 | else: 27 | single_main(opt) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser( 32 | description='train.py', 33 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 34 | 35 | opts.add_md_help_argument(parser) 36 | opts.model_opts(parser) 37 | opts.train_opts(parser) 38 | 39 | opt = parser.parse_args() 40 | main(opt) 41 | -------------------------------------------------------------------------------- /Bi-selective Encoding/translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import division, unicode_literals 5 | import argparse 6 | 7 | from onmt.utils.logging import init_logger 8 | from onmt.translate.translator import build_translator 9 | 10 | import onmt.inputters 11 | import onmt.translate 12 | import onmt 13 | import onmt.model_builder 14 | import onmt.modules 15 | import onmt.opts 16 | 17 | 18 | def main(opt): 19 | translator = build_translator(opt, report_score=True) 20 | translator.translate(src_path=opt.src, 21 | tgt_path=opt.tgt, 22 | template_path=opt.template, 23 | src_dir=opt.src_dir, 24 | batch_size=opt.batch_size, 25 | attn_debug=opt.attn_debug) 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser( 29 | description='translate.py', 30 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 31 | onmt.opts.add_md_help_argument(parser) 32 | onmt.opts.translate_opts(parser) 33 | 34 | opt = parser.parse_args() 35 | logger = init_logger(opt.log_file) 36 | main(opt) 37 | -------------------------------------------------------------------------------- /FastRerank/PyRouge/README.md: -------------------------------------------------------------------------------- 1 | # PyRouge 2 | Rouge evaluation script implemented with Python. 3 | 4 | Currently, only Rouge-N is implemented. 5 | 6 | **WARNING** The result is slightly different from the official Rouge. 7 | 8 | ## Usage 9 | 10 | ```python 11 | python compute.py ref_file.txt predict_file.txt 12 | ``` 13 | 14 | `ref_file.txt` and `predict_file.txt` are line-by-line text files. 
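The scorer can also be called directly from Python. Below is a minimal sketch that mirrors the interface used in ```compute.py``` (run it from the ```FastRerank/PyRouge``` directory so the ```Rouge``` package is importable; it needs ```nltk```, ```numpy``` and ```scipy```, and the example sentences are placeholders):

```python
# Minimal sketch of scoring a few summaries directly, mirroring compute.py.
from Rouge import Rouge

rouge = Rouge.Rouge()
# One list of reference summaries per example, one system summary per example.
references = [["the cat sat on the mat"], ["police arrest suspect after robbery"]]
systems = ["the cat is on the mat", "suspect arrested after robbery"]

scores = rouge.compute_rouge(references, systems)
print(scores["rouge-1"]["f"][0])  # mean ROUGE-1 F score
```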
15 | 16 | Output format 17 | 18 | Now the script returns a dictionary, which looks like: 19 | ```python 20 | {'rouge-1': {'p': (0.34902417721047307, 0.0013577881868447896, (0.34636268087782229, 0.35168567354312386)), 'r': (0.29738279969648435, 0.0011050260347502225, (0.29521676027482341, 0.29954883911814528)), 'f': (0.31108022747945868, 0.0010620266366127937, (0.30899847420902349, 0.31316198074989388))}, 'rouge-2': {'p': (0.13283309482481312, 0.0010693069735949634, (0.13073707085268366, 0.13492911879694258)), 'r': (0.11229619796675784, 0.00089595545126339876, (0.11053997253273248, 0.1140524234007832)), 'f': (0.11772894246560359, 0.00090995790892731512, (0.11594526982730757, 0.1195126151038996))}} 21 | ``` 22 | For each measurement, the result is a tuple of the form (mean, std_error, (95% confidence interval)). -------------------------------------------------------------------------------- /FastRerank/PyRouge/Rouge/Rouge.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division, print_function, unicode_literals 3 | from collections import Counter 4 | import re 5 | 6 | from nltk.stem.porter import PorterStemmer 7 | import numpy as np, scipy.stats as st 8 | import itertools 9 | 10 | 11 | stemmer = PorterStemmer() 12 | 13 | 14 | class Rouge(object): 15 | def __init__(self, stem=True, use_ngram_buf=False): 16 | self.N = 2 17 | self.stem = stem 18 | self.use_ngram_buf = use_ngram_buf 19 | self.ngram_buf = {} 20 | 21 | @staticmethod 22 | def _format_sentence(sentence): 23 | s = sentence.lower() 24 | s = re.sub(r"[^0-9a-z]", " ", s) 25 | s = re.sub(r"\s+", " ", s) 26 | s = s.strip() 27 | return s 28 | 29 | def _create_n_gram(self, raw_sentence, n, stem): 30 | if self.use_ngram_buf: 31 | if raw_sentence in self.ngram_buf: 32 | return self.ngram_buf[raw_sentence] 33 | res = {} 34 | sentence = Rouge._format_sentence(raw_sentence) 35 | tokens = sentence.split(' ') 36 | if stem: 37 | # try: # TODO older NLTK has a bug in Porter Stemmer 38 | tokens = [stemmer.stem(t) for t in tokens] 39 | # except: 40 | # pass 41 | sent_len = len(tokens) 42 | for _n in range(n): 43 | buf = Counter() 44 | for idx, token in enumerate(tokens): 45 | if idx + _n >= sent_len: 46 | break 47 | ngram = ' '.join(tokens[idx: idx + _n + 1]) 48 | buf[ngram] += 1 49 | res[_n] = buf 50 | if self.use_ngram_buf: 51 | self.ngram_buf[raw_sentence] = res 52 | return res 53 | 54 | def get_ngram(self, sents, N, stem=False): 55 | if isinstance(sents, list): 56 | res = {} 57 | for _n in range(N): 58 | res[_n] = Counter() 59 | for sent in sents: 60 | ngrams = self._create_n_gram(sent, N, stem) 61 | for this_n, counter in ngrams.items(): 62 | # res[this_n] = res[this_n] + counter 63 | self_counter = res[this_n] 64 | for elem, count in counter.items(): 65 | if elem not in self_counter: 66 | self_counter[elem] = count 67 | else: 68 | self_counter[elem] += count 69 | return res 70 | elif isinstance(sents, str): 71 | return self._create_n_gram(sents, N, stem) 72 | else: 73 | raise ValueError 74 | 75 | def get_mean_sd_internal(self, x): 76 | mean = np.mean(x) 77 | sd = st.sem(x) 78 | res = st.t.interval(0.95, len(x) - 1, loc=mean, scale=sd) 79 | return (mean, sd, res) 80 | 81 | def compute_rouge(self, references, systems): 82 | assert (len(references) == len(systems)) 83 | 84 | peer_count = len(references) 85 | 86 | 87 | result_buf = {} 88 | for n in range(self.N): 89 | result_buf[n] = {'p': [], 'r': [], 'f': []} 90 | 91 | for ref_sent, sys_sent in 
zip(references, systems): 92 | ref_ngrams = self.get_ngram(ref_sent, self.N, self.stem) 93 | sys_ngrams = self.get_ngram(sys_sent, self.N, self.stem) 94 | for n in range(self.N): 95 | ref_ngram = ref_ngrams[n] 96 | sys_ngram = sys_ngrams[n] 97 | ref_count = sum(ref_ngram.values()) 98 | sys_count = sum(sys_ngram.values()) 99 | match_count = 0 100 | for k, v in sys_ngram.items(): 101 | if k in ref_ngram: 102 | match_count += min(v, ref_ngram[k]) 103 | p = match_count / sys_count if sys_count != 0 else 0 104 | r = match_count / ref_count if ref_count != 0 else 0 105 | f = 0 if (p == 0 or r == 0) else 2 * p * r / (p + r) 106 | result_buf[n]['p'].append(p) 107 | result_buf[n]['r'].append(r) 108 | result_buf[n]['f'].append(f) 109 | 110 | 111 | 112 | res = {} 113 | for n in range(self.N): 114 | n_key = 'rouge-{0}'.format(n + 1) 115 | res[n_key] = {} 116 | if len(result_buf[n]['p']) >= 50: 117 | res[n_key]['p'] = self.get_mean_sd_internal(result_buf[n]['p']) 118 | res[n_key]['r'] = self.get_mean_sd_internal(result_buf[n]['r']) 119 | res[n_key]['f'] = self.get_mean_sd_internal(result_buf[n]['f']) 120 | else: 121 | # not enough samples to calculate confidence interval 122 | res[n_key]['p'] = (np.mean(np.array(result_buf[n]['p'])), 0, (0, 0)) 123 | res[n_key]['r'] = (np.mean(np.array(result_buf[n]['r'])), 0, (0, 0)) 124 | res[n_key]['f'] = (np.mean(np.array(result_buf[n]['f'])), 0, (0, 0)) 125 | return res 126 | 127 | 128 | -------------------------------------------------------------------------------- /FastRerank/PyRouge/Rouge/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /FastRerank/PyRouge/compute.py: -------------------------------------------------------------------------------- 1 | 2 | from Rouge import Rouge 3 | rouge = Rouge.Rouge() 4 | 5 | ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/test/ref-word.txt" 6 | 7 | system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0815-171759 512 linear 18.41/dev.out.107" 8 | 9 | inputs = [] 10 | systems = [] 11 | refs = [] 12 | 13 | f0 = open('input-word.txt', encoding='utf-8') 14 | f1 = open(system_file, encoding='utf-8') 15 | f2 = open(ref_file, encoding='utf-8') 16 | 17 | for l0, l1, l2 in zip(f0, f1, f2): 18 | if not l0: 19 | break 20 | l0 = l0.strip() 21 | if l0 != 'unknown_word': 22 | systems.append(l1.strip()) 23 | refs.append([l2.strip()]) 24 | 25 | print(len(systems)) 26 | scores = rouge.compute_rouge(refs, systems) 27 | print(scores) 28 | 29 | 30 | ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref0.txt" 31 | system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0821-213049/dev.out.1" 32 | 33 | inputs = [] 34 | systems = [] 35 | refs = [] 36 | 37 | f0 = open('input-word.txt', encoding='utf-8') 38 | f1 = open(system_file, encoding='utf-8') 39 | f2 = open(ref_file, encoding='utf-8') 40 | 41 | for l0, l1, l2 in zip(f0, f1, f2): 42 | if not l0: 43 | break 44 | l0 = l0.strip() 45 | # if l0 != 'unknown_word': 46 | systems.append(l1.strip()[:75]) 47 | refs.append([l2.strip()]) 48 | 49 | print(len(systems)) 50 | scores = rouge.compute_rouge(refs, systems) 51 | print(scores) 52 | 53 | ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref1.txt" 54 | system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0821-213049/dev.out.1" 55 | 56 | inputs = [] 57 | systems = [] 58 | refs = [] 59 | 60 | f0 = open('input-word.txt', encoding='utf-8') 61 | f1 = open(system_file, encoding='utf-8') 62 | f2 = 
open(ref_file, encoding='utf-8') 63 | 64 | for l0, l1, l2 in zip(f0, f1, f2): 65 | if not l0: 66 | break 67 | l0 = l0.strip() 68 | # if l0 != 'unknown_word': 69 | systems.append(l1.strip()[:75]) 70 | refs.append([l2.strip()]) 71 | 72 | print(len(systems)) 73 | scores = rouge.compute_rouge(refs, systems) 74 | print(scores) 75 | 76 | ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref2.txt" 77 | system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0821-213049/dev.out.1" 78 | 79 | inputs = [] 80 | systems = [] 81 | refs = [] 82 | 83 | f0 = open('input-word.txt', encoding='utf-8') 84 | f1 = open(system_file, encoding='utf-8') 85 | f2 = open(ref_file, encoding='utf-8') 86 | 87 | for l0, l1, l2 in zip(f0, f1, f2): 88 | if not l0: 89 | break 90 | l0 = l0.strip() 91 | # if l0 != 'unknown_word': 92 | systems.append(l1.strip()[:75]) 93 | refs.append([l2.strip()]) 94 | 95 | print(len(systems)) 96 | scores = rouge.compute_rouge(refs, systems) 97 | print(scores) 98 | 99 | ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref3.txt" 100 | system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0821-213049/dev.out.1" 101 | 102 | inputs = [] 103 | systems = [] 104 | refs = [] 105 | 106 | f0 = open('input-word.txt', encoding='utf-8') 107 | f1 = open(system_file, encoding='utf-8') 108 | f2 = open(ref_file, encoding='utf-8') 109 | 110 | for l0, l1, l2 in zip(f0, f1, f2): 111 | if not l0: 112 | break 113 | l0 = l0.strip() 114 | # if l0 != 'unknown_word': 115 | systems.append(l1.strip()[:75]) 116 | refs.append([l2.strip()]) 117 | 118 | print(len(systems)) 119 | scores = rouge.compute_rouge(refs, systems) 120 | print(scores) 121 | # ref_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref.txt" 122 | # system_file = "/data/QJXMS/SEASS-FASTCNN/data/giga/models/0821-213049/dev.out.1" 123 | # ref_file0 = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref0.txt" 124 | # ref_file1 = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref1.txt" 125 | # ref_file2 = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref2.txt" 126 | # ref_file3 = "/data/QJXMS/SEASS-FASTCNN/data/giga/duc/task1_ref3.txt" 127 | # inputs = [] 128 | # systems = [] 129 | # refs = [] 130 | # 131 | # f0 = open('input-word.txt', encoding='utf-8') 132 | # f1 = open(system_file, encoding='utf-8') 133 | # f20 = open(ref_file0, encoding='utf-8') 134 | # f21 = open(ref_file1, encoding='utf-8') 135 | # f22 = open(ref_file2, encoding='utf-8') 136 | # f23 = open(ref_file3, encoding='utf-8') 137 | # 138 | # for l0, l1, l20,l21,l22,l23 in zip(f0, f1, f20,f21,f22,f23): 139 | # if not l0: 140 | # break 141 | # l0 = l0.strip() 142 | # # if l0 != 'unknown_word': 143 | # systems.append(l1.strip()[:75]) 144 | # refs.append([l20.strip(),l21.strip(),l22.strip(),l23.strip()]) 145 | # 146 | # print(len(systems)) 147 | # scores = rouge.compute_rouge(refs, systems) 148 | # print(scores) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 InitialBug 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BiSET: Bi-directional Selective Encoding with Template for Abstractive Summarization (ACL 2019) 2 | 3 | This [paper](https://www.aclweb.org/anthology/P19-1207) consists of three basic modules: **Retrieve**, **FastRerank**, and **Bi-selective Encoding**. Their usage is described below. 4 | 5 | ## Retrieve 6 | The Retrieve module is based on [Apache Lucene](http://lucene.apache.org/), an open-source search library. You should first download the core library from the website and then build the Java project. After that, you can index and search the dataset with the following steps: 7 | 1. Change the paths in ```Constants.java``` to your own directories. 8 | 2. Run ```Indexer.java``` to build the index of the training set. (This process may take several days, but it only needs to be done once.) 9 | 3. Run ```Searcher.java``` to search for the candidates and generate the template index files. 10 | 11 | ## FastRerank 12 | The FastRerank module is implemented with PyTorch. Before running it, you should first prepare all the data (the template index retrieved by the Retrieve module and the raw dataset). 13 | 1. Run ```python config.py --mode preprocess``` to preprocess the data. 14 | 2. Run ```python config.py --mode train``` to train the model, or ```python config.py --mode train --model modelname``` to finetune a model (e.g. ```python config.py --mode train --model model_final.pkl```). 15 | 3. Run ```python config.py --mode dev --model modelname``` to evaluate or test the model; the template with the highest score will be stored. 16 | 17 | ## Bi-selective Encoding 18 | The Bi-selective Encoding module is integrated with [OpenNMT](https://github.com/OpenNMT/OpenNMT-py). At present it only includes the bi-selective encoding layer; the other three interaction methods (concatenation, multi-head attention, DCN attention) will be added later. You can train it end to end directly with the [data](https://drive.google.com/file/d/1WtaDnpufPyqf8afFyfC13U_h56ars6CY/view?usp=sharing) using the following steps: 19 | 1. Run ```python preprocess.py``` to prepare the data. 20 | 2. Run ```python train.py``` to train the model. 21 | 3. Run ```python translate.py``` to generate the summaries (a sketch of scoring the output with ROUGE follows below). 22 | 
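The generated summaries can then be scored with the bundled ROUGE wrapper ```tools/test_rouge.py``` (its source appears earlier in this listing). The sketch below assumes it is run from the ```Bi-selective Encoding``` directory, that the ```pyrouge``` package and a working ROUGE-1.5.5 installation are available, and that the file names are placeholders:

```python
# Sketch: score generated summaries against references with the bundled wrapper.
# Assumes the repository's "Bi-selective Encoding" directory is the working directory.
import codecs
from tools.test_rouge import test_rouge, rouge_results_to_str

candidates = codecs.open("summaries.out", encoding="utf-8")   # output of translate.py (placeholder name)
references = codecs.open("test.title.txt", encoding="utf-8")  # gold summaries (placeholder name)
print(rouge_results_to_str(test_rouge(candidates, references)))
```

The same wrapper can also be run from the command line as ```python tools/test_rouge.py -c <candidate file> -r <reference file>```.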
23 | ## Notice 24 | 1. If you are not familiar with Java or find the first two steps too time-consuming, you can directly train the **Bi-selective Encoding** module with the retrieved and reranked templates and data in [Google Disk](https://drive.google.com/file/d/1WtaDnpufPyqf8afFyfC13U_h56ars6CY/view?usp=sharing). 25 | 2. I refactored my code for clarity and conciseness (renaming the variables and classes), but I have not had enough time to do a thorough test. If the code has any problems or you have any questions, please raise an issue and I will figure it out whenever I'm 26 | available. 27 | 3. For personal communication related to BiSET, please contact me (```wangk73@mail2.sysu.edu.cn```). 28 | -------------------------------------------------------------------------------- /Retrieve/src/com/wk/lucene/Constants.java: -------------------------------------------------------------------------------- 1 | package com.wk.lucene; 2 | 3 | public class Constants { 4 | //directory to save the index 5 | public static final String indexDir = "/home/k/Data/index_article"; 6 | 7 | //the target directory to be indexed, put the data in separate files, each file contains one data 8 | public static final String dataDir = "/home/k/Data/train_article"; 9 | 10 | //put the queries in one file, each line contains one query 11 | public static final String queryPath = "/home/k/Data/train.article.txt"; 12 | 13 | //the file to save the results, each line is the index of indexed file 14 | public static final String results="/home/k/Data/train.template.index"; 15 | 16 | // query number 17 | public static final int query_num=30; 18 | 19 | } 20 | -------------------------------------------------------------------------------- /Retrieve/src/com/wk/lucene/Indexer.java: -------------------------------------------------------------------------------- 1 | package com.wk.lucene; 2 | 3 | import java.io.File; 4 | import java.io.FileReader; 5 | import java.io.IOException; 6 | import java.nio.file.Paths; 7 | import java.util.concurrent.TimeUnit; 8 | 9 | import org.apache.lucene.analysis.Analyzer; 10 | import org.apache.lucene.analysis.en.EnglishAnalyzer; 11 | import org.apache.lucene.document.Document; 12 | import org.apache.lucene.document.Field; 13 | import org.apache.lucene.document.TextField; 14 | import org.apache.lucene.index.IndexWriter; 15 | import org.apache.lucene.index.IndexWriterConfig; 16 | import org.apache.lucene.store.Directory; 17 | import org.apache.lucene.store.FSDirectory; 18 | 19 | import com.wk.lucene.Constants; 20 | 21 | /*build index */ 22 | public class Indexer { 23 | private IndexWriter writer; 24 | 25 | /** 26 | * @param indexDir 27 | * @throws IOException 28 | */ 29 | public Indexer(String indexDir) throws IOException { 30 | //get the directory of index 31 | Directory directory = FSDirectory.open(Paths.get(indexDir)); 32 | // use the EnglishAnalyzer to tokenize 33 | Analyzer analyzer = new EnglishAnalyzer(); 34 | //save the config 35 | IndexWriterConfig iwConfig = new IndexWriterConfig(analyzer); 36 | 37 | writer = new IndexWriter(directory, iwConfig); 38 | } 39 | 40 | /** 41 | * close index 42 | * 43 | * @throws Exception 44 | * @return the number of indexed documents 45 | */ 46 | public void close() throws IOException { 47 | writer.close(); 48 | } 49 | 50 | public int index(String dataDir) throws Exception { 51 | File[] files = new File(dataDir).listFiles(); 52 | for (File file : files) { 53 | 54 | indexFile(file); 55 | } 56 | return writer.numDocs(); 57 | 58 | } 59 | 60 | /** 61 | * index target file 62 | * 63 | * @param file 64 | */ 65 | private void indexFile(File f) throws Exception { 66 | System.out.println("index file:" + f.getCanonicalPath()); 67 | Document doc = getDocument(f); 68 | writer.addDocument(doc); 69 | } 70 | 71 | 72 | private Document getDocument(File f) throws Exception { 73 | Document doc = new Document(); 74 | doc.add(new TextField("contents", new FileReader(f))); 75 | doc.add(new TextField("fileName", f.getName(), 
Field.Store.YES)); 76 | doc.add(new TextField("fullPath", f.getCanonicalPath(), Field.Store.YES)); 77 | return doc; 78 | } 79 | 80 | public static void main(String[] args) { 81 | try { 82 | TimeUnit.MINUTES.sleep(10); 83 | } catch (InterruptedException e1) { 84 | // TODO Auto-generated catch block 85 | e1.printStackTrace(); 86 | } 87 | //directory to save the index 88 | String indexDir = Constants.indexDir; 89 | 90 | //the target directory to be indexed, put the data in separate files, each file contains one data 91 | String dataDir = Constants.dataDir; 92 | Indexer indexer = null; 93 | int numIndexed = 0; 94 | long start = System.currentTimeMillis(); 95 | try { 96 | indexer = new Indexer(indexDir); 97 | numIndexed = indexer.index(dataDir); 98 | } catch (Exception e) { 99 | // TODO Auto-generated catch block 100 | e.printStackTrace(); 101 | } finally { 102 | try { 103 | indexer.close(); 104 | } catch (Exception e) { 105 | // TODO Auto-generated catch block 106 | e.printStackTrace(); 107 | } 108 | } 109 | long end = System.currentTimeMillis(); 110 | System.out.println("index:" + numIndexed + " files cost" + (end - start) + " ms"); 111 | } 112 | 113 | } -------------------------------------------------------------------------------- /Retrieve/src/com/wk/lucene/Searcher.java: -------------------------------------------------------------------------------- 1 | package com.wk.lucene; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.FileReader; 5 | import java.io.FileWriter; 6 | import java.io.IOException; 7 | import java.nio.file.Paths; 8 | 9 | import org.apache.lucene.analysis.Analyzer; 10 | import org.apache.lucene.analysis.en.EnglishAnalyzer; 11 | import org.apache.lucene.document.Document; 12 | import org.apache.lucene.index.DirectoryReader; 13 | import org.apache.lucene.index.IndexReader; 14 | import org.apache.lucene.queryparser.classic.QueryParser; 15 | import org.apache.lucene.search.IndexSearcher; 16 | import org.apache.lucene.search.Query; 17 | import org.apache.lucene.search.ScoreDoc; 18 | import org.apache.lucene.search.TopDocs; 19 | import org.apache.lucene.store.Directory; 20 | import org.apache.lucene.store.FSDirectory; 21 | 22 | import com.wk.lucene.Constants; 23 | 24 | public class Searcher { 25 | 26 | public static void search(String indexDir, String queryPath,String fileName, int query_num) throws Exception { 27 | 28 | FileWriter writer=new FileWriter(fileName,true); 29 | 30 | Directory dir = FSDirectory.open(Paths.get(indexDir)); 31 | IndexReader reader = DirectoryReader.open(dir); 32 | IndexSearcher is = new IndexSearcher(reader); 33 | Analyzer analyzer = new EnglishAnalyzer(); 34 | 35 | QueryParser parser = new QueryParser("contents", analyzer); 36 | int num=0; 37 | String line=""; 38 | BufferedReader in=new BufferedReader(new FileReader(queryPath)); 39 | while(true){ 40 | System.out.println(num); 41 | num++; 42 | line=in.readLine(); 43 | if(line==null) 44 | break; 45 | 46 | line=line.replaceAll("[^a-z<>#\\s]" , ""); 47 | 48 | System.out.println(line); 49 | Query query = parser.parse(line); 50 | 51 | 52 | TopDocs hits = is.search(query, query_num); 53 | 54 | 55 | for (ScoreDoc scoreDoc : hits.scoreDocs) { 56 | Document doc = is.doc(scoreDoc.doc); 57 | writer.write(doc.get("fileName")+' '); 58 | } 59 | writer.write('\n'); 60 | 61 | } 62 | reader.close(); 63 | writer.close(); 64 | 65 | } 66 | 67 | public static void main(String[] args) throws IOException { 68 | 69 | //the directory of index 70 | String indexDir = Constants.indexDir; 71 | //put the queries in one 
file, each line contains one query 72 | String queryPath = Constants.queryPath; 73 | //the file to save the results, each line is the index of indexed file 74 | String results=Constants.results; 75 | // query number 76 | int query_num=Constants.query_num; 77 | 78 | try { 79 | search(indexDir, queryPath,results,query_num); 80 | } catch (Exception e) { 81 | // TODO Auto-generated catch block 82 | e.printStackTrace(); 83 | } 84 | } 85 | } --------------------------------------------------------------------------------