├── translator ├── config │ ├── config.preprocess.yml │ ├── config.translate.yml │ └── config.train.yml ├── torchtext │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── example.py │ │ ├── pipeline.py │ │ ├── batch.py │ │ └── utils.py │ ├── datasets │ │ ├── __init__.py │ │ ├── imdb.py │ │ ├── trec.py │ │ ├── sequence_tagging.py │ │ ├── sst.py │ │ └── babi.py │ └── utils.py ├── onmt │ ├── models │ │ ├── __init__.py │ │ ├── model.py │ │ ├── stacked_rnn.py │ │ └── model_saver.py │ ├── utils │ │ ├── rnn_factory.py │ │ ├── __init__.py │ │ ├── logging.py │ │ ├── cnn_factory.py │ │ ├── misc.py │ │ ├── distributed.py │ │ ├── parse.py │ │ ├── statistics.py │ │ └── report_manager.py │ ├── decoders │ │ ├── __init__.py │ │ ├── cnn_decoder.py │ │ └── ensemble.py │ ├── __init__.py │ ├── encoders │ │ ├── __init__.py │ │ ├── mean_encoder.py │ │ ├── encoder.py │ │ ├── cnn_encoder.py │ │ ├── rnn_encoder.py │ │ ├── transformer.py │ │ ├── image_encoder.py │ │ └── audio_encoder.py │ ├── translate │ │ ├── __init__.py │ │ ├── penalties.py │ │ ├── decode_strategy.py │ │ └── translation.py │ ├── modules │ │ ├── __init__.py │ │ ├── position_ffn.py │ │ ├── structured_attention.py │ │ ├── util_class.py │ │ ├── sparse_activations.py │ │ ├── sparse_losses.py │ │ ├── conv_multi_step_attention.py │ │ ├── gate.py │ │ └── average_attn.py │ ├── inputters │ │ ├── __init__.py │ │ ├── datareader_base.py │ │ └── image_dataset.py │ └── train_single.py ├── README.md ├── translate.py ├── recover-dummy.py ├── train.py ├── preprocess.py └── evaluate.py ├── stitcher ├── stitching.sh ├── README.md └── err_utils.py ├── tokenizer ├── clang │ ├── __init__.py │ └── enumerations.py └── format-pairs.py ├── scripts ├── clang-format-spoc └── split-on-field.py └── README.md /translator/config/config.preprocess.yml: -------------------------------------------------------------------------------- 1 | train_src: data/input-tok-train-shuf.tsv 2 | train_tgt: "-" 3 | valid_src: data/input-tok-eval.tsv 4 | valid_tgt: "-" 5 | save_data: out/preprocessed 6 | dynamic_dict: true 7 | -------------------------------------------------------------------------------- /translator/torchtext/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import datasets 3 | from . import utils 4 | from . 
import vocab 5 | 6 | __version__ = '0.4.0' 7 | 8 | __all__ = ['data', 9 | 'datasets', 10 | 'utils', 11 | 'vocab'] 12 | -------------------------------------------------------------------------------- /translator/onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | from onmt.models.sru import check_sru_requirement 5 | 6 | __all__ = ["build_model_saver", "ModelSaver", 7 | "NMTModel", "check_sru_requirement"] 8 | -------------------------------------------------------------------------------- /translator/config/config.translate.yml: -------------------------------------------------------------------------------- 1 | src: data/input-tok-test-src.tsv 2 | tgt: data/input-tok-test-tgt.tsv 3 | output: /dev/null 4 | 5 | replace_unk: true 6 | verbose: true 7 | beam_size: 100 8 | n_best: 100 9 | gpu: 0 10 | batch_size: 1000 # Use 1 for large beam_size (~1000) to reduce memory load 11 | -------------------------------------------------------------------------------- /translator/onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | import torch.nn as nn 5 | import onmt.models 6 | 7 | 8 | def rnn_factory(rnn_type, **kwargs): 9 | """ RNN factory; use the PyTorch implementation when available. """ 10 | no_pack_padded_seq = False 11 | if rnn_type == "SRU": 12 | # SRU doesn't support PackedSequence. 13 | no_pack_padded_seq = True 14 | rnn = onmt.models.sru.SRU(**kwargs) 15 | else: 16 | rnn = getattr(nn, rnn_type)(**kwargs) 17 | return rnn, no_pack_padded_seq 18 | -------------------------------------------------------------------------------- /stitcher/stitching.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]]; then 3 | echo "Usage: $0 NAME" 4 | echo "  where NAME is the name of the .tsv and .summary files" 5 | exit 1 6 | fi 7 | 8 | P=10 9 | 10 | # Count the number of programs 11 | N=$(tail -n+2 ${1}.tsv | cut -f 3-6 | uniq | wc -l) 12 | 13 | # Change the stitcher (-o) to the appropriate one! 
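# (Flag reference, per the stitcher README below: -o = oracle, -x = top-1, -g = Gibbs, -b = best-first; -p sets the number of predictions, usually 100.)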
14 | i=1 15 | while [[ $i -le $N ]]; do 16 | echo "Submitting: python stitcher/stitch.py -o -p $P $1 $i --out-dir out/" 17 | python stitcher/stitch.py -o -p $P $1 $i --out-dir out/ 18 | i=$(($i + 1)) 19 | done 20 | -------------------------------------------------------------------------------- /translator/config/config.train.yml: -------------------------------------------------------------------------------- 1 | data: data/preprocessed/preprocessed 2 | save_model: out/model 3 | log_file: out/log.txt 4 | 5 | world_size: 1 6 | gpu_ranks: 0 7 | 8 | save_checkpoint_steps: 1000 9 | valid_steps: 1000 10 | train_steps: 20000 11 | 12 | src_word_vec_size: 200 13 | tgt_word_vec_size: 200 14 | enc_rnn_size: 200 15 | dec_rnn_size: 200 16 | 17 | copy_attn: true 18 | copy_attn_force: true 19 | 20 | batch_size: 32 21 | valid_batch_size: 32 22 | 23 | optim: adam 24 | learning_rate: 0.001 25 | 26 | encoder_type: brnn 27 | coverage_attn: true 28 | -------------------------------------------------------------------------------- /translator/onmt/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining decoders.""" 2 | from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \ 3 | StdRNNDecoder 4 | from onmt.decoders.transformer import TransformerDecoder 5 | from onmt.decoders.cnn_decoder import CNNDecoder 6 | 7 | 8 | str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder, 9 | "cnn": CNNDecoder, "transformer": TransformerDecoder} 10 | 11 | __all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder", 12 | "InputFeedRNNDecoder", "str2dec"] 13 | -------------------------------------------------------------------------------- /translator/onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed 3 | from onmt.utils.report_manager import ReportMgr, build_report_manager 4 | from onmt.utils.statistics import Statistics 5 | from onmt.utils.optimizers import MultipleOptimizer, \ 6 | Optimizer, AdaFactor 7 | 8 | __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", 9 | "build_report_manager", "Statistics", 10 | "MultipleOptimizer", "Optimizer", "AdaFactor"] 11 | -------------------------------------------------------------------------------- /translator/onmt/__init__.py: -------------------------------------------------------------------------------- 1 | """ Main entry point of the ONMT library """ 2 | from __future__ import division, print_function 3 | 4 | import onmt.inputters 5 | import onmt.encoders 6 | import onmt.decoders 7 | import onmt.models 8 | import onmt.utils 9 | import onmt.modules 10 | from onmt.trainer import Trainer 11 | import sys 12 | import onmt.utils.optimizers 13 | onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer 14 | sys.modules["onmt.Optim"] = onmt.utils.optimizers 15 | 16 | # For Flake 17 | __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, 18 | onmt.utils, onmt.modules, "Trainer"] 19 | 20 | __version__ = "0.8.2" 21 | -------------------------------------------------------------------------------- /stitcher/README.md: -------------------------------------------------------------------------------- 1 | To run stitching: 2 | 1. mkdir stitch_{exp_name} 3 | 2. cd stitch_{exp_name} 4 | 3. Copy `postprocess.sh`, `cleanup.sh`, and `stitch.py` into it 5 | 4. 
Copy your summary file as `{exp_name}.summary` and `input_all_test.tsv` as `{exp_name}.tsv` 6 | 5. Edit the `postprocess.sh` script so the for loop covers the right number of programs to stitch (a larger count is fine; the extra processes simply exit with an out-of-bounds error) 7 | 6. Set one of `-o`, `-x`, `-g`, or `-b` to run in oracle, top-1, Gibbs, or best-first mode, respectively 8 | 7. Set the `-p` parameter to the appropriate number of predictions (usually 100) 9 | 8. If something goes wrong in the run, run `bash cleanup.sh` to remove all the generated directories and output files. 10 | -------------------------------------------------------------------------------- /tokenizer/clang/__init__.py: -------------------------------------------------------------------------------- 1 | #===- __init__.py - Clang Python Bindings --------------------*- python -*--===# 2 | # 3 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | #===------------------------------------------------------------------------===# 8 | 9 | r""" 10 | Clang Library Bindings 11 | ====================== 12 | 13 | This package provides access to the Clang compiler and libraries. 14 | 15 | The available modules are: 16 | 17 | cindex 18 | 19 | Bindings for the Clang indexing library. 20 | """ 21 | 22 | __all__ = ['cindex'] 23 | 24 | -------------------------------------------------------------------------------- /translator/onmt/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining encoders.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.encoders.transformer import TransformerEncoder 4 | from onmt.encoders.rnn_encoder import RNNEncoder 5 | from onmt.encoders.cnn_encoder import CNNEncoder 6 | from onmt.encoders.mean_encoder import MeanEncoder 7 | from onmt.encoders.audio_encoder import AudioEncoder 8 | from onmt.encoders.image_encoder import ImageEncoder 9 | 10 | 11 | str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, 12 | "transformer": TransformerEncoder, "img": ImageEncoder, 13 | "audio": AudioEncoder, "mean": MeanEncoder} 14 | 15 | __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", 16 | "MeanEncoder", "str2enc"] 17 | -------------------------------------------------------------------------------- /translator/torchtext/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch import Batch 2 | from .dataset import Dataset, TabularDataset 3 | from .example import Example 4 | from .field import RawField, Field, ReversibleField, SubwordField, NestedField, LabelField 5 | from .iterator import (batch, BucketIterator, Iterator, BPTTIterator, 6 | pool) 7 | from .pipeline import Pipeline 8 | from .utils import get_tokenizer, interleave_keys 9 | 10 | __all__ = ["Batch", 11 | "Dataset", "TabularDataset", 12 | "Example", 13 | "RawField", "Field", "ReversibleField", "SubwordField", "NestedField", 14 | "LabelField", 15 | "batch", "BucketIterator", "Iterator", "BPTTIterator", 16 | "pool", 17 | "Pipeline", 18 | "get_tokenizer", "interleave_keys"] 19 | -------------------------------------------------------------------------------- /translator/onmt/translate/__init__.py: -------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import 
Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam import Beam, GNMTGlobalScorer 5 | from onmt.translate.beam_search import BeamSearch 6 | from onmt.translate.decode_strategy import DecodeStrategy 7 | from onmt.translate.random_sampling import RandomSampling 8 | from onmt.translate.penalties import PenaltyBuilder 9 | from onmt.translate.translation_server import TranslationServer, \ 10 | ServerModelError 11 | 12 | __all__ = ['Translator', 'Translation', 'Beam', 'BeamSearch', 13 | 'GNMTGlobalScorer', 'TranslationBuilder', 14 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', 15 | "DecodeStrategy", "RandomSampling"] 16 | -------------------------------------------------------------------------------- /translator/onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET, log_mode='a'): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = logging.FileHandler(log_file, mode=log_mode) 20 | file_handler.setLevel(log_file_level) 21 | file_handler.setFormatter(log_format) 22 | logger.addHandler(file_handler) 23 | 24 | return logger 25 | -------------------------------------------------------------------------------- /translator/torchtext/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_modeling import LanguageModelingDataset, WikiText2, WikiText103, PennTreebank # NOQA 2 | from .nli import SNLI, MultiNLI 3 | from .sst import SST 4 | from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA 5 | from .sequence_tagging import SequenceTaggingDataset, UDPOS, CoNLL2000Chunking # NOQA 6 | from .trec import TREC 7 | from .imdb import IMDB 8 | from .babi import BABI20 9 | 10 | 11 | __all__ = ['LanguageModelingDataset', 12 | 'SNLI', 13 | 'MultiNLI', 14 | 'SST', 15 | 'TranslationDataset', 16 | 'Multi30k', 17 | 'IWSLT', 18 | 'WMT14', 19 | 'WikiText2', 20 | 'WikiText103', 21 | 'PennTreebank', 22 | 'TREC', 23 | 'IMDB', 24 | 'SequenceTaggingDataset', 25 | 'UDPOS', 26 | 'CoNLL2000Chunking', 27 | 'BABI20'] 28 | -------------------------------------------------------------------------------- /translator/README.md: -------------------------------------------------------------------------------- 1 | # Commands 2 | 3 | ## Preprocessing 4 | 5 | ``` 6 | ./opennmt/preprocess.py -train_src data/debug/text/trainA_000.text -train_tgt data/debug/code/trainA_000.code -valid_src data/debug/text/trainA_001.text -valid_tgt data/debug/code/trainA_001.code -save_data out/debug -dynamic_dict 7 | ``` 8 | TODO: Make this handle directories of text and code 9 | 10 | ## Training 11 | 12 | ``` 13 | ./opennmt/train.py -config opennmt/config/small.train.yml -data out/debug -save_model out/model 14 | ``` 15 | 16 | To run on GPU, add the following: 17 | ``` 18 | -world_size 1 -gpu_ranks 0 19 | ``` 20 | and request 1 GPU on the cluster. 
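For example, the full GPU training command is just the one above with those flags appended (a sketch combining the two snippets above, not a verified run):
```
./opennmt/train.py -config opennmt/config/small.train.yml -data out/debug -save_model out/model -world_size 1 -gpu_ranks 0
```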
21 | 22 | ## Prediction 23 | ``` 24 | ./opennmt/translate.py -config opennmt/config/small.translate.yml -model out/model_step_10000.pt -src data/debug/text/trainA_004.text -output out/pred.txt 25 | ``` 26 | 27 | To run on GPU, add the following: 28 | ``` 29 | -gpu 1 30 | ``` 31 | and request 1 GPU on the cluster. 32 | -------------------------------------------------------------------------------- /translator/onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention and normalization modules """ 2 | from onmt.modules.util_class import Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 6 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ 7 | CopyGeneratorLossCompute 8 | from onmt.modules.multi_headed_attn import MultiHeadedAttention 9 | from onmt.modules.embeddings import Embeddings, PositionalEncoding 10 | from onmt.modules.weight_norm import WeightNormConv2d 11 | from onmt.modules.average_attn import AverageAttention 12 | 13 | __all__ = ["Elementwise", "context_gate_factory", "ContextGate", 14 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 15 | "CopyGeneratorLoss", "CopyGeneratorLossCompute", 16 | "MultiHeadedAttention", "Embeddings", "PositionalEncoding", 17 | "WeightNormConv2d", "AverageAttention"] 18 | -------------------------------------------------------------------------------- /translator/onmt/encoders/mean_encoder.py: -------------------------------------------------------------------------------- 1 | """Define a minimal encoder.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | 4 | 5 | class MeanEncoder(EncoderBase): 6 | """A trivial non-recurrent encoder. Simply applies mean pooling. 7 | 8 | Args: 9 | num_layers (int): number of replicated layers 10 | embeddings (onmt.modules.Embeddings): embedding module to use 11 | """ 12 | 13 | def __init__(self, num_layers, embeddings): 14 | super(MeanEncoder, self).__init__() 15 | self.num_layers = num_layers 16 | self.embeddings = embeddings 17 | 18 | @classmethod 19 | def from_opt(cls, opt, embeddings): 20 | """Alternate constructor.""" 21 | return cls( 22 | opt.enc_layers, 23 | embeddings) 24 | 25 | def forward(self, src, lengths=None): 26 | """See :func:`EncoderBase.forward()`""" 27 | self._check_args(src, lengths) 28 | 29 | emb = self.embeddings(src) 30 | _, batch, emb_dim = emb.size() 31 | mean = emb.mean(0).expand(self.num_layers, batch, emb_dim) 32 | memory_bank = emb 33 | encoder_final = (mean, mean) 34 | return encoder_final, memory_bank, lengths 35 | -------------------------------------------------------------------------------- /translator/onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings. 
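The ``str2reader`` and ``str2sortkey`` dicts defined below map each input modality (``text``, ``img``, ``audio``) to its data reader class and its batch sort key.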
5 | """ 6 | from onmt.inputters.inputter import \ 7 | load_old_vocab, get_fields, OrderedIterator, \ 8 | build_vocab, old_style_vocab, filter_example 9 | from onmt.inputters.dataset_base import Dataset 10 | from onmt.inputters.text_dataset import text_sort_key, TextDataReader 11 | from onmt.inputters.image_dataset import img_sort_key, ImageDataReader 12 | from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader 13 | from onmt.inputters.datareader_base import DataReaderBase 14 | 15 | 16 | str2reader = { 17 | "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader} 18 | str2sortkey = { 19 | 'text': text_sort_key, 'img': img_sort_key, 'audio': audio_sort_key} 20 | 21 | 22 | __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'DataReaderBase', 23 | 'filter_example', 'old_style_vocab', 24 | 'build_vocab', 'OrderedIterator', 25 | 'text_sort_key', 'img_sort_key', 'audio_sort_key', 26 | 'TextDataReader', 'ImageDataReader', 'AudioDataReader'] 27 | -------------------------------------------------------------------------------- /tokenizer/clang/enumerations.py: -------------------------------------------------------------------------------- 1 | #===- enumerations.py - Python Enumerations ------------------*- python -*--===# 2 | # 3 | # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 | # See https://llvm.org/LICENSE.txt for license information. 5 | # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 | # 7 | #===------------------------------------------------------------------------===# 8 | 9 | """ 10 | Clang Enumerations 11 | ================== 12 | 13 | This module provides static definitions of enumerations that exist in libclang. 14 | 15 | Enumerations are typically defined as a list of tuples. The exported values are 16 | typically munged into other types or classes at module load time. 17 | 18 | All enumerations are centrally defined in this file so they are all grouped 19 | together and easier to audit. And, maybe even one day this file will be 20 | automatically generated by scanning the libclang headers! 21 | """ 22 | 23 | # Maps to CXTokenKind. Note that libclang maintains a separate set of token 24 | # enumerations from the C++ API. 25 | TokenKinds = [ 26 | ('PUNCTUATION', 0), 27 | ('KEYWORD', 1), 28 | ('IDENTIFIER', 2), 29 | ('LITERAL', 3), 30 | ('COMMENT', 4), 31 | ] 32 | 33 | __all__ = ['TokenKinds'] 34 | -------------------------------------------------------------------------------- /translator/onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """Position feed-forward network from "Attention is All You Need".""" 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class PositionwiseFeedForward(nn.Module): 7 | """ A two-layer Feed-Forward-Network with residual layer norm. 8 | 9 | Args: 10 | d_model (int): the size of input for the first-layer of the FFN. 11 | d_ff (int): the hidden layer size of the second-layer 12 | of the FNN. 13 | dropout (float): dropout probability in :math:`[0, 1)`. 14 | """ 15 | 16 | def __init__(self, d_model, d_ff, dropout=0.1): 17 | super(PositionwiseFeedForward, self).__init__() 18 | self.w_1 = nn.Linear(d_model, d_ff) 19 | self.w_2 = nn.Linear(d_ff, d_model) 20 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 21 | self.dropout_1 = nn.Dropout(dropout) 22 | self.relu = nn.ReLU() 23 | self.dropout_2 = nn.Dropout(dropout) 24 | 25 | def forward(self, x): 26 | """Layer definition. 
27 | 28 | Args: 29 | x: ``(batch_size, input_len, model_dim)`` 30 | 31 | Returns: 32 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 33 | """ 34 | 35 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 36 | output = self.dropout_2(self.w_2(inter)) 37 | return output + x 38 | -------------------------------------------------------------------------------- /translator/onmt/inputters/datareader_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | # several data readers need optional dependencies. There's no 5 | # appropriate builtin exception 6 | class MissingDependencyException(Exception): 7 | pass 8 | 9 | 10 | class DataReaderBase(object): 11 | """Read data from file system and yield as dicts. 12 | 13 | Raises: 14 | onmt.inputters.datareader_base.MissingDependencyException: A number 15 | of DataReaders need specific additional packages. 16 | If any are missing, this will be raised. 17 | """ 18 | 19 | @classmethod 20 | def from_opt(cls, opt): 21 | """Alternative constructor. 22 | 23 | Args: 24 | opt (argparse.Namespace): The parsed arguments. 25 | """ 26 | 27 | return cls() 28 | 29 | @classmethod 30 | def _read_file(cls, path): 31 | """Line-by-line read a file as bytes.""" 32 | with open(path, "rb") as f: 33 | for line in f: 34 | yield line 35 | 36 | @staticmethod 37 | def _raise_missing_dep(*missing_deps): 38 | """Raise missing dep exception with standard error message.""" 39 | raise MissingDependencyException( 40 | "Could not create reader. Be sure to install " 41 | "the following dependencies: " + ", ".join(missing_deps)) 42 | 43 | def read(self, data, side, src_dir): 44 | """Read data from file system and yield as dicts.""" 45 | raise NotImplementedError() 46 | -------------------------------------------------------------------------------- /translator/translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | from itertools import repeat 6 | 7 | from onmt.utils.logging import init_logger 8 | from onmt.utils.misc import split_corpus 9 | from onmt.translate.translator import build_translator 10 | 11 | import onmt.opts as opts 12 | from onmt.utils.parse import ArgumentParser 13 | 14 | 15 | def main(opt): 16 | ArgumentParser.validate_translate_opts(opt) 17 | logger = init_logger(opt.log_file, log_mode='w') 18 | 19 | translator = build_translator(opt, report_score=True, logger=logger) 20 | src_shards = split_corpus(opt.src, opt.shard_size) 21 | tgt_shards = split_corpus(opt.tgt, opt.shard_size) \ 22 | if opt.tgt is not None else repeat(None) 23 | shard_pairs = zip(src_shards, tgt_shards) 24 | 25 | for i, (src_shard, tgt_shard) in enumerate(shard_pairs): 26 | logger.info("Translating shard %d." 
% i) 27 | translator.translate( 28 | src=src_shard, 29 | tgt=tgt_shard, 30 | src_dir=opt.src_dir, 31 | batch_size=opt.batch_size, 32 | attn_debug=opt.attn_debug 33 | ) 34 | 35 | 36 | def _get_parser(): 37 | parser = ArgumentParser(description='translate.py') 38 | 39 | opts.config_opts(parser) 40 | opts.translate_opts(parser) 41 | return parser 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = _get_parser() 46 | 47 | opt = parser.parse_args() 48 | main(opt) 49 | -------------------------------------------------------------------------------- /translator/onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | 5 | 6 | class MatrixTree(nn.Module): 7 | """Implementation of the matrix-tree theorem for computing marginals 8 | of non-projective dependency parsing. This attention layer is used 9 | in the paper "Learning Structured Text Representations" 10 | :cite:`DBLP:journals/corr/LiuL17d`. 11 | """ 12 | 13 | def __init__(self, eps=1e-5): 14 | self.eps = eps 15 | super(MatrixTree, self).__init__() 16 | 17 | def forward(self, input): 18 | laplacian = input.exp() + self.eps 19 | output = input.clone() 20 | for b in range(input.size(0)): 21 | lap = laplacian[b].masked_fill( 22 | torch.eye(input.size(1), device=input.device).ne(0), 0) 23 | lap = -lap + torch.diag(lap.sum(0)) 24 | # store roots on diagonal 25 | lap[0] = input[b].diag().exp() 26 | inv_laplacian = lap.inverse() 27 | 28 | factor = inv_laplacian.diag().unsqueeze(1)\ 29 | .expand_as(input[b]).transpose(0, 1) 30 | term1 = input[b].exp().mul(factor).clone() 31 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 32 | term1[:, 0] = 0 33 | term2[0] = 0 34 | output[b] = term1 - term2 35 | roots_output = input[b].diag().exp().mul( 36 | inv_laplacian.transpose(0, 1)[0]) 37 | output[b] = output[b] + torch.diag(roots_output) 38 | return output 39 | -------------------------------------------------------------------------------- /translator/recover-dummy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Add DUMMY lines to the summary file 5 | """ 6 | 7 | import sys, os, shutil, re, argparse, json, random 8 | from collections import defaultdict, Counter 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('-v', '--verbose', action='store_true') 14 | parser.add_argument('orig_test_tsv') 15 | parser.add_argument('summary_nodummy_tsv') 16 | args = parser.parse_args() 17 | 18 | index = 1 19 | with open(args.orig_test_tsv) as orig_f, open(args.summary_nodummy_tsv) as summ_f: 20 | orig_head = orig_f.readline().rstrip('\n') 21 | summ_head = summ_f.readline().rstrip('\n') 22 | print(summ_head) 23 | for orig_line in orig_f: 24 | orig_line = orig_line.rstrip('\n').split('\t') 25 | if not orig_line[0]: 26 | # Print a dummy line 27 | print('\t'.join([ 28 | str(index), 29 | 'DUMMY', 30 | '0.0', 31 | '0.0', 32 | orig_line[1], 33 | 'DUMMY', 34 | ])) 35 | else: 36 | # Replace index 37 | summ_line = summ_f.readline().rstrip('\n').split('\t', 1) 38 | print('{}\t{}'.format(index, summ_line[1])) 39 | index += 1 40 | # Sanity check: the summary file should have been fully consumed 41 | blank = summ_f.readline() 42 | assert not blank 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | 48 | -------------------------------------------------------------------------------- 
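As a quick illustration of `recover-dummy.py` above: per its `print` call, a dummy row is the tab-separated sequence index, `DUMMY`, `0.0`, `0.0`, the original line's second column, `DUMMY`. So for an original row with an empty first column and `}` in its second column, the emitted line would look like this (the index value is illustrative):

```
42	DUMMY	0.0	0.0	}	DUMMY
```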
/translator/onmt/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for encoders and generic multi encoders.""" 2 | 3 | import torch.nn as nn 4 | 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class EncoderBase(nn.Module): 9 | """ 10 | Base encoder class. Specifies the interface used by different encoder types 11 | and required by :class:`onmt.Models.NMTModel`. 12 | 13 | .. mermaid:: 14 | 15 | graph BT 16 | A[Input] 17 | subgraph RNN 18 | C[Pos 1] 19 | D[Pos 2] 20 | E[Pos N] 21 | end 22 | F[Memory_Bank] 23 | G[Final] 24 | A-->C 25 | A-->D 26 | A-->E 27 | C-->F 28 | D-->F 29 | E-->F 30 | E-->G 31 | """ 32 | 33 | @classmethod 34 | def from_opt(cls, opt, embeddings=None): 35 | raise NotImplementedError 36 | 37 | def _check_args(self, src, lengths=None, hidden=None): 38 | _, n_batch, _ = src.size() 39 | if lengths is not None: 40 | n_batch_, = lengths.size() 41 | aeq(n_batch, n_batch_) 42 | 43 | def forward(self, src, lengths=None): 44 | """ 45 | Args: 46 | src (LongTensor): 47 | padded sequences of sparse indices ``(src_len, batch, nfeat)`` 48 | lengths (LongTensor): length of each sequence ``(batch,)`` 49 | 50 | 51 | Returns: 52 | (FloatTensor, FloatTensor): 53 | 54 | * final encoder state, used to initialize decoder 55 | * memory bank for attention, ``(src_len, batch, hidden)`` 56 | """ 57 | 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /translator/onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # At the moment this class is only used by embeddings.Embeddings look-up tables 7 | class Elementwise(nn.ModuleList): 8 | """ 9 | A simple network container. 10 | Parameters are a list of modules. 11 | Inputs are a 3d Tensor whose last dimension is the same length 12 | as the list. 13 | Outputs are the result of applying modules to inputs elementwise. 14 | An optional merge parameter allows the outputs to be reduced to a 15 | single Tensor. 16 | """ 17 | 18 | def __init__(self, merge=None, *args): 19 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 20 | self.merge = merge 21 | super(Elementwise, self).__init__(*args) 22 | 23 | def forward(self, inputs): 24 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 25 | assert len(self) == len(inputs_) 26 | outputs = [f(x) for f, x in zip(self, inputs_)] 27 | if self.merge == 'first': 28 | return outputs[0] 29 | elif self.merge == 'concat' or self.merge == 'mlp': 30 | return torch.cat(outputs, 2) 31 | elif self.merge == 'sum': 32 | return sum(outputs) 33 | else: 34 | return outputs 35 | 36 | 37 | class Cast(nn.Module): 38 | """ 39 | Basic layer that casts its input to a specific data type. The same tensor 40 | is returned if the data type is already correct. 
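Illustrative usage (not from this repo): appended to an ``nn.Sequential``, e.g. ``nn.Sequential(nn.Linear(4, 4), Cast(torch.float16))``, it casts that stage's output to FP16.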
41 | """ 42 | 43 | def __init__(self, dtype): 44 | super(Cast, self).__init__() 45 | self._dtype = dtype 46 | 47 | def forward(self, x): 48 | return x.to(self._dtype) 49 | -------------------------------------------------------------------------------- /translator/onmt/utils/cnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | 8 | import onmt.modules 9 | 10 | SCALE_WEIGHT = 0.5 ** 0.5 11 | 12 | 13 | def shape_transform(x): 14 | """ Tranform the size of the tensors to fit for conv input. """ 15 | return torch.unsqueeze(torch.transpose(x, 1, 2), 3) 16 | 17 | 18 | class GatedConv(nn.Module): 19 | """ Gated convolution for CNN class """ 20 | 21 | def __init__(self, input_size, width=3, dropout=0.2, nopad=False): 22 | super(GatedConv, self).__init__() 23 | self.conv = onmt.modules.WeightNormConv2d( 24 | input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), 25 | padding=(width // 2 * (1 - nopad), 0)) 26 | init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5) 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, x_var): 30 | x_var = self.dropout(x_var) 31 | x_var = self.conv(x_var) 32 | out, gate = x_var.split(int(x_var.size(1) / 2), 1) 33 | out = out * torch.sigmoid(gate) 34 | return out 35 | 36 | 37 | class StackedCNN(nn.Module): 38 | """ Stacked CNN class """ 39 | 40 | def __init__(self, num_layers, input_size, cnn_kernel_width=3, 41 | dropout=0.2): 42 | super(StackedCNN, self).__init__() 43 | self.dropout = dropout 44 | self.num_layers = num_layers 45 | self.layers = nn.ModuleList() 46 | for _ in range(num_layers): 47 | self.layers.append( 48 | GatedConv(input_size, cnn_kernel_width, dropout)) 49 | 50 | def forward(self, x): 51 | for conv in self.layers: 52 | x = x + conv(x) 53 | x *= SCALE_WEIGHT 54 | return x 55 | -------------------------------------------------------------------------------- /translator/onmt/encoders/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch.nn as nn 5 | 6 | from onmt.encoders.encoder import EncoderBase 7 | from onmt.utils.cnn_factory import shape_transform, StackedCNN 8 | 9 | SCALE_WEIGHT = 0.5 ** 0.5 10 | 11 | 12 | class CNNEncoder(EncoderBase): 13 | """Encoder based on "Convolutional Sequence to Sequence Learning" 14 | :cite:`DBLP:journals/corr/GehringAGYD17`. 
15 | """ 16 | 17 | def __init__(self, num_layers, hidden_size, 18 | cnn_kernel_width, dropout, embeddings): 19 | super(CNNEncoder, self).__init__() 20 | 21 | self.embeddings = embeddings 22 | input_size = embeddings.embedding_size 23 | self.linear = nn.Linear(input_size, hidden_size) 24 | self.cnn = StackedCNN(num_layers, hidden_size, 25 | cnn_kernel_width, dropout) 26 | 27 | @classmethod 28 | def from_opt(cls, opt, embeddings): 29 | """Alternate constructor.""" 30 | return cls( 31 | opt.enc_layers, 32 | opt.enc_rnn_size, 33 | opt.cnn_kernel_width, 34 | opt.dropout, 35 | embeddings) 36 | 37 | def forward(self, input, lengths=None, hidden=None): 38 | """See :class:`onmt.modules.EncoderBase.forward()`""" 39 | self._check_args(input, lengths, hidden) 40 | 41 | emb = self.embeddings(input) 42 | # s_len, batch, emb_dim = emb.size() 43 | 44 | emb = emb.transpose(0, 1).contiguous() 45 | emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) 46 | emb_remap = self.linear(emb_reshape) 47 | emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) 48 | emb_remap = shape_transform(emb_remap) 49 | out = self.cnn(emb_remap) 50 | 51 | return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \ 52 | out.squeeze(3).transpose(0, 1).contiguous(), lengths 53 | -------------------------------------------------------------------------------- /translator/onmt/models/model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | 4 | 5 | class NMTModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 9 | 10 | Args: 11 | encoder (onmt.encoders.EncoderBase): an encoder object 12 | decoder (onmt.decoders.DecoderBase): a decoder object 13 | """ 14 | 15 | def __init__(self, encoder, decoder): 16 | super(NMTModel, self).__init__() 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | 20 | def forward(self, src, tgt, lengths, bptt=False): 21 | """Forward propagate a `src` and `tgt` pair for training. 22 | Possible initialized with a beginning decoder state. 23 | 24 | Args: 25 | src (Tensor): A source sequence passed to encoder. 26 | typically for inputs this will be a padded `LongTensor` 27 | of size ``(len, batch, features)``. However, may be an 28 | image or other generic input depending on encoder. 29 | tgt (LongTensor): A target sequence of size ``(tgt_len, batch)``. 30 | lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. 31 | bptt (Boolean): A flag indicating if truncated bptt is set. 
32 | If not set, ``init_state`` is called to initialize the decoder state. 33 | 34 | Returns: 35 | (FloatTensor, dict[str, FloatTensor]): 36 | 37 | * decoder output ``(tgt_len, batch, hidden)`` 38 | * dictionary attention dists of ``(tgt_len, batch, src_len)`` 39 | """ 40 | tgt = tgt[:-1] # exclude last target from inputs 41 | 42 | enc_state, memory_bank, lengths = self.encoder(src, lengths) 43 | if bptt is False: 44 | self.decoder.init_state(src, memory_bank, enc_state) 45 | dec_out, attns = self.decoder(tgt, memory_bank, 46 | memory_lengths=lengths) 47 | return dec_out, attns 48 | -------------------------------------------------------------------------------- /translator/onmt/models/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | """ Implementation of ONMT RNN for Input Feeding Decoding """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class StackedLSTM(nn.Module): 7 | """ 8 | Our own implementation of stacked LSTM. 9 | Needed for the decoder, because we do input feeding. 10 | """ 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTM, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | self.layers = nn.ModuleList() 17 | 18 | for _ in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, input_feed, hidden): 23 | h_0, c_0 = hidden 24 | h_1, c_1 = [], [] 25 | for i, layer in enumerate(self.layers): 26 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 27 | input_feed = h_1_i 28 | if i + 1 != self.num_layers: 29 | input_feed = self.dropout(input_feed) 30 | h_1 += [h_1_i] 31 | c_1 += [c_1_i] 32 | 33 | h_1 = torch.stack(h_1) 34 | c_1 = torch.stack(c_1) 35 | 36 | return input_feed, (h_1, c_1) 37 | 38 | 39 | class StackedGRU(nn.Module): 40 | """ 41 | Our own implementation of stacked GRU. 42 | Needed for the decoder, because we do input feeding. 
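(Shapes, inferred from the ``nn.GRUCell`` calls below: ``input_feed`` is ``(batch, input_size)``, and ``hidden`` is a 1-tuple whose single element stacks the per-layer states as ``(num_layers, batch, rnn_size)``.)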
43 | """ 44 | 45 | def __init__(self, num_layers, input_size, rnn_size, dropout): 46 | super(StackedGRU, self).__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList() 50 | 51 | for _ in range(num_layers): 52 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 53 | input_size = rnn_size 54 | 55 | def forward(self, input_feed, hidden): 56 | h_1 = [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i = layer(input_feed, hidden[0][i]) 59 | input_feed = h_1_i 60 | if i + 1 != self.num_layers: 61 | input_feed = self.dropout(input_feed) 62 | h_1 += [h_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | return input_feed, (h_1,) 66 | -------------------------------------------------------------------------------- /scripts/clang-format-spoc: -------------------------------------------------------------------------------- 1 | # version: 3.8 2 | # exclude: cmake/* 3 | # include: *.h 4 | # include: *.hpp 5 | # include: *.cpp 6 | 7 | Language: Cpp 8 | AccessModifierOffset: -4 9 | AlignAfterOpenBracket: true 10 | AlignConsecutiveAssignments: false 11 | AlignEscapedNewlinesLeft: false 12 | AlignOperands: true 13 | AlignTrailingComments: true 14 | AllowAllParametersOfDeclarationOnNextLine: true 15 | AllowShortBlocksOnASingleLine: true 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: Empty 18 | AllowShortIfStatementsOnASingleLine: true 19 | AllowShortLoopsOnASingleLine: true 20 | AlwaysBreakAfterDefinitionReturnType: None 21 | AlwaysBreakBeforeMultilineStrings: false 22 | AlwaysBreakTemplateDeclarations: true 23 | BinPackArguments: true 24 | BinPackParameters: true 25 | BreakBeforeBinaryOperators: All 26 | BreakBeforeBraces: Custom 27 | BraceWrapping: 28 | AfterFunction: false 29 | AfterControlStatement: false 30 | BreakBeforeTernaryOperators: true 31 | BreakConstructorInitializersBeforeComma: false 32 | ColumnLimit: 800 33 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 34 | ConstructorInitializerIndentWidth: 4 35 | ContinuationIndentWidth: 4 36 | Cpp11BracedListStyle: true 37 | DerivePointerAlignment: false 38 | DisableFormat: false 39 | ExperimentalAutoDetectBinPacking: false 40 | IndentCaseLabels: true 41 | IndentWidth: 4 42 | IndentWrappedFunctionNames: false 43 | KeepEmptyLinesAtTheStartOfBlocks: true 44 | MacroBlockBegin: '' 45 | MacroBlockEnd: '' 46 | MaxEmptyLinesToKeep: 1 47 | NamespaceIndentation: None 48 | ObjCBlockIndentWidth: 2 49 | ObjCSpaceAfterProperty: false 50 | ObjCSpaceBeforeProtocolList: true 51 | PenaltyBreakBeforeFirstCallParameter: 19 52 | PenaltyBreakComment: 22312 53 | PenaltyBreakFirstLessLess: 120 54 | PenaltyBreakString: 2123 55 | PenaltyExcessCharacter: 1000000 56 | PenaltyReturnTypeOnItsOwnLine: 60 57 | PointerAlignment: Right 58 | SortIncludes: false 59 | SpaceAfterCStyleCast: false 60 | SpaceBeforeAssignmentOperators: true 61 | SpaceBeforeParens: ControlStatements 62 | SpaceInEmptyParentheses: false 63 | SpacesBeforeTrailingComments: 1 64 | SpacesInAngles: false 65 | SpacesInContainerLiterals: false 66 | SpacesInCStyleCastParentheses: false 67 | SpacesInParentheses: false 68 | SpacesInSquareBrackets: false 69 | Standard: Cpp11 70 | TabWidth: 4 71 | UseTab: Never 72 | 73 | -------------------------------------------------------------------------------- /translator/torchtext/data/example.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import six 4 | 5 | 6 | class Example(object): 7 | """Defines a single 
training or test example. 8 | 9 | Stores each column of the example as an attribute. 10 | """ 11 | 12 | @classmethod 13 | def fromJSON(cls, data, fields): 14 | return cls.fromdict(json.loads(data), fields) 15 | 16 | @classmethod 17 | def fromdict(cls, data, fields): 18 | ex = cls() 19 | for key, vals in fields.items(): 20 | if key not in data: 21 | raise ValueError("Specified key {} was not found in " 22 | "the input data".format(key)) 23 | if vals is not None: 24 | if not isinstance(vals, list): 25 | vals = [vals] 26 | for val in vals: 27 | name, field = val 28 | setattr(ex, name, field.preprocess(data[key])) 29 | return ex 30 | 31 | @classmethod 32 | def fromCSV(cls, data, fields, field_to_index=None): 33 | if field_to_index is None: 34 | return cls.fromlist(data, fields) 35 | else: 36 | assert(isinstance(fields, dict)) 37 | data_dict = {f: data[idx] for f, idx in field_to_index.items()} 38 | return cls.fromdict(data_dict, fields) 39 | 40 | @classmethod 41 | def fromlist(cls, data, fields): 42 | ex = cls() 43 | for (name, field), val in zip(fields, data): 44 | if field is not None: 45 | if isinstance(val, six.string_types): 46 | val = val.rstrip('\n') 47 | # Handle field tuples 48 | if isinstance(name, tuple): 49 | for n, f in zip(name, field): 50 | setattr(ex, n, f.preprocess(val)) 51 | else: 52 | setattr(ex, name, field.preprocess(val)) 53 | return ex 54 | 55 | @classmethod 56 | def fromtree(cls, data, fields, subtrees=False): 57 | try: 58 | from nltk.tree import Tree 59 | except ImportError: 60 | print("Please install NLTK. " 61 | "See the docs at http://nltk.org for more information.") 62 | raise 63 | tree = Tree.fromstring(data) 64 | if subtrees: 65 | return [cls.fromlist( 66 | [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()] 67 | return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields) 68 | -------------------------------------------------------------------------------- /translator/torchtext/utils.py: -------------------------------------------------------------------------------- 1 | import six 2 | import requests 3 | import csv 4 | from tqdm import tqdm 5 | 6 | 7 | def reporthook(t): 8 | """https://github.com/tqdm/tqdm""" 9 | last_b = [0] 10 | 11 | def inner(b=1, bsize=1, tsize=None): 12 | """ 13 | b: int, optional 14 | Number of blocks just transferred [default: 1]. 15 | bsize: int, optional 16 | Size of each block (in tqdm units) [default: 1]. 17 | tsize: int, optional 18 | Total size (in tqdm units). If [default: None] remains unchanged. 
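Typical usage, following the tqdm docs linked above (a sketch, not code from this repo): pass the returned ``inner`` as the ``reporthook`` argument of ``urllib.request.urlretrieve``, e.g. ``urlretrieve(url, path, reporthook=reporthook(t))``.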
19 | """ 20 | if tsize is not None: 21 | t.total = tsize 22 | t.update((b - last_b[0]) * bsize) 23 | last_b[0] = b 24 | return inner 25 | 26 | 27 | def download_from_url(url, path): 28 | """Download file, with logic (from tensor2tensor) for Google Drive""" 29 | def process_response(r): 30 | chunk_size = 16 * 1024 31 | total_size = int(r.headers.get('Content-length', 0)) 32 | with open(path, "wb") as file: 33 | with tqdm(total=total_size, unit='B', 34 | unit_scale=1, desc=path.split('/')[-1]) as t: 35 | for chunk in r.iter_content(chunk_size): 36 | if chunk: 37 | file.write(chunk) 38 | t.update(len(chunk)) 39 | 40 | if 'drive.google.com' not in url: 41 | response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) 42 | process_response(response) 43 | return 44 | 45 | print('downloading from Google Drive; may take a few minutes') 46 | confirm_token = None 47 | session = requests.Session() 48 | response = session.get(url, stream=True) 49 | for k, v in response.cookies.items(): 50 | if k.startswith("download_warning"): 51 | confirm_token = v 52 | 53 | if confirm_token: 54 | url = url + "&confirm=" + confirm_token 55 | response = session.get(url, stream=True) 56 | 57 | process_response(response) 58 | 59 | 60 | def unicode_csv_reader(unicode_csv_data, **kwargs): 61 | """Since the standard csv library does not handle unicode in Python 2, we need a wrapper. 62 | Borrowed and slightly modified from the Python docs: 63 | https://docs.python.org/2/library/csv.html#csv-examples""" 64 | if six.PY2: 65 | # csv.py doesn't do Unicode; encode temporarily as UTF-8: 66 | csv_reader = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs) 67 | for row in csv_reader: 68 | # decode UTF-8 back to Unicode, cell by cell: 69 | yield [cell.decode('utf-8') for cell in row] 70 | else: 71 | for line in csv.reader(unicode_csv_data, **kwargs): 72 | yield line 73 | 74 | 75 | def utf_8_encoder(unicode_csv_data): 76 | for line in unicode_csv_data: 77 | yield line.encode('utf-8') 78 | -------------------------------------------------------------------------------- /translator/onmt/modules/sparse_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of sparsemax (Martins & Astudillo, 2016). See 3 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
4 | 5 | By Ben Peters and Vlad Niculae 6 | """ 7 | 8 | import torch 9 | from torch.autograd import Function 10 | import torch.nn as nn 11 | 12 | 13 | def _make_ix_like(input, dim=0): 14 | d = input.size(dim) 15 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 16 | view = [1] * input.dim() 17 | view[0] = -1 18 | return rho.view(view).transpose(0, dim) 19 | 20 | 21 | def _threshold_and_support(input, dim=0): 22 | """Sparsemax building block: compute the threshold 23 | 24 | Args: 25 | input: any dimension 26 | dim: dimension along which to apply the sparsemax 27 | 28 | Returns: 29 | the threshold value 30 | """ 31 | 32 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 33 | input_cumsum = input_srt.cumsum(dim) - 1 34 | rhos = _make_ix_like(input, dim) 35 | support = rhos * input_srt > input_cumsum 36 | 37 | support_size = support.sum(dim=dim).unsqueeze(dim) 38 | tau = input_cumsum.gather(dim, support_size - 1) 39 | tau /= support_size.to(input.dtype) 40 | return tau, support_size 41 | 42 | 43 | class SparsemaxFunction(Function): 44 | 45 | @staticmethod 46 | def forward(ctx, input, dim=0): 47 | """sparsemax: normalizing sparse transform (a la softmax) 48 | 49 | Parameters: 50 | input (Tensor): any shape 51 | dim: dimension along which to apply sparsemax 52 | 53 | Returns: 54 | output (Tensor): same shape as input 55 | """ 56 | ctx.dim = dim 57 | max_val, _ = input.max(dim=dim, keepdim=True) 58 | input -= max_val # same numerical stability trick as for softmax 59 | tau, supp_size = _threshold_and_support(input, dim=dim) 60 | output = torch.clamp(input - tau, min=0) 61 | ctx.save_for_backward(supp_size, output) 62 | return output 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | supp_size, output = ctx.saved_tensors 67 | dim = ctx.dim 68 | grad_input = grad_output.clone() 69 | grad_input[output == 0] = 0 70 | 71 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 72 | v_hat = v_hat.unsqueeze(dim) 73 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 74 | return grad_input, None 75 | 76 | 77 | sparsemax = SparsemaxFunction.apply 78 | 79 | 80 | class Sparsemax(nn.Module): 81 | 82 | def __init__(self, dim=0): 83 | self.dim = dim 84 | super(Sparsemax, self).__init__() 85 | 86 | def forward(self, input): 87 | return sparsemax(input, self.dim) 88 | 89 | 90 | class LogSparsemax(nn.Module): 91 | 92 | def __init__(self, dim=0): 93 | self.dim = dim 94 | super(LogSparsemax, self).__init__() 95 | 96 | def forward(self, input): 97 | return torch.log(sparsemax(input, self.dim)) 98 | -------------------------------------------------------------------------------- /translator/onmt/modules/sparse_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from onmt.modules.sparse_activations import _threshold_and_support 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class SparsemaxLossFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, input, target): 12 | """ 13 | input (FloatTensor): ``(n, num_classes)``. 
14 | target (LongTensor): ``(n,)``, the indices of the target classes 15 | """ 16 | input_batch, classes = input.size() 17 | target_batch = target.size(0) 18 | aeq(input_batch, target_batch) 19 | 20 | z_k = input.gather(1, target.unsqueeze(1)).squeeze() 21 | tau_z, support_size = _threshold_and_support(input, dim=1) 22 | support = input > tau_z 23 | x = torch.where( 24 | support, input**2 - tau_z**2, 25 | torch.tensor(0.0, device=input.device) 26 | ).sum(dim=1) 27 | ctx.save_for_backward(input, target, tau_z) 28 | # clamping necessary because of numerical errors: loss should be lower 29 | # bounded by zero, but negative values near zero are possible without 30 | # the clamp 31 | return torch.clamp(x / 2 - z_k + 0.5, min=0.0) 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, target, tau_z = ctx.saved_tensors 36 | sparsemax_out = torch.clamp(input - tau_z, min=0) 37 | delta = torch.zeros_like(sparsemax_out) 38 | delta.scatter_(1, target.unsqueeze(1), 1) 39 | return sparsemax_out - delta, None 40 | 41 | 42 | sparsemax_loss = SparsemaxLossFunction.apply 43 | 44 | 45 | class SparsemaxLoss(nn.Module): 46 | """ 47 | An implementation of sparsemax loss, first proposed in 48 | :cite:`DBLP:journals/corr/MartinsA16`. If using 49 | a sparse output layer, it is not possible to use negative log likelihood 50 | because the loss is infinite in the case the target is assigned zero 51 | probability. Inputs to SparsemaxLoss are arbitrary dense real-valued 52 | vectors (like in nn.CrossEntropyLoss), not probability vectors (like in 53 | nn.NLLLoss). 54 | """ 55 | 56 | def __init__(self, weight=None, ignore_index=-100, 57 | reduction='elementwise_mean'): 58 | assert reduction in ['elementwise_mean', 'sum', 'none'] 59 | self.reduction = reduction 60 | self.weight = weight 61 | self.ignore_index = ignore_index 62 | super(SparsemaxLoss, self).__init__() 63 | 64 | def forward(self, input, target): 65 | loss = sparsemax_loss(input, target) 66 | if self.ignore_index >= 0: 67 | ignored_positions = target == self.ignore_index 68 | size = float((target.size(0) - ignored_positions.sum()).item()) 69 | loss.masked_fill_(ignored_positions, 0.0) 70 | else: 71 | size = float(target.size(0)) 72 | if self.reduction == 'sum': 73 | loss = loss.sum() 74 | elif self.reduction == 'elementwise_mean': 75 | loss = loss.sum() / size 76 | return loss 77 | -------------------------------------------------------------------------------- /translator/onmt/modules/conv_multi_step_attention.py: -------------------------------------------------------------------------------- 1 | """ Multi Step Attention for CNN """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | SCALE_WEIGHT = 0.5 ** 0.5 9 | 10 | 11 | def seq_linear(linear, x): 12 | """ linear transform for 3-d tensor """ 13 | batch, hidden_size, length, _ = x.size() 14 | h = linear(torch.transpose(x, 1, 2).contiguous().view( 15 | batch * length, hidden_size)) 16 | return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) 17 | 18 | 19 | class ConvMultiStepAttention(nn.Module): 20 | """ 21 | Conv attention takes a key matrix, a value matrix and a query vector. 22 | Attention weights are computed from the key matrix and the query vector, 23 | then used to take a weighted sum over the value matrix. The same 24 | operation is applied in each decoder conv layer. 
25 | """ 26 | 27 | def __init__(self, input_size): 28 | super(ConvMultiStepAttention, self).__init__() 29 | self.linear_in = nn.Linear(input_size, input_size) 30 | self.mask = None 31 | 32 | def apply_mask(self, mask): 33 | """ Apply mask """ 34 | self.mask = mask 35 | 36 | def forward(self, base_target_emb, input_from_dec, encoder_out_top, 37 | encoder_out_combine): 38 | """ 39 | Args: 40 | base_target_emb: target emb tensor 41 | input_from_dec: output of decode conv 42 | encoder_out_top: the key matrix for calculation of attetion weight, 43 | which is the top output of encode conv 44 | encoder_out_combine: 45 | the value matrix for the attention-weighted sum, 46 | which is the combination of base emb and top output of encode 47 | """ 48 | 49 | # checks 50 | # batch, channel, height, width = base_target_emb.size() 51 | batch, _, height, _ = base_target_emb.size() 52 | # batch_, channel_, height_, width_ = input_from_dec.size() 53 | batch_, _, height_, _ = input_from_dec.size() 54 | aeq(batch, batch_) 55 | aeq(height, height_) 56 | 57 | # enc_batch, enc_channel, enc_height = encoder_out_top.size() 58 | enc_batch, _, enc_height = encoder_out_top.size() 59 | # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() 60 | enc_batch_, _, enc_height_ = encoder_out_combine.size() 61 | 62 | aeq(enc_batch, enc_batch_) 63 | aeq(enc_height, enc_height_) 64 | 65 | preatt = seq_linear(self.linear_in, input_from_dec) 66 | target = (base_target_emb + preatt) * SCALE_WEIGHT 67 | target = torch.squeeze(target, 3) 68 | target = torch.transpose(target, 1, 2) 69 | pre_attn = torch.bmm(target, encoder_out_top) 70 | 71 | if self.mask is not None: 72 | pre_attn.data.masked_fill_(self.mask, -float('inf')) 73 | 74 | attn = F.softmax(pre_attn, dim=2) 75 | 76 | context_output = torch.bmm( 77 | attn, torch.transpose(encoder_out_combine, 1, 2)) 78 | context_output = torch.transpose( 79 | torch.unsqueeze(context_output, 3), 1, 2) 80 | return context_output, attn 81 | -------------------------------------------------------------------------------- /translator/torchtext/data/pipeline.py: -------------------------------------------------------------------------------- 1 | class Pipeline(object): 2 | """Defines a pipeline for transforming sequence data. 3 | 4 | The input is assumed to be utf-8 encoded `str` (Python 3) or 5 | `unicode` (Python 2). 6 | 7 | Attributes: 8 | convert_token: The function to apply to input sequence data. 9 | pipes: The Pipelines that will be applied to input sequence 10 | data in order. 11 | """ 12 | def __init__(self, convert_token=None): 13 | """Create a pipeline. 14 | 15 | Arguments: 16 | convert_token: The function to apply to input sequence data. 17 | If None, the identity function is used. Default: None 18 | """ 19 | if convert_token is None: 20 | self.convert_token = Pipeline.identity 21 | elif callable(convert_token): 22 | self.convert_token = convert_token 23 | else: 24 | raise ValueError("Pipeline input convert_token {} is not None " 25 | "or callable".format(convert_token)) 26 | self.pipes = [self] 27 | 28 | def __call__(self, x, *args): 29 | """Apply the the current Pipeline(s) to an input. 30 | 31 | Arguments: 32 | x: The input to process with the Pipeline(s). 33 | Positional arguments: Forwarded to the `call` function 34 | of the Pipeline(s). 35 | """ 36 | for pipe in self.pipes: 37 | x = pipe.call(x, *args) 38 | return x 39 | 40 | def call(self, x, *args): 41 | """Apply _only_ the convert_token function of the current pipeline 42 | to the input. 
If the input is a list, a list with the results of 43 | applying the `convert_token` function to all input elements is 44 | returned. 45 | 46 | Arguments: 47 | x: The input to apply the convert_token function to. 48 | Positional arguments: Forwarded to the `convert_token` function 49 | of the current Pipeline. 50 | """ 51 | if isinstance(x, list): 52 | return [self.convert_token(tok, *args) for tok in x] 53 | return self.convert_token(x, *args) 54 | 55 | def add_before(self, pipeline): 56 | """Add a Pipeline to be applied before this processing pipeline. 57 | 58 | Arguments: 59 | pipeline: The Pipeline or callable to apply before this 60 | Pipeline. 61 | """ 62 | if not isinstance(pipeline, Pipeline): 63 | pipeline = Pipeline(pipeline) 64 | self.pipes = pipeline.pipes[:] + self.pipes[:] 65 | return self 66 | 67 | def add_after(self, pipeline): 68 | """Add a Pipeline to be applied after this processing pipeline. 69 | 70 | Arguments: 71 | pipeline: The Pipeline or callable to apply after this 72 | Pipeline. 73 | """ 74 | if not isinstance(pipeline, Pipeline): 75 | pipeline = Pipeline(pipeline) 76 | self.pipes = self.pipes[:] + pipeline.pipes[:] 77 | return self 78 | 79 | @staticmethod 80 | def identity(x): 81 | """Return a copy of the input. 82 | 83 | This is here for serialization compatibility with pickle. 84 | """ 85 | return x 86 | -------------------------------------------------------------------------------- /translator/torchtext/datasets/imdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import io 4 | 5 | from .. import data 6 | 7 | 8 | class IMDB(data.Dataset): 9 | 10 | urls = ['http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'] 11 | name = 'imdb' 12 | dirname = 'aclImdb' 13 | 14 | @staticmethod 15 | def sort_key(ex): 16 | return len(ex.text) 17 | 18 | def __init__(self, path, text_field, label_field, **kwargs): 19 | """Create an IMDB dataset instance given a path and fields. 20 | 21 | Arguments: 22 | path: Path to the dataset's highest level directory 23 | text_field: The field that will be used for text data. 24 | label_field: The field that will be used for label data. 25 | Remaining keyword arguments: Passed to the constructor of 26 | data.Dataset. 27 | """ 28 | fields = [('text', text_field), ('label', label_field)] 29 | examples = [] 30 | 31 | for label in ['pos', 'neg']: 32 | for fname in glob.iglob(os.path.join(path, label, '*.txt')): 33 | with io.open(fname, 'r', encoding="utf-8") as f: 34 | text = f.readline() 35 | examples.append(data.Example.fromlist([text, label], fields)) 36 | 37 | super(IMDB, self).__init__(examples, fields, **kwargs) 38 | 39 | @classmethod 40 | def splits(cls, text_field, label_field, root='.data', 41 | train='train', test='test', **kwargs): 42 | """Create dataset objects for splits of the IMDB dataset. 43 | 44 | Arguments: 45 | text_field: The field that will be used for the sentence. 46 | label_field: The field that will be used for label data. 47 | root: Root dataset storage directory. Default is '.data'. 48 | train: The directory that contains the training examples 49 | test: The directory that contains the test examples 50 | Remaining keyword arguments: Passed to the splits method of 51 | Dataset. 
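Example (mirroring the ``iters`` classmethod below): ``train, test = IMDB.splits(TEXT, LABEL)``, where ``TEXT = data.Field()`` and ``LABEL = data.Field(sequential=False)``.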
52 | """ 53 | return super(IMDB, cls).splits( 54 | root=root, text_field=text_field, label_field=label_field, 55 | train=train, validation=None, test=test, **kwargs) 56 | 57 | @classmethod 58 | def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs): 59 | """Create iterator objects for splits of the IMDB dataset. 60 | 61 | Arguments: 62 | batch_size: Batch_size 63 | device: Device to create batches on. Use - 1 for CPU and None for 64 | the currently active GPU device. 65 | root: The root directory that contains the imdb dataset subdirectory 66 | vectors: one of the available pretrained vectors or a list with each 67 | element one of the available pretrained vectors (see Vocab.load_vectors) 68 | 69 | Remaining keyword arguments: Passed to the splits method. 70 | """ 71 | TEXT = data.Field() 72 | LABEL = data.Field(sequential=False) 73 | 74 | train, test = cls.splits(TEXT, LABEL, root=root, **kwargs) 75 | 76 | TEXT.build_vocab(train, vectors=vectors) 77 | LABEL.build_vocab(train) 78 | 79 | return data.BucketIterator.splits( 80 | (train, test), batch_size=batch_size, device=device) 81 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPoC: Search-based Pseudocode to Code 2 | 3 | [[Paper](https://arxiv.org/abs/1906.04908)] [[Webpage](https://cs.stanford.edu/~sumith/spoc)] [[Codalab](https://worksheets.codalab.org/worksheets/0xd445b1bd087d46d3b84f2dcf9a8094fa)] 4 | 5 | For your convenience, we provide a docker image as well as preprocessed data along with original data. 6 | 7 | ## Dependencies 8 | 9 | * GCC: TODO 10 | * Python: 3.6.5 11 | * PyTorch: 0.4.1 12 | 13 | ``` 14 | pip install cython 15 | pip install tqdm 16 | ``` 17 | 18 | **Docker** (WIP, currently not updated) 19 | 20 | ``` 21 | docker build -t sumith1896/spoc . 22 | docker run -it sumith1896/spoc bash 23 | ``` 24 | 25 | **Code** 26 | We use a modified version of [OpenNMT-py](https://github.com/OpenNMT/OpenNMT-py) as our `translator`, bundled with this repository. 27 | 28 | ``` 29 | git clone https://github.com/Sumith1896/synlp.git 30 | cd synlp 31 | ``` 32 | 33 | ## Data 34 | 35 | Get the dataset: 36 | 37 | ``` 38 | wget https://sumith1896.github.io/spoc/data/spoc.zip 39 | unzip spoc.zip && mv spoc/ data/ && rm spoc.zip 40 | ``` 41 | 42 | For tokenization of code, we rely on Clang (7.0.1). 
Fetch the right version of Clang: 43 | 44 | ``` 45 | wget http://releases.llvm.org/7.0.1/clang+llvm-7.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz 46 | tar -xf clang+llvm-7.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz 47 | mkdir tokenizer/lib 48 | mv clang+llvm-7.0.1-x86_64-linux-gnu-ubuntu-16.04/lib/libclang.so.7 tokenizer/lib/libclang.so 49 | rm -rf clang+llvm-7.0.1-x86_64-linux-gnu-ubuntu-16.04 50 | rm clang+llvm-7.0.1-x86_64-linux-gnu-ubuntu-16.04.tar.xz 51 | ``` 52 | 53 | Finally, tokenize the data: 54 | 55 | ``` 56 | ./tokenizer/format-pairs.py -c ./tokenizer/lib -H -t data/train/split/spoc-train-train.tsv > data/input-tok-train.tsv 57 | ./tokenizer/format-pairs.py -c ./tokenizer/lib -H -t data/train/split/spoc-train-eval.tsv > data/input-tok-eval.tsv 58 | ./tokenizer/format-pairs.py -c ./tokenizer/lib -H -t data/train/split/spoc-train-test.tsv > data/input-tok-test.tsv 59 | ``` 60 | 61 | ## Translation 62 | 63 | Preprocessing the data (or download [this preprocessed data](https://worksheets.codalab.org/worksheets/0xd445b1bd087d46d3b84f2dcf9a8094fa) from the CodaLab platform): 64 | 65 | ``` 66 | shuf data/input-tok-train.tsv -o data/input-tok-train-shuf.tsv 67 | mkdir out 68 | ./translator/preprocess.py -config translator/config/config.preprocess.yml > preprocess.log 69 | ``` 70 | 71 | Training the translation model: 72 | 73 | ``` 74 | ./translator/train.py -config translator/config/config.train.yml > train.log 75 | ``` 76 | Monitoring validation accuracy and saved checkpoints: 77 | ``` 78 | grep 'Validation.*accuracy\|Saving' train.log 79 | ``` 80 | 81 | Evaluating the translation model on the test split: 82 | 83 | ``` 84 | # Separate test source (pseudocode) and target data (code) 85 | cut -f1 data/input-tok-test.tsv > data/input-tok-test-src.tsv 86 | cut -f2 data/input-tok-test.tsv > data/input-tok-test-tgt.tsv 87 | 88 | # Run inference (w/ beam search) for test 89 | ./translator/translate.py -config translator/config/config.translate.yml -model out/model_best.pt -log_file out/translate.log 90 | 91 | # Get summary statistics of result 92 | ./translator/evaluate.py out/translate.log data/input-tok-test-tgt.tsv -o out/translate.summary-nodummy -a | tee out/translate.stats 93 | 94 | # Insert back dummy lines that were not translated 95 | ./translator/recover-dummy.py data/train/split/spoc-train-test.tsv out/translate.summary-nodummy > out/translate.summary 96 | ``` 97 | 98 | ## Stitching 99 | 100 | ``` 101 | cp out/translate.summary out/spoc-train-test.summary 102 | cp data/train/split/spoc-train-test.tsv out/ 103 | bash stitcher/stitching.sh out/spoc-train-test 104 | 105 | ## TODO use different options 106 | 107 | ``` 108 | -------------------------------------------------------------------------------- /translator/torchtext/datasets/trec.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .. import data 4 | 5 | 6 | class TREC(data.Dataset): 7 | 8 | urls = ['http://cogcomp.org/Data/QA/QC/train_5500.label', 9 | 'http://cogcomp.org/Data/QA/QC/TREC_10.label'] 10 | name = 'trec' 11 | dirname = '' 12 | 13 | @staticmethod 14 | def sort_key(ex): 15 | return len(ex.text) 16 | 17 | def __init__(self, path, text_field, label_field, 18 | fine_grained=False, **kwargs): 19 | """Create a TREC dataset instance given a path and fields. 20 | 21 | Arguments: 22 | path: Path to the data file. 23 | text_field: The field that will be used for text data. 24 | label_field: The field that will be used for label data.
25 | fine_grained: Whether to use the fine-grained (50-class) version of TREC 26 | or the coarse-grained (6-class) version. 27 | Remaining keyword arguments: Passed to the constructor of 28 | data.Dataset. 29 | """ 30 | fields = [('text', text_field), ('label', label_field)] 31 | examples = [] 32 | 33 | def get_label_str(label): 34 | return label.split(':')[0] if not fine_grained else label 35 | label_field.preprocessing = data.Pipeline(get_label_str) 36 | 37 | for line in open(os.path.expanduser(path), 'rb'): 38 | # there is one non-ASCII byte: sisterBADBYTEcity; replaced with space 39 | label, _, text = line.replace(b'\xf0', b' ').decode().partition(' ') 40 | examples.append(data.Example.fromlist([text, label], fields)) 41 | 42 | super(TREC, self).__init__(examples, fields, **kwargs) 43 | 44 | @classmethod 45 | def splits(cls, text_field, label_field, root='.data', 46 | train='train_5500.label', test='TREC_10.label', **kwargs): 47 | """Create dataset objects for splits of the TREC dataset. 48 | 49 | Arguments: 50 | text_field: The field that will be used for the sentence. 51 | label_field: The field that will be used for label data. 52 | root: Root dataset storage directory. Default is '.data'. 53 | train: The filename of the train data. Default: 'train_5500.label'. 54 | test: The filename of the test data, or None to not load the test 55 | set. Default: 'TREC_10.label'. 56 | Remaining keyword arguments: Passed to the splits method of 57 | Dataset. 58 | """ 59 | return super(TREC, cls).splits( 60 | root=root, text_field=text_field, label_field=label_field, 61 | train=train, validation=None, test=test, **kwargs) 62 | 63 | @classmethod 64 | def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs): 65 | """Create iterator objects for splits of the TREC dataset. 66 | 67 | Arguments: 68 | batch_size: Batch size. 69 | device: Device to create batches on. Use -1 for CPU and None for 70 | the currently active GPU device. 71 | root: The root directory that contains the trec dataset subdirectory 72 | vectors: one of the available pretrained vectors or a list with each 73 | element one of the available pretrained vectors (see Vocab.load_vectors) 74 | Remaining keyword arguments: Passed to the splits method. 75 | """ 76 | TEXT = data.Field() 77 | LABEL = data.Field(sequential=False) 78 | 79 | train, test = cls.splits(TEXT, LABEL, root=root, **kwargs) 80 | 81 | TEXT.build_vocab(train, vectors=vectors) 82 | LABEL.build_vocab(train) 83 | 84 | return data.BucketIterator.splits( 85 | (train, test), batch_size=batch_size, device=device) 86 | -------------------------------------------------------------------------------- /translator/onmt/inputters/image_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | from torchtext.data import Field 7 | 8 | from onmt.inputters.datareader_base import DataReaderBase 9 | 10 | # domain specific dependencies 11 | try: 12 | from PIL import Image 13 | from torchvision import transforms 14 | import cv2 15 | except ImportError: 16 | Image, transforms, cv2 = None, None, None 17 | 18 | 19 | class ImageDataReader(DataReaderBase): 20 | """Read image data from disk. 21 | 22 | Args: 23 | truncate (tuple[int] or NoneType): maximum img size. Use 24 | ``(0,0)`` or ``None`` for unlimited. 25 | channel_size (int): Number of channels per image.
26 | 27 | Raises: 28 | onmt.inputters.datareader_base.MissingDependencyException: If 29 | importing any of ``PIL``, ``torchvision``, or ``cv2`` fails. 30 | """ 31 | 32 | def __init__(self, truncate=None, channel_size=3): 33 | self._check_deps() 34 | self.truncate = truncate 35 | self.channel_size = channel_size 36 | 37 | @classmethod 38 | def from_opt(cls, opt): 39 | return cls(channel_size=opt.image_channel_size) 40 | 41 | @classmethod 42 | def _check_deps(cls): 43 | if any([Image is None, transforms is None, cv2 is None]): 44 | cls._raise_missing_dep( 45 | "PIL", "torchvision", "cv2") 46 | 47 | def read(self, images, side, img_dir=None): 48 | """Read data into dicts. 49 | 50 | Args: 51 | images (str or Iterable[str]): Sequence of image paths or 52 | path to a file containing image paths. 53 | In either case, the filenames may be relative to ``img_dir`` 54 | (default behavior) or absolute. 55 | side (str): Prefix used in return dict. Usually 56 | ``"src"`` or ``"tgt"``. 57 | img_dir (str): Location of source image files. See ``images``. 58 | 59 | Yields: 60 | a dictionary containing image data, path and index for each line. 61 | """ 62 | if isinstance(images, str): 63 | images = DataReaderBase._read_file(images) 64 | 65 | for i, filename in enumerate(images): 66 | filename = filename.decode("utf-8").strip() 67 | img_path = os.path.join(img_dir, filename) 68 | if not os.path.exists(img_path): 69 | img_path = filename 70 | 71 | assert os.path.exists(img_path), \ 72 | 'img path %s not found' % filename 73 | 74 | if self.channel_size == 1: 75 | img = transforms.ToTensor()( 76 | Image.fromarray(cv2.imread(img_path, 0))) 77 | else: 78 | img = transforms.ToTensor()(Image.open(img_path)) 79 | if self.truncate and self.truncate != (0, 0): 80 | if not (img.size(1) <= self.truncate[0] 81 | and img.size(2) <= self.truncate[1]): 82 | continue 83 | yield {side: img, side + '_path': filename, 'indices': i} 84 | 85 | 86 | def img_sort_key(ex): 87 | """Sort using the size of the image: (width, height).""" 88 | return ex.src.size(2), ex.src.size(1) 89 | 90 | 91 | def batch_img(data, vocab): 92 | """Pad and batch a sequence of images.""" 93 | c = data[0].size(0) 94 | h = max([t.size(1) for t in data]) 95 | w = max([t.size(2) for t in data]) 96 | imgs = torch.zeros(len(data), c, h, w).fill_(1) 97 | for i, img in enumerate(data): 98 | imgs[i, :, 0:img.size(1), 0:img.size(2)] = img 99 | return imgs 100 | 101 | 102 | def image_fields(**kwargs): 103 | img = Field( 104 | use_vocab=False, dtype=torch.float, 105 | postprocessing=batch_img, sequential=False) 106 | return img 107 | -------------------------------------------------------------------------------- /translator/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Train models.""" 3 | import os 4 | import signal 5 | import torch 6 | 7 | import onmt.opts as opts 8 | import onmt.utils.distributed 9 | 10 | from onmt.utils.logging import logger 11 | from onmt.train_single import main as single_main 12 | from onmt.utils.parse import ArgumentParser 13 | 14 | 15 | def main(opt): 16 | ArgumentParser.validate_train_opts(opt) 17 | ArgumentParser.update_model_opts(opt) 18 | ArgumentParser.validate_model_opts(opt) 19 | 20 | nb_gpu = len(opt.gpu_ranks) 21 | 22 | if opt.world_size > 1: 23 | mp = torch.multiprocessing.get_context('spawn') 24 | # Create a thread to listen for errors in the child processes.
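# (The ErrorHandler below drains error_queue on a listener thread and re-raises a child's traceback in the parent via SIGUSR1.)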
25 | error_queue = mp.SimpleQueue() 26 | error_handler = ErrorHandler(error_queue) 27 | # Train with multiprocessing. 28 | procs = [] 29 | for device_id in range(nb_gpu): 30 | procs.append(mp.Process(target=run, args=( 31 | opt, device_id, error_queue, ), daemon=True)) 32 | procs[device_id].start() 33 | logger.info(" Starting process pid: %d " % procs[device_id].pid) 34 | error_handler.add_child(procs[device_id].pid) 35 | for p in procs: 36 | p.join() 37 | 38 | elif nb_gpu == 1: # case 1 GPU only 39 | single_main(opt, 0) 40 | else: # case only CPU 41 | single_main(opt, -1) 42 | 43 | 44 | def run(opt, device_id, error_queue): 45 | """ run process """ 46 | try: 47 | gpu_rank = onmt.utils.distributed.multi_init(opt, device_id) 48 | if gpu_rank != opt.gpu_ranks[device_id]: 49 | raise AssertionError("An error occurred in \ 50 | Distributed initialization") 51 | single_main(opt, device_id) 52 | except KeyboardInterrupt: 53 | pass # killed by parent, do nothing 54 | except Exception: 55 | # propagate exception to parent process, keeping original traceback 56 | import traceback 57 | error_queue.put((opt.gpu_ranks[device_id], traceback.format_exc())) 58 | 59 | 60 | class ErrorHandler(object): 61 | """A class that listens for exceptions in children processes and propagates 62 | the tracebacks to the parent process.""" 63 | 64 | def __init__(self, error_queue): 65 | """ init error handler """ 66 | import signal 67 | import threading 68 | self.error_queue = error_queue 69 | self.children_pids = [] 70 | self.error_thread = threading.Thread( 71 | target=self.error_listener, daemon=True) 72 | self.error_thread.start() 73 | signal.signal(signal.SIGUSR1, self.signal_handler) 74 | 75 | def add_child(self, pid): 76 | """ error handler """ 77 | self.children_pids.append(pid) 78 | 79 | def error_listener(self): 80 | """ error listener """ 81 | (rank, original_trace) = self.error_queue.get() 82 | self.error_queue.put((rank, original_trace)) 83 | os.kill(os.getpid(), signal.SIGUSR1) 84 | 85 | def signal_handler(self, signalnum, stackframe): 86 | """ signal handler """ 87 | for pid in self.children_pids: 88 | os.kill(pid, signal.SIGINT) # kill children processes 89 | (rank, original_trace) = self.error_queue.get() 90 | msg = """\n\n-- Tracebacks above this line can probably 91 | be ignored --\n\n""" 92 | msg += original_trace 93 | raise Exception(msg) 94 | 95 | 96 | def _get_parser(): 97 | parser = ArgumentParser(description='train.py') 98 | 99 | opts.config_opts(parser) 100 | opts.model_opts(parser) 101 | opts.train_opts(parser) 102 | return parser 103 | 104 | 105 | if __name__ == "__main__": 106 | parser = _get_parser() 107 | 108 | opt = parser.parse_args() 109 | main(opt) 110 | -------------------------------------------------------------------------------- /translator/onmt/modules/gate.py: -------------------------------------------------------------------------------- 1 | """ ContextGate module """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def context_gate_factory(gate_type, embeddings_size, decoder_size, 7 | attention_size, output_size): 8 | """Returns the correct ContextGate class""" 9 | 10 | gate_types = {'source': SourceContextGate, 11 | 'target': TargetContextGate, 12 | 'both': BothContextGate} 13 | 14 | assert gate_type in gate_types, "Not valid ContextGate type: {0}".format( 15 | gate_type) 16 | return gate_types[gate_type](embeddings_size, decoder_size, attention_size, 17 | output_size) 18 | 19 | 20 | class ContextGate(nn.Module): 21 | """ 22 | Context gate is a decoder module that takes 
as input the previous word 23 | embedding, the current decoder state and the attention state, and 24 | produces a gate. 25 | The gate can be used to select the input from the target side context 26 | (decoder state), from the source context (attention state) or both. 27 | """ 28 | 29 | def __init__(self, embeddings_size, decoder_size, 30 | attention_size, output_size): 31 | super(ContextGate, self).__init__() 32 | input_size = embeddings_size + decoder_size + attention_size 33 | self.gate = nn.Linear(input_size, output_size, bias=True) 34 | self.sig = nn.Sigmoid() 35 | self.source_proj = nn.Linear(attention_size, output_size) 36 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 37 | output_size) 38 | 39 | def forward(self, prev_emb, dec_state, attn_state): 40 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 41 | z = self.sig(self.gate(input_tensor)) 42 | proj_source = self.source_proj(attn_state) 43 | proj_target = self.target_proj( 44 | torch.cat((prev_emb, dec_state), dim=1)) 45 | return z, proj_source, proj_target 46 | 47 | 48 | class SourceContextGate(nn.Module): 49 | """Apply the context gate only to the source context""" 50 | 51 | def __init__(self, embeddings_size, decoder_size, 52 | attention_size, output_size): 53 | super(SourceContextGate, self).__init__() 54 | self.context_gate = ContextGate(embeddings_size, decoder_size, 55 | attention_size, output_size) 56 | self.tanh = nn.Tanh() 57 | 58 | def forward(self, prev_emb, dec_state, attn_state): 59 | z, source, target = self.context_gate( 60 | prev_emb, dec_state, attn_state) 61 | return self.tanh(target + z * source) 62 | 63 | 64 | class TargetContextGate(nn.Module): 65 | """Apply the context gate only to the target context""" 66 | 67 | def __init__(self, embeddings_size, decoder_size, 68 | attention_size, output_size): 69 | super(TargetContextGate, self).__init__() 70 | self.context_gate = ContextGate(embeddings_size, decoder_size, 71 | attention_size, output_size) 72 | self.tanh = nn.Tanh() 73 | 74 | def forward(self, prev_emb, dec_state, attn_state): 75 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 76 | return self.tanh(z * target + source) 77 | 78 | 79 | class BothContextGate(nn.Module): 80 | """Apply the context gate to both contexts""" 81 | 82 | def __init__(self, embeddings_size, decoder_size, 83 | attention_size, output_size): 84 | super(BothContextGate, self).__init__() 85 | self.context_gate = ContextGate(embeddings_size, decoder_size, 86 | attention_size, output_size) 87 | self.tanh = nn.Tanh() 88 | 89 | def forward(self, prev_emb, dec_state, attn_state): 90 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 91 | return self.tanh((1. - z) * target + z * source) 92 | -------------------------------------------------------------------------------- /translator/onmt/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """Returns the Length and Coverage Penalty function for Beam Search. 7 | 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | 12 | Attributes: 13 | has_cov_pen (bool): Whether coverage penalty is None (applying it 14 | is a no-op). Note that the converse isn't true. Setting beta 15 | to 0 should force the coverage penalty to be a no-op. 16 | has_len_pen (bool): Whether length penalty is None (applying it 17 | is a no-op).
Note that the converse isn't true. Setting alpha 18 | to 0 should force the length penalty to be a no-op. 19 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 20 | Calculates the coverage penalty. 21 | length_penalty (callable[[int, float], float]): Calculates 22 | the length penalty. 23 | """ 24 | 25 | def __init__(self, cov_pen, length_pen): 26 | self.has_cov_pen = not self._pen_is_none(cov_pen) 27 | self.coverage_penalty = self._coverage_penalty(cov_pen) 28 | self.has_len_pen = not self._pen_is_none(length_pen) 29 | self.length_penalty = self._length_penalty(length_pen) 30 | 31 | @staticmethod 32 | def _pen_is_none(pen): 33 | return pen == "none" or pen is None 34 | 35 | def _coverage_penalty(self, cov_pen): 36 | if cov_pen == "wu": 37 | return self.coverage_wu 38 | elif cov_pen == "summary": 39 | return self.coverage_summary 40 | elif self._pen_is_none(cov_pen): 41 | return self.coverage_none 42 | else: 43 | raise NotImplementedError("No '{:s}' coverage penalty.".format( 44 | cov_pen)) 45 | 46 | def _length_penalty(self, length_pen): 47 | if length_pen == "wu": 48 | return self.length_wu 49 | elif length_pen == "avg": 50 | return self.length_average 51 | elif self._pen_is_none(length_pen): 52 | return self.length_none 53 | else: 54 | raise NotImplementedError("No '{:s}' length penalty.".format( 55 | length_pen)) 56 | 57 | # Below are all the different penalty terms implemented so far. 58 | # Subtract coverage penalty from topk log probs. 59 | # Divide topk log probs by length penalty. 60 | 61 | def coverage_wu(self, cov, beta=0.): 62 | """GNMT coverage re-ranking score. 63 | 64 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 65 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 66 | probably ``batch_size x beam_size`` but could be several 67 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 68 | then the ``seq_len`` axis probably sums to (almost) 1. 69 | """ 70 | 71 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 72 | return beta * penalty 73 | 74 | def coverage_summary(self, cov, beta=0.): 75 | """Our summary penalty.""" 76 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 77 | penalty -= cov.size(-1) 78 | return beta * penalty 79 | 80 | def coverage_none(self, cov, beta=0.): 81 | """Returns zero as penalty""" 82 | none = torch.zeros((1,), device=cov.device, 83 | dtype=torch.float) 84 | if cov.dim() == 3: 85 | none = none.unsqueeze(0) 86 | return none 87 | 88 | def length_wu(self, cur_len, alpha=0.): 89 | """GNMT length re-ranking score. 90 | 91 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 92 | """ 93 | 94 | return ((5 + cur_len) / 6.0) ** alpha 95 | 96 | def length_average(self, cur_len, alpha=0.): 97 | """Returns the current sequence length.""" 98 | return cur_len 99 | 100 | def length_none(self, cur_len, alpha=0.): 101 | """Returns unmodified scores.""" 102 | return 1.0 103 | -------------------------------------------------------------------------------- /translator/torchtext/datasets/sequence_tagging.py: -------------------------------------------------------------------------------- 1 | from .. import data 2 | import random 3 | 4 | 5 | class SequenceTaggingDataset(data.Dataset): 6 | """Defines a dataset for sequence tagging. Examples in this dataset 7 | contain paired lists -- paired list of words and tags. 8 | 9 | For example, in the case of part-of-speech tagging, an example is of the 10 | form 11 | [I, love, PyTorch, .]
paired with [PRON, VERB, PROPN, PUNCT] 12 | 13 | See torchtext/test/sequence_tagging.py on how to use this class. 14 | """ 15 | 16 | @staticmethod 17 | def sort_key(example): 18 | for attr in dir(example): 19 | if not callable(getattr(example, attr)) and \ 20 | not attr.startswith("__"): 21 | return len(getattr(example, attr)) 22 | return 0 23 | 24 | def __init__(self, path, fields, encoding="utf-8", separator="\t", **kwargs): 25 | examples = [] 26 | columns = [] 27 | 28 | with open(path, encoding=encoding) as input_file: 29 | for line in input_file: 30 | line = line.strip() 31 | if line == "": 32 | if columns: 33 | examples.append(data.Example.fromlist(columns, fields)) 34 | columns = [] 35 | else: 36 | for i, column in enumerate(line.split(separator)): 37 | if len(columns) < i + 1: 38 | columns.append([]) 39 | columns[i].append(column) 40 | 41 | if columns: 42 | examples.append(data.Example.fromlist(columns, fields)) 43 | super(SequenceTaggingDataset, self).__init__(examples, fields, 44 | **kwargs) 45 | 46 | 47 | class UDPOS(SequenceTaggingDataset): 48 | 49 | # Universal Dependencies English Web Treebank. 50 | # Download original at http://universaldependencies.org/ 51 | # License: http://creativecommons.org/licenses/by-sa/4.0/ 52 | urls = ['https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip'] 53 | dirname = 'en-ud-v2' 54 | name = 'udpos' 55 | 56 | @classmethod 57 | def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt", 58 | validation="en-ud-tag.v2.dev.txt", 59 | test="en-ud-tag.v2.test.txt", **kwargs): 60 | """Downloads and loads the Universal Dependencies Version 2 POS Tagged 61 | data. 62 | """ 63 | 64 | return super(UDPOS, cls).splits( 65 | fields=fields, root=root, train=train, validation=validation, 66 | test=test, **kwargs) 67 | 68 | 69 | class CoNLL2000Chunking(SequenceTaggingDataset): 70 | # CoNLL 2000 Chunking Dataset 71 | # https://www.clips.uantwerpen.be/conll2000/chunking/ 72 | urls = ['https://www.clips.uantwerpen.be/conll2000/chunking/train.txt.gz', 73 | 'https://www.clips.uantwerpen.be/conll2000/chunking/test.txt.gz'] 74 | dirname = '' 75 | name = 'conll2000' 76 | 77 | @classmethod 78 | def splits(cls, fields, root=".data", train="train.txt", 79 | test="test.txt", validation_frac=0.1, **kwargs): 80 | """Downloads and loads the CoNLL 2000 Chunking dataset. 81 | NOTE: There is only a train and test dataset so we use 82 | 10% of the train set as validation 83 | """ 84 | 85 | train, test = super(CoNLL2000Chunking, cls).splits( 86 | fields=fields, root=root, train=train, 87 | test=test, separator=' ', **kwargs) 88 | 89 | # HACK: Saving the sort key function as the split() call removes it 90 | sort_key = train.sort_key 91 | 92 | # Now split the train set 93 | # Force a random seed to make the split deterministic 94 | random.seed(0) 95 | train, val = train.split(1 - validation_frac, random_state=random.getstate()) 96 | # Reset the seed 97 | random.seed() 98 | 99 | # HACK: Set the sort key 100 | train.sort_key = sort_key 101 | val.sort_key = sort_key 102 | 103 | return train, val, test 104 | -------------------------------------------------------------------------------- /translator/torchtext/data/batch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Batch(object): 5 | """Defines a batch of examples along with its Fields. 6 | 7 | Attributes: 8 | batch_size: Number of examples in the batch. 
9 | dataset: A reference to the dataset object the examples come from 10 | (which itself contains the dataset's Field objects). 11 | train: Deprecated: this attribute is left for backwards compatibility, 12 | however it is UNUSED as of the merger with pytorch 0.4. 13 | input_fields: The names of the fields that are used as input for the model 14 | target_fields: The names of the fields that are used as targets during 15 | model training 16 | 17 | Also stores the Variable for each column in the batch as an attribute. 18 | """ 19 | 20 | def __init__(self, data=None, dataset=None, device=None): 21 | """Create a Batch from a list of examples.""" 22 | if data is not None: 23 | self.batch_size = len(data) 24 | self.dataset = dataset 25 | self.fields = dataset.fields.keys() # copy field names 26 | self.input_fields = [k for k, v in dataset.fields.items() if 27 | v is not None and not v.is_target] 28 | self.target_fields = [k for k, v in dataset.fields.items() if 29 | v is not None and v.is_target] 30 | 31 | for (name, field) in dataset.fields.items(): 32 | if field is not None: 33 | batch = [getattr(x, name) for x in data] 34 | setattr(self, name, field.process(batch, device=device)) 35 | 36 | @classmethod 37 | def fromvars(cls, dataset, batch_size, train=None, **kwargs): 38 | """Create a Batch directly from a number of Variables.""" 39 | batch = cls() 40 | batch.batch_size = batch_size 41 | batch.dataset = dataset 42 | batch.fields = dataset.fields.keys() 43 | for k, v in kwargs.items(): 44 | setattr(batch, k, v) 45 | return batch 46 | 47 | def __repr__(self): 48 | return str(self) 49 | 50 | def __str__(self): 51 | if not self.__dict__: 52 | return 'Empty {} instance'.format(torch.typename(self)) 53 | 54 | fields_to_index = filter(lambda field: field is not None, self.fields) 55 | var_strs = '\n'.join(['\t[.' 
+ name + ']' + ":" + _short_str(getattr(self, name)) 56 | for name in fields_to_index if hasattr(self, name)]) 57 | 58 | data_str = (' from {}'.format(self.dataset.name.upper()) 59 | if hasattr(self.dataset, 'name') 60 | and isinstance(self.dataset.name, str) else '') 61 | 62 | strt = '[{} of size {}{}]\n{}'.format(torch.typename(self), 63 | self.batch_size, data_str, var_strs) 64 | return '\n' + strt 65 | 66 | def __len__(self): 67 | return self.batch_size 68 | 69 | def _get_field_values(self, fields): 70 | if len(fields) == 0: 71 | return None 72 | elif len(fields) == 1: 73 | return getattr(self, fields[0]) 74 | else: 75 | return tuple(getattr(self, f) for f in fields) 76 | 77 | def __iter__(self): 78 | yield self._get_field_values(self.input_fields) 79 | yield self._get_field_values(self.target_fields) 80 | 81 | 82 | def _short_str(tensor): 83 | # unwrap variable to tensor 84 | if not torch.is_tensor(tensor): 85 | # (1) unpack variable 86 | if hasattr(tensor, 'data'): 87 | tensor = getattr(tensor, 'data') 88 | # (2) handle include_lengths 89 | elif isinstance(tensor, tuple): 90 | return str(tuple(_short_str(t) for t in tensor)) 91 | # (3) fallback to default str 92 | else: 93 | return str(tensor) 94 | 95 | # copied from torch _tensor_str 96 | size_str = 'x'.join(str(size) for size in tensor.size()) 97 | device_str = '' if not tensor.is_cuda else \ 98 | ' (GPU {})'.format(tensor.get_device()) 99 | strt = '[{} of size {}{}]'.format(torch.typename(tensor), 100 | size_str, device_str) 101 | return strt 102 | -------------------------------------------------------------------------------- /translator/onmt/models/model_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from collections import deque 6 | from onmt.utils.logging import logger 7 | 8 | from copy import deepcopy 9 | 10 | 11 | def build_model_saver(model_opt, opt, model, fields, optim): 12 | model_saver = ModelSaver(opt.save_model, 13 | model, 14 | model_opt, 15 | fields, 16 | optim, 17 | opt.keep_checkpoint) 18 | return model_saver 19 | 20 | 21 | class ModelSaverBase(object): 22 | """Base class for model saving operations 23 | 24 | Inherited classes must implement private methods: 25 | * `_save` 26 | * `_rm_checkpoint` 27 | """ 28 | 29 | def __init__(self, base_path, model, model_opt, fields, optim, 30 | keep_checkpoint=-1): 31 | self.base_path = base_path 32 | self.model = model 33 | self.model_opt = model_opt 34 | self.fields = fields 35 | self.optim = optim 36 | self.last_saved_step = None 37 | self.keep_checkpoint = keep_checkpoint 38 | if keep_checkpoint > 0: 39 | self.checkpoint_queue = deque([], maxlen=keep_checkpoint) 40 | 41 | def save(self, step, moving_average=None): 42 | """Main entry point for model saver 43 | 44 | It wraps the `_save` method with checks and applies the 45 | `keep_checkpoint`-related logic 46 | """ 47 | 48 | if self.keep_checkpoint == 0 or step == self.last_saved_step: 49 | return 50 | 51 | if moving_average: 52 | save_model = deepcopy(self.model) 53 | for avg, param in zip(moving_average, save_model.parameters()): 54 | param.data.copy_(avg.data) 55 | else: 56 | save_model = self.model 57 | 58 | chkpt, chkpt_name = self._save(step, save_model) 59 | self.last_saved_step = step 60 | 61 | if moving_average: 62 | del save_model 63 | 64 | if self.keep_checkpoint > 0: 65 | if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: 66 | todel = self.checkpoint_queue.popleft() 67 | self._rm_checkpoint(todel)
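# Remember the new checkpoint so it can be rotated out in a later save.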
68 | self.checkpoint_queue.append(chkpt_name) 69 | 70 | def _save(self, step, model): 71 | """Save a resumable checkpoint. 72 | 73 | Args: 74 | step (int): step number 75 | model (nn.Module): model to save 76 | Returns: 77 | (object, str): 78 | 79 | * checkpoint: the saved object 80 | * checkpoint_name: name (or path) of the saved checkpoint 81 | """ 82 | 83 | raise NotImplementedError() 84 | 85 | def _rm_checkpoint(self, name): 86 | """Remove a checkpoint 87 | 88 | Args: 89 | name(str): name that identifies the checkpoint 90 | (it may be a filepath) 91 | """ 92 | 93 | raise NotImplementedError() 94 | 95 | 96 | class ModelSaver(ModelSaverBase): 97 | """Simple model saver to filesystem""" 98 | 99 | def _save(self, step, model): 100 | real_model = (model.module 101 | if isinstance(model, nn.DataParallel) 102 | else model) 103 | real_generator = (real_model.generator.module 104 | if isinstance(real_model.generator, nn.DataParallel) 105 | else real_model.generator) 106 | 107 | model_state_dict = real_model.state_dict() 108 | model_state_dict = {k: v for k, v in model_state_dict.items() 109 | if 'generator' not in k} 110 | generator_state_dict = real_generator.state_dict() 111 | checkpoint = { 112 | 'model': model_state_dict, 113 | 'generator': generator_state_dict, 114 | 'vocab': self.fields, 115 | 'opt': self.model_opt, 116 | 'optim': self.optim.state_dict(), 117 | } 118 | 119 | logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) 120 | checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) 121 | torch.save(checkpoint, checkpoint_path) 122 | return checkpoint, checkpoint_path 123 | 124 | def _rm_checkpoint(self, name): 125 | os.remove(name) 126 | -------------------------------------------------------------------------------- /translator/onmt/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import random 5 | import inspect 6 | from itertools import islice 7 | 8 | 9 | def split_corpus(path, shard_size): 10 | with open(path, "rb") as f: 11 | if shard_size <= 0: 12 | yield f.readlines() 13 | else: 14 | while True: 15 | shard = list(islice(f, shard_size)) 16 | if not shard: 17 | break 18 | yield shard 19 | 20 | 21 | def aeq(*args): 22 | """ 23 | Assert all arguments have the same value 24 | """ 25 | arguments = (arg for arg in args) 26 | first = next(arguments) 27 | assert all(arg == first for arg in arguments), \ 28 | "Not all arguments have the same value: " + str(args) 29 | 30 | 31 | def sequence_mask(lengths, max_len=None): 32 | """ 33 | Creates a boolean mask from sequence lengths. 34 | """ 35 | batch_size = lengths.numel() 36 | max_len = max_len or lengths.max() 37 | return (torch.arange(0, max_len) 38 | .type_as(lengths) 39 | .repeat(batch_size, 1) 40 | .lt(lengths.unsqueeze(1))) 41 | 42 | 43 | def tile(x, count, dim=0): 44 | """ 45 | Tiles x on dimension dim count times.
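For example, with count=2 and dim=0, rows [a, b] become [a, a, b, b] (each entry is repeated in place).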
46 | """ 47 | perm = list(range(len(x.size()))) 48 | if dim != 0: 49 | perm[0], perm[dim] = perm[dim], perm[0] 50 | x = x.permute(perm).contiguous() 51 | out_size = list(x.size()) 52 | out_size[0] *= count 53 | batch = x.size(0) 54 | x = x.view(batch, -1) \ 55 | .transpose(0, 1) \ 56 | .repeat(count, 1) \ 57 | .transpose(0, 1) \ 58 | .contiguous() \ 59 | .view(*out_size) 60 | if dim != 0: 61 | x = x.permute(perm).contiguous() 62 | return x 63 | 64 | 65 | def use_gpu(opt): 66 | """ 67 | Creates a boolean if gpu used 68 | """ 69 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 70 | (hasattr(opt, 'gpu') and opt.gpu > -1) 71 | 72 | 73 | def set_random_seed(seed, is_cuda): 74 | """Sets the random seed.""" 75 | if seed > 0: 76 | torch.manual_seed(seed) 77 | # this one is needed for torchtext random call (shuffled iterator) 78 | # in multi gpu it ensures datasets are read in the same order 79 | random.seed(seed) 80 | # some cudnn methods can be random even after fixing the seed 81 | # unless you tell it to be deterministic 82 | torch.backends.cudnn.deterministic = True 83 | 84 | if is_cuda and seed > 0: 85 | # These ensure same initialization in multi gpu mode 86 | torch.cuda.manual_seed(seed) 87 | 88 | 89 | def generate_relative_positions_matrix(length, max_relative_positions, 90 | cache=False): 91 | """Generate the clipped relative positions matrix 92 | for a given length and maximum relative positions""" 93 | if cache: 94 | distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0) 95 | else: 96 | range_vec = torch.arange(length) 97 | range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) 98 | distance_mat = range_mat - range_mat.transpose(0, 1) 99 | distance_mat_clipped = torch.clamp(distance_mat, 100 | min=-max_relative_positions, 101 | max=max_relative_positions) 102 | # Shift values to be >= 0 103 | final_mat = distance_mat_clipped + max_relative_positions 104 | return final_mat 105 | 106 | 107 | def relative_matmul(x, z, transpose): 108 | """Helper function for relative positions attention.""" 109 | batch_size = x.shape[0] 110 | heads = x.shape[1] 111 | length = x.shape[2] 112 | x_t = x.permute(2, 0, 1, 3) 113 | x_t_r = x_t.reshape(length, heads * batch_size, -1) 114 | if transpose: 115 | z_t = z.transpose(1, 2) 116 | x_tz_matmul = torch.matmul(x_t_r, z_t) 117 | else: 118 | x_tz_matmul = torch.matmul(x_t_r, z) 119 | x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1) 120 | x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) 121 | return x_tz_matmul_r_t 122 | 123 | 124 | def fn_args(fun): 125 | """Returns the list of function arguments name.""" 126 | return inspect.getfullargspec(fun).args 127 | -------------------------------------------------------------------------------- /translator/onmt/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | import torch.distributed 12 | 13 | from onmt.utils.logging import logger 14 | 15 | 16 | def is_master(opt, device_id): 17 | return opt.gpu_ranks[device_id] == 0 18 | 19 | 20 | def multi_init(opt, device_id): 21 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format( 22 | master_ip=opt.master_ip, 23 | master_port=opt.master_port) 24 | dist_world_size = opt.world_size 25 | 
torch.distributed.init_process_group( 26 | backend=opt.gpu_backend, init_method=dist_init_method, 27 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(opt, device_id): 30 | logger.disabled = True 31 | 32 | return gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. # of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /translator/onmt/modules/average_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Average Attention module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 
from onmt.modules.position_ffn import PositionwiseFeedForward 8 | 9 | 10 | class AverageAttention(nn.Module): 11 | """ 12 | Average Attention module from 13 | "Accelerating Neural Transformer via an Average Attention Network" 14 | :cite:`DBLP:journals/corr/abs-1805-00631`. 15 | 16 | Args: 17 | model_dim (int): the dimension of keys/values/queries, 18 | must be divisible by head_count 19 | dropout (float): dropout parameter 20 | """ 21 | 22 | def __init__(self, model_dim, dropout=0.1): 23 | self.model_dim = model_dim 24 | 25 | super(AverageAttention, self).__init__() 26 | 27 | self.average_layer = PositionwiseFeedForward(model_dim, model_dim, 28 | dropout) 29 | self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2) 30 | 31 | def cumulative_average_mask(self, batch_size, inputs_len): 32 | """ 33 | Builds the mask to compute the cumulative average as described in 34 | :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3 35 | 36 | Args: 37 | batch_size (int): batch size 38 | inputs_len (int): length of the inputs 39 | 40 | Returns: 41 | (FloatTensor): 42 | 43 | * A Tensor of shape ``(batch_size, input_len, input_len)`` 44 | """ 45 | 46 | triangle = torch.tril(torch.ones(inputs_len, inputs_len)) 47 | weights = torch.ones(1, inputs_len) / torch.arange( 48 | 1, inputs_len + 1, dtype=torch.float) 49 | mask = triangle * weights.transpose(0, 1) 50 | 51 | return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len) 52 | 53 | def cumulative_average(self, inputs, mask_or_step, 54 | layer_cache=None, step=None): 55 | """ 56 | Computes the cumulative average as described in 57 | :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6) 58 | 59 | Args: 60 | inputs (FloatTensor): sequence to average 61 | ``(batch_size, input_len, dimension)`` 62 | mask_or_step: if cache is set, this is assumed 63 | to be the current step of the 64 | dynamic decoding. Otherwise, it is the mask matrix 65 | used to compute the cumulative average. 66 | layer_cache: a dictionary containing the cumulative average 67 | of the previous step. 68 | 69 | Returns: 70 | a tensor of the same shape and type as ``inputs``. 
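During incremental decoding (when ``layer_cache`` is set), the running average is updated in place as g_t = (x_t + step * prev_g) / (step + 1), with prev_g stored in ``layer_cache["prev_g"]``.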
71 | """ 72 | 73 | if layer_cache is not None: 74 | step = mask_or_step 75 | device = inputs.device 76 | average_attention = (inputs + step * 77 | layer_cache["prev_g"].to(device)) / (step + 1) 78 | layer_cache["prev_g"] = average_attention 79 | return average_attention 80 | else: 81 | mask = mask_or_step 82 | return torch.matmul(mask, inputs) 83 | 84 | def forward(self, inputs, mask=None, layer_cache=None, step=None): 85 | """ 86 | Args: 87 | inputs (FloatTensor): ``(batch_size, input_len, model_dim)`` 88 | 89 | Returns: 90 | (FloatTensor, FloatTensor): 91 | 92 | * gating_outputs ``(batch_size, input_len, model_dim)`` 93 | * average_outputs average attention 94 | ``(batch_size, input_len, model_dim)`` 95 | """ 96 | 97 | batch_size = inputs.size(0) 98 | inputs_len = inputs.size(1) 99 | 100 | device = inputs.device 101 | average_outputs = self.cumulative_average( 102 | inputs, self.cumulative_average_mask(batch_size, 103 | inputs_len).to(device).float() 104 | if layer_cache is None else step, layer_cache=layer_cache) 105 | average_outputs = self.average_layer(average_outputs) 106 | gating_outputs = self.gating_layer(torch.cat((inputs, 107 | average_outputs), -1)) 108 | input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2) 109 | gating_outputs = torch.sigmoid(input_gate) * inputs + \ 110 | torch.sigmoid(forget_gate) * average_outputs 111 | 112 | return gating_outputs, average_outputs 113 | -------------------------------------------------------------------------------- /translator/onmt/train_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Training on a single process.""" 3 | import os 4 | 5 | import torch 6 | 7 | import onmt.inputters as inputters 8 | from onmt.inputters.inputter import build_dataset_iter, \ 9 | load_old_vocab, old_style_vocab 10 | from onmt.model_builder import build_model 11 | from onmt.utils.optimizers import Optimizer 12 | from onmt.utils.misc import set_random_seed 13 | from onmt.trainer import build_trainer 14 | from onmt.models import build_model_saver 15 | from onmt.utils.logging import init_logger, logger 16 | from onmt.utils.parse import ArgumentParser 17 | 18 | 19 | def _check_save_model_path(opt): 20 | save_model_path = os.path.abspath(opt.save_model) 21 | model_dirname = os.path.dirname(save_model_path) 22 | if not os.path.exists(model_dirname): 23 | os.makedirs(model_dirname) 24 | 25 | 26 | def _tally_parameters(model): 27 | enc = 0 28 | dec = 0 29 | for name, param in model.named_parameters(): 30 | if 'encoder' in name: 31 | enc += param.nelement() 32 | else: 33 | dec += param.nelement() 34 | return enc + dec, enc, dec 35 | 36 | 37 | def configure_process(opt, device_id): 38 | if device_id >= 0: 39 | torch.cuda.set_device(device_id) 40 | set_random_seed(opt.seed, device_id >= 0) 41 | 42 | 43 | def main(opt, device_id): 44 | # NOTE: It's important that ``opt`` has been validated and updated 45 | # at this point. 46 | configure_process(opt, device_id) 47 | init_logger(opt.log_file) 48 | # Load checkpoint if we resume from a previous training. 49 | if opt.train_from: 50 | logger.info('Loading checkpoint from %s' % opt.train_from) 51 | checkpoint = torch.load(opt.train_from, 52 | map_location=lambda storage, loc: storage) 53 | 54 | model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) 55 | ArgumentParser.update_model_opts(model_opt) 56 | ArgumentParser.validate_model_opts(model_opt) 57 | logger.info('Loading vocab from checkpoint at %s.' 
% opt.train_from) 58 | vocab = checkpoint['vocab'] 59 | else: 60 | checkpoint = None 61 | model_opt = opt 62 | vocab = torch.load(opt.data + '.vocab.pt') 63 | 64 | # check for code where vocab is saved instead of fields 65 | # (in the future this will be done in a smarter way) 66 | if old_style_vocab(vocab): 67 | fields = load_old_vocab( 68 | vocab, opt.model_type, dynamic_dict=opt.copy_attn) 69 | else: 70 | fields = vocab 71 | 72 | sort_key = inputters.str2sortkey[opt.data_type] 73 | 74 | # Report src and tgt vocab sizes, including for features 75 | for side in ['src', 'tgt']: 76 | f = fields[side] 77 | try: 78 | f_iter = iter(f) 79 | except TypeError: 80 | f_iter = [(side, f)] 81 | for sn, sf in f_iter: 82 | if sf.use_vocab: 83 | logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) 84 | 85 | # Build model. 86 | model = build_model(model_opt, opt, fields, checkpoint) 87 | n_params, enc, dec = _tally_parameters(model) 88 | logger.info('encoder: %d' % enc) 89 | logger.info('decoder: %d' % dec) 90 | logger.info('* number of parameters: %d' % n_params) 91 | _check_save_model_path(opt) 92 | 93 | # Build optimizer. 94 | optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint) 95 | 96 | # Build model saver 97 | model_saver = build_model_saver(model_opt, opt, model, fields, optim) 98 | 99 | trainer = build_trainer( 100 | opt, device_id, model, fields, optim, model_saver=model_saver) 101 | 102 | train_iter = build_dataset_iter("train", fields, sort_key, opt) 103 | valid_iter = build_dataset_iter( 104 | "valid", fields, sort_key, opt, is_train=False) 105 | 106 | if len(opt.gpu_ranks): 107 | logger.info('Starting training on GPU: %s' % opt.gpu_ranks) 108 | else: 109 | logger.info('Starting training on CPU, could be very slow') 110 | train_steps = opt.train_steps 111 | if opt.single_pass and train_steps > 0: 112 | logger.warning("Option single_pass is enabled, ignoring train_steps.") 113 | train_steps = 0 114 | trainer.train( 115 | train_iter, 116 | train_steps, 117 | save_checkpoint_steps=opt.save_checkpoint_steps, 118 | valid_iter=valid_iter, 119 | valid_steps=opt.valid_steps) 120 | 121 | if opt.tensorboard: 122 | trainer.report_manager.tensorboard_writer.close() 123 | -------------------------------------------------------------------------------- /translator/onmt/encoders/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | """Define RNN-based encoders.""" 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 7 | 8 | from onmt.encoders.encoder import EncoderBase 9 | from onmt.utils.rnn_factory import rnn_factory 10 | 11 | 12 | class RNNEncoder(EncoderBase): 13 | """ A generic recurrent neural network encoder. 
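When ``bidirectional`` is set, ``hidden_size`` is split evenly between the two directions, so the concatenated outputs keep the requested size.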
14 | 15 | Args: 16 | rnn_type (str): 17 | style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] 18 | bidirectional (bool) : use a bidirectional RNN 19 | num_layers (int) : number of stacked layers 20 | hidden_size (int) : hidden size of each layer 21 | dropout (float) : dropout value for :class:`torch.nn.Dropout` 22 | embeddings (onmt.modules.Embeddings): embedding module to use 23 | """ 24 | 25 | def __init__(self, rnn_type, bidirectional, num_layers, 26 | hidden_size, dropout=0.0, embeddings=None, 27 | use_bridge=False): 28 | super(RNNEncoder, self).__init__() 29 | assert embeddings is not None 30 | 31 | num_directions = 2 if bidirectional else 1 32 | assert hidden_size % num_directions == 0 33 | hidden_size = hidden_size // num_directions 34 | self.embeddings = embeddings 35 | 36 | self.rnn, self.no_pack_padded_seq = \ 37 | rnn_factory(rnn_type, 38 | input_size=embeddings.embedding_size, 39 | hidden_size=hidden_size, 40 | num_layers=num_layers, 41 | dropout=dropout, 42 | bidirectional=bidirectional) 43 | 44 | # Initialize the bridge layer 45 | self.use_bridge = use_bridge 46 | if self.use_bridge: 47 | self._initialize_bridge(rnn_type, 48 | hidden_size, 49 | num_layers) 50 | 51 | @classmethod 52 | def from_opt(cls, opt, embeddings): 53 | """Alternate constructor.""" 54 | return cls( 55 | opt.rnn_type, 56 | opt.brnn, 57 | opt.enc_layers, 58 | opt.enc_rnn_size, 59 | opt.dropout, 60 | embeddings, 61 | opt.bridge) 62 | 63 | def forward(self, src, lengths=None): 64 | """See :func:`EncoderBase.forward()`""" 65 | self._check_args(src, lengths) 66 | 67 | emb = self.embeddings(src) 68 | # s_len, batch, emb_dim = emb.size() 69 | 70 | packed_emb = emb 71 | if lengths is not None and not self.no_pack_padded_seq: 72 | # Lengths data is wrapped inside a Tensor. 
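# Flatten to a plain Python list before packing the padded batch.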
73 | lengths_list = lengths.view(-1).tolist() 74 | packed_emb = pack(emb, lengths_list) 75 | 76 | memory_bank, encoder_final = self.rnn(packed_emb) 77 | 78 | if lengths is not None and not self.no_pack_padded_seq: 79 | memory_bank = unpack(memory_bank)[0] 80 | 81 | if self.use_bridge: 82 | encoder_final = self._bridge(encoder_final) 83 | return encoder_final, memory_bank, lengths 84 | 85 | def _initialize_bridge(self, rnn_type, 86 | hidden_size, 87 | num_layers): 88 | 89 | # LSTM has hidden and cell state, other only one 90 | number_of_states = 2 if rnn_type == "LSTM" else 1 91 | # Total number of states 92 | self.total_hidden_dim = hidden_size * num_layers 93 | 94 | # Build a linear layer for each 95 | self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim, 96 | self.total_hidden_dim, 97 | bias=True) 98 | for _ in range(number_of_states)]) 99 | 100 | def _bridge(self, hidden): 101 | """Forward hidden state through bridge.""" 102 | def bottle_hidden(linear, states): 103 | """ 104 | Transform from 3D to 2D, apply linear and return initial size 105 | """ 106 | size = states.size() 107 | result = linear(states.view(-1, self.total_hidden_dim)) 108 | return F.relu(result).view(size) 109 | 110 | if isinstance(hidden, tuple): # LSTM 111 | outs = tuple([bottle_hidden(layer, hidden[ix]) 112 | for ix, layer in enumerate(self.bridge)]) 113 | else: 114 | outs = bottle_hidden(self.bridge[0], hidden) 115 | return outs 116 | -------------------------------------------------------------------------------- /translator/onmt/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | from onmt.encoders.encoder import EncoderBase 8 | from onmt.modules import MultiHeadedAttention 9 | from onmt.modules.position_ffn import PositionwiseFeedForward 10 | 11 | 12 | class TransformerEncoderLayer(nn.Module): 13 | """ 14 | A single layer of the transformer encoder. 15 | 16 | Args: 17 | d_model (int): the dimension of keys/values/queries in 18 | MultiHeadedAttention, also the input size of 19 | the first-layer of the PositionwiseFeedForward. 20 | heads (int): the number of head for MultiHeadedAttention. 21 | d_ff (int): the second-layer of the PositionwiseFeedForward. 22 | dropout (float): dropout probability(0-1.0). 23 | """ 24 | 25 | def __init__(self, d_model, heads, d_ff, dropout, 26 | max_relative_positions=0): 27 | super(TransformerEncoderLayer, self).__init__() 28 | 29 | self.self_attn = MultiHeadedAttention( 30 | heads, d_model, dropout=dropout, 31 | max_relative_positions=max_relative_positions) 32 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 33 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 34 | self.dropout = nn.Dropout(dropout) 35 | 36 | def forward(self, inputs, mask): 37 | """ 38 | Args: 39 | inputs (FloatTensor): ``(batch_size, src_len, model_dim)`` 40 | mask (LongTensor): ``(batch_size, src_len, src_len)`` 41 | 42 | Returns: 43 | (FloatTensor): 44 | 45 | * outputs ``(batch_size, src_len, model_dim)`` 46 | """ 47 | input_norm = self.layer_norm(inputs) 48 | context, _ = self.self_attn(input_norm, input_norm, input_norm, 49 | mask=mask, type="self") 50 | out = self.dropout(context) + inputs 51 | return self.feed_forward(out) 52 | 53 | 54 | class TransformerEncoder(EncoderBase): 55 | """The Transformer encoder from "Attention is All You Need" 56 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` 57 | 58 | .. 
mermaid:: 59 | 60 | graph BT 61 | A[input] 62 | B[multi-head self-attn] 63 | C[feed forward] 64 | O[output] 65 | A --> B 66 | B --> C 67 | C --> O 68 | 69 | Args: 70 | num_layers (int): number of encoder layers 71 | d_model (int): size of the model 72 | heads (int): number of heads 73 | d_ff (int): size of the inner FF layer 74 | dropout (float): dropout parameters 75 | embeddings (onmt.modules.Embeddings): 76 | embeddings to use, should have positional encodings 77 | 78 | Returns: 79 | (torch.FloatTensor, torch.FloatTensor): 80 | 81 | * embeddings ``(src_len, batch_size, model_dim)`` 82 | * memory_bank ``(src_len, batch_size, model_dim)`` 83 | """ 84 | 85 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings, 86 | max_relative_positions): 87 | super(TransformerEncoder, self).__init__() 88 | 89 | self.embeddings = embeddings 90 | self.transformer = nn.ModuleList( 91 | [TransformerEncoderLayer( 92 | d_model, heads, d_ff, dropout, 93 | max_relative_positions=max_relative_positions) 94 | for i in range(num_layers)]) 95 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 96 | 97 | @classmethod 98 | def from_opt(cls, opt, embeddings): 99 | """Alternate constructor.""" 100 | return cls( 101 | opt.enc_layers, 102 | opt.enc_rnn_size, 103 | opt.heads, 104 | opt.transformer_ff, 105 | opt.dropout, 106 | embeddings, 107 | opt.max_relative_positions) 108 | 109 | def forward(self, src, lengths=None): 110 | """See :func:`EncoderBase.forward()`""" 111 | self._check_args(src, lengths) 112 | 113 | emb = self.embeddings(src) 114 | 115 | out = emb.transpose(0, 1).contiguous() 116 | words = src[:, :, 0].transpose(0, 1) 117 | w_batch, w_len = words.size() 118 | padding_idx = self.embeddings.word_padding_idx 119 | mask = words.data.eq(padding_idx).unsqueeze(1) # [B, 1, T] 120 | # Run the forward pass of every layer of the transformer. 121 | for layer in self.transformer: 122 | out = layer(out, mask) 123 | out = self.layer_norm(out) 124 | 125 | return emb, out.transpose(0, 1).contiguous(), lengths 126 | -------------------------------------------------------------------------------- /scripts/split-on-field.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Split the TSV file into train/eval/test such that: 5 | - The amount of data follows the specified ratio 6 | - Each group of examples with the same specified field goes together. 7 | - If `test_field_values_file` is specified, read field values from the file, 8 | and put examples with those field values in the test set. 9 | """ 10 | 11 | import sys, os, shutil, re, argparse, json, random 12 | from collections import defaultdict 13 | 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-v', '--verbose', action='store_true') 19 | parser.add_argument('-s', '--seed', type=int, default=42) 20 | parser.add_argument('-c', '--sort-by-count', action='store_true', 21 | help='Instead of shuffling, sort the groups by example counts.'
22 |                         ' Groups with few examples will go to the test set')
23 |     parser.add_argument('-t', '--train-ratio', type=float, default=.9,
24 |                         help='Fraction of training data within train + dev data'
25 |                         ' (i.e., test data already excluded)')
26 |     group = parser.add_mutually_exclusive_group()
27 |     group.add_argument('-r', '--test-ratio', type=float, default=.1,
28 |                        help='Fraction of test data within all data')
29 |     group.add_argument('-f', '--test-field-values-file',
30 |                        help='Read field values from this file,'
31 |                        ' and put examples with those field values in the test set.')
32 |     parser.add_argument('infile', help='TSV file')
33 |     parser.add_argument('field', help='Field for grouping examples')
34 |     args = parser.parse_args()
35 | 
36 |     # field value -> (raw line string, key-value dict)
37 |     data = defaultdict(list)
38 |     n = 0
39 | 
40 |     with open(args.infile) as fin:
41 |         header_line = fin.readline()
42 |         header = header_line.rstrip('\n').split('\t')
43 |         for i, line in enumerate(fin):
44 |             n += 1
45 |             kv = dict(zip(header, line.rstrip('\n').split('\t')))
46 |             data[kv[args.field]].append((line, kv))
47 | 
48 |     print('Read {} lines in {} groups'.format(n, len(data)))
49 | 
50 |     # Split data
51 |     keys = list(data.keys())
52 |     if args.sort_by_count:
53 |         keys.sort(key=lambda x: -len(data[x]))
54 |     else:
55 |         random.seed(args.seed)
56 |         random.shuffle(keys)
57 |     print('Num examples in each group:', [len(data[x]) for x in keys])
58 | 
59 |     train_data = []
60 |     eval_data = []
61 |     test_data = []
62 | 
63 |     if args.test_field_values_file:
64 |         # Put test examples in test_data first
65 |         with open(args.test_field_values_file) as fin:
66 |             test_keys = [x.strip() for x in fin]
67 |         for key in test_keys:
68 |             if key not in data:
69 |                 print('WARNING: test key {} not in raw data'.format(key))
70 |             else:
71 |                 test_data.extend(data[key])
72 |     else:
73 |         # Put a certain number of examples in test_data
74 |         # Prioritize small groups
75 |         key_iter = reversed(keys)
76 |         test_keys = []
77 |         while len(test_data) < n * args.test_ratio:
78 |             key = next(key_iter)
79 |             test_data.extend(data[key])
80 |             test_keys.append(key)
81 | 
82 |     # Divide the rest
83 |     keys = [x for x in keys if x not in set(test_keys)]
84 |     print('Remaining:', [len(data[x]) for x in keys])
85 |     key_iter = iter(keys)
86 |     while len(train_data) < (n - len(test_data)) * args.train_ratio:
87 |         train_data.extend(data[next(key_iter)])
88 |     for key in key_iter:
89 |         eval_data.extend(data[key])
90 | 
91 |     print('Examples: {} train / {} eval / {} test'.format(
92 |         len(train_data), len(eval_data), len(test_data)))
93 |     if not train_data or not eval_data or not test_data:
94 |         print('WARNING: some split has 0 examples!!!')
95 | 
96 |     # Print statistics
97 |     if args.verbose:
98 |         for other_field in header:
99 |             print('Unique {}: {} train / {} eval / {} test'.format(
100 |                 other_field,
101 |                 len(set(x[other_field] for _, x in train_data)),
102 |                 len(set(x[other_field] for _, x in eval_data)),
103 |                 len(set(x[other_field] for _, x in test_data)),
104 |             ))
105 | 
106 |     # Dump to files
107 |     prefix = re.sub(r'\.tsv$', '', args.infile)
108 |     for suffix, data in (
109 |         ('-train.tsv', train_data),
110 |         ('-eval.tsv', eval_data),
111 |         ('-test.tsv', test_data),
112 |     ):
113 |         with open(prefix + suffix, 'w') as fout:
114 |             fout.write(header_line)
115 |             for line, _ in data:
116 |                 fout.write(line)
117 | 
118 | 
119 | if __name__ == '__main__':
120 |     main()
121 | 
122 | 
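# --- Editorial sketch, not part of the repository -----------------------------
# Hypothetical invocation on SPoC-style data, grouping rows by a `probid`
# column (the file and column names here are assumptions for illustration):
#
#     python scripts/split-on-field.py -v -r 0.1 input-tok.tsv probid
#
# And a quick, self-contained check of the invariant the script is meant to
# guarantee: no grouping-field value ends up in more than one output split.

import csv
from itertools import combinations

def field_values(path, field='probid'):
    with open(path) as f:
        return {row[field] for row in csv.DictReader(f, delimiter='\t')}

splits = {s: field_values('input-tok-%s.tsv' % s)
          for s in ('train', 'eval', 'test')}
for a, b in combinations(splits, 2):
    assert not (splits[a] & splits[b]), 'group leaked across %s/%s' % (a, b)
# -------------------------------------------------------------------------------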
--------------------------------------------------------------------------------
/translator/torchtext/datasets/sst.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from .. import data
4 | 
5 | 
6 | class SST(data.Dataset):
7 | 
8 |     urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip']
9 |     dirname = 'trees'
10 |     name = 'sst'
11 | 
12 |     @staticmethod
13 |     def sort_key(ex):
14 |         return len(ex.text)
15 | 
16 |     def __init__(self, path, text_field, label_field, subtrees=False,
17 |                  fine_grained=False, **kwargs):
18 |         """Create an SST dataset instance given a path and fields.
19 | 
20 |         Arguments:
21 |             path: Path to the data file
22 |             text_field: The field that will be used for text data.
23 |             label_field: The field that will be used for label data.
24 |             subtrees: Whether to include sentiment-tagged subphrases
25 |                 in addition to complete examples. Default: False.
26 |             fine_grained: Whether to use 5-class instead of 3-class
27 |                 labeling. Default: False.
28 |             Remaining keyword arguments: Passed to the constructor of
29 |                 data.Dataset.
30 |         """
31 |         fields = [('text', text_field), ('label', label_field)]
32 | 
33 |         def get_label_str(label):
34 |             pre = 'very ' if fine_grained else ''
35 |             return {'0': pre + 'negative', '1': 'negative', '2': 'neutral',
36 |                     '3': 'positive', '4': pre + 'positive', None: None}[label]
37 |         label_field.preprocessing = data.Pipeline(get_label_str)
38 |         with open(os.path.expanduser(path)) as f:
39 |             if subtrees:
40 |                 examples = [ex for line in f for ex in
41 |                             data.Example.fromtree(line, fields, True)]
42 |             else:
43 |                 examples = [data.Example.fromtree(line, fields) for line in f]
44 |         super(SST, self).__init__(examples, fields, **kwargs)
45 | 
46 |     @classmethod
47 |     def splits(cls, text_field, label_field, root='.data',
48 |                train='train.txt', validation='dev.txt', test='test.txt',
49 |                train_subtrees=False, **kwargs):
50 |         """Create dataset objects for splits of the SST dataset.
51 | 
52 |         Arguments:
53 |             text_field: The field that will be used for the sentence.
54 |             label_field: The field that will be used for label data.
55 |             root: The root directory that the dataset's zip archive will be
56 |                 expanded into; therefore the directory in whose trees
57 |                 subdirectory the data files will be stored.
58 |             train: The filename of the train data. Default: 'train.txt'.
59 |             validation: The filename of the validation data, or None to not
60 |                 load the validation set. Default: 'dev.txt'.
61 |             test: The filename of the test data, or None to not load the test
62 |                 set. Default: 'test.txt'.
63 |             train_subtrees: Whether to use all subtrees in the training set.
64 |                 Default: False.
65 |             Remaining keyword arguments: Passed to the splits method of
66 |                 Dataset.
67 |         """
68 |         path = cls.download(root)
69 | 
70 |         train_data = None if train is None else cls(
71 |             os.path.join(path, train), text_field, label_field, subtrees=train_subtrees,
72 |             **kwargs)
73 |         val_data = None if validation is None else cls(
74 |             os.path.join(path, validation), text_field, label_field, **kwargs)
75 |         test_data = None if test is None else cls(
76 |             os.path.join(path, test), text_field, label_field, **kwargs)
77 |         return tuple(d for d in (train_data, val_data, test_data)
78 |                      if d is not None)
79 | 
80 |     @classmethod
81 |     def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs):
82 |         """Create iterator objects for splits of the SST dataset.
83 | 
84 |         Arguments:
85 |             batch_size: Batch size.
86 |             device: Device to create batches on. Use -1 for CPU and None for
87 |                 the currently active GPU device.
88 |             root: The root directory that the dataset's zip archive will be
89 |                 expanded into; therefore the directory in whose trees
90 |                 subdirectory the data files will be stored.
91 |             vectors: one of the available pretrained vectors or a list with each
92 |                 element one of the available pretrained vectors (see Vocab.load_vectors)
93 |             Remaining keyword arguments: Passed to the splits method.
94 |         """
95 |         TEXT = data.Field()
96 |         LABEL = data.Field(sequential=False)
97 | 
98 |         train, val, test = cls.splits(TEXT, LABEL, root=root, **kwargs)
99 | 
100 |         TEXT.build_vocab(train, vectors=vectors)
101 |         LABEL.build_vocab(train)
102 | 
103 |         return data.BucketIterator.splits(
104 |             (train, val, test), batch_size=batch_size, device=device)
105 | 
--------------------------------------------------------------------------------
/translator/onmt/encoders/image_encoder.py:
--------------------------------------------------------------------------------
1 | """Image Encoder."""
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 | 
6 | from onmt.encoders.encoder import EncoderBase
7 | 
8 | 
9 | class ImageEncoder(EncoderBase):
10 |     """A simple encoder CNN -> RNN for image src.
11 | 
12 |     Args:
13 |         num_layers (int): number of encoder layers.
14 |         bidirectional (bool): bidirectional encoder.
15 |         rnn_size (int): size of hidden states of the rnn.
16 |         dropout (float): dropout probability.
17 |     """
18 | 
19 |     def __init__(self, num_layers, bidirectional, rnn_size, dropout,
20 |                  image_channel_size=3):
21 |         super(ImageEncoder, self).__init__()
22 |         self.num_layers = num_layers
23 |         self.num_directions = 2 if bidirectional else 1
24 |         self.hidden_size = rnn_size
25 | 
26 |         self.layer1 = nn.Conv2d(image_channel_size, 64, kernel_size=(3, 3),
27 |                                 padding=(1, 1), stride=(1, 1))
28 |         self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3),
29 |                                 padding=(1, 1), stride=(1, 1))
30 |         self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3),
31 |                                 padding=(1, 1), stride=(1, 1))
32 |         self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3),
33 |                                 padding=(1, 1), stride=(1, 1))
34 |         self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3),
35 |                                 padding=(1, 1), stride=(1, 1))
36 |         self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3),
37 |                                 padding=(1, 1), stride=(1, 1))
38 | 
39 |         self.batch_norm1 = nn.BatchNorm2d(256)
40 |         self.batch_norm2 = nn.BatchNorm2d(512)
41 |         self.batch_norm3 = nn.BatchNorm2d(512)
42 | 
43 |         src_size = 512
44 |         self.rnn = nn.LSTM(src_size, int(rnn_size / self.num_directions),
45 |                            num_layers=num_layers,
46 |                            dropout=dropout,
47 |                            bidirectional=bidirectional)
48 |         self.pos_lut = nn.Embedding(1000, src_size)
49 | 
50 |     @classmethod
51 |     def from_opt(cls, opt, embeddings=None):
52 |         """Alternate constructor."""
53 |         if embeddings is not None:
54 |             raise ValueError("Cannot use embeddings with ImageEncoder.")
55 |         # why is the model_opt.__dict__ check necessary?
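        # (Editorial note, not from the original authors: the __dict__ guard
        # below is presumably for checkpoints saved before the
        # image_channel_size option existed, whose saved opt namespaces lack
        # the attribute. This is a guess, not documented behavior.)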
56 | if "image_channel_size" not in opt.__dict__: 57 | image_channel_size = 3 58 | else: 59 | image_channel_size = opt.image_channel_size 60 | return cls( 61 | opt.enc_layers, 62 | opt.brnn, 63 | opt.enc_rnn_size, 64 | opt.dropout, 65 | image_channel_size 66 | ) 67 | 68 | def load_pretrained_vectors(self, opt): 69 | """Pass in needed options only when modify function definition.""" 70 | pass 71 | 72 | def forward(self, src, lengths=None): 73 | """See :func:`onmt.encoders.encoder.EncoderBase.forward()`""" 74 | 75 | batch_size = src.size(0) 76 | # (batch_size, 64, imgH, imgW) 77 | # layer 1 78 | src = F.relu(self.layer1(src[:, :, :, :] - 0.5), True) 79 | 80 | # (batch_size, 64, imgH/2, imgW/2) 81 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 82 | 83 | # (batch_size, 128, imgH/2, imgW/2) 84 | # layer 2 85 | src = F.relu(self.layer2(src), True) 86 | 87 | # (batch_size, 128, imgH/2/2, imgW/2/2) 88 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 89 | 90 | # (batch_size, 256, imgH/2/2, imgW/2/2) 91 | # layer 3 92 | # batch norm 1 93 | src = F.relu(self.batch_norm1(self.layer3(src)), True) 94 | 95 | # (batch_size, 256, imgH/2/2, imgW/2/2) 96 | # layer4 97 | src = F.relu(self.layer4(src), True) 98 | 99 | # (batch_size, 256, imgH/2/2/2, imgW/2/2) 100 | src = F.max_pool2d(src, kernel_size=(1, 2), stride=(1, 2)) 101 | 102 | # (batch_size, 512, imgH/2/2/2, imgW/2/2) 103 | # layer 5 104 | # batch norm 2 105 | src = F.relu(self.batch_norm2(self.layer5(src)), True) 106 | 107 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 108 | src = F.max_pool2d(src, kernel_size=(2, 1), stride=(2, 1)) 109 | 110 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 111 | src = F.relu(self.batch_norm3(self.layer6(src)), True) 112 | 113 | # # (batch_size, 512, H, W) 114 | all_outputs = [] 115 | for row in range(src.size(2)): 116 | inp = src[:, :, row, :].transpose(0, 2) \ 117 | .transpose(1, 2) 118 | row_vec = torch.Tensor(batch_size).type_as(inp.data) \ 119 | .long().fill_(row) 120 | pos_emb = self.pos_lut(row_vec) 121 | with_pos = torch.cat( 122 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0) 123 | outputs, hidden_t = self.rnn(with_pos) 124 | all_outputs.append(outputs) 125 | out = torch.cat(all_outputs, 0) 126 | 127 | return hidden_t, out, lengths 128 | -------------------------------------------------------------------------------- /translator/onmt/utils/parse.py: -------------------------------------------------------------------------------- 1 | import configargparse as cfargparse 2 | import os 3 | 4 | import torch 5 | 6 | import onmt.opts as opts 7 | from onmt.utils.logging import logger 8 | 9 | 10 | class ArgumentParser(cfargparse.ArgumentParser): 11 | def __init__( 12 | self, 13 | config_file_parser_class=cfargparse.YAMLConfigFileParser, 14 | formatter_class=cfargparse.ArgumentDefaultsHelpFormatter, 15 | **kwargs): 16 | super(ArgumentParser, self).__init__( 17 | config_file_parser_class=config_file_parser_class, 18 | formatter_class=formatter_class, 19 | **kwargs) 20 | 21 | @classmethod 22 | def defaults(cls, *args): 23 | """Get default arguments added to a parser by all ``*args``.""" 24 | dummy_parser = cls() 25 | for callback in args: 26 | callback(dummy_parser) 27 | defaults = dummy_parser.parse_known_args([])[0] 28 | return defaults 29 | 30 | @classmethod 31 | def update_model_opts(cls, model_opt): 32 | if model_opt.word_vec_size > 0: 33 | model_opt.src_word_vec_size = model_opt.word_vec_size 34 | model_opt.tgt_word_vec_size = model_opt.word_vec_size 35 | 36 | if 
model_opt.layers > 0: 37 | model_opt.enc_layers = model_opt.layers 38 | model_opt.dec_layers = model_opt.layers 39 | 40 | if model_opt.rnn_size > 0: 41 | model_opt.enc_rnn_size = model_opt.rnn_size 42 | model_opt.dec_rnn_size = model_opt.rnn_size 43 | 44 | model_opt.brnn = model_opt.encoder_type == "brnn" 45 | 46 | if model_opt.copy_attn_type is None: 47 | model_opt.copy_attn_type = model_opt.global_attention 48 | 49 | @classmethod 50 | def validate_model_opts(cls, model_opt): 51 | assert model_opt.model_type in ["text", "img", "audio"], \ 52 | "Unsupported model type %s" % model_opt.model_type 53 | 54 | # this check is here because audio allows the encoder and decoder to 55 | # be different sizes, but other model types do not yet 56 | same_size = model_opt.enc_rnn_size == model_opt.dec_rnn_size 57 | assert model_opt.model_type == 'audio' or same_size, \ 58 | "The encoder and decoder rnns must be the same size for now" 59 | 60 | assert model_opt.rnn_type != "SRU" or model_opt.gpu_ranks, \ 61 | "Using SRU requires -gpu_ranks set." 62 | if model_opt.share_embeddings: 63 | if model_opt.model_type != "text": 64 | raise AssertionError( 65 | "--share_embeddings requires --model_type text.") 66 | if model_opt.model_dtype == "fp16": 67 | logger.warning( 68 | "FP16 is experimental, the generated checkpoints may " 69 | "be incompatible with a future version") 70 | 71 | @classmethod 72 | def ckpt_model_opts(cls, ckpt_opt): 73 | # Load default opt values, then overwrite with the opts in 74 | # the checkpoint. That way, if there are new options added, 75 | # the defaults are used. 76 | opt = cls.defaults(opts.model_opts) 77 | opt.__dict__.update(ckpt_opt.__dict__) 78 | return opt 79 | 80 | @classmethod 81 | def validate_train_opts(cls, opt): 82 | if opt.epochs: 83 | raise AssertionError( 84 | "-epochs is deprecated please use -train_steps.") 85 | if opt.truncated_decoder > 0 and opt.accum_count > 1: 86 | raise AssertionError("BPTT is not compatible with -accum > 1") 87 | if opt.gpuid: 88 | raise AssertionError("gpuid is deprecated \ 89 | see world_size and gpu_ranks") 90 | if torch.cuda.is_available() and not opt.gpu_ranks: 91 | logger.info("WARNING: You have a CUDA device, \ 92 | should run with -gpu_ranks") 93 | 94 | @classmethod 95 | def validate_translate_opts(cls, opt): 96 | if opt.beam_size != 1 and opt.random_sampling_topk != 1: 97 | raise ValueError('Can either do beam search OR random sampling.') 98 | 99 | @classmethod 100 | def validate_preprocess_args(cls, opt): 101 | assert opt.max_shard_size == 0, \ 102 | "-max_shard_size is deprecated. Please use \ 103 | -shard_size (number of examples) instead." 104 | assert opt.shuffle == 0, \ 105 | "-shuffle is not implemented. Please shuffle \ 106 | your data before pre-processing." 107 | 108 | ################ MODIFIED ################ 109 | assert os.path.isfile(opt.train_src), \ 110 | "Please check path of your train src file!" 111 | assert not opt.train_tgt or opt.train_tgt == '-' or os.path.isfile(opt.train_tgt), \ 112 | "Please check path of your train tgt file!" 113 | ########################################## 114 | assert not opt.valid_src or os.path.isfile(opt.valid_src), \ 115 | "Please check path of your valid src file!" 116 | assert not opt.valid_tgt or opt.valid_tgt == '-' or os.path.isfile(opt.valid_tgt), \ 117 | "Please check path of your valid tgt file!" 
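# --- Editorial sketch, not part of parse.py -----------------------------------
# The MODIFIED asserts above deliberately accept a missing or "-" target,
# matching config.preprocess.yml, where train_tgt is "-" because sources and
# targets share one TSV. A minimal illustration with a stand-in opt namespace
# (only the attributes the asserts read are filled in):
#
#     from types import SimpleNamespace
#     opt = SimpleNamespace(max_shard_size=0, shuffle=0,
#                           train_src='data/input-tok-train-shuf.tsv',
#                           train_tgt='-',   # tolerated by the modified check
#                           valid_src='data/input-tok-eval.tsv',
#                           valid_tgt='-')
#     ArgumentParser.validate_preprocess_args(opt)  # passes if the files exist
# -------------------------------------------------------------------------------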
118 | 
--------------------------------------------------------------------------------
/translator/onmt/decoders/cnn_decoder.py:
--------------------------------------------------------------------------------
1 | """Implementation of the CNN Decoder part of
2 | "Convolutional Sequence to Sequence Learning"
3 | """
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from onmt.modules import ConvMultiStepAttention, GlobalAttention
8 | from onmt.utils.cnn_factory import shape_transform, GatedConv
9 | from onmt.decoders.decoder import DecoderBase
10 | 
11 | SCALE_WEIGHT = 0.5 ** 0.5
12 | 
13 | 
14 | class CNNDecoder(DecoderBase):
15 |     """Decoder based on "Convolutional Sequence to Sequence Learning"
16 |     :cite:`DBLP:journals/corr/GehringAGYD17`.
17 | 
18 |     Consists of residual convolutional layers, with ConvMultiStepAttention.
19 |     """
20 | 
21 |     def __init__(self, num_layers, hidden_size, attn_type,
22 |                  copy_attn, cnn_kernel_width, dropout, embeddings,
23 |                  copy_attn_type):
24 |         super(CNNDecoder, self).__init__()
25 | 
26 |         self.cnn_kernel_width = cnn_kernel_width
27 |         self.embeddings = embeddings
28 | 
29 |         # Decoder State
30 |         self.state = {}
31 | 
32 |         input_size = self.embeddings.embedding_size
33 |         self.linear = nn.Linear(input_size, hidden_size)
34 |         self.conv_layers = nn.ModuleList(
35 |             [GatedConv(hidden_size, cnn_kernel_width, dropout, True)
36 |              for i in range(num_layers)]
37 |         )
38 |         self.attn_layers = nn.ModuleList(
39 |             [ConvMultiStepAttention(hidden_size) for i in range(num_layers)]
40 |         )
41 | 
42 |         # CNNDecoder has its own attention mechanism.
43 |         # Set up a separate copy attention layer if needed.
44 |         assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
45 |         if copy_attn:
46 |             self.copy_attn = GlobalAttention(
47 |                 hidden_size, attn_type=copy_attn_type)
48 |         else:
49 |             self.copy_attn = None
50 | 
51 |     @classmethod
52 |     def from_opt(cls, opt, embeddings):
53 |         """Alternate constructor."""
54 |         return cls(
55 |             opt.dec_layers,
56 |             opt.dec_rnn_size,
57 |             opt.global_attention,
58 |             opt.copy_attn,
59 |             opt.cnn_kernel_width,
60 |             opt.dropout,
61 |             embeddings,
62 |             opt.copy_attn_type)
63 | 
64 |     def init_state(self, _, memory_bank, enc_hidden):
65 |         """Init decoder state."""
66 |         self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT
67 |         self.state["previous_input"] = None
68 | 
69 |     def map_state(self, fn):
70 |         self.state["src"] = fn(self.state["src"], 1)
71 |         if self.state["previous_input"] is not None:
72 |             self.state["previous_input"] = fn(self.state["previous_input"], 1)
73 | 
74 |     def detach_state(self):
75 |         self.state["previous_input"] = self.state["previous_input"].detach()
76 | 
77 |     def forward(self, tgt, memory_bank, step=None, **kwargs):
78 |         """See :obj:`onmt.modules.RNNDecoderBase.forward()`"""
79 | 
80 |         if self.state["previous_input"] is not None:
81 |             tgt = torch.cat([self.state["previous_input"], tgt], 0)
82 | 
83 |         dec_outs = []
84 |         attns = {"std": []}
85 |         if self.copy_attn is not None:
86 |             attns["copy"] = []
87 | 
88 |         emb = self.embeddings(tgt)
89 |         assert emb.dim() == 3  # len x batch x embedding_dim
90 | 
91 |         tgt_emb = emb.transpose(0, 1).contiguous()
92 |         # The output of CNNEncoder.
93 |         src_memory_bank_t = memory_bank.transpose(0, 1).contiguous()
94 |         # The combination of output of CNNEncoder and source embeddings.
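        # (Editorial note: in the ConvS2S attention scheme, `_t` -- the raw
        # encoder output -- is what attention scores are computed against,
        # while `_c` -- encoder output combined with source embeddings -- is
        # what gets averaged into the context vector.)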
95 |         src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous()
96 | 
97 |         emb_reshape = tgt_emb.contiguous().view(
98 |             tgt_emb.size(0) * tgt_emb.size(1), -1)
99 |         linear_out = self.linear(emb_reshape)
100 |         x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1)
101 |         x = shape_transform(x)
102 | 
103 |         pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)
104 | 
105 |         pad = pad.type_as(x)
106 |         base_target_emb = x
107 | 
108 |         for conv, attention in zip(self.conv_layers, self.attn_layers):
109 |             new_target_input = torch.cat([pad, x], 2)
110 |             out = conv(new_target_input)
111 |             c, attn = attention(base_target_emb, out,
112 |                                 src_memory_bank_t, src_memory_bank_c)
113 |             x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT
114 |         output = x.squeeze(3).transpose(1, 2)
115 | 
116 |         # Process the result and update the attentions.
117 |         dec_outs = output.transpose(0, 1).contiguous()
118 |         if self.state["previous_input"] is not None:
119 |             dec_outs = dec_outs[self.state["previous_input"].size(0):]
120 |             attn = attn[:, self.state["previous_input"].size(0):].squeeze()
121 |             attn = torch.stack([attn])
122 |         attns["std"] = attn
123 |         if self.copy_attn is not None:
124 |             attns["copy"] = attn
125 | 
126 |         # Update the state.
127 |         self.state["previous_input"] = tgt
128 |         # TODO change the way attns is returned dict => list or tuple (onnx)
129 |         return dec_outs, attns
130 | 
--------------------------------------------------------------------------------
/translator/torchtext/data/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | from contextlib import contextmanager
3 | from copy import deepcopy
4 | 
5 | from functools import partial
6 | 
7 | 
8 | def _split_tokenizer(x):
9 |     return x.split()
10 | 
11 | 
12 | def _spacy_tokenize(x, spacy):
13 |     return [tok.text for tok in spacy.tokenizer(x)]
14 | 
15 | 
16 | def get_tokenizer(tokenizer, language='en'):
17 |     # default tokenizer is string.split(), added as a module function for serialization
18 |     if tokenizer is None:
19 |         return _split_tokenizer
20 | 
21 |     # simply return if a function is passed
22 |     if callable(tokenizer):
23 |         return tokenizer
24 | 
25 |     if tokenizer == "spacy":
26 |         try:
27 |             import spacy
28 |             spacy = spacy.load(language)
29 |             return partial(_spacy_tokenize, spacy=spacy)
30 |         except ImportError:
31 |             print("Please install SpaCy. "
32 |                   "See the docs at https://spacy.io for more information.")
33 |             raise
34 |         except AttributeError:
35 |             print("Please install SpaCy and the SpaCy {} tokenizer. "
36 |                   "See the docs at https://spacy.io for more "
37 |                   "information.".format(language))
38 |             raise
39 |     elif tokenizer == "moses":
40 |         try:
41 |             from sacremoses import MosesTokenizer
42 |             moses_tokenizer = MosesTokenizer()
43 |             return moses_tokenizer.tokenize
44 |         except ImportError:
45 |             print("Please install SacreMoses. "
46 |                   "See the docs at https://github.com/alvations/sacremoses "
47 |                   "for more information.")
48 |             raise
49 |     elif tokenizer == "toktok":
50 |         try:
51 |             from nltk.tokenize.toktok import ToktokTokenizer
52 |             toktok = ToktokTokenizer()
53 |             return toktok.tokenize
54 |         except ImportError:
55 |             print("Please install NLTK. "
56 |                   "See the docs at https://nltk.org for more information.")
57 |             raise
58 |     elif tokenizer == 'revtok':
59 |         try:
60 |             import revtok
61 |             return revtok.tokenize
62 |         except ImportError:
63 |             print("Please install revtok.")
64 |             raise
65 |     elif tokenizer == 'subword':
66 |         try:
67 |             import revtok
68 |             return partial(revtok.tokenize, decap=True)
69 |         except ImportError:
70 |             print("Please install revtok.")
71 |             raise
72 |     raise ValueError("Requested tokenizer {}, valid choices are a "
73 |                      "callable that takes a single string as input, "
74 |                      "\"revtok\" for the revtok reversible tokenizer, "
75 |                      "\"subword\" for the revtok caps-aware tokenizer, \"toktok\" for the NLTK ToktokTokenizer, "
76 |                      "\"spacy\" for the SpaCy English tokenizer, or "
77 |                      "\"moses\" for the NLTK port of the Moses tokenization "
78 |                      "script.".format(tokenizer))
79 | 
80 | 
81 | def is_tokenizer_serializable(tokenizer, language):
82 |     """Extend with other tokenizers which are found to not be serializable
83 |     """
84 |     if tokenizer == 'spacy':
85 |         return False
86 |     return True
87 | 
88 | 
89 | def interleave_keys(a, b):
90 |     """Interleave bits from two sort keys to form a joint sort key.
91 | 
92 |     Examples that are similar in both of the provided keys will have similar
93 |     values for the key defined by this function. Useful for tasks with two
94 |     text fields like machine translation or natural language inference.
95 |     """
96 |     def interleave(args):
97 |         return ''.join([x for t in zip(*args) for x in t])
98 |     return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2)
99 | 
100 | 
101 | def get_torch_version():
102 |     import torch
103 |     v = torch.__version__
104 |     version_substrings = v.split('.')
105 |     major, minor = version_substrings[0], version_substrings[1]
106 |     return int(major), int(minor)
107 | 
108 | 
109 | def dtype_to_attr(dtype):
110 |     # convert torch.dtype to dtype string id
111 |     # e.g. torch.int32 -> "int32"
112 |     # used for serialization
113 |     _, dtype = str(dtype).split('.')
114 |     return dtype
115 | 
116 | 
117 | class RandomShuffler(object):
118 |     """Use random functions while keeping track of the random state to make it
119 |     reproducible and deterministic."""
120 | 
121 |     def __init__(self, random_state=None):
122 |         self._random_state = random_state
123 |         if self._random_state is None:
124 |             self._random_state = random.getstate()
125 | 
126 |     @contextmanager
127 |     def use_internal_state(self):
128 |         """Use a specific RNG state."""
129 |         old_state = random.getstate()
130 |         random.setstate(self._random_state)
131 |         yield
132 |         self._random_state = random.getstate()
133 |         random.setstate(old_state)
134 | 
135 |     @property
136 |     def random_state(self):
137 |         return deepcopy(self._random_state)
138 | 
139 |     @random_state.setter
140 |     def random_state(self, s):
141 |         self._random_state = s
142 | 
143 |     def __call__(self, data):
144 |         """Shuffle and return a new list."""
145 |         with self.use_internal_state():
146 |             return random.sample(data, len(data))
147 | 
--------------------------------------------------------------------------------
/translator/onmt/utils/statistics.py:
--------------------------------------------------------------------------------
1 | """ Statistics calculation utility """
2 | from __future__ import division
3 | import time
4 | import math
5 | import sys
6 | 
7 | from onmt.utils.logging import logger
8 | 
9 | ################ MODIFIED ################
10 | 
11 | class Statistics(object):
12 |     """
13 |     Accumulator for loss statistics.
14 |     Currently calculates:
15 | 
16 |     * accuracy
17 |     * perplexity
18 |     * elapsed time
19 |     """
20 | 
21 |     def __init__(self, loss=0, n_words=0, n_correct=0, n_seqs=0, n_seqs_correct=0):
22 |         self.loss = loss
23 |         self.n_words = n_words
24 |         self.n_correct = n_correct
25 |         self.n_seqs = n_seqs
26 |         self.n_seqs_correct = n_seqs_correct
27 |         self.n_src_words = 0
28 |         self.start_time = time.time()
29 | 
30 |     @staticmethod
31 |     def all_gather_stats(stat, max_size=4096):
32 |         """
33 |         Gather a `Statistics` object across multiple processes/nodes
34 | 
35 |         Args:
36 |             stat(:obj:Statistics): the statistics object to gather
37 |                 across all processes/nodes
38 |             max_size(int): max buffer size to use
39 | 
40 |         Returns:
41 |             `Statistics`, the updated stats object
42 |         """
43 |         stats = Statistics.all_gather_stats_list([stat], max_size=max_size)
44 |         return stats[0]
45 | 
46 |     @staticmethod
47 |     def all_gather_stats_list(stat_list, max_size=4096):
48 |         """
49 |         Gather a `Statistics` list across all processes/nodes
50 | 
51 |         Args:
52 |             stat_list(list([`Statistics`])): list of statistics objects to
53 |                 gather across all processes/nodes
54 |             max_size(int): max buffer size to use
55 | 
56 |         Returns:
57 |             our_stats(list([`Statistics`])): list of updated stats
58 |         """
59 |         from torch.distributed import get_rank
60 |         from onmt.utils.distributed import all_gather_list
61 | 
62 |         # Get a list of world_size lists with len(stat_list) Statistics objects
63 |         all_stats = all_gather_list(stat_list, max_size=max_size)
64 | 
65 |         our_rank = get_rank()
66 |         our_stats = all_stats[our_rank]
67 |         for other_rank, stats in enumerate(all_stats):
68 |             if other_rank == our_rank:
69 |                 continue
70 |             for i, stat in enumerate(stats):
71 |                 our_stats[i].update(stat, update_n_src_words=True)
72 |         return our_stats
73 | 
74 |     def update(self, stat, update_n_src_words=False):
75 |         """
76 |         Update statistics by summing values with another `Statistics` object
77 | 
78 |         Args:
79 |             stat: another statistic object
80 |             update_n_src_words(bool): whether to update (sum) `n_src_words`
81 |                 or not
82 | 
83 |         """
84 |         self.loss += stat.loss
85 |         self.n_words += stat.n_words
86 |         self.n_correct += stat.n_correct
87 |         self.n_seqs += stat.n_seqs
88 |         self.n_seqs_correct += stat.n_seqs_correct
89 | 
90 |         if update_n_src_words:
91 |             self.n_src_words += stat.n_src_words
92 | 
93 |     def accuracy(self):
94 |         """ compute accuracy """
95 |         return 100 * (self.n_correct / self.n_words)
96 | 
97 |     def seq_accuracy(self):
98 |         """ sequence-level accuracy """
99 |         return 100 * (self.n_seqs_correct / self.n_seqs)
100 | 
101 |     def xent(self):
102 |         """ compute cross entropy """
103 |         return self.loss / self.n_words
104 | 
105 |     def ppl(self):
106 |         """ compute perplexity """
107 |         return math.exp(min(self.loss / self.n_words, 100))
108 | 
109 |     def elapsed_time(self):
110 |         """ compute elapsed time """
111 |         return time.time() - self.start_time
112 | 
113 |     def output(self, step, num_steps, learning_rate, start):
114 |         """Write out statistics to stdout.
115 | 
116 |         Args:
117 |             step (int): current step
118 |             num_steps (int): total steps
119 |             start (int): start time of step.
120 | """ 121 | t = self.elapsed_time() 122 | step_fmt = "%2d" % step 123 | if num_steps > 0: 124 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 125 | logger.info( 126 | ("Step %s; acc: %6.2f; seqacc: %6.2f; " + 127 | "ppl: %5.2f; xent: %4.2f; " + 128 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 129 | % (step_fmt, 130 | self.accuracy(), 131 | self.seq_accuracy(), 132 | self.ppl(), 133 | self.xent(), 134 | learning_rate, 135 | self.n_src_words / (t + 1e-5), 136 | self.n_words / (t + 1e-5), 137 | time.time() - start)) 138 | sys.stdout.flush() 139 | 140 | def log_tensorboard(self, prefix, writer, learning_rate, step): 141 | """ display statistics to tensorboard """ 142 | t = self.elapsed_time() 143 | writer.add_scalar(prefix + "/xent", self.xent(), step) 144 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 145 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 146 | writer.add_scalar(prefix + "/seq_acc", self.seq_accuracy(), step) 147 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 148 | writer.add_scalar(prefix + "/lr", learning_rate, step) 149 | 150 | ######################################### 151 | -------------------------------------------------------------------------------- /tokenizer/format-pairs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Convert the (text, code) pairs from mturk-to-pairs.py into various formats. 5 | 6 | The default is to convert into templates. For instance: 7 | set x to 7 x = 7; 8 | will become 9 | set $1 to $2 x = $2 ; 10 | """ 11 | 12 | import sys, os, shutil, re, argparse, json 13 | 14 | 15 | ################################################ 16 | # Basic tokenization 17 | 18 | TEXT_TOKENIZER = re.compile(r'\w+|[^\w\s]', re.UNICODE) 19 | 20 | 21 | def tokenize_text(text): 22 | return TEXT_TOKENIZER.findall(text) 23 | 24 | 25 | ################################################ 26 | # Clang interface 27 | 28 | def setup_clang(clang_path): 29 | from clang.cindex import Config 30 | Config.set_library_path(clang_path) 31 | 32 | def fix_char_string_tok(tokens): 33 | res_tokens = [] 34 | if tokens and tokens[0] == "}": 35 | tokens = tokens[1:] 36 | if tokens and tokens[-1] == "{": 37 | tokens = tokens[:-1] 38 | for token in tokens: 39 | if token[0] == "\"" and token[-1] == "\"": 40 | res_tokens.append("\"") 41 | res_tokens.append(token[1:-1]) 42 | res_tokens.append("\"") 43 | elif token[0] == "\'" and token[-1] == "\'": 44 | res_tokens.append("\'") 45 | res_tokens.append(token[1:-1]) 46 | res_tokens.append("\'") 47 | else: 48 | res_tokens.append(token) 49 | return res_tokens 50 | 51 | def tokenize_code(code): 52 | from clang.cindex import Index 53 | index = Index.create() 54 | tu = index.parse('tmp.cpp', args=['-std=c++11'], unsaved_files=[('tmp.cpp', code)]) 55 | tokens = [token.spelling for token in tu.get_tokens(extent=tu.cursor.extent)] 56 | tokens = fix_char_string_tok(tokens) 57 | return tokens 58 | 59 | ################################################ 60 | # Extract templates 61 | 62 | VARNAMES = re.compile(r'[A-Za-z]\w*', re.UNICODE) 63 | NUMBERS = re.compile(r'\d+', re.UNICODE) 64 | RESERVED = { 65 | 'alignas', 'alignof', 'and', 'and_eq', 'asm', 'atomic_cancel', 66 | 'atomic_commit', 'atomic_noexcept', 'auto', 'bitand', 'bitor', 'bool', 67 | 'break', 'case', 'catch', 'char', 'char16_t', 'char32_t', 'char8_t', 68 | 'class', 'co_await', 'co_return', 'co_yield', 'compl', 'concept', 'const', 69 | 'const_cast', 'consteval', 
'constexpr', 'continue', 'decltype', 'default', 70 | 'delete', 'do', 'double', 'dynamic_cast', 'else', 'enum', 'explicit', 71 | 'export', 'extern', 'false', 'float', 'for', 'friend', 'goto', 'if', 72 | 'import', 'inline', 'int', 'long', 'module', 'mutable', 'namespace', 'new', 73 | 'noexcept', 'not', 'not_eq', 'nullptr', 'operator', 'or', 'or_eq', 74 | 'private', 'protected', 'public', 'reflexpr', 'register', 'reinterpret_cast', 75 | 'requires', 'return', 'short', 'signed', 'sizeof', 'static', 'static_assert', 76 | 'static_cast', 'struct', 'switch', 'synchronized', 'template', 'this', 77 | 'thread_local', 'throw', 'true', 'try', 'typedef', 'typeid', 'typename', 78 | 'union', 'unsigned', 'using', 'virtual', 'void', 'volatile', 'wchar_t', 79 | 'while', 'xor', 'xor_eq', 80 | } 81 | 82 | 83 | def can_placehold(token): 84 | return ( 85 | (VARNAMES.match(token) or NUMBERS.match(token)) 86 | and token not in RESERVED 87 | ) 88 | 89 | 90 | def match(text_tokens, code_tokens): 91 | text_tokens = text_tokens[:] 92 | code_tokens = code_tokens[:] 93 | placeholder = 1 94 | for i, x in enumerate(text_tokens): 95 | if x in code_tokens and can_placehold(x): 96 | # Replace all occurrences of x in code_tokens and text_tokens 97 | for j, y in enumerate(code_tokens): 98 | if y == x: 99 | code_tokens[j] = '${}'.format(placeholder) 100 | for j, y in enumerate(text_tokens): 101 | if y == x: 102 | text_tokens[j] = '${}'.format(placeholder) 103 | placeholder += 1 104 | return text_tokens, code_tokens 105 | 106 | 107 | ################################################ 108 | # Main 109 | 110 | def main(): 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('-B', '--no-blank', action='store_true', 113 | help='Replace blank utterance with something else') 114 | parser.add_argument('-c', '--clang', 115 | help='Use clang from this location to tokenize code') 116 | parser.add_argument('-H', '--has-header', action='store_true', 117 | help='Input is a TSV file with a header') 118 | parser.add_argument('-t', '--tokenize-only', action='store_true', 119 | help='Do not replace matching tokens with placeholders.') 120 | parser.add_argument('infile', help='TSV files with text and code columns') 121 | args = parser.parse_args() 122 | 123 | if args.clang: 124 | setup_clang(args.clang) 125 | 126 | with open(args.infile) as fin: 127 | if args.has_header: 128 | header = fin.readline().rstrip('\n').split('\t') 129 | 130 | for line in fin: 131 | if header: 132 | data = dict(zip(header, line.rstrip('\n').split('\t'))) 133 | text, code = data['text'], data['code'] 134 | if not text: 135 | if args.no_blank: 136 | text = 'DUMMY' 137 | else: 138 | continue 139 | else: 140 | text, code = line.rstrip('\n').split('\t') 141 | 142 | text_tokens = tokenize_text(text) 143 | if args.clang: 144 | code_tokens = tokenize_code(code) 145 | else: 146 | code_tokens = tokenize_text(code) 147 | if not args.tokenize_only: 148 | text_tokens, code_tokens = match(text_tokens, code_tokens) 149 | 150 | print('{}\t{}'.format(' '.join(text_tokens), ' '.join(code_tokens))) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /translator/onmt/translate/decode_strategy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class DecodeStrategy(object): 5 | """Base class for generation strategies. 6 | 7 | Args: 8 | pad (int): Magic integer in output vocab. 9 | bos (int): Magic integer in output vocab. 
10 |         eos (int): Magic integer in output vocab.
11 |         batch_size (int): Current batch size.
12 |         device (torch.device or str): Device for memory bank (encoder).
13 |         parallel_paths (int): Decoding strategies like beam search
14 |             use parallel paths. Each batch is repeated ``parallel_paths``
15 |             times in relevant state tensors.
16 |         min_length (int): Shortest acceptable generation, not counting
17 |             begin-of-sentence or end-of-sentence.
18 |         max_length (int): Longest acceptable sequence, not counting
19 |             begin-of-sentence (presumably there has been no EOS
20 |             yet if max_length is used as a cutoff).
21 |         block_ngram_repeat (int): Block beams where
22 |             ``block_ngram_repeat``-grams repeat.
23 |         exclusion_tokens (set[int]): If a gram contains any of these
24 |             tokens, it may repeat.
25 |         return_attention (bool): Whether to work with attention too. If this
26 |             is true, it is assumed that the decoder is attentional.
27 | 
28 |     Attributes:
29 |         pad (int): See above.
30 |         bos (int): See above.
31 |         eos (int): See above.
32 |         predictions (list[list[LongTensor]]): For each batch, holds a
33 |             list of beam prediction sequences.
34 |         scores (list[list[FloatTensor]]): For each batch, holds a
35 |             list of scores.
36 |         attention (list[list[FloatTensor or list[]]]): For each
37 |             batch, holds a list of attention sequence tensors
38 |             (or empty lists) having shape ``(step, inp_seq_len)`` where
39 |             ``inp_seq_len`` is the length of the sample (not the max
40 |             length of all inp seqs).
41 |         alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
42 |             This sequence grows in the ``step`` axis on each call to
43 |             :func:`advance()`.
44 |         is_finished (ByteTensor or NoneType): Shape
45 |             ``(B, parallel_paths)``. Initialized to ``None``.
46 |         alive_attn (FloatTensor or NoneType): If tensor, shape is
47 |             ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len``
48 |             is the (max) length of the input sequence.
49 |         min_length (int): See above.
50 |         max_length (int): See above.
51 |         block_ngram_repeat (int): See above.
52 |         exclusion_tokens (set[int]): See above.
53 |         return_attention (bool): See above.
54 |         done (bool): Whether decoding has finished.
55 |     """
56 | 
57 |     def __init__(self, pad, bos, eos, batch_size, device, parallel_paths,
58 |                  min_length, block_ngram_repeat, exclusion_tokens,
59 |                  return_attention, max_length):
60 | 
61 |         # magic indices
62 |         self.pad = pad
63 |         self.bos = bos
64 |         self.eos = eos
65 | 
66 |         # result caching
67 |         self.predictions = [[] for _ in range(batch_size)]
68 |         self.scores = [[] for _ in range(batch_size)]
69 |         self.attention = [[] for _ in range(batch_size)]
70 | 
71 |         self.alive_seq = torch.full(
72 |             [batch_size * parallel_paths, 1], self.bos,
73 |             dtype=torch.long, device=device)
74 |         self.is_finished = torch.zeros(
75 |             [batch_size, parallel_paths],
76 |             dtype=torch.uint8, device=device)
77 |         self.alive_attn = None
78 | 
79 |         self.min_length = min_length
80 |         self.max_length = max_length
81 |         self.block_ngram_repeat = block_ngram_repeat
82 |         self.exclusion_tokens = exclusion_tokens
83 |         self.return_attention = return_attention
84 | 
85 |         self.done = False
86 | 
87 |     def __len__(self):
88 |         return self.alive_seq.shape[1]
89 | 
90 |     def ensure_min_length(self, log_probs):
91 |         if len(self) <= self.min_length:
92 |             log_probs[:, self.eos] = -1e20
93 | 
94 |     def ensure_max_length(self):
95 |         # add one to account for BOS. Don't account for EOS because hitting
96 |         # this implies it hasn't been found.
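        # (Editorial note: len(self) counts BOS plus generated tokens, so this
        # fires once max_length real tokens have been produced.)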
97 |         if len(self) == self.max_length + 1:
98 |             self.is_finished.fill_(1)
99 | 
100 |     def block_ngram_repeats(self, log_probs):
101 |         cur_len = len(self)
102 |         if self.block_ngram_repeat > 0 and cur_len > 1:
103 |             for path_idx in range(self.alive_seq.shape[0]):
104 |                 # skip BOS
105 |                 hyp = self.alive_seq[path_idx, 1:]
106 |                 ngrams = set()
107 |                 fail = False
108 |                 gram = []
109 |                 for i in range(cur_len - 1):
110 |                     # Last n tokens, n = block_ngram_repeat
111 |                     gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
112 |                     # skip the blocking if any token in gram is excluded
113 |                     if set(gram) & self.exclusion_tokens:
114 |                         continue
115 |                     if tuple(gram) in ngrams:
116 |                         fail = True
117 |                     ngrams.add(tuple(gram))
118 |                 if fail:
119 |                     log_probs[path_idx] = -10e20
120 | 
121 |     def advance(self, log_probs, attn):
122 |         """DecodeStrategy subclasses should override :func:`advance()`.
123 | 
124 |         Advance is used to update ``self.alive_seq``, ``self.is_finished``,
125 |         and, when appropriate, ``self.alive_attn``.
126 |         """
127 | 
128 |         raise NotImplementedError()
129 | 
130 |     def update_finished(self):
131 |         """DecodeStrategy subclasses should override :func:`update_finished()`.
132 | 
133 |         ``update_finished`` is used to update ``self.predictions``,
134 |         ``self.scores``, and other "output" attributes.
135 |         """
136 | 
137 |         raise NotImplementedError()
138 | 
--------------------------------------------------------------------------------
/translator/preprocess.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | """
4 | Pre-process Data / features files and build vocabulary
5 | """
6 | import codecs
7 | import glob
8 | import sys
9 | import gc
10 | import torch
11 | from functools import partial
12 | 
13 | from onmt.utils.logging import init_logger, logger
14 | from onmt.utils.misc import split_corpus
15 | import onmt.inputters as inputters
16 | import onmt.opts as opts
17 | from onmt.utils.parse import ArgumentParser
18 | 
19 | 
20 | def check_existing_pt_files(opt):
21 |     """ Check if there are existing .pt files to avoid overwriting them """
22 |     for ext in ('pt', 'gz'):
23 |         pattern = opt.save_data + '.{}*.' + ext
24 |         for t in ['train', 'valid', 'vocab']:
25 |             path = pattern.format(t)
26 |             if glob.glob(path):
27 |                 sys.stderr.write("Please backup existing file: %s, "
28 |                                  "to avoid overwriting them!\n" % path)
29 |                 sys.exit(1)
30 | 
31 | 
32 | ################ MODIFIED ################
33 | def get_shard_pairs(src, tgt, shard_size):
34 |     if not tgt or tgt == '-':
35 |         for pair_shard in split_corpus(src, shard_size):
36 |             pairs = [pair.split(b'\t') for pair in pair_shard]
37 |             yield zip(*pairs)
38 |     else:
39 |         src_shards = split_corpus(src, shard_size)
40 |         tgt_shards = split_corpus(tgt, shard_size)
41 |         yield from zip(src_shards, tgt_shards)  # a bare `return` here would end the generator without yielding
42 | ##########################################
43 | 
44 | 
45 | def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
46 |     assert corpus_type in ['train', 'valid']
47 | 
48 |     if corpus_type == 'train':
49 |         src = opt.train_src
50 |         tgt = opt.train_tgt
51 |     else:
52 |         src = opt.valid_src
53 |         tgt = opt.valid_tgt
54 | 
55 |     logger.info("Reading source and target files: %s %s." % (src, tgt))
56 | 
57 |     ################ MODIFIED ################
58 |     shard_pairs = get_shard_pairs(src, tgt, opt.shard_size)
59 |     dataset_paths = []
60 |     filter_pred = None  # Avoid dropping examples
61 |     ##########################################
62 |     for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
63 |         assert len(src_shard) == len(tgt_shard)
64 |         logger.info("Building shard %d." % i)
65 |         dataset = inputters.Dataset.from_raw(
66 |             fields,
67 |             readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
68 |             data=([("src", src_shard), ("tgt", tgt_shard)]
69 |                   if tgt_reader else [("src", src_shard)]),
70 |             dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
71 |             sort_key=inputters.str2sortkey[opt.data_type],
72 |             filter_pred=filter_pred
73 |         )
74 | 
75 |         data_path = "{:s}.{:s}.{:d}.gz".format(opt.save_data, corpus_type, i)
76 |         dataset_paths.append(data_path)
77 | 
78 |         logger.info(" * saving %sth %s data shard to %s."
79 |                     % (i, corpus_type, data_path))
80 | 
81 |         dataset.save_jsonl(data_path)
82 | 
83 |         del dataset.examples
84 |         gc.collect()
85 |         del dataset
86 |         gc.collect()
87 | 
88 |     return dataset_paths
89 | 
90 | 
91 | def build_save_vocab(train_dataset, fields, opt):
92 |     fields = inputters.build_vocab(
93 |         train_dataset, fields, opt.data_type, opt.share_vocab,
94 |         opt.src_vocab, opt.src_vocab_size, opt.src_words_min_frequency,
95 |         opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency,
96 |         vocab_size_multiple=opt.vocab_size_multiple
97 |     )
98 | 
99 |     vocab_path = opt.save_data + '.vocab.pt'
100 |     torch.save(fields, vocab_path)
101 | 
102 | 
103 | def count_features(path):
104 |     """
105 |     path: location of a corpus file with whitespace-delimited tokens and
106 |     │-delimited features within the token
107 |     returns: the number of features in the dataset
108 |     """
109 |     with codecs.open(path, "r", "utf-8") as f:
110 |         first_tok = f.readline().split(None, 1)[0]
111 |         return len(first_tok.split(u"│")) - 1
112 | 
113 | 
114 | def main(opt):
115 |     ArgumentParser.validate_preprocess_args(opt)
116 |     torch.manual_seed(opt.seed)
117 |     check_existing_pt_files(opt)
118 | 
119 |     init_logger(opt.log_file)
120 |     logger.info("Extracting features...")
121 | 
122 |     ################ MODIFIED ################
123 |     #src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
124 |     #    else 0
125 |     #tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
126 |     #logger.info(" * number of source features: %d." % src_nfeats)
127 |     #logger.info(" * number of target features: %d." % tgt_nfeats)
128 |     src_nfeats = tgt_nfeats = 0
129 |     ##########################################
130 | 
131 |     logger.info("Building `Fields` object...")
132 |     fields = inputters.get_fields(
133 |         opt.data_type,
134 |         src_nfeats,
135 |         tgt_nfeats,
136 |         dynamic_dict=opt.dynamic_dict,
137 |         src_truncate=opt.src_seq_length_trunc,
138 |         tgt_truncate=opt.tgt_seq_length_trunc)
139 | 
140 |     src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
141 |     tgt_reader = inputters.str2reader["text"].from_opt(opt)
142 | 
143 |     logger.info("Building & saving training data...")
144 |     train_dataset_files = build_save_dataset(
145 |         'train', fields, src_reader, tgt_reader, opt)
146 | 
147 |     if opt.valid_src and opt.valid_tgt:
148 |         logger.info("Building & saving validation data...")
149 |         build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
150 | 
151 |     logger.info("Building & saving vocabulary...")
152 |     build_save_vocab(train_dataset_files, fields, opt)
153 | 
154 | 
155 | def _get_parser():
156 |     parser = ArgumentParser(description='preprocess.py')
157 | 
158 |     opts.config_opts(parser)
159 |     opts.preprocess_opts(parser)
160 |     return parser
161 | 
162 | 
163 | if __name__ == "__main__":
164 |     parser = _get_parser()
165 | 
166 |     opt = parser.parse_args()
167 |     main(opt)
168 | 
--------------------------------------------------------------------------------
/translator/onmt/decoders/ensemble.py:
--------------------------------------------------------------------------------
1 | """Ensemble decoding.
2 | 
3 | Decodes using multiple models simultaneously,
4 | combining their prediction distributions by averaging.
5 | All models in the ensemble must share a target vocabulary.
6 | """
7 | 
8 | import torch
9 | import torch.nn as nn
10 | 
11 | from onmt.encoders.encoder import EncoderBase
12 | from onmt.models import NMTModel
13 | import onmt.model_builder
14 | 
15 | 
16 | class EnsembleDecoderOutput(object):
17 |     """Wrapper around multiple decoder final hidden states."""
18 |     def __init__(self, model_dec_outs):
19 |         self.model_dec_outs = tuple(model_dec_outs)
20 | 
21 |     def squeeze(self, dim=None):
22 |         """Delegate squeeze to avoid modifying
23 |         :func:`onmt.translate.translator.Translator.translate_batch()`
24 |         """
25 |         return EnsembleDecoderOutput([
26 |             x.squeeze(dim) for x in self.model_dec_outs])
27 | 
28 |     def __getitem__(self, index):
29 |         return self.model_dec_outs[index]
30 | 
31 | 
32 | class EnsembleEncoder(EncoderBase):
33 |     """Dummy Encoder that delegates to individual real Encoders."""
34 |     def __init__(self, model_encoders):
35 |         super(EnsembleEncoder, self).__init__()
36 |         self.model_encoders = nn.ModuleList(model_encoders)
37 | 
38 |     def forward(self, src, lengths=None):
39 |         enc_hidden, memory_bank, _ = zip(*[
40 |             model_encoder(src, lengths)
41 |             for model_encoder in self.model_encoders])
42 |         return enc_hidden, memory_bank, lengths
43 | 
44 | 
45 | class EnsembleDecoder(nn.Module):
46 |     """Dummy Decoder that delegates to individual real Decoders."""
47 |     def __init__(self, model_decoders):
48 |         super(EnsembleDecoder, self).__init__()
49 |         self.model_decoders = nn.ModuleList(model_decoders)
50 | 
51 |     def forward(self, tgt, memory_bank, memory_lengths=None, step=None):
52 |         """See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
53 |         # Memory_lengths is a single tensor shared between all models.
54 |         # This assumption will not hold if Translator is modified
55 |         # to calculate memory_lengths as something other than the length
56 |         # of the input.
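        # (Editorial note: sharing one lengths tensor is sound here because
        # every ensemble member encoded the same batch, so source lengths
        # coincide across models.)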
57 |         dec_outs, attns = zip(*[
58 |             model_decoder(
59 |                 tgt, memory_bank[i],
60 |                 memory_lengths=memory_lengths, step=step)
61 |             for i, model_decoder in enumerate(self.model_decoders)])
62 |         mean_attns = self.combine_attns(attns)
63 |         return EnsembleDecoderOutput(dec_outs), mean_attns
64 | 
65 |     def combine_attns(self, attns):
66 |         result = {}
67 |         for key in attns[0].keys():
68 |             result[key] = torch.stack([attn[key] for attn in attns]).mean(0)
69 |         return result
70 | 
71 |     def init_state(self, src, memory_bank, enc_hidden):
72 |         """ See :obj:`RNNDecoderBase.init_state()` """
73 |         for i, model_decoder in enumerate(self.model_decoders):
74 |             model_decoder.init_state(src, memory_bank[i], enc_hidden[i])
75 | 
76 |     def map_state(self, fn):
77 |         for model_decoder in self.model_decoders:
78 |             model_decoder.map_state(fn)
79 | 
80 | 
81 | class EnsembleGenerator(nn.Module):
82 |     """
83 |     Dummy Generator that delegates to individual real Generators,
84 |     and then averages the resulting target distributions.
85 |     """
86 |     def __init__(self, model_generators, raw_probs=False):
87 |         super(EnsembleGenerator, self).__init__()
88 |         self.model_generators = nn.ModuleList(model_generators)
89 |         self._raw_probs = raw_probs
90 | 
91 |     def forward(self, hidden, attn=None, src_map=None):
92 |         """
93 |         Compute a distribution over the target dictionary
94 |         by averaging distributions from models in the ensemble.
95 |         All models in the ensemble must share a target vocabulary.
96 |         """
97 |         distributions = torch.stack(
98 |             [mg(h) if attn is None else mg(h, attn, src_map)
99 |              for h, mg in zip(hidden, self.model_generators)]
100 |         )
101 |         if self._raw_probs:
102 |             return torch.log(torch.exp(distributions).mean(0))
103 |         else:
104 |             return distributions.mean(0)
105 | 
106 | 
107 | class EnsembleModel(NMTModel):
108 |     """Dummy NMTModel wrapping individual real NMTModels."""
109 |     def __init__(self, models, raw_probs=False):
110 |         encoder = EnsembleEncoder(model.encoder for model in models)
111 |         decoder = EnsembleDecoder(model.decoder for model in models)
112 |         super(EnsembleModel, self).__init__(encoder, decoder)
113 |         self.generator = EnsembleGenerator(
114 |             [model.generator for model in models], raw_probs)
115 |         self.models = nn.ModuleList(models)
116 | 
117 | 
118 | def load_test_model(opt):
119 |     """Read in multiple models for ensemble."""
120 |     shared_fields = None
121 |     shared_model_opt = None
122 |     models = []
123 |     for model_path in opt.models:
124 |         fields, model, model_opt = \
125 |             onmt.model_builder.load_test_model(opt, model_path=model_path)
126 |         if shared_fields is None:
127 |             shared_fields = fields
128 |         else:
129 |             for key, field in fields.items():
130 |                 try:
131 |                     f_iter = iter(field)
132 |                 except TypeError:
133 |                     f_iter = [(key, field)]
134 |                 for sn, sf in f_iter:
135 |                     if sf is not None and 'vocab' in sf.__dict__:
136 |                         sh_field = shared_fields[key]
137 |                         try:
138 |                             sh_f_iter = iter(sh_field)
139 |                         except TypeError:
140 |                             sh_f_iter = [(key, sh_field)]
141 |                         sh_f_dict = dict(sh_f_iter)
142 |                         assert sf.vocab.stoi == sh_f_dict[sn].vocab.stoi, \
143 |                             "Ensemble models must use the same " \
144 |                             "preprocessed data"
145 |         models.append(model)
146 |         if shared_model_opt is None:
147 |             shared_model_opt = model_opt
148 |     ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
149 |     return shared_fields, ensemble_model, shared_model_opt
150 | 
--------------------------------------------------------------------------------
/translator/onmt/encoders/audio_encoder.py:
-------------------------------------------------------------------------------- 1 | """Audio encoder""" 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 8 | 9 | from onmt.utils.rnn_factory import rnn_factory 10 | from onmt.encoders.encoder import EncoderBase 11 | 12 | 13 | class AudioEncoder(EncoderBase): 14 | """A simple encoder CNN -> RNN for audio input. 15 | 16 | Args: 17 | rnn_type (str): Type of RNN (e.g. GRU, LSTM, etc). 18 | enc_layers (int): Number of encoder layers. 19 | dec_layers (int): Number of decoder layers. 20 | brnn (bool): Bidirectional encoder. 21 | enc_rnn_size (int): Size of hidden states of the rnn. 22 | dec_rnn_size (int): Size of the decoder hidden states. 23 | enc_pooling (str): A comma separated list either of length 1 24 | or of length ``enc_layers`` specifying the pooling amount. 25 | dropout (float): dropout probablity. 26 | sample_rate (float): input spec 27 | window_size (int): input spec 28 | """ 29 | 30 | def __init__(self, rnn_type, enc_layers, dec_layers, brnn, 31 | enc_rnn_size, dec_rnn_size, enc_pooling, dropout, 32 | sample_rate, window_size): 33 | super(AudioEncoder, self).__init__() 34 | self.enc_layers = enc_layers 35 | self.rnn_type = rnn_type 36 | self.dec_layers = dec_layers 37 | num_directions = 2 if brnn else 1 38 | self.num_directions = num_directions 39 | assert enc_rnn_size % num_directions == 0 40 | enc_rnn_size_real = enc_rnn_size // num_directions 41 | assert dec_rnn_size % num_directions == 0 42 | self.dec_rnn_size = dec_rnn_size 43 | dec_rnn_size_real = dec_rnn_size // num_directions 44 | self.dec_rnn_size_real = dec_rnn_size_real 45 | self.dec_rnn_size = dec_rnn_size 46 | input_size = int(math.floor((sample_rate * window_size) / 2) + 1) 47 | enc_pooling = enc_pooling.split(',') 48 | assert len(enc_pooling) == enc_layers or len(enc_pooling) == 1 49 | if len(enc_pooling) == 1: 50 | enc_pooling = enc_pooling * enc_layers 51 | enc_pooling = [int(p) for p in enc_pooling] 52 | self.enc_pooling = enc_pooling 53 | 54 | if dropout > 0: 55 | self.dropout = nn.Dropout(dropout) 56 | else: 57 | self.dropout = None 58 | self.W = nn.Linear(enc_rnn_size, dec_rnn_size, bias=False) 59 | self.batchnorm_0 = nn.BatchNorm1d(enc_rnn_size, affine=True) 60 | self.rnn_0, self.no_pack_padded_seq = \ 61 | rnn_factory(rnn_type, 62 | input_size=input_size, 63 | hidden_size=enc_rnn_size_real, 64 | num_layers=1, 65 | dropout=dropout, 66 | bidirectional=brnn) 67 | self.pool_0 = nn.MaxPool1d(enc_pooling[0]) 68 | for l in range(enc_layers - 1): 69 | batchnorm = nn.BatchNorm1d(enc_rnn_size, affine=True) 70 | rnn, _ = \ 71 | rnn_factory(rnn_type, 72 | input_size=enc_rnn_size, 73 | hidden_size=enc_rnn_size_real, 74 | num_layers=1, 75 | dropout=dropout, 76 | bidirectional=brnn) 77 | setattr(self, 'rnn_%d' % (l + 1), rnn) 78 | setattr(self, 'pool_%d' % (l + 1), 79 | nn.MaxPool1d(enc_pooling[l + 1])) 80 | setattr(self, 'batchnorm_%d' % (l + 1), batchnorm) 81 | 82 | @classmethod 83 | def from_opt(cls, opt, embeddings=None): 84 | """Alternate constructor.""" 85 | if embeddings is not None: 86 | raise ValueError("Cannot use embeddings with AudioEncoder.") 87 | return cls( 88 | opt.rnn_type, 89 | opt.enc_layers, 90 | opt.dec_layers, 91 | opt.brnn, 92 | opt.enc_rnn_size, 93 | opt.dec_rnn_size, 94 | opt.audio_enc_pooling, 95 | opt.dropout, 96 | opt.sample_rate, 97 | opt.window_size) 98 | 99 | def forward(self, src, lengths=None): 100 | """See 
:func:`onmt.encoders.encoder.EncoderBase.forward()`""" 101 | batch_size, _, nfft, t = src.size() 102 | src = src.transpose(0, 1).transpose(0, 3).contiguous() \ 103 | .view(t, batch_size, nfft) 104 | orig_lengths = lengths 105 | lengths = lengths.view(-1).tolist() 106 | 107 | for l in range(self.enc_layers): 108 | rnn = getattr(self, 'rnn_%d' % l) 109 | pool = getattr(self, 'pool_%d' % l) 110 | batchnorm = getattr(self, 'batchnorm_%d' % l) 111 | stride = self.enc_pooling[l] 112 | packed_emb = pack(src, lengths) 113 | memory_bank, tmp = rnn(packed_emb) 114 | memory_bank = unpack(memory_bank)[0] 115 | t, _, _ = memory_bank.size() 116 | memory_bank = memory_bank.transpose(0, 2) 117 | memory_bank = pool(memory_bank) 118 | lengths = [int(math.floor((length - stride) / stride + 1)) 119 | for length in lengths] 120 | memory_bank = memory_bank.transpose(0, 2) 121 | src = memory_bank 122 | t, _, num_feat = src.size() 123 | src = batchnorm(src.contiguous().view(-1, num_feat)) 124 | src = src.view(t, -1, num_feat) 125 | if self.dropout and l + 1 != self.enc_layers: 126 | src = self.dropout(src) 127 | 128 | memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2)) 129 | memory_bank = self.W(memory_bank).view(-1, batch_size, 130 | self.dec_rnn_size) 131 | 132 | state = memory_bank.new_full((self.dec_layers * self.num_directions, 133 | batch_size, self.dec_rnn_size_real), 0) 134 | if self.rnn_type == 'LSTM': 135 | # The encoder hidden is (layers*directions) x batch x dim. 136 | encoder_final = (state, state) 137 | else: 138 | encoder_final = state 139 | return encoder_final, memory_bank, orig_lengths.new_tensor(lengths) 140 | -------------------------------------------------------------------------------- /translator/torchtext/datasets/babi.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import open 3 | 4 | import torch 5 | 6 | from ..data import Dataset, Field, Example, Iterator 7 | 8 | 9 | class BABI20Field(Field): 10 | 11 | def __init__(self, memory_size, **kwargs): 12 | super(BABI20Field, self).__init__(**kwargs) 13 | self.memory_size = memory_size 14 | self.unk_token = None 15 | self.batch_first = True 16 | 17 | def preprocess(self, x): 18 | if isinstance(x, list): 19 | return [super(BABI20Field, self).preprocess(s) for s in x] 20 | else: 21 | return super(BABI20Field, self).preprocess(x) 22 | 23 | def pad(self, minibatch): 24 | if isinstance(minibatch[0][0], list): 25 | self.fix_length = max(max(len(x) for x in ex) for ex in minibatch) 26 | padded = [] 27 | for ex in minibatch: 28 | # sentences are indexed in reverse order and truncated to memory_size 29 | nex = ex[::-1][:self.memory_size] 30 | padded.append( 31 | super(BABI20Field, self).pad(nex) 32 | + [[self.pad_token] * self.fix_length] 33 | * (self.memory_size - len(nex))) 34 | self.fix_length = None 35 | return padded 36 | else: 37 | return super(BABI20Field, self).pad(minibatch) 38 | 39 | def numericalize(self, arr, device=None): 40 | if isinstance(arr[0][0], list): 41 | tmp = [ 42 | super(BABI20Field, self).numericalize(x, device=device).data 43 | for x in arr 44 | ] 45 | arr = torch.stack(tmp) 46 | if self.sequential: 47 | arr = arr.contiguous() 48 | return arr 49 | else: 50 | return super(BABI20Field, self).numericalize(arr, device=device) 51 | 52 | 53 | class BABI20(Dataset): 54 | urls = ['http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz'] 55 | name = '' 56 | dirname = '' 57 | 58 | def __init__(self, path, text_field, only_supporting=False, 
**kwargs):
59 |         fields = [('story', text_field), ('query', text_field), ('answer', text_field)]
60 |         self.sort_key = lambda x: len(x.query)
61 | 
62 |         with open(path, 'r', encoding="utf-8") as f:
63 |             triplets = self._parse(f, only_supporting)
64 |             examples = [Example.fromlist(triplet, fields) for triplet in triplets]
65 | 
66 |         super(BABI20, self).__init__(examples, fields, **kwargs)
67 | 
68 |     @staticmethod
69 |     def _parse(file, only_supporting):
70 |         data, story = [], []
71 |         for line in file:
72 |             tid, text = line.rstrip('\n').split(' ', 1)
73 |             if tid == '1':
74 |                 story = []
75 |             # sentence
76 |             if text.endswith('.'):
77 |                 story.append(text[:-1])
78 |             # question
79 |             else:
80 |                 # remove any leading or trailing whitespace after splitting
81 |                 query, answer, supporting = (x.strip() for x in text.split('\t'))
82 |                 if only_supporting:
83 |                     substory = [story[int(i) - 1] for i in supporting.split()]
84 |                 else:
85 |                     substory = [x for x in story if x]
86 |                 data.append((substory, query[:-1], answer))  # remove '?'
87 |                 story.append("")
88 |         return data
89 | 
90 |     @classmethod
91 |     def splits(cls, text_field, path=None, root='.data', task=1, joint=False, tenK=False,
92 |                only_supporting=False, train=None, validation=None, test=None, **kwargs):
93 |         assert isinstance(task, int) and 1 <= task <= 20
94 |         if tenK:
95 |             cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid-10k')
96 |         else:
97 |             cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid')
98 |         if path is None:
99 |             path = cls.download(root)
100 |         if train is None:
101 |             if joint:  # put all tasks together for joint learning
102 |                 train = 'all_train.txt'
103 |                 if not os.path.isfile(os.path.join(path, train)):
104 |                     with open(os.path.join(path, train), 'w') as tf:
105 |                         for t in range(1, 21):  # fresh name; don't clobber the task argument
106 |                             with open(
107 |                                     os.path.join(path,
108 |                                                  'qa' + str(t) + '_train.txt')) as f:
109 |                                 tf.write(f.read())
110 |             else:
111 |                 train = 'qa' + str(task) + '_train.txt'
112 |         if validation is None:
113 |             if joint:  # put all tasks together for joint learning
114 |                 validation = 'all_valid.txt'
115 |                 if not os.path.isfile(os.path.join(path, validation)):
116 |                     with open(os.path.join(path, validation), 'w') as tf:
117 |                         for t in range(1, 21):
118 |                             with open(
119 |                                     os.path.join(path,
120 |                                                  'qa' + str(t) + '_valid.txt')) as f:
121 |                                 tf.write(f.read())
122 |             else:
123 |                 validation = 'qa' + str(task) + '_valid.txt'
124 |         if test is None:
125 |             test = 'qa' + str(task) + '_test.txt'
126 |         return super(BABI20,
127 |                      cls).splits(path=path, root=root, text_field=text_field, train=train,
128 |                                  validation=validation, test=test, **kwargs)
129 | 
130 |     @classmethod
131 |     def iters(cls, batch_size=32, root='.data', memory_size=50, task=1, joint=False,
132 |               tenK=False, only_supporting=False, sort=False, shuffle=False, device=None,
133 |               **kwargs):
134 |         text = BABI20Field(memory_size)
135 |         train, val, test = BABI20.splits(text, root=root, task=task, joint=joint,
136 |                                          tenK=tenK, only_supporting=only_supporting,
137 |                                          **kwargs)
138 |         text.build_vocab(train)
139 |         return Iterator.splits((train, val, test), batch_size=batch_size, sort=sort,
140 |                                shuffle=shuffle, device=device)
141 | 
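For reference, `BABI20.iters` above is the one-call entry point: it builds the field, downloads and splits the data, builds the vocabulary, and returns train/validation/test iterators. A minimal sketch of driving it (task number and batch size are arbitrary example values; assumes the vendored torchtext package above is on the import path):

    from torchtext.datasets.babi import BABI20

    train_iter, val_iter, test_iter = BABI20.iters(batch_size=32, task=3, tenK=False)
    for batch in train_iter:
        # batch_first=True is set in BABI20Field, so shapes are roughly:
        #   batch.story:  (batch, memory_size, sentence_length)
        #   batch.query:  (batch, query_length)
        #   batch.answer: (batch, 1)
        story, query, answer = batch.story, batch.query, batch.answer
        break
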
--------------------------------------------------------------------------------
/translator/onmt/utils/report_manager.py:
--------------------------------------------------------------------------------
1 | """ Report manager utility """
2 | from __future__ import print_function
3 | import time
4 | from datetime import datetime
5 | 
6 | import onmt
7 | 
8 | from onmt.utils.logging import logger
9 | 
10 | 
11 | def build_report_manager(opt):
12 |     if opt.tensorboard:
13 |         from tensorboardX import SummaryWriter
14 |         tensorboard_log_dir = opt.tensorboard_log_dir
15 | 
16 |         if not opt.train_from:
17 |             tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S")
18 | 
19 |         writer = SummaryWriter(tensorboard_log_dir,
20 |                                comment="Unmt")
21 |     else:
22 |         writer = None
23 | 
24 |     report_mgr = ReportMgr(opt.report_every, start_time=-1,
25 |                            tensorboard_writer=writer)
26 |     return report_mgr
27 | 
28 | 
29 | class ReportMgrBase(object):
30 |     """
31 |     Report Manager Base class
32 |     Inherited classes should override:
33 |         * `_report_training`
34 |         * `_report_step`
35 |     """
36 | 
37 |     def __init__(self, report_every, start_time=-1.):
38 |         """
39 |         Args:
40 |             report_every(int): Report status every this many steps
41 |             start_time(float): manually set report start time. Negative values
42 |                 mean that you will need to set it later or use `start()`
43 |         """
44 |         self.report_every = report_every
45 |         self.progress_step = 0
46 |         self.start_time = start_time
47 | 
48 |     def start(self):
49 |         self.start_time = time.time()
50 | 
51 |     def log(self, *args, **kwargs):
52 |         logger.info(*args, **kwargs)
53 | 
54 |     def report_training(self, step, num_steps, learning_rate,
55 |                         report_stats, multigpu=False):
56 |         """
57 |         This is the user-defined batch-level training progress
58 |         report function.
59 | 
60 |         Args:
61 |             step(int): current step count.
62 |             num_steps(int): total number of batches.
63 |             learning_rate(float): current learning rate.
64 |             report_stats(Statistics): old Statistics instance.
65 |         Returns:
66 |             report_stats(Statistics): updated Statistics instance.
67 |         """
68 |         if self.start_time < 0:
69 |             raise ValueError("""ReportMgr needs to be started
70 |                                 (set 'start_time' or use 'start()')""")
71 | 
72 |         if step % self.report_every == 0:
73 |             if multigpu:
74 |                 report_stats = \
75 |                     onmt.utils.Statistics.all_gather_stats(report_stats)
76 |             self._report_training(
77 |                 step, num_steps, learning_rate, report_stats)
78 |             self.progress_step += 1
79 |             return onmt.utils.Statistics()
80 |         else:
81 |             return report_stats
82 | 
83 |     def _report_training(self, *args, **kwargs):
84 |         """ To be overridden """
85 |         raise NotImplementedError()
86 | 
87 |     def report_step(self, lr, step, train_stats=None, valid_stats=None):
88 |         """
89 |         Report stats of a step
90 | 
91 |         Args:
92 |             train_stats(Statistics): training stats
93 |             valid_stats(Statistics): validation stats
94 |             lr(float): current learning rate
95 |         """
96 |         self._report_step(
97 |             lr, step, train_stats=train_stats, valid_stats=valid_stats)
98 | 
99 |     def _report_step(self, *args, **kwargs):
100 |         raise NotImplementedError()
101 | 
102 | 
103 | class ReportMgr(ReportMgrBase):
104 |     def __init__(self, report_every, start_time=-1., tensorboard_writer=None):
105 |         """
106 |         A report manager that writes statistics on standard output as well as
107 |         (optionally) TensorBoard
108 | 
109 |         Args:
110 |             report_every(int): Report status every this many steps
111 |             tensorboard_writer(:obj:`tensorboard.SummaryWriter`):
112 |                 The TensorBoard Summary writer to use or None
113 |         """
114 |         super(ReportMgr, self).__init__(report_every, start_time)
115 |         self.tensorboard_writer = tensorboard_writer
116 | 
117 |     def maybe_log_tensorboard(self, stats, prefix, learning_rate, step):
118 |         if self.tensorboard_writer is not None:
119 |             stats.log_tensorboard(
120 |                 prefix, self.tensorboard_writer, learning_rate, step)
121 | 
122 |     def _report_training(self, step, num_steps,
learning_rate, 123 | report_stats): 124 | """ 125 | See base class method `ReportMgrBase.report_training`. 126 | """ 127 | report_stats.output(step, num_steps, 128 | learning_rate, self.start_time) 129 | 130 | # Log the progress using the number of batches on the x-axis. 131 | self.maybe_log_tensorboard(report_stats, 132 | "progress", 133 | learning_rate, 134 | self.progress_step) 135 | report_stats = onmt.utils.Statistics() 136 | 137 | return report_stats 138 | 139 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 140 | """ 141 | See base class method `ReportMgrBase.report_step`. 142 | """ 143 | ################ MODIFIED ################ 144 | if train_stats is not None: 145 | self.log('Train perplexity: %g' % train_stats.ppl()) 146 | self.log('Train accuracy: %g' % train_stats.accuracy()) 147 | self.log('Train sequence accuracy: %g' % train_stats.seq_accuracy()) 148 | 149 | self.maybe_log_tensorboard(train_stats, 150 | "train", 151 | lr, 152 | step) 153 | 154 | if valid_stats is not None: 155 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 156 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 157 | self.log('Validation sequence accuracy: %g' % valid_stats.seq_accuracy()) 158 | 159 | self.maybe_log_tensorboard(valid_stats, 160 | "valid", 161 | lr, 162 | step) 163 | ########################################## 164 | -------------------------------------------------------------------------------- /translator/onmt/translate/translation.py: -------------------------------------------------------------------------------- 1 | """ Translation main class """ 2 | from __future__ import unicode_literals, print_function 3 | 4 | import torch 5 | from onmt.inputters.text_dataset import TextMultiField 6 | 7 | 8 | class TranslationBuilder(object): 9 | """ 10 | Build a word-based translation from the batch output 11 | of translator and the underlying dictionaries. 12 | 13 | Replacement based on "Addressing the Rare Word 14 | Problem in Neural Machine Translation" :cite:`Luong2015b` 15 | 16 | Args: 17 | data (onmt.inputters.Dataset): Data. 
18 |         fields (List[Tuple[str, torchtext.data.Field]]): data fields
19 |         n_best (int): number of translations produced
20 |         replace_unk (bool): replace unknown words using attention
21 |         has_tgt (bool): will the batch have gold targets
22 |     """
23 | 
24 |     def __init__(self, data, fields, n_best=1, replace_unk=False,
25 |                  has_tgt=False):
26 |         self.data = data
27 |         self.fields = fields
28 |         self._has_text_src = isinstance(
29 |             dict(self.fields)["src"], TextMultiField)
30 |         self.n_best = n_best
31 |         self.replace_unk = replace_unk
32 |         self.has_tgt = has_tgt
33 | 
34 |     def _build_target_tokens(self, src, src_vocab, src_raw, pred, attn):
35 |         tgt_field = dict(self.fields)["tgt"].base_field
36 |         vocab = tgt_field.vocab
37 |         tokens = []
38 |         for tok in pred:
39 |             if tok < len(vocab):
40 |                 tokens.append(vocab.itos[tok])
41 |             else:
42 |                 tokens.append(src_vocab.itos[tok - len(vocab)])
43 |             if tokens[-1] == tgt_field.eos_token:
44 |                 tokens = tokens[:-1]
45 |                 break
46 |         if self.replace_unk and attn is not None and src is not None:
47 |             for i in range(len(tokens)):
48 |                 if tokens[i] == tgt_field.unk_token:
49 |                     _, max_index = attn[i].max(0)
50 |                     tokens[i] = src_raw[max_index.item()]
51 |         return tokens
52 | 
53 |     def from_batch(self, translation_batch):
54 |         batch = translation_batch["batch"]
55 |         assert(len(translation_batch["gold_score"]) ==
56 |                len(translation_batch["predictions"]))
57 |         batch_size = batch.batch_size
58 | 
59 |         preds, pred_score, attn, gold_score, indices = list(zip(
60 |             *sorted(zip(translation_batch["predictions"],
61 |                         translation_batch["scores"],
62 |                         translation_batch["attention"],
63 |                         translation_batch["gold_score"],
64 |                         batch.indices.data),
65 |                     key=lambda x: x[-1])))
66 | 
67 |         # Restore the original example order within the batch
68 |         inds, perm = torch.sort(batch.indices)
69 |         if self._has_text_src:
70 |             src = batch.src[0][:, :, 0].index_select(1, perm)
71 |         else:
72 |             src = None
73 |         tgt = batch.tgt[:, :, 0].index_select(1, perm) \
74 |             if self.has_tgt else None
75 | 
76 |         translations = []
77 |         for b in range(batch_size):
78 |             if self._has_text_src:
79 |                 src_vocab = self.data.src_vocabs[inds[b]] \
80 |                     if self.data.src_vocabs else None
81 |                 src_raw = self.data.examples[inds[b]].src[0]
82 |             else:
83 |                 src_vocab = None
84 |                 src_raw = None
85 |             pred_sents = [self._build_target_tokens(
86 |                 src[:, b] if src is not None else None,
87 |                 src_vocab, src_raw,
88 |                 preds[b][n], attn[b][n])
89 |                 for n in range(self.n_best)]
90 |             gold_sent = None
91 |             if tgt is not None:
92 |                 gold_sent = self._build_target_tokens(
93 |                     src[:, b] if src is not None else None,
94 |                     src_vocab, src_raw,
95 |                     tgt[1:, b] if tgt is not None else None, None)
96 | 
97 |             translation = Translation(
98 |                 src[:, b] if src is not None else None,
99 |                 src_raw, pred_sents, attn[b], pred_score[b],
100 |                 gold_sent, gold_score[b]
101 |             )
102 |             translations.append(translation)
103 | 
104 |         return translations
105 | 
106 | 
107 | class Translation(object):
108 |     """Container for a translated sentence.
109 | 
110 |     Attributes:
111 |         src (LongTensor): Source word IDs.
112 |         src_raw (List[str]): Raw source words.
113 |         pred_sents (List[List[str]]): Words from the n-best translations.
114 |         pred_scores (List[float]): Log-probs of the n-best translations.
115 |         attns (List[FloatTensor]): Attention distribution for each
116 |             translation.
117 |         gold_sent (List[str]): Words from the gold translation.
118 |         gold_score (float): Log-prob of the gold translation.
119 | """ 120 | 121 | __slots__ = ["src", "src_raw", "pred_sents", "attns", "pred_scores", 122 | "gold_sent", "gold_score"] 123 | 124 | def __init__(self, src, src_raw, pred_sents, 125 | attn, pred_scores, tgt_sent, gold_score): 126 | self.src = src 127 | self.src_raw = src_raw 128 | self.pred_sents = pred_sents 129 | self.attns = attn 130 | self.pred_scores = pred_scores 131 | self.gold_sent = tgt_sent 132 | self.gold_score = gold_score 133 | 134 | def log(self, sent_number): 135 | """ 136 | Log translation. 137 | """ 138 | 139 | msg = ['\nSENT {}: {}\n'.format(sent_number, self.src_raw)] 140 | 141 | best_pred = self.pred_sents[0] 142 | best_score = self.pred_scores[0] 143 | pred_sent = ' '.join(best_pred) 144 | msg.append('PRED {}: {}\n'.format(sent_number, pred_sent)) 145 | msg.append("PRED SCORE: {:.4f}\n".format(best_score)) 146 | 147 | if self.gold_sent is not None: 148 | tgt_sent = ' '.join(self.gold_sent) 149 | msg.append('GOLD {}: {}\n'.format(sent_number, tgt_sent)) 150 | msg.append(("GOLD SCORE: {:.4f}\n".format(self.gold_score))) 151 | if len(self.pred_sents) > 1: 152 | msg.append('\nBEST HYP:\n') 153 | for score, sent in zip(self.pred_scores, self.pred_sents): 154 | msg.append("[{:.4f}] {}\n".format(score, sent)) 155 | 156 | return "".join(msg) 157 | -------------------------------------------------------------------------------- /translator/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Parse the log file from translate.py and output evaluation scores. 5 | """ 6 | from __future__ import print_function 7 | 8 | import re 9 | import argparse 10 | from collections import Counter 11 | 12 | 13 | SKIPTHISLINE = "SKIPTHISLINE" 14 | DUMMY = "DUMMY" 15 | GOLD_NOT_FOUND = 999999 16 | 17 | 18 | class Output(object): 19 | IGNORE_BRACES = False 20 | 21 | def __init__(self): 22 | self.index = None 23 | self.sentence = None 24 | self.preds = [] 25 | self.pred_scores = [] 26 | self.gold = None 27 | self.gold_score = None 28 | self.gold_unked = None 29 | 30 | @property 31 | def rank(self): 32 | """ 33 | Compute the rank (1-indexed) of the gold among the predictions. 34 | If the gold is not found, return GOLD_NOT_FOUND. 
35 | """ 36 | for i, pred in enumerate(self.preds): 37 | if pred == self.gold: 38 | return i + 1 39 | return GOLD_NOT_FOUND 40 | 41 | @classmethod 42 | def dump_header(self, fout): 43 | stuff = [ 44 | "index", 45 | "text", 46 | "gold_score", 47 | "pred_score", 48 | "gold", 49 | "pred", 50 | ] 51 | print("\t".join(str(x) for x in stuff), file=fout) 52 | 53 | 54 | def dump(self, fout, args): 55 | """Print a TSV line summarizing the example.""" 56 | stuff = [ 57 | self.index, 58 | self.sentence, 59 | self.gold_score, 60 | self.pred_scores[0], 61 | self.gold, 62 | self.preds[0], 63 | ] 64 | if args.gold_rank: 65 | stuff.append(self.rank) 66 | if args.dump_all_preds: 67 | stuff += self.preds[1:] 68 | stuff += self.pred_scores 69 | print("\t".join(str(x) for x in stuff), file=fout) 70 | if args.dump_all_preds_verbose: 71 | for i, pred in enumerate(self.preds): 72 | print("\t".join(["#", str(i), pred]), file=fout) 73 | print(file=fout) 74 | 75 | @classmethod 76 | def format_code(cls, code): 77 | # Resplit and rejoin 78 | code = ' '.join(code.split()) 79 | # Remove braces if specified 80 | if cls.IGNORE_BRACES: 81 | code = re.sub('^}|{$', '', code).strip() 82 | return code 83 | 84 | @classmethod 85 | def parse_file(cls, pred_file, tgt_file): 86 | """"Yield Output objects based on the given files.""" 87 | output = None 88 | with open(pred_file) as fin, open(tgt_file) as ftgt: 89 | while True: 90 | line = fin.readline() 91 | if not line: 92 | break 93 | line = line.rstrip('\n') 94 | # SENT 1: [...] (begins a new output) 95 | m = re.match(r'^SENT (\d+): (.*)$', line) 96 | if m: 97 | if output is not None: 98 | yield output 99 | output = Output() 100 | i, sentence = m.groups() 101 | output.index = int(i) 102 | output.sentence = " ".join(eval(sentence)) 103 | continue 104 | # GOLD 1: ... 105 | m = re.match(r'^GOLD (\d+): (.*)$', line) 106 | if m: 107 | i, gold = m.groups() 108 | assert int(i) == output.index 109 | output.gold_unk = gold 110 | output.gold = cls.format_code(ftgt.readline().strip()) 111 | continue 112 | # GOLD SCORE: ... 
113 |                 m = re.match(r'^GOLD SCORE: (.*)$', line)
114 |                 if m:
115 |                     output.gold_score = float(m.groups()[0])
116 |                     continue
117 |                 # BEST HYP:
118 |                 if line == "BEST HYP:":
119 |                     while True:
120 |                         pred_line = fin.readline().strip()
121 |                         if not pred_line:
122 |                             break
123 |                         m = re.match(r'\[([^]]+)\] (.*)$', pred_line)
124 |                         score, pred = m.groups()
125 |                         output.preds.append(cls.format_code(" ".join(eval(pred))))
126 |                         output.pred_scores.append(float(score))
127 |             # The last output
128 |             if output is not None:
129 |                 yield output
130 | 
131 | 
132 | def main():
133 |     parser = argparse.ArgumentParser()
134 |     parser.add_argument('-o', '--out_file', default='/dev/null',
135 |                         help='dump prediction TSV to this file')
136 |     parser.add_argument('-b', '--ignore-braces', action='store_true',
137 |                         help='ignore the leading and trailing braces in code')
138 |     parser.add_argument('-a', '--dump-all-preds', action='store_true',
139 |                         help='When dumping predictions, dump all ranks')
140 |     parser.add_argument('-A', '--dump-all-preds-verbose', action='store_true',
141 |                         help='When dumping predictions, dump all ranks in their own lines')
142 |     parser.add_argument('-r', '--gold-rank', action='store_true',
143 |                         help='Print gold rank')
144 |     parser.add_argument("pred_file",
145 |                         help="the prediction file from translate.py")
146 |     parser.add_argument("tgt_file",
147 |                         help="gold target file")
148 |     args = parser.parse_args()
149 | 
150 |     if args.ignore_braces:
151 |         Output.IGNORE_BRACES = True
152 | 
153 |     ranks = Counter()
154 | 
155 |     with open(args.out_file, 'w') as fout:
156 |         Output.dump_header(fout)
157 |         for output in Output.parse_file(args.pred_file, args.tgt_file):
158 |             if args.out_file != '/dev/null':
159 |                 output.dump(fout, args)
160 |             if output.sentence != SKIPTHISLINE and output.sentence != DUMMY:
161 |                 ranks[output.rank] += 1
162 | 
163 |     n = sum(ranks.values())
164 |     print("Number of examples: {}".format(n))
165 |     accum = 0
166 |     mrr = 0.
167 |     for rank, count in sorted(ranks.items()):
168 |         accum += count
169 |         if rank != GOLD_NOT_FOUND:
170 |             mrr += count * 1. / rank  # Output.rank is already 1-indexed
171 |         print("RANK {:>2} = {:5} = {:6.2f} % | accum: {:5} = {:6.2f} %".format(
172 |             rank if rank != GOLD_NOT_FOUND else "no",
173 |             count, count * 100. / n, accum, accum * 100. / n))
174 |     print("MRR: {:.6f}".format(mrr / n))
175 | 
176 | if __name__ == '__main__':
177 |     main()
178 | 
179 | 
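For reference, the log consumed by evaluate.py is exactly what Translation.log() above emits (SENT/PRED/PRED SCORE/GOLD/GOLD SCORE lines, plus a BEST HYP: block when n_best > 1). A minimal sketch, with a hypothetical log file name, of reusing the parser to compute MRR programmatically instead of through the CLI:

    from collections import Counter
    from evaluate import Output, GOLD_NOT_FOUND

    ranks = Counter()
    for output in Output.parse_file('pred.log', 'data/input-tok-test-tgt.tsv'):
        ranks[output.rank] += 1  # 1-indexed rank of the gold among the predictions

    n = sum(ranks.values())
    mrr = sum(count * 1. / rank for rank, count in ranks.items()
              if rank != GOLD_NOT_FOUND) / n
    print('MRR over {} examples: {:.6f}'.format(n, mrr))
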
27 | """ 28 | lines = raw_err_msg.split('\n') 29 | for line in lines: 30 | m = re.match('[^:]*:(\d+):[:0-9 ]+error: (.*)', line) 31 | if not m: 32 | continue 33 | lineno, message = m.groups() 34 | if tokenize: 35 | message = ' '.join(tokenize_err_msg(message)) 36 | return int(lineno) - LINE_OFFSET, message.strip() 37 | return None, None 38 | 39 | 40 | def post_request(server, post_fields): 41 | request = Request(server, urlencode(post_fields).encode()) 42 | response = urlopen(request).read().decode() 43 | response = json.loads(response) 44 | return response 45 | 46 | 47 | ################################################ 48 | 49 | 50 | class ErrDetector(object): 51 | 52 | def __init__(self, args): 53 | pass 54 | 55 | def detect(self, code_lines, raw_err_msg): 56 | """ 57 | Detect where the error actually happens based on the code and g++ message. 58 | 59 | Args: 60 | code_lines: tuple of (pseudocode str, code str, indent int) 61 | raw_err_msg: (str) error message from g++ 62 | 63 | Returns: 64 | tuple (err_line, err_msg) 65 | - err_line: (int) The stmt_idx of the predicted error line 66 | (i.e., real line number minus LINE_OFFSET) 67 | To abstain prediction, let err_line = None 68 | - err_msg: (str) The tokenized message. 69 | """ 70 | raise NotImplementedError 71 | 72 | 73 | class NaiveErrDetector(ErrDetector): 74 | """ 75 | Just return the first error line number from the g++ message. 76 | """ 77 | 78 | def detect(self, code_lines, raw_err_msg): 79 | return parse_error(raw_err_msg) 80 | 81 | 82 | class TemplateErrDetector(ErrDetector): 83 | """ 84 | Match the error message against a template corpus. 85 | If a high-confident match is found, return the error line. 86 | Otherwise, abstain. 87 | """ 88 | 89 | MODES = ['all', 'vars', 'none'] 90 | 91 | def __init__(self, args): 92 | self.templates = {k: {} for k in self.MODES} 93 | with open(args.err_template_file) as fin: 94 | for line in fin: 95 | # mode total_count template line_offset count_percent 96 | line = line.rstrip('\n').split('\t') 97 | if float(line[4]) >= args.err_template_threshold: 98 | self.templates[line[0]][line[2]] = int(line[3]) 99 | for mode, templates in self.templates.items(): 100 | print('Read {} {} templates'.format(len(templates), mode)) 101 | 102 | def anonymize(self, msg, mode): 103 | if mode == 'none': 104 | return msg 105 | if mode == 'vars': 106 | return re.sub('‘[A-Za-z0-9_ ]*’', '@@@', msg) 107 | return re.sub('‘[^’]*’', '@@@', msg) 108 | 109 | def detect(self, code_lines, raw_err_msg): 110 | lineno, msg = parse_error(raw_err_msg, tokenize=False) 111 | if msg is None: 112 | return None, None 113 | tokenized_msg = ' '.join(tokenize_err_msg(msg)).strip() 114 | for mode in self.MODES: 115 | anon_msg = self.anonymize(msg, mode) 116 | if anon_msg in self.templates[mode]: 117 | return lineno - self.templates[mode][anon_msg], tokenized_msg 118 | return None, tokenized_msg 119 | 120 | 121 | class BinaryErrDetector(ErrDetector): 122 | """ 123 | Ask the PyTorch server if the error line from g++ is correct. 124 | If not, abstain. 
125 | """ 126 | 127 | def __init__(self, args): 128 | self.server = args.err_server 129 | self.info = {'probno': args.probno} 130 | 131 | def detect(self, code_lines, raw_err_msg): 132 | lineno, msg = parse_error(raw_err_msg, tokenize=True) 133 | if msg is None: 134 | return None, None 135 | # Call the server 136 | q = { 137 | 'info': self.info, 138 | 'code_lines': code_lines, 139 | 'err_line': { 140 | 'lineno': lineno, 141 | 'msg': msg, 142 | } 143 | } 144 | response = post_request(self.server, {'q': json.dumps(q)}) 145 | if response['pred'][0]: 146 | return lineno, msg 147 | return None, msg 148 | 149 | 150 | class AdvancedErrDetector(ErrDetector): 151 | """ 152 | Ask the PyTorch server what the actual error line is. 153 | """ 154 | 155 | def __init__(self, args): 156 | self.server = args.err_server 157 | self.info = {'probno': args.probno} 158 | self.threshold = args.err_advanced_threshold 159 | 160 | def detect(self, code_lines, raw_err_msg): 161 | lineno, msg = parse_error(raw_err_msg, tokenize=True) 162 | if msg is None: 163 | return None, None 164 | # Call the server 165 | q = { 166 | 'info': self.info, 167 | 'code_lines': code_lines, 168 | 'err_line': { 169 | 'lineno': lineno, 170 | 'msg': msg, 171 | } 172 | } 173 | response = post_request(self.server, {'q': json.dumps(q)}) 174 | argmax = response['pred'][0] 175 | probs = self.softmax(response['logit'][0]) 176 | if probs[argmax] >= self.threshold: 177 | return argmax, msg 178 | return None, msg 179 | 180 | def softmax(self, numbers): 181 | numbers = [math.exp(x - max(numbers)) for x in numbers] 182 | return [x / sum(numbers) for x in numbers] 183 | 184 | 185 | ################################################ 186 | 187 | 188 | def get_err_detector(args): 189 | if args.err_detector == 'naive': 190 | return NaiveErrDetector(args) 191 | if args.err_detector == 'template': 192 | return TemplateErrDetector(args) 193 | if args.err_detector == 'binary': 194 | return BinaryErrDetector(args) 195 | if args.err_detector == 'advanced': 196 | return AdvancedErrDetector(args) 197 | raise ValueError('Unknown detector: {}'.format(args.err_detector)) 198 | --------------------------------------------------------------------------------