├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── average-checkpoints.py ├── batch_translate.py ├── create-experiment.py ├── data └── make-dataset.py ├── onmt ├── __init__.py ├── bin │ ├── __init__.py │ ├── average_models.py │ ├── preprocess.py │ ├── server.py │ ├── train.py │ └── translate.py ├── decoders │ ├── __init__.py │ ├── cnn_decoder.py │ ├── decoder.py │ ├── ensemble.py │ ├── hierarchical_decoder.py │ └── transformer.py ├── encoders │ ├── __init__.py │ ├── audio_encoder.py │ ├── cnn_encoder.py │ ├── encoder.py │ ├── hierarchical_transformer.py │ ├── image_encoder.py │ ├── mean_encoder.py │ ├── rnn_encoder.py │ └── transformer.py ├── inputters │ ├── __init__.py │ ├── audio_dataset.py │ ├── datareader_base.py │ ├── dataset_base.py │ ├── image_dataset.py │ ├── inputter.py │ ├── text_dataset.py │ └── vec_dataset.py ├── model_builder.py ├── models │ ├── __init__.py │ ├── model.py │ ├── model_saver.py │ ├── sru.py │ └── stacked_rnn.py ├── modules │ ├── __init__.py │ ├── average_attn.py │ ├── conv_multi_step_attention.py │ ├── copy_generator.py │ ├── embeddings.py │ ├── gate.py │ ├── global_attention.py │ ├── glu.py │ ├── hierarchical_attention.py │ ├── multi_headed_attn.py │ ├── position_ffn.py │ ├── self_attention.py │ ├── sparse_activations.py │ ├── sparse_losses.py │ ├── structured_attention.py │ ├── table_embeddings.py │ ├── util_class.py │ └── weight_norm.py ├── opts.py ├── tests │ ├── __init__.py │ ├── pull_request_chk.sh │ ├── rebuild_test_models.sh │ ├── test_attention.py │ ├── test_audio_dataset.py │ ├── test_beam_search.py │ ├── test_copy_generator.py │ ├── test_embeddings.py │ ├── test_greedy_search.py │ ├── test_image_dataset.py │ ├── test_model.pt │ ├── test_model2.pt │ ├── test_models.py │ ├── test_models.sh │ ├── test_preprocess.py │ ├── test_simple.py │ ├── test_structured_attention.py │ ├── test_text_dataset.py │ ├── test_translation_server.py │ └── utils_for_tests.py ├── train_single.py ├── trainer.py ├── translate │ ├── __init__.py │ ├── beam_search.py │ ├── decode_strategy.py │ ├── greedy_search.py │ ├── penalties.py │ ├── process_zh.py │ ├── translation.py │ ├── translation_server.py │ └── translator.py └── utils │ ├── __init__.py │ ├── alignment.py │ ├── cnn_factory.py │ ├── distributed.py │ ├── earlystopping.py │ ├── logging.py │ ├── loss.py │ ├── misc.py │ ├── optimizers.py │ ├── parse.py │ ├── report_manager.py │ ├── rnn_factory.py │ └── statistics.py ├── outputs.zip ├── preprocess.cfg ├── preprocess.py ├── requirements.txt ├── train.cfg ├── train.py ├── translate.cfg └── translate.py /.gitattributes: -------------------------------------------------------------------------------- 1 | htransformer.tar.bz2 filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | experiments/ 2 | data/rotowire/ 3 | *.txt 4 | __pycache__/ 5 | .ipynb_checkpoints/ 6 | *.pyc 7 | *.bz2 8 | *.tar 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Data-to-Text-Hierarchical [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/a-hierarchical-model-for-data-to-text/data-to-text-generation-on-rotowire)](https://paperswithcode.com/sota/data-to-text-generation-on-rotowire?p=a-hierarchical-model-for-data-to-text) 2 | 3 | Code for [A Hierarchical Model for 
Data-to-Text Generation](https://arxiv.org/abs/1912.10011) (Rebuffel, Soulier, Scoutheeten, Gallinari; ECIR 2020); most of this code is based on [OpenNMT](https://github.com/OpenNMT/OpenNMT-py).
4 | 
5 | UPDATE 11/03/2021: The original checkpoints used to produce results from the
6 | paper are officially lost. However, I still have the actual model outputs, which
7 | are now included in this repo. Simply `unzip outputs.zip`.
8 | 
9 | Furthermore, [Radmil Raychev][1] and [Craig Thomson][2] from the University of Aberdeen
10 | are currently working with this repo, and have agreed to share their checkpoints,
11 | namely [htransformer.tar.bz2][5].
12 | (Note that this file cannot be downloaded from the command line; I am still looking
13 | for a better alternative.)
14 | 
15 | Once it's downloaded, simply run `tar -xvf htransformer.tar.bz2`.
16 | You'll find the `data` used to train the model, as well as `*.cfg` files and
17 | `*.pt` checkpoints. Note that the data is from [SportSett][3], which contains some
18 | additional info (such as the day of the week, for instance).
19 | (Also see [Thomson et al.][4] for more info regarding the impact of additional data
20 | on system performance.)
21 | 
22 | [1]: https://github.com/radmilr
23 | [2]: https://github.com/nlgcat
24 | [3]: https://github.com/nlgcat/sport_sett_basketball
25 | [4]: https://www.aclweb.org/anthology/2020.inlg-1.6/
26 | [5]: https://dl.orangedox.com/hierarchical-transformer-checkpoint
27 | 
28 | ## Requirements
29 | 
30 | You will need a recent Python to use this code as is, especially for OpenNMT. However, I guess the files could be tweaked to work with older Pythons. Please note that, at the time of writing, torch==1.1.0 can be problematic with very recent versions of Python. I suggest running the code with python=3.6.
31 | 
32 | Full requirements can be found in `requirements.txt`. Note that not all of them are strictly required: it's the full pip freeze of a clean conda virtual env, so you can probably make it work with less.
33 | 
34 | Beyond the standard packages included in miniconda, useful packages are torch==1.1.0, torchtext==0.4, and some others required to make onmt work (PyYAML and configargparse, for example). I also use more_itertools to format the dataset.
35 | 
36 | # Dataset
37 | 
38 | The dataset used in the paper can be downloaded [here](https://github.com/harvardnlp/boxscore-data). More specifically, you just need to download the [RotoWire dataset](https://github.com/harvardnlp/boxscore-data/blob/master/rotowire.tar.bz2):
39 | 
40 | ```bash
41 | cd data
42 | wget https://github.com/harvardnlp/boxscore-data/raw/master/rotowire.tar.bz2
43 | tar -xvjf rotowire.tar.bz2
44 | cd ..
45 | ```
46 | 
47 | You'll need to format the dataset so that it can be preprocessed by OpenNMT:
48 | 
49 | `python data/make-dataset.py --folder data/`
50 | 
51 | At this stage, your repository should look like this:
52 | 
53 | ```
54 | .
55 | ├── onmt/ # Most of the heavy-lifting is done by onmt
56 | ├── data/ # Dataset is here
57 | │ ├── rotowire/ # Raw data stored here
58 | │ ├── make-dataset.py # formatting script
59 | │ ├── train_input.txt
60 | │ ├── train_output.txt
61 | │ └── ...
62 | └── ...
63 | ```
64 | 
65 | # Experiments
66 | 
67 | Before running any code, we build an experiment folder to keep things contained:
68 | 
69 | `python create-experiment.py --name exp-1`
70 | 
71 | At this stage, your repository should look like this:
72 | 
73 | ```
74 | .
75 | ├── onmt # Most of the heavy-lifting is done by onmt
76 | ├── experiments # Experiments are stored here
77 | │ └── exp-1
78 | │ │ ├── data
79 | │ │ ├── gens
80 | │ │ └── models
81 | ├── data # Dataset is here
82 | └── ...
83 | ```
84 | 
85 | # Preprocessing
86 | 
87 | Before training models via OpenNMT, you must preprocess the data. I've handled all useful parameters with a config file. Please check it out if you want to tweak things; I have tried to include comments on each command. For further info, you can always check out the OpenNMT [preprocessing doc](http://opennmt.net/OpenNMT-py/options/preprocess.html).
88 | 
89 | ```
90 | python preprocess.py --config preprocess.cfg
91 | ```
92 | 
93 | At this stage, your repository should look like this:
94 | 
95 | ```
96 | ├── onmt # Most of the heavy-lifting is done by onmt
97 | ├── experiments # Experiments are stored here
98 | │ └── exp-1
99 | │ │ ├── data
100 | │ │ │ ├── data.train.0.pt
101 | │ │ │ ├── data.valid.0.pt
102 | │ │ │ ├── data.vocab.pt
103 | │ │ │ ├── preprocess-log.txt
104 | │ │ ├── gens
105 | │ │ └── models
106 | ├── data # Dataset is here
107 | └── ...
108 | ```
109 | 
110 | # Training
111 | 
112 | To train a hierarchical model on RotoWire, you can run:
113 | 
114 | `python train.py --config train.cfg`
115 | 
116 | To train with different parameters than those used in the paper, please refer to my comments in the config file, or check the OpenNMT [train doc](http://opennmt.net/OpenNMT-py/options/train.html).
117 | 
118 | This config file runs the training for 100 000 steps; however, we manually stopped the training at 30 000.
119 | 
120 | # Translating
121 | 
122 | Before translating, we average the last checkpoints. If you did anything different from the previous commands, please change the first few lines of `average-checkpoints.py`.
123 | 
124 | You can average checkpoints by running:
125 | 
126 | `python average-checkpoints.py --folder exp-1 --steps 31000 32000 33000 34000 35000`
127 | 
128 | Now you can simply translate the test input by running:
129 | 
130 | `python translate.py --config translate.cfg`
131 | 
132 | # Evaluation
133 | 
134 | RG, CS and CO metrics were originally ([see here](https://github.com/harvardnlp/data2text)) coded in Lua and python2.
135 | Because of compatibility issues with modern hardware, I have re-implemented RG
136 | in PyTorch and python3:
137 | 
138 | Follow the instructions at: [KaijuML/rotowire-rg-metric](https://github.com/KaijuML/rotowire-rg-metric).
139 | 
140 | You can evaluate the BLEU score using [SacreBLEU](https://github.com/mjpost/sacreBLEU)
141 | from [Post, 2018](https://aclweb.org/anthology/W18-6319).
142 | See the repo for installation; it should be a breeze with pip.
143 | 
144 | You can get the BLEU score by running:
145 | 
146 | `cat experiments/exp-1/gens/test/predictions.txt | sacrebleu --force data/test_output.txt`
147 | 
148 | (Note that --force is not required, as it doesn't change the score computation;
149 | it just suppresses a warning because this dataset is already tokenized.)
150 | 
151 | Alternatively, you can use any preferred method for BLEU computation.
152 | I have also checked scoring models with [NLTK](https://www.nltk.org/) and scores were virtually the same.
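If you prefer to compute BLEU from Python instead of the command line, here is a minimal sketch using SacreBLEU's Python API (the file paths are the ones used in the commands above; adapt them to your own layout):

```python
import sacrebleu

# One generated summary per line, written by translate.py.
with open('experiments/exp-1/gens/test/predictions.txt', encoding='utf8') as f:
    hypotheses = [line.strip() for line in f]

# One reference summary per line, produced by data/make-dataset.py.
with open('data/test_output.txt', encoding='utf8') as f:
    references = [line.strip() for line in f]

# corpus_bleu takes a list of hypothesis strings and a list of reference streams.
bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(bleu.score)
```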
153 | 
--------------------------------------------------------------------------------
/average-checkpoints.py:
--------------------------------------------------------------------------------
1 | """This file is nearly word-for-word taken from the folder tools in OpenNMT"""
2 | import pkg_resources
3 | import argparse
4 | import torch
5 | import os
6 | 
7 | 
8 | def average_checkpoints(checkpoint_files):
9 |     vocab = None
10 |     opt = None
11 |     avg_model = None
12 |     avg_generator = None
13 | 
14 |     for i, checkpoint_file in enumerate(checkpoint_files):
15 |         m = torch.load(checkpoint_file, map_location='cpu')
16 |         model_weights = m['model']
17 |         generator_weights = m['generator']
18 | 
19 |         if i == 0:
20 |             vocab, opt = m['vocab'], m['opt']
21 |             avg_model = model_weights
22 |             avg_generator = generator_weights
23 |         else:
24 |             for (k, v) in avg_model.items():
25 |                 avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1)
26 | 
27 |             for (k, v) in avg_generator.items():
28 |                 avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1)
29 | 
30 |     return {"vocab": vocab, 'opt': opt, 'optim': None,
31 |             "generator": avg_generator, "model": avg_model}
32 | 
33 | def main():
34 |     parser = argparse.ArgumentParser(description='This script merges checkpoints of the same model')
35 |     parser.add_argument('--folder', dest="folder", help="experiment name")
36 |     parser.add_argument('--steps', dest="steps", nargs="+", help="checkpoints step numbers")
37 | 
38 |     args = parser.parse_args()
39 | 
40 |     expfolder = pkg_resources.resource_filename(__name__, 'experiments')
41 |     model_folder = os.path.join(expfolder, args.folder, 'models')
42 | 
43 |     assert os.path.exists(model_folder), f'{model_folder} is not a valid folder'
44 | 
45 |     checkpoint_files = [os.path.join(model_folder, f'model_step_{step}.pt') for step in args.steps]
46 | 
47 |     avg_cp = average_checkpoints(checkpoint_files)
48 |     torch.save(avg_cp, os.path.join(model_folder, 'avg_model.pt'))
49 | 
50 | 
51 | if __name__ == "__main__":
52 |     main()
53 | 
--------------------------------------------------------------------------------
/batch_translate.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import functools
3 | import argparse
4 | import torch
5 | import os
6 | import re
7 | 
8 | 
9 | partial_shell = functools.partial(subprocess.run, shell=True,
10 |                                   stdout=subprocess.PIPE)
11 | def shell(cmd):
12 |     """Execute cmd as if from the command line"""
13 |     completed_process = partial_shell(cmd)
14 |     return completed_process.stdout.decode('utf8')
15 | 
16 | 
17 | if __name__ == "__main__":
18 |     parser = argparse.ArgumentParser()
--------------------------------------------------------------------------------
/create-experiment.py:
--------------------------------------------------------------------------------
1 | import pkg_resources
2 | import argparse
3 | import shutil
4 | import os
5 | 
6 | 
7 | if __name__ == '__main__':
8 |     parser = argparse.ArgumentParser(description="Create an experiment folder")
9 |     parser.add_argument('--name', dest='name', required=True)
10 | 
11 |     args = parser.parse_args()
12 | 
13 |     folder = pkg_resources.resource_filename(__name__, 'experiments')
14 | 
15 |     if not os.path.exists(folder):
16 |         print("Creating a folder 'experiments/' where all experiments will be stored.")
17 |         os.mkdir(folder)
18 | 
19 |     folder = os.path.join(folder, args.name)
20 | 
21 |     if os.path.exists(folder):
22 |         raise ValueError('An experiment with this name already exists')
23 | 
24 |     os.mkdir(folder)
25 |     os.mkdir(os.path.join(folder, 'data'))
26 |     os.mkdir(os.path.join(folder, 'models'))
27 |     os.mkdir(os.path.join(folder, 'gens'))
28 |     os.mkdir(os.path.join(folder, 'gens', 'test'))
29 |     os.mkdir(os.path.join(folder, 'gens', 'valid'))
30 | 
31 |     print(f'Experiment {args.name} created.')
32 | 
33 | 
--------------------------------------------------------------------------------
/data/make-dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | In this file we build the RotoWire dataset so that it can be used in OpenNMT
3 | and it can be used by our proposed hierarchical model.
4 | 
5 | All tables are represented as a sequence, where every ENT_SIZE tokens are one
6 | entity, so that seq.view(ENT_SIZE, -1) separates all entities.
7 | Each entity starts with an <ent> token, for learning entity repr
8 | 
9 | A lot of this file comes from previous work on this dataset:
10 | https://github.com/ratishsp/data2text-plan-py/blob/master/scripts/create_dataset.py
11 | """
12 | 
13 | from more_itertools import collapse
14 | 
15 | import pkg_resources
16 | import json, os, re
17 | import argparse
18 | 
19 | 
20 | # OpenNMT has a fancy pipe
21 | DELIM = "│"
22 | 
23 | # I manually checked and there are at most 24 elements in an entity
24 | ENT_SIZE = 24
25 | 
26 | 
27 | bs_keys = ['START_POSITION', 'MIN', 'PTS', 'FGM', 'FGA', 'FG_PCT', 'FG3M',
28 |            'FG3A', 'FG3_PCT', 'FTM', 'FTA', 'FT_PCT', 'OREB', 'DREB', 'REB',
29 |            'AST', 'TO', 'STL', 'BLK', 'PF', 'FIRST_NAME', 'SECOND_NAME']
30 | 
31 | ls_keys = ['PTS_QTR1', 'PTS_QTR2', 'PTS_QTR3', 'PTS_QTR4', 'PTS', 'FG_PCT',
32 |            'FG3_PCT', 'FT_PCT', 'REB', 'AST', 'TOV', 'WINS', 'LOSSES', 'CITY',
33 |            'NAME']
34 | ls_keys = [f'TEAM-{key}' for key in ls_keys]
35 | 
36 | 
37 | def _build_home(entry):
38 |     """The team who hosted the game"""
39 |     records = [DELIM.join(['<ent>', '<ent>'])]
40 |     for key in ls_keys:
41 |         records.append(DELIM.join([
42 |             entry['home_line'][key].replace(' ', '_'),
43 |             key
44 |         ]))
45 | 
46 |     # Contrary to previous work, IS_HOME is now a unique token at the end
47 |     records.append(DELIM.join(['yes', 'IS_HOME']))
48 | 
49 |     # We pad the entity to size ENT_SIZE with the OpenNMT <blank> token
50 |     records.extend([DELIM.join(['<blank>', '<blank>'])] * (ENT_SIZE - len(records)))
51 |     return records
52 | 
53 | 
54 | def _build_vis(entry):
55 |     """The visiting team"""
56 |     records = [DELIM.join(['<ent>', '<ent>'])]
57 |     for key in ls_keys:
58 |         records.append(DELIM.join([
59 |             entry['vis_line'][key].replace(' ', '_'),
60 |             key
61 |         ]))
62 | 
63 |     # Contrary to previous work, IS_HOME is now a unique token at the end
64 |     records.append(DELIM.join(['no', 'IS_HOME']))
65 | 
66 |     # We pad the entity to size ENT_SIZE with the OpenNMT <blank> token
67 |     records.extend([DELIM.join(['<blank>', '<blank>'])] * (ENT_SIZE - len(records)))
68 |     return records
69 | 
70 | 
71 | def get_player_idxs(entry):
72 |     # In 4 instances the Clippers play against the Lakers
73 |     # Both are from LA... We simply divide the players in half
74 |     # In all 4, there are 26 players so we return 13-25 & 0-12
75 |     # as it is always visiting first and home second.
76 |     if entry['home_city'] == entry['vis_city']:
77 |         assert entry['home_city'] == 'Los Angeles'
78 |         return ([str(idx) for idx in range(13, 26)],
79 |                 [str(idx) for idx in range(13)])
80 | 
81 |     nplayers = len(entry['box_score']['PTS'])
82 |     home_players, vis_players = list(), list()
83 |     for i in range(nplayers):
84 |         player_city = entry['box_score']['TEAM_CITY'][str(i)]
85 |         if player_city == entry['home_city']:
86 |             home_players.append(str(i))
87 |         else:
88 |             vis_players.append(str(i))
89 |     return home_players, vis_players
90 | 
91 | 
92 | def box_preprocess(entry, remove_na=True):
93 |     home_players, vis_players = get_player_idxs(entry)
94 | 
95 |     all_entities = list()  # will contain all records of the input table
96 | 
97 | 
98 |     for is_home, player_idxs in enumerate([vis_players, home_players]):
99 |         for player_idx in player_idxs:
100 |             player = [DELIM.join(['<ent>', '<ent>'])]
101 |             for key in bs_keys:
102 |                 val = entry['box_score'][key][player_idx]
103 |                 if remove_na and val == 'N/A': continue
104 |                 player.append(DELIM.join([
105 |                     val.replace(' ', '_'),
106 |                     key
107 |                 ]))
108 |             is_home_str = 'yes' if is_home else 'no'
109 |             player.append(DELIM.join([is_home_str, 'IS_HOME']))
110 | 
111 |             # We pad the entity to size ENT_SIZE with the OpenNMT <blank> token
112 |             player.extend([DELIM.join(['<blank>', '<blank>'])] * (ENT_SIZE - len(player)))
113 |             all_entities.append(player)
114 | 
115 |     all_entities.append(_build_home(entry))
116 |     all_entities.append(_build_vis(entry))
117 |     return list(collapse(all_entities))
118 | 
119 | 
120 | def _clean_summary(summary, tokens):
121 |     """
122 |     In here, we slightly help the copy mechanism
123 |     When we built the source sequence, we took all multi-word values
124 |     and replaced spaces by underscores. We replace those as well in
125 |     the summaries, so that the copy mechanism knows it was a copy.
126 |     It only happens with city names like "Los Angeles".
127 | """ 128 | summary = ' '.join(summary) 129 | for token in tokens: 130 | val = token.split(DELIM)[0] 131 | if '_' in val: 132 | val_no_underscore = val.replace('_', ' ') 133 | summary = summary.replace(val_no_underscore, val) 134 | return summary 135 | 136 | 137 | if __name__ == '__main__': 138 | parser = argparse.ArgumentParser() 139 | parser.add_argument('--folder', dest='folder', required=True, 140 | help='Save the preprocessed dataset to this folder') 141 | parser.add_argument('--keep-na', dest='keep_na', action='store_true', 142 | help='Activate to keep NA in the dataset') 143 | 144 | args = parser.parse_args() 145 | 146 | if not os.path.exists(args.folder): 147 | print('Creating folder to store preprocessed dataset at:') 148 | print(args.folder) 149 | os.mkdir(args.folder) 150 | 151 | for setname in ['train', 'valid', 'test']: 152 | filename = f'rotowire/{setname}.json' 153 | filename = pkg_resources.resource_filename(__name__, filename) 154 | with open(filename, encoding='utf8', mode='r') as f: 155 | data = json.load(f) 156 | 157 | input_filename = os.path.join(args.folder, f'{setname}_input.txt') 158 | output_filename = os.path.join(args.folder, f'{setname}_output.txt') 159 | with open(input_filename, mode='w', encoding='utf8') as inputf: 160 | with open(output_filename, mode='w', encoding='utf8') as outputf: 161 | for entry in data: 162 | input = box_preprocess(entry) 163 | inputf.write(' '.join(input) + '\n') 164 | outputf.write(_clean_summary(entry['summary'], input) + '\n') 165 | -------------------------------------------------------------------------------- /onmt/__init__.py: -------------------------------------------------------------------------------- 1 | """ Main entry point of the ONMT library """ 2 | from __future__ import division, print_function 3 | 4 | import onmt.inputters 5 | import onmt.encoders 6 | import onmt.decoders 7 | import onmt.models 8 | import onmt.utils 9 | import onmt.modules 10 | from onmt.trainer import Trainer 11 | import sys 12 | import onmt.utils.optimizers 13 | onmt.utils.optimizers.Optim = onmt.utils.optimizers.Optimizer 14 | sys.modules["onmt.Optim"] = onmt.utils.optimizers 15 | 16 | ENT_SIZE = 24 # Used for hierarchical training on RotoWire 17 | 18 | # For Flake 19 | __all__ = [onmt.inputters, onmt.encoders, onmt.decoders, onmt.models, 20 | onmt.utils, onmt.modules, "Trainer"] 21 | 22 | __version__ = "1.0.0" 23 | -------------------------------------------------------------------------------- /onmt/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaijuML/data-to-text-hierarchical/da88d2d4491266fccc39ac1cc1fbb56bd7bbc30c/onmt/bin/__init__.py -------------------------------------------------------------------------------- /onmt/bin/average_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import torch 4 | 5 | 6 | def average_models(model_files, fp32=False): 7 | vocab = None 8 | opt = None 9 | avg_model = None 10 | avg_generator = None 11 | 12 | for i, model_file in enumerate(model_files): 13 | m = torch.load(model_file, map_location='cpu') 14 | model_weights = m['model'] 15 | generator_weights = m['generator'] 16 | 17 | if fp32: 18 | for k, v in model_weights.items(): 19 | model_weights[k] = v.float() 20 | for k, v in generator_weights.items(): 21 | generator_weights[k] = v.float() 22 | 23 | if i == 0: 24 | vocab, opt = m['vocab'], m['opt'] 25 | avg_model = 
model_weights 26 | avg_generator = generator_weights 27 | else: 28 | for (k, v) in avg_model.items(): 29 | avg_model[k].mul_(i).add_(model_weights[k]).div_(i + 1) 30 | 31 | for (k, v) in avg_generator.items(): 32 | avg_generator[k].mul_(i).add_(generator_weights[k]).div_(i + 1) 33 | 34 | final = {"vocab": vocab, "opt": opt, "optim": None, 35 | "generator": avg_generator, "model": avg_model} 36 | return final 37 | 38 | 39 | def main(): 40 | parser = argparse.ArgumentParser(description="") 41 | parser.add_argument("-models", "-m", nargs="+", required=True, 42 | help="List of models") 43 | parser.add_argument("-output", "-o", required=True, 44 | help="Output file") 45 | parser.add_argument("-fp32", "-f", action="store_true", 46 | help="Cast params to float32") 47 | opt = parser.parse_args() 48 | 49 | final = average_models(opt.models, opt.fp32) 50 | torch.save(final, opt.output) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /onmt/bin/server.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import configargparse 3 | 4 | from flask import Flask, jsonify, request 5 | from onmt.translate import TranslationServer, ServerModelError 6 | 7 | STATUS_OK = "ok" 8 | STATUS_ERROR = "error" 9 | 10 | 11 | def start(config_file, 12 | url_root="./translator", 13 | host="0.0.0.0", 14 | port=5000, 15 | debug=True): 16 | def prefix_route(route_function, prefix='', mask='{0}{1}'): 17 | def newroute(route, *args, **kwargs): 18 | return route_function(mask.format(prefix, route), *args, **kwargs) 19 | return newroute 20 | 21 | app = Flask(__name__) 22 | app.route = prefix_route(app.route, url_root) 23 | translation_server = TranslationServer() 24 | translation_server.start(config_file) 25 | 26 | @app.route('/models', methods=['GET']) 27 | def get_models(): 28 | out = translation_server.list_models() 29 | return jsonify(out) 30 | 31 | @app.route('/health', methods=['GET']) 32 | def health(): 33 | out = {} 34 | out['status'] = STATUS_OK 35 | return jsonify(out) 36 | 37 | @app.route('/clone_model/', methods=['POST']) 38 | def clone_model(model_id): 39 | out = {} 40 | data = request.get_json(force=True) 41 | timeout = -1 42 | if 'timeout' in data: 43 | timeout = data['timeout'] 44 | del data['timeout'] 45 | 46 | opt = data.get('opt', None) 47 | try: 48 | model_id, load_time = translation_server.clone_model( 49 | model_id, opt, timeout) 50 | except ServerModelError as e: 51 | out['status'] = STATUS_ERROR 52 | out['error'] = str(e) 53 | else: 54 | out['status'] = STATUS_OK 55 | out['model_id'] = model_id 56 | out['load_time'] = load_time 57 | 58 | return jsonify(out) 59 | 60 | @app.route('/unload_model/', methods=['GET']) 61 | def unload_model(model_id): 62 | out = {"model_id": model_id} 63 | 64 | try: 65 | translation_server.unload_model(model_id) 66 | out['status'] = STATUS_OK 67 | except Exception as e: 68 | out['status'] = STATUS_ERROR 69 | out['error'] = str(e) 70 | 71 | return jsonify(out) 72 | 73 | @app.route('/translate', methods=['POST']) 74 | def translate(): 75 | inputs = request.get_json(force=True) 76 | out = {} 77 | try: 78 | trans, scores, n_best, _, aligns = translation_server.run(inputs) 79 | assert len(trans) == len(inputs) * n_best 80 | assert len(scores) == len(inputs) * n_best 81 | assert len(aligns) == len(inputs) * n_best 82 | 83 | out = [[] for _ in range(n_best)] 84 | for i in range(len(trans)): 85 | response = {"src": inputs[i // 
n_best]['src'], "tgt": trans[i], 86 | "n_best": n_best, "pred_score": scores[i]} 87 | if aligns[i] is not None: 88 | response["align"] = aligns[i] 89 | out[i % n_best].append(response) 90 | except ServerModelError as e: 91 | out['error'] = str(e) 92 | out['status'] = STATUS_ERROR 93 | 94 | return jsonify(out) 95 | 96 | @app.route('/to_cpu/', methods=['GET']) 97 | def to_cpu(model_id): 98 | out = {'model_id': model_id} 99 | translation_server.models[model_id].to_cpu() 100 | 101 | out['status'] = STATUS_OK 102 | return jsonify(out) 103 | 104 | @app.route('/to_gpu/', methods=['GET']) 105 | def to_gpu(model_id): 106 | out = {'model_id': model_id} 107 | translation_server.models[model_id].to_gpu() 108 | 109 | out['status'] = STATUS_OK 110 | return jsonify(out) 111 | 112 | app.run(debug=debug, host=host, port=port, use_reloader=False, 113 | threaded=True) 114 | 115 | 116 | def _get_parser(): 117 | parser = configargparse.ArgumentParser( 118 | config_file_parser_class=configargparse.YAMLConfigFileParser, 119 | description="OpenNMT-py REST Server") 120 | parser.add_argument("--ip", type=str, default="0.0.0.0") 121 | parser.add_argument("--port", type=int, default="5000") 122 | parser.add_argument("--url_root", type=str, default="/translator") 123 | parser.add_argument("--debug", "-d", action="store_true") 124 | parser.add_argument("--config", "-c", type=str, 125 | default="./available_models/conf.json") 126 | return parser 127 | 128 | 129 | def main(): 130 | parser = _get_parser() 131 | args = parser.parse_args() 132 | start(args.config, url_root=args.url_root, host=args.ip, port=args.port, 133 | debug=args.debug) 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | -------------------------------------------------------------------------------- /onmt/bin/translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from __future__ import unicode_literals 5 | 6 | from onmt.utils.logging import init_logger 7 | from onmt.utils.misc import split_corpus 8 | from onmt.translate.translator import build_translator 9 | 10 | import onmt.opts as opts 11 | from onmt.utils.parse import ArgumentParser 12 | 13 | 14 | def translate(opt): 15 | ArgumentParser.validate_translate_opts(opt) 16 | logger = init_logger(opt.log_file) 17 | 18 | translator = build_translator(opt, report_score=True) 19 | src_shards = split_corpus(opt.src, opt.shard_size) 20 | tgt_shards = split_corpus(opt.tgt, opt.shard_size) 21 | shard_pairs = zip(src_shards, tgt_shards) 22 | 23 | for i, (src_shard, tgt_shard) in enumerate(shard_pairs): 24 | logger.info("Translating shard %d." 
% i) 25 | translator.translate( 26 | src=src_shard, 27 | tgt=tgt_shard, 28 | src_dir=opt.src_dir, 29 | batch_size=opt.batch_size, 30 | batch_type=opt.batch_type, 31 | attn_debug=opt.attn_debug, 32 | align_debug=opt.align_debug 33 | ) 34 | 35 | 36 | def _get_parser(): 37 | parser = ArgumentParser(description='translate.py') 38 | 39 | opts.config_opts(parser) 40 | opts.translate_opts(parser) 41 | return parser 42 | 43 | 44 | def main(): 45 | parser = _get_parser() 46 | 47 | opt = parser.parse_args() 48 | translate(opt) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /onmt/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining decoders.""" 2 | from onmt.decoders.decoder import DecoderBase, InputFeedRNNDecoder, \ 3 | StdRNNDecoder 4 | from onmt.decoders.transformer import TransformerDecoder 5 | from onmt.decoders.cnn_decoder import CNNDecoder 6 | from onmt.decoders.hierarchical_decoder import HierarchicalRNNDecoder 7 | 8 | 9 | str2dec = {"rnn": StdRNNDecoder, "ifrnn": InputFeedRNNDecoder, 10 | "cnn": CNNDecoder, "transformer": TransformerDecoder, 11 | "hrnn": HierarchicalRNNDecoder} 12 | 13 | __all__ = ["DecoderBase", "TransformerDecoder", "StdRNNDecoder", "CNNDecoder", 14 | "InputFeedRNNDecoder", "str2dec", "HierarchicalRNNDecoder"] 15 | -------------------------------------------------------------------------------- /onmt/decoders/cnn_decoder.py: -------------------------------------------------------------------------------- 1 | """Implementation of the CNN Decoder part of 2 | "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onmt.modules import ConvMultiStepAttention, GlobalAttention 8 | from onmt.utils.cnn_factory import shape_transform, GatedConv 9 | from onmt.decoders.decoder import DecoderBase 10 | 11 | SCALE_WEIGHT = 0.5 ** 0.5 12 | 13 | 14 | class CNNDecoder(DecoderBase): 15 | """Decoder based on "Convolutional Sequence to Sequence Learning" 16 | :cite:`DBLP:journals/corr/GehringAGYD17`. 17 | 18 | Consists of residual convolutional layers, with ConvMultiStepAttention. 19 | """ 20 | 21 | def __init__(self, num_layers, hidden_size, attn_type, 22 | copy_attn, cnn_kernel_width, dropout, embeddings, 23 | copy_attn_type): 24 | super(CNNDecoder, self).__init__() 25 | 26 | self.cnn_kernel_width = cnn_kernel_width 27 | self.embeddings = embeddings 28 | 29 | # Decoder State 30 | self.state = {} 31 | 32 | input_size = self.embeddings.embedding_size 33 | self.linear = nn.Linear(input_size, hidden_size) 34 | self.conv_layers = nn.ModuleList( 35 | [GatedConv(hidden_size, cnn_kernel_width, dropout, True) 36 | for i in range(num_layers)] 37 | ) 38 | self.attn_layers = nn.ModuleList( 39 | [ConvMultiStepAttention(hidden_size) for i in range(num_layers)] 40 | ) 41 | 42 | # CNNDecoder has its own attention mechanism. 43 | # Set up a separate copy attention layer if needed. 
44 | assert not copy_attn, "Copy mechanism not yet tested in conv2conv" 45 | if copy_attn: 46 | self.copy_attn = GlobalAttention( 47 | hidden_size, attn_type=copy_attn_type) 48 | else: 49 | self.copy_attn = None 50 | 51 | @classmethod 52 | def from_opt(cls, opt, embeddings): 53 | """Alternate constructor.""" 54 | return cls( 55 | opt.dec_layers, 56 | opt.dec_rnn_size, 57 | opt.global_attention, 58 | opt.copy_attn, 59 | opt.cnn_kernel_width, 60 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 61 | embeddings, 62 | opt.copy_attn_type) 63 | 64 | def init_state(self, _, memory_bank, enc_hidden): 65 | """Init decoder state.""" 66 | self.state["src"] = (memory_bank + enc_hidden) * SCALE_WEIGHT 67 | self.state["previous_input"] = None 68 | 69 | def map_state(self, fn): 70 | self.state["src"] = fn(self.state["src"], 1) 71 | if self.state["previous_input"] is not None: 72 | self.state["previous_input"] = fn(self.state["previous_input"], 1) 73 | 74 | def detach_state(self): 75 | self.state["previous_input"] = self.state["previous_input"].detach() 76 | 77 | def forward(self, tgt, memory_bank, step=None, **kwargs): 78 | """ See :obj:`onmt.modules.RNNDecoderBase.forward()`""" 79 | 80 | if self.state["previous_input"] is not None: 81 | tgt = torch.cat([self.state["previous_input"], tgt], 0) 82 | 83 | dec_outs = [] 84 | attns = {"std": []} 85 | if self.copy_attn is not None: 86 | attns["copy"] = [] 87 | 88 | emb = self.embeddings(tgt) 89 | assert emb.dim() == 3 # len x batch x embedding_dim 90 | 91 | tgt_emb = emb.transpose(0, 1).contiguous() 92 | # The output of CNNEncoder. 93 | src_memory_bank_t = memory_bank.transpose(0, 1).contiguous() 94 | # The combination of output of CNNEncoder and source embeddings. 95 | src_memory_bank_c = self.state["src"].transpose(0, 1).contiguous() 96 | 97 | emb_reshape = tgt_emb.contiguous().view( 98 | tgt_emb.size(0) * tgt_emb.size(1), -1) 99 | linear_out = self.linear(emb_reshape) 100 | x = linear_out.view(tgt_emb.size(0), tgt_emb.size(1), -1) 101 | x = shape_transform(x) 102 | 103 | pad = torch.zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1) 104 | 105 | pad = pad.type_as(x) 106 | base_target_emb = x 107 | 108 | for conv, attention in zip(self.conv_layers, self.attn_layers): 109 | new_target_input = torch.cat([pad, x], 2) 110 | out = conv(new_target_input) 111 | c, attn = attention(base_target_emb, out, 112 | src_memory_bank_t, src_memory_bank_c) 113 | x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT 114 | output = x.squeeze(3).transpose(1, 2) 115 | 116 | # Process the result and update the attentions. 117 | dec_outs = output.transpose(0, 1).contiguous() 118 | if self.state["previous_input"] is not None: 119 | dec_outs = dec_outs[self.state["previous_input"].size(0):] 120 | attn = attn[:, self.state["previous_input"].size(0):].squeeze() 121 | attn = torch.stack([attn]) 122 | attns["std"] = attn 123 | if self.copy_attn is not None: 124 | attns["copy"] = attn 125 | 126 | # Update the state. 127 | self.state["previous_input"] = tgt 128 | # TODO change the way attns is returned dict => list or tuple (onnx) 129 | return dec_outs, attns 130 | 131 | def update_dropout(self, dropout): 132 | for layer in self.conv_layers: 133 | layer.dropout.p = dropout 134 | -------------------------------------------------------------------------------- /onmt/decoders/ensemble.py: -------------------------------------------------------------------------------- 1 | """Ensemble decoding. 
2 | 3 | Decodes using multiple models simultaneously, 4 | combining their prediction distributions by averaging. 5 | All models in the ensemble must share a target vocabulary. 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from onmt.encoders.encoder import EncoderBase 12 | from onmt.decoders.decoder import DecoderBase 13 | from onmt.models import NMTModel 14 | import onmt.model_builder 15 | 16 | 17 | class EnsembleDecoderOutput(object): 18 | """Wrapper around multiple decoder final hidden states.""" 19 | def __init__(self, model_dec_outs): 20 | self.model_dec_outs = tuple(model_dec_outs) 21 | 22 | def squeeze(self, dim=None): 23 | """Delegate squeeze to avoid modifying 24 | :func:`onmt.translate.translator.Translator.translate_batch()` 25 | """ 26 | return EnsembleDecoderOutput([ 27 | x.squeeze(dim) for x in self.model_dec_outs]) 28 | 29 | def __getitem__(self, index): 30 | return self.model_dec_outs[index] 31 | 32 | 33 | class EnsembleEncoder(EncoderBase): 34 | """Dummy Encoder that delegates to individual real Encoders.""" 35 | def __init__(self, model_encoders): 36 | super(EnsembleEncoder, self).__init__() 37 | self.model_encoders = nn.ModuleList(model_encoders) 38 | 39 | def forward(self, src, lengths=None): 40 | enc_hidden, memory_bank, _ = zip(*[ 41 | model_encoder(src, lengths) 42 | for model_encoder in self.model_encoders]) 43 | return enc_hidden, memory_bank, lengths 44 | 45 | 46 | class EnsembleDecoder(DecoderBase): 47 | """Dummy Decoder that delegates to individual real Decoders.""" 48 | def __init__(self, model_decoders): 49 | model_decoders = nn.ModuleList(model_decoders) 50 | attentional = any([dec.attentional for dec in model_decoders]) 51 | super(EnsembleDecoder, self).__init__(attentional) 52 | self.model_decoders = model_decoders 53 | 54 | def forward(self, tgt, memory_bank, memory_lengths=None, step=None, 55 | **kwargs): 56 | """See :func:`onmt.decoders.decoder.DecoderBase.forward()`.""" 57 | # Memory_lengths is a single tensor shared between all models. 58 | # This assumption will not hold if Translator is modified 59 | # to calculate memory_lengths as something other than the length 60 | # of the input. 61 | dec_outs, attns = zip(*[ 62 | model_decoder( 63 | tgt, memory_bank[i], 64 | memory_lengths=memory_lengths, step=step) 65 | for i, model_decoder in enumerate(self.model_decoders)]) 66 | mean_attns = self.combine_attns(attns) 67 | return EnsembleDecoderOutput(dec_outs), mean_attns 68 | 69 | def combine_attns(self, attns): 70 | result = {} 71 | for key in attns[0].keys(): 72 | result[key] = torch.stack( 73 | [attn[key] for attn in attns if attn[key] is not None]).mean(0) 74 | return result 75 | 76 | def init_state(self, src, memory_bank, enc_hidden): 77 | """ See :obj:`RNNDecoderBase.init_state()` """ 78 | for i, model_decoder in enumerate(self.model_decoders): 79 | model_decoder.init_state(src, memory_bank[i], enc_hidden[i]) 80 | 81 | def map_state(self, fn): 82 | for model_decoder in self.model_decoders: 83 | model_decoder.map_state(fn) 84 | 85 | 86 | class EnsembleGenerator(nn.Module): 87 | """ 88 | Dummy Generator that delegates to individual real Generators, 89 | and then averages the resulting target distributions. 
90 | """ 91 | def __init__(self, model_generators, raw_probs=False): 92 | super(EnsembleGenerator, self).__init__() 93 | self.model_generators = nn.ModuleList(model_generators) 94 | self._raw_probs = raw_probs 95 | 96 | def forward(self, hidden, attn=None, src_map=None): 97 | """ 98 | Compute a distribution over the target dictionary 99 | by averaging distributions from models in the ensemble. 100 | All models in the ensemble must share a target vocabulary. 101 | """ 102 | distributions = torch.stack( 103 | [mg(h) if attn is None else mg(h, attn, src_map) 104 | for h, mg in zip(hidden, self.model_generators)] 105 | ) 106 | if self._raw_probs: 107 | return torch.log(torch.exp(distributions).mean(0)) 108 | else: 109 | return distributions.mean(0) 110 | 111 | 112 | class EnsembleModel(NMTModel): 113 | """Dummy NMTModel wrapping individual real NMTModels.""" 114 | def __init__(self, models, raw_probs=False): 115 | encoder = EnsembleEncoder(model.encoder for model in models) 116 | decoder = EnsembleDecoder(model.decoder for model in models) 117 | super(EnsembleModel, self).__init__(encoder, decoder) 118 | self.generator = EnsembleGenerator( 119 | [model.generator for model in models], raw_probs) 120 | self.models = nn.ModuleList(models) 121 | 122 | 123 | def load_test_model(opt): 124 | """Read in multiple models for ensemble.""" 125 | shared_fields = None 126 | shared_model_opt = None 127 | models = [] 128 | for model_path in opt.models: 129 | fields, model, model_opt = \ 130 | onmt.model_builder.load_test_model(opt, model_path=model_path) 131 | if shared_fields is None: 132 | shared_fields = fields 133 | else: 134 | for key, field in fields.items(): 135 | try: 136 | f_iter = iter(field) 137 | except TypeError: 138 | f_iter = [(key, field)] 139 | for sn, sf in f_iter: 140 | if sf is not None and 'vocab' in sf.__dict__: 141 | sh_field = shared_fields[key] 142 | try: 143 | sh_f_iter = iter(sh_field) 144 | except TypeError: 145 | sh_f_iter = [(key, sh_field)] 146 | sh_f_dict = dict(sh_f_iter) 147 | assert sf.vocab.stoi == sh_f_dict[sn].vocab.stoi, \ 148 | "Ensemble models must use the same " \ 149 | "preprocessed data" 150 | models.append(model) 151 | if shared_model_opt is None: 152 | shared_model_opt = model_opt 153 | ensemble_model = EnsembleModel(models, opt.avg_raw_probs) 154 | return shared_fields, ensemble_model, shared_model_opt 155 | -------------------------------------------------------------------------------- /onmt/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining encoders.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.encoders.transformer import TransformerEncoder 4 | from onmt.encoders.rnn_encoder import RNNEncoder 5 | from onmt.encoders.cnn_encoder import CNNEncoder 6 | from onmt.encoders.mean_encoder import MeanEncoder 7 | from onmt.encoders.audio_encoder import AudioEncoder 8 | from onmt.encoders.image_encoder import ImageEncoder 9 | from onmt.encoders.hierarchical_transformer import HierarchicalTransformerEncoder 10 | 11 | 12 | str2enc = {"rnn": RNNEncoder, "brnn": RNNEncoder, "cnn": CNNEncoder, 13 | "transformer": TransformerEncoder, "img": ImageEncoder, 14 | "audio": AudioEncoder, "mean": MeanEncoder, "htransformer": HierarchicalTransformerEncoder} 15 | 16 | __all__ = ["EncoderBase", "TransformerEncoder", "RNNEncoder", "CNNEncoder", 17 | "MeanEncoder", "str2enc", "HierarchicalTransformerEncoder"] 18 | -------------------------------------------------------------------------------- 
/onmt/encoders/audio_encoder.py: -------------------------------------------------------------------------------- 1 | """Audio encoder""" 2 | import math 3 | 4 | import torch.nn as nn 5 | 6 | from torch.nn.utils.rnn import pack_padded_sequence as pack 7 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 8 | 9 | from onmt.utils.rnn_factory import rnn_factory 10 | from onmt.encoders.encoder import EncoderBase 11 | 12 | 13 | class AudioEncoder(EncoderBase): 14 | """A simple encoder CNN -> RNN for audio input. 15 | 16 | Args: 17 | rnn_type (str): Type of RNN (e.g. GRU, LSTM, etc). 18 | enc_layers (int): Number of encoder layers. 19 | dec_layers (int): Number of decoder layers. 20 | brnn (bool): Bidirectional encoder. 21 | enc_rnn_size (int): Size of hidden states of the rnn. 22 | dec_rnn_size (int): Size of the decoder hidden states. 23 | enc_pooling (str): A comma separated list either of length 1 24 | or of length ``enc_layers`` specifying the pooling amount. 25 | dropout (float): dropout probablity. 26 | sample_rate (float): input spec 27 | window_size (int): input spec 28 | """ 29 | 30 | def __init__(self, rnn_type, enc_layers, dec_layers, brnn, 31 | enc_rnn_size, dec_rnn_size, enc_pooling, dropout, 32 | sample_rate, window_size): 33 | super(AudioEncoder, self).__init__() 34 | self.enc_layers = enc_layers 35 | self.rnn_type = rnn_type 36 | self.dec_layers = dec_layers 37 | num_directions = 2 if brnn else 1 38 | self.num_directions = num_directions 39 | assert enc_rnn_size % num_directions == 0 40 | enc_rnn_size_real = enc_rnn_size // num_directions 41 | assert dec_rnn_size % num_directions == 0 42 | self.dec_rnn_size = dec_rnn_size 43 | dec_rnn_size_real = dec_rnn_size // num_directions 44 | self.dec_rnn_size_real = dec_rnn_size_real 45 | self.dec_rnn_size = dec_rnn_size 46 | input_size = int(math.floor((sample_rate * window_size) / 2) + 1) 47 | enc_pooling = enc_pooling.split(',') 48 | assert len(enc_pooling) == enc_layers or len(enc_pooling) == 1 49 | if len(enc_pooling) == 1: 50 | enc_pooling = enc_pooling * enc_layers 51 | enc_pooling = [int(p) for p in enc_pooling] 52 | self.enc_pooling = enc_pooling 53 | 54 | if type(dropout) is not list: 55 | dropout = [dropout] 56 | if max(dropout) > 0: 57 | self.dropout = nn.Dropout(dropout[0]) 58 | else: 59 | self.dropout = None 60 | self.W = nn.Linear(enc_rnn_size, dec_rnn_size, bias=False) 61 | self.batchnorm_0 = nn.BatchNorm1d(enc_rnn_size, affine=True) 62 | self.rnn_0, self.no_pack_padded_seq = \ 63 | rnn_factory(rnn_type, 64 | input_size=input_size, 65 | hidden_size=enc_rnn_size_real, 66 | num_layers=1, 67 | dropout=dropout[0], 68 | bidirectional=brnn) 69 | self.pool_0 = nn.MaxPool1d(enc_pooling[0]) 70 | for l in range(enc_layers - 1): 71 | batchnorm = nn.BatchNorm1d(enc_rnn_size, affine=True) 72 | rnn, _ = \ 73 | rnn_factory(rnn_type, 74 | input_size=enc_rnn_size, 75 | hidden_size=enc_rnn_size_real, 76 | num_layers=1, 77 | dropout=dropout[0], 78 | bidirectional=brnn) 79 | setattr(self, 'rnn_%d' % (l + 1), rnn) 80 | setattr(self, 'pool_%d' % (l + 1), 81 | nn.MaxPool1d(enc_pooling[l + 1])) 82 | setattr(self, 'batchnorm_%d' % (l + 1), batchnorm) 83 | 84 | @classmethod 85 | def from_opt(cls, opt, embeddings=None): 86 | """Alternate constructor.""" 87 | if embeddings is not None: 88 | raise ValueError("Cannot use embeddings with AudioEncoder.") 89 | return cls( 90 | opt.rnn_type, 91 | opt.enc_layers, 92 | opt.dec_layers, 93 | opt.brnn, 94 | opt.enc_rnn_size, 95 | opt.dec_rnn_size, 96 | opt.audio_enc_pooling, 97 | opt.dropout, 98 | 
opt.sample_rate, 99 | opt.window_size) 100 | 101 | def forward(self, src, lengths=None): 102 | """See :func:`onmt.encoders.encoder.EncoderBase.forward()`""" 103 | batch_size, _, nfft, t = src.size() 104 | src = src.transpose(0, 1).transpose(0, 3).contiguous() \ 105 | .view(t, batch_size, nfft) 106 | orig_lengths = lengths 107 | lengths = lengths.view(-1).tolist() 108 | 109 | for l in range(self.enc_layers): 110 | rnn = getattr(self, 'rnn_%d' % l) 111 | pool = getattr(self, 'pool_%d' % l) 112 | batchnorm = getattr(self, 'batchnorm_%d' % l) 113 | stride = self.enc_pooling[l] 114 | packed_emb = pack(src, lengths) 115 | memory_bank, tmp = rnn(packed_emb) 116 | memory_bank = unpack(memory_bank)[0] 117 | t, _, _ = memory_bank.size() 118 | memory_bank = memory_bank.transpose(0, 2) 119 | memory_bank = pool(memory_bank) 120 | lengths = [int(math.floor((length - stride) / stride + 1)) 121 | for length in lengths] 122 | memory_bank = memory_bank.transpose(0, 2) 123 | src = memory_bank 124 | t, _, num_feat = src.size() 125 | src = batchnorm(src.contiguous().view(-1, num_feat)) 126 | src = src.view(t, -1, num_feat) 127 | if self.dropout and l + 1 != self.enc_layers: 128 | src = self.dropout(src) 129 | 130 | memory_bank = memory_bank.contiguous().view(-1, memory_bank.size(2)) 131 | memory_bank = self.W(memory_bank).view(-1, batch_size, 132 | self.dec_rnn_size) 133 | 134 | state = memory_bank.new_full((self.dec_layers * self.num_directions, 135 | batch_size, self.dec_rnn_size_real), 0) 136 | if self.rnn_type == 'LSTM': 137 | # The encoder hidden is (layers*directions) x batch x dim. 138 | encoder_final = (state, state) 139 | else: 140 | encoder_final = state 141 | return encoder_final, memory_bank, orig_lengths.new_tensor(lengths) 142 | 143 | def update_dropout(self, dropout): 144 | self.dropout.p = dropout 145 | for i in range(self.enc_layers - 1): 146 | getattr(self, 'rnn_%d' % i).dropout = dropout 147 | -------------------------------------------------------------------------------- /onmt/encoders/cnn_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch.nn as nn 5 | 6 | from onmt.encoders.encoder import EncoderBase 7 | from onmt.utils.cnn_factory import shape_transform, StackedCNN 8 | 9 | SCALE_WEIGHT = 0.5 ** 0.5 10 | 11 | 12 | class CNNEncoder(EncoderBase): 13 | """Encoder based on "Convolutional Sequence to Sequence Learning" 14 | :cite:`DBLP:journals/corr/GehringAGYD17`. 
15 | """ 16 | 17 | def __init__(self, num_layers, hidden_size, 18 | cnn_kernel_width, dropout, embeddings): 19 | super(CNNEncoder, self).__init__() 20 | 21 | self.embeddings = embeddings 22 | input_size = embeddings.embedding_size 23 | self.linear = nn.Linear(input_size, hidden_size) 24 | self.cnn = StackedCNN(num_layers, hidden_size, 25 | cnn_kernel_width, dropout) 26 | 27 | @classmethod 28 | def from_opt(cls, opt, embeddings): 29 | """Alternate constructor.""" 30 | return cls( 31 | opt.enc_layers, 32 | opt.enc_rnn_size, 33 | opt.cnn_kernel_width, 34 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 35 | embeddings) 36 | 37 | def forward(self, input, lengths=None, hidden=None): 38 | """See :class:`onmt.modules.EncoderBase.forward()`""" 39 | self._check_args(input, lengths, hidden) 40 | 41 | emb = self.embeddings(input) 42 | # s_len, batch, emb_dim = emb.size() 43 | 44 | emb = emb.transpose(0, 1).contiguous() 45 | emb_reshape = emb.view(emb.size(0) * emb.size(1), -1) 46 | emb_remap = self.linear(emb_reshape) 47 | emb_remap = emb_remap.view(emb.size(0), emb.size(1), -1) 48 | emb_remap = shape_transform(emb_remap) 49 | out = self.cnn(emb_remap) 50 | 51 | return emb_remap.squeeze(3).transpose(0, 1).contiguous(), \ 52 | out.squeeze(3).transpose(0, 1).contiguous(), lengths 53 | 54 | def update_dropout(self, dropout): 55 | self.cnn.dropout.p = dropout 56 | -------------------------------------------------------------------------------- /onmt/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | """Base class for encoders and generic multi encoders.""" 2 | 3 | import torch.nn as nn 4 | 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class EncoderBase(nn.Module): 9 | """ 10 | Base encoder class. Specifies the interface used by different encoder types 11 | and required by :class:`onmt.Models.NMTModel`. 12 | 13 | .. mermaid:: 14 | 15 | graph BT 16 | A[Input] 17 | subgraph RNN 18 | C[Pos 1] 19 | D[Pos 2] 20 | E[Pos N] 21 | end 22 | F[Memory_Bank] 23 | G[Final] 24 | A-->C 25 | A-->D 26 | A-->E 27 | C-->F 28 | D-->F 29 | E-->F 30 | E-->G 31 | """ 32 | 33 | @classmethod 34 | def from_opt(cls, opt, embeddings=None): 35 | raise NotImplementedError 36 | 37 | def _check_args(self, src, lengths=None, hidden=None): 38 | n_batch = src.size(1) 39 | if lengths is not None: 40 | n_batch_, = lengths.size() 41 | aeq(n_batch, n_batch_) 42 | 43 | def forward(self, src, lengths=None): 44 | """ 45 | Args: 46 | src (LongTensor): 47 | padded sequences of sparse indices ``(src_len, batch, nfeat)`` 48 | lengths (LongTensor): length of each sequence ``(batch,)`` 49 | 50 | 51 | Returns: 52 | (FloatTensor, FloatTensor): 53 | 54 | * final encoder state, used to initialize decoder 55 | * memory bank for attention, ``(src_len, batch, hidden)`` 56 | """ 57 | 58 | raise NotImplementedError 59 | -------------------------------------------------------------------------------- /onmt/encoders/image_encoder.py: -------------------------------------------------------------------------------- 1 | """Image Encoder.""" 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | 6 | from onmt.encoders.encoder import EncoderBase 7 | 8 | 9 | class ImageEncoder(EncoderBase): 10 | """A simple encoder CNN -> RNN for image src. 11 | 12 | Args: 13 | num_layers (int): number of encoder layers. 14 | bidirectional (bool): bidirectional encoder. 15 | rnn_size (int): size of hidden states of the rnn. 16 | dropout (float): dropout probablity. 
17 | """ 18 | 19 | def __init__(self, num_layers, bidirectional, rnn_size, dropout, 20 | image_chanel_size=3): 21 | super(ImageEncoder, self).__init__() 22 | self.num_layers = num_layers 23 | self.num_directions = 2 if bidirectional else 1 24 | self.hidden_size = rnn_size 25 | 26 | self.layer1 = nn.Conv2d(image_chanel_size, 64, kernel_size=(3, 3), 27 | padding=(1, 1), stride=(1, 1)) 28 | self.layer2 = nn.Conv2d(64, 128, kernel_size=(3, 3), 29 | padding=(1, 1), stride=(1, 1)) 30 | self.layer3 = nn.Conv2d(128, 256, kernel_size=(3, 3), 31 | padding=(1, 1), stride=(1, 1)) 32 | self.layer4 = nn.Conv2d(256, 256, kernel_size=(3, 3), 33 | padding=(1, 1), stride=(1, 1)) 34 | self.layer5 = nn.Conv2d(256, 512, kernel_size=(3, 3), 35 | padding=(1, 1), stride=(1, 1)) 36 | self.layer6 = nn.Conv2d(512, 512, kernel_size=(3, 3), 37 | padding=(1, 1), stride=(1, 1)) 38 | 39 | self.batch_norm1 = nn.BatchNorm2d(256) 40 | self.batch_norm2 = nn.BatchNorm2d(512) 41 | self.batch_norm3 = nn.BatchNorm2d(512) 42 | 43 | src_size = 512 44 | dropout = dropout[0] if type(dropout) is list else dropout 45 | self.rnn = nn.LSTM(src_size, int(rnn_size / self.num_directions), 46 | num_layers=num_layers, 47 | dropout=dropout, 48 | bidirectional=bidirectional) 49 | self.pos_lut = nn.Embedding(1000, src_size) 50 | 51 | @classmethod 52 | def from_opt(cls, opt, embeddings=None): 53 | """Alternate constructor.""" 54 | if embeddings is not None: 55 | raise ValueError("Cannot use embeddings with ImageEncoder.") 56 | # why is the model_opt.__dict__ check necessary? 57 | if "image_channel_size" not in opt.__dict__: 58 | image_channel_size = 3 59 | else: 60 | image_channel_size = opt.image_channel_size 61 | return cls( 62 | opt.enc_layers, 63 | opt.brnn, 64 | opt.enc_rnn_size, 65 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 66 | image_channel_size 67 | ) 68 | 69 | def load_pretrained_vectors(self, opt): 70 | """Pass in needed options only when modify function definition.""" 71 | pass 72 | 73 | def forward(self, src, lengths=None): 74 | """See :func:`onmt.encoders.encoder.EncoderBase.forward()`""" 75 | 76 | batch_size = src.size(0) 77 | # (batch_size, 64, imgH, imgW) 78 | # layer 1 79 | src = F.relu(self.layer1(src[:, :, :, :] - 0.5), True) 80 | 81 | # (batch_size, 64, imgH/2, imgW/2) 82 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 83 | 84 | # (batch_size, 128, imgH/2, imgW/2) 85 | # layer 2 86 | src = F.relu(self.layer2(src), True) 87 | 88 | # (batch_size, 128, imgH/2/2, imgW/2/2) 89 | src = F.max_pool2d(src, kernel_size=(2, 2), stride=(2, 2)) 90 | 91 | # (batch_size, 256, imgH/2/2, imgW/2/2) 92 | # layer 3 93 | # batch norm 1 94 | src = F.relu(self.batch_norm1(self.layer3(src)), True) 95 | 96 | # (batch_size, 256, imgH/2/2, imgW/2/2) 97 | # layer4 98 | src = F.relu(self.layer4(src), True) 99 | 100 | # (batch_size, 256, imgH/2/2/2, imgW/2/2) 101 | src = F.max_pool2d(src, kernel_size=(1, 2), stride=(1, 2)) 102 | 103 | # (batch_size, 512, imgH/2/2/2, imgW/2/2) 104 | # layer 5 105 | # batch norm 2 106 | src = F.relu(self.batch_norm2(self.layer5(src)), True) 107 | 108 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 109 | src = F.max_pool2d(src, kernel_size=(2, 1), stride=(2, 1)) 110 | 111 | # (batch_size, 512, imgH/2/2/2, imgW/2/2/2) 112 | src = F.relu(self.batch_norm3(self.layer6(src)), True) 113 | 114 | # # (batch_size, 512, H, W) 115 | all_outputs = [] 116 | for row in range(src.size(2)): 117 | inp = src[:, :, row, :].transpose(0, 2) \ 118 | .transpose(1, 2) 119 | row_vec = 
torch.Tensor(batch_size).type_as(inp.data) \ 120 | .long().fill_(row) 121 | pos_emb = self.pos_lut(row_vec) 122 | with_pos = torch.cat( 123 | (pos_emb.view(1, pos_emb.size(0), pos_emb.size(1)), inp), 0) 124 | outputs, hidden_t = self.rnn(with_pos) 125 | all_outputs.append(outputs) 126 | out = torch.cat(all_outputs, 0) 127 | 128 | return hidden_t, out, lengths 129 | 130 | def update_dropout(self, dropout): 131 | self.rnn.dropout = dropout 132 | -------------------------------------------------------------------------------- /onmt/encoders/mean_encoder.py: -------------------------------------------------------------------------------- 1 | """Define a minimal encoder.""" 2 | from onmt.encoders.encoder import EncoderBase 3 | from onmt.utils.misc import sequence_mask 4 | import torch 5 | 6 | 7 | class MeanEncoder(EncoderBase): 8 | """A trivial non-recurrent encoder. Simply applies mean pooling. 9 | 10 | Args: 11 | num_layers (int): number of replicated layers 12 | embeddings (onmt.modules.Embeddings): embedding module to use 13 | """ 14 | 15 | def __init__(self, num_layers, embeddings): 16 | super(MeanEncoder, self).__init__() 17 | self.num_layers = num_layers 18 | self.embeddings = embeddings 19 | 20 | @classmethod 21 | def from_opt(cls, opt, embeddings): 22 | """Alternate constructor.""" 23 | return cls( 24 | opt.enc_layers, 25 | embeddings) 26 | 27 | def forward(self, src, lengths=None): 28 | """See :func:`EncoderBase.forward()`""" 29 | self._check_args(src, lengths) 30 | 31 | emb = self.embeddings(src) 32 | _, batch, emb_dim = emb.size() 33 | 34 | if lengths is not None: 35 | # we avoid padding while mean pooling 36 | mask = sequence_mask(lengths).float() 37 | mask = mask / lengths.unsqueeze(1).float() 38 | mean = torch.bmm(mask.unsqueeze(1), emb.transpose(0, 1)).squeeze(1) 39 | else: 40 | mean = emb.mean(0) 41 | 42 | mean = mean.expand(self.num_layers, batch, emb_dim) 43 | memory_bank = emb 44 | encoder_final = (mean, mean) 45 | return encoder_final, memory_bank, lengths 46 | -------------------------------------------------------------------------------- /onmt/encoders/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | """Define RNN-based encoders.""" 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 7 | 8 | from onmt.encoders.encoder import EncoderBase 9 | from onmt.utils.rnn_factory import rnn_factory 10 | 11 | 12 | class RNNEncoder(EncoderBase): 13 | """ A generic recurrent neural network encoder. 
14 | 15 | Args: 16 | rnn_type (str): 17 | style of recurrent unit to use, one of [RNN, LSTM, GRU, SRU] 18 | bidirectional (bool) : use a bidirectional RNN 19 | num_layers (int) : number of stacked layers 20 | hidden_size (int) : hidden size of each layer 21 | dropout (float) : dropout value for :class:`torch.nn.Dropout` 22 | embeddings (onmt.modules.Embeddings): embedding module to use 23 | """ 24 | 25 | def __init__(self, rnn_type, bidirectional, num_layers, 26 | hidden_size, dropout=0.0, embeddings=None, 27 | use_bridge=False): 28 | super(RNNEncoder, self).__init__() 29 | assert embeddings is not None 30 | 31 | num_directions = 2 if bidirectional else 1 32 | assert hidden_size % num_directions == 0 33 | hidden_size = hidden_size // num_directions 34 | self.embeddings = embeddings 35 | 36 | self.rnn, self.no_pack_padded_seq = \ 37 | rnn_factory(rnn_type, 38 | input_size=embeddings.embedding_size, 39 | hidden_size=hidden_size, 40 | num_layers=num_layers, 41 | dropout=dropout, 42 | bidirectional=bidirectional) 43 | 44 | # Initialize the bridge layer 45 | self.use_bridge = use_bridge 46 | if self.use_bridge: 47 | self._initialize_bridge(rnn_type, 48 | hidden_size, 49 | num_layers) 50 | 51 | @classmethod 52 | def from_opt(cls, opt, embeddings): 53 | """Alternate constructor.""" 54 | return cls( 55 | opt.rnn_type, 56 | opt.brnn, 57 | opt.enc_layers, 58 | opt.enc_rnn_size, 59 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 60 | embeddings, 61 | opt.bridge) 62 | 63 | def forward(self, src, lengths=None): 64 | """See :func:`EncoderBase.forward()`""" 65 | self._check_args(src, lengths) 66 | 67 | emb = self.embeddings(src) 68 | # s_len, batch, emb_dim = emb.size() 69 | 70 | packed_emb = emb 71 | if lengths is not None and not self.no_pack_padded_seq: 72 | # Lengths data is wrapped inside a Tensor. 
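
# Standalone sketch of the pack/unpack round trip used just below, assuming the
# batch is already sorted by decreasing length (toy sizes, no onmt imports):
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

emb = torch.randn(5, 3, 8)                 # (src_len, batch, emb_dim)
lengths = torch.tensor([5, 4, 2])          # descending, as packing requires
packed = pack_padded_sequence(emb, lengths.view(-1).tolist())
memory_bank, encoder_final = torch.nn.LSTM(8, 6)(packed)
memory_bank = pad_packed_sequence(memory_bank)[0]   # back to (src_len, batch, 6)
print(memory_bank.shape)                   # torch.Size([5, 3, 6])
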
73 | lengths_list = lengths.view(-1).tolist() 74 | packed_emb = pack(emb, lengths_list) 75 | 76 | memory_bank, encoder_final = self.rnn(packed_emb) 77 | 78 | if lengths is not None and not self.no_pack_padded_seq: 79 | memory_bank = unpack(memory_bank)[0] 80 | 81 | if self.use_bridge: 82 | encoder_final = self._bridge(encoder_final) 83 | return encoder_final, memory_bank, lengths 84 | 85 | def _initialize_bridge(self, rnn_type, 86 | hidden_size, 87 | num_layers): 88 | 89 | # LSTM has hidden and cell state, other only one 90 | number_of_states = 2 if rnn_type == "LSTM" else 1 91 | # Total number of states 92 | self.total_hidden_dim = hidden_size * num_layers 93 | 94 | # Build a linear layer for each 95 | self.bridge = nn.ModuleList([nn.Linear(self.total_hidden_dim, 96 | self.total_hidden_dim, 97 | bias=True) 98 | for _ in range(number_of_states)]) 99 | 100 | def _bridge(self, hidden): 101 | """Forward hidden state through bridge.""" 102 | def bottle_hidden(linear, states): 103 | """ 104 | Transform from 3D to 2D, apply linear and return initial size 105 | """ 106 | size = states.size() 107 | result = linear(states.view(-1, self.total_hidden_dim)) 108 | return F.relu(result).view(size) 109 | 110 | if isinstance(hidden, tuple): # LSTM 111 | outs = tuple([bottle_hidden(layer, hidden[ix]) 112 | for ix, layer in enumerate(self.bridge)]) 113 | else: 114 | outs = bottle_hidden(self.bridge[0], hidden) 115 | return outs 116 | 117 | def update_dropout(self, dropout): 118 | self.rnn.dropout = dropout 119 | -------------------------------------------------------------------------------- /onmt/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Attention is All You Need" 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | from onmt.encoders.encoder import EncoderBase 8 | from onmt.modules import MultiHeadedAttention 9 | from onmt.modules.position_ffn import PositionwiseFeedForward 10 | from onmt.utils.misc import sequence_mask 11 | 12 | 13 | class TransformerEncoderLayer(nn.Module): 14 | """ 15 | A single layer of the transformer encoder. 16 | 17 | Args: 18 | d_model (int): the dimension of keys/values/queries in 19 | MultiHeadedAttention, also the input size of 20 | the first-layer of the PositionwiseFeedForward. 21 | heads (int): the number of head for MultiHeadedAttention. 22 | d_ff (int): the second-layer of the PositionwiseFeedForward. 23 | dropout (float): dropout probability(0-1.0). 
24 | """ 25 | 26 | def __init__(self, d_model, heads, d_ff, dropout, attention_dropout, 27 | max_relative_positions=0): 28 | super(TransformerEncoderLayer, self).__init__() 29 | 30 | self.self_attn = MultiHeadedAttention( 31 | heads, d_model, dropout=attention_dropout, 32 | max_relative_positions=max_relative_positions) 33 | self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout) 34 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 35 | self.dropout = nn.Dropout(dropout) 36 | 37 | def forward(self, inputs, mask): 38 | """ 39 | Args: 40 | inputs (FloatTensor): ``(batch_size, src_len, model_dim)`` 41 | mask (LongTensor): ``(batch_size, 1, src_len)`` 42 | 43 | Returns: 44 | (FloatTensor): 45 | 46 | * outputs ``(batch_size, src_len, model_dim)`` 47 | """ 48 | input_norm = self.layer_norm(inputs) 49 | context, _ = self.self_attn(input_norm, input_norm, input_norm, 50 | mask=mask, attn_type="self") 51 | out = self.dropout(context) + inputs 52 | return self.feed_forward(out) 53 | 54 | def update_dropout(self, dropout, attention_dropout): 55 | self.self_attn.update_dropout(attention_dropout) 56 | self.feed_forward.update_dropout(dropout) 57 | self.dropout.p = dropout 58 | 59 | 60 | class TransformerEncoder(EncoderBase): 61 | """The Transformer encoder from "Attention is All You Need" 62 | :cite:`DBLP:journals/corr/VaswaniSPUJGKP17` 63 | 64 | .. mermaid:: 65 | 66 | graph BT 67 | A[input] 68 | B[multi-head self-attn] 69 | C[feed forward] 70 | O[output] 71 | A --> B 72 | B --> C 73 | C --> O 74 | 75 | Args: 76 | num_layers (int): number of encoder layers 77 | d_model (int): size of the model 78 | heads (int): number of heads 79 | d_ff (int): size of the inner FF layer 80 | dropout (float): dropout parameters 81 | embeddings (onmt.modules.Embeddings): 82 | embeddings to use, should have positional encodings 83 | 84 | Returns: 85 | (torch.FloatTensor, torch.FloatTensor): 86 | 87 | * embeddings ``(src_len, batch_size, model_dim)`` 88 | * memory_bank ``(src_len, batch_size, model_dim)`` 89 | """ 90 | 91 | def __init__(self, num_layers, d_model, heads, d_ff, dropout, 92 | attention_dropout, embeddings, max_relative_positions): 93 | super(TransformerEncoder, self).__init__() 94 | 95 | self.embeddings = embeddings 96 | self.transformer = nn.ModuleList( 97 | [TransformerEncoderLayer( 98 | d_model, heads, d_ff, dropout, attention_dropout, 99 | max_relative_positions=max_relative_positions) 100 | for i in range(num_layers)]) 101 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 102 | 103 | @classmethod 104 | def from_opt(cls, opt, embeddings): 105 | """Alternate constructor.""" 106 | return cls( 107 | opt.enc_layers, 108 | opt.enc_rnn_size, 109 | opt.heads, 110 | opt.transformer_ff, 111 | opt.dropout[0] if type(opt.dropout) is list else opt.dropout, 112 | opt.attention_dropout[0] if type(opt.attention_dropout) 113 | is list else opt.attention_dropout, 114 | embeddings, 115 | opt.max_relative_positions) 116 | 117 | def forward(self, src, lengths=None): 118 | """See :func:`EncoderBase.forward()`""" 119 | self._check_args(src, lengths) 120 | 121 | emb = self.embeddings(src) 122 | 123 | out = emb.transpose(0, 1).contiguous() 124 | mask = ~sequence_mask(lengths).unsqueeze(1) 125 | # Run the forward pass of every layer of the tranformer. 
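
# Quick numeric sketch of the padding mask built above: ~sequence_mask(lengths)
# marks positions at or beyond each source length with True so self-attention
# ignores them (toy lengths; the shape matches the docstrings, (batch, 1, src_len)).
import torch

lengths = torch.tensor([5, 3])
valid = torch.arange(5).unsqueeze(0) < lengths.unsqueeze(1)   # (batch, src_len)
pad_mask = ~valid.unsqueeze(1)                                # (batch, 1, src_len)
print(pad_mask[1])   # tensor([[False, False, False,  True,  True]])
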
126 | for layer in self.transformer: 127 | out = layer(out, mask) 128 | out = self.layer_norm(out) 129 | 130 | return emb, out.transpose(0, 1).contiguous(), lengths 131 | 132 | def update_dropout(self, dropout, attention_dropout): 133 | self.embeddings.update_dropout(dropout) 134 | for layer in self.transformer: 135 | layer.update_dropout(dropout, attention_dropout) 136 | -------------------------------------------------------------------------------- /onmt/inputters/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining inputters. 2 | 3 | Inputters implement the logic of transforming raw data to vectorized inputs, 4 | e.g., from a line of text to a sequence of embeddings. 5 | """ 6 | from onmt.inputters.inputter import \ 7 | load_old_vocab, get_fields, OrderedIterator, \ 8 | build_vocab, old_style_vocab, filter_example 9 | from onmt.inputters.dataset_base import Dataset 10 | from onmt.inputters.text_dataset import text_sort_key, TextDataReader 11 | from onmt.inputters.image_dataset import img_sort_key, ImageDataReader 12 | from onmt.inputters.audio_dataset import audio_sort_key, AudioDataReader 13 | from onmt.inputters.vec_dataset import vec_sort_key, VecDataReader 14 | from onmt.inputters.datareader_base import DataReaderBase 15 | 16 | 17 | str2reader = { 18 | "text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader, 19 | "vec": VecDataReader} 20 | str2sortkey = { 21 | 'text': text_sort_key, 'img': img_sort_key, 'audio': audio_sort_key, 22 | 'vec': vec_sort_key} 23 | 24 | 25 | __all__ = ['Dataset', 'load_old_vocab', 'get_fields', 'DataReaderBase', 26 | 'filter_example', 'old_style_vocab', 27 | 'build_vocab', 'OrderedIterator', 28 | 'text_sort_key', 'img_sort_key', 'audio_sort_key', 'vec_sort_key', 29 | 'TextDataReader', 'ImageDataReader', 'AudioDataReader', 30 | 'VecDataReader'] 31 | -------------------------------------------------------------------------------- /onmt/inputters/datareader_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | 4 | # several data readers need optional dependencies. There's no 5 | # appropriate builtin exception 6 | class MissingDependencyException(Exception): 7 | pass 8 | 9 | 10 | class DataReaderBase(object): 11 | """Read data from file system and yield as dicts. 12 | 13 | Raises: 14 | onmt.inputters.datareader_base.MissingDependencyException: A number 15 | of DataReaders need specific additional packages. 16 | If any are missing, this will be raised. 17 | """ 18 | 19 | @classmethod 20 | def from_opt(cls, opt): 21 | """Alternative constructor. 22 | 23 | Args: 24 | opt (argparse.Namespace): The parsed arguments. 25 | """ 26 | 27 | return cls() 28 | 29 | @classmethod 30 | def _read_file(cls, path): 31 | """Line-by-line read a file as bytes.""" 32 | with open(path, "rb") as f: 33 | for line in f: 34 | yield line 35 | 36 | @staticmethod 37 | def _raise_missing_dep(*missing_deps): 38 | """Raise missing dep exception with standard error message.""" 39 | raise MissingDependencyException( 40 | "Could not create reader. 
Be sure to install " 41 | "the following dependencies: " + ", ".join(missing_deps)) 42 | 43 | def read(self, data, side, src_dir): 44 | """Read data from file system and yield as dicts.""" 45 | raise NotImplementedError() 46 | -------------------------------------------------------------------------------- /onmt/inputters/dataset_base.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from itertools import chain, starmap 4 | from collections import Counter 5 | 6 | import torch 7 | from torchtext.data import Dataset as TorchtextDataset 8 | from torchtext.data import Example 9 | from torchtext.vocab import Vocab 10 | 11 | 12 | def _join_dicts(*args): 13 | """ 14 | Args: 15 | dictionaries with disjoint keys. 16 | 17 | Returns: 18 | a single dictionary that has the union of these keys. 19 | """ 20 | 21 | return dict(chain(*[d.items() for d in args])) 22 | 23 | 24 | def _dynamic_dict(example, src_field, tgt_field): 25 | """Create copy-vocab and numericalize with it. 26 | 27 | In-place adds ``"src_map"`` to ``example``. That is the copy-vocab 28 | numericalization of the tokenized ``example["src"]``. If ``example`` 29 | has a ``"tgt"`` key, adds ``"alignment"`` to example. That is the 30 | copy-vocab numericalization of the tokenized ``example["tgt"]``. The 31 | alignment has an initial and final UNK token to match the BOS and EOS 32 | tokens. 33 | 34 | Args: 35 | example (dict): An example dictionary with a ``"src"`` key and 36 | maybe a ``"tgt"`` key. (This argument changes in place!) 37 | src_field (torchtext.data.Field): Field object. 38 | tgt_field (torchtext.data.Field): Field object. 39 | 40 | Returns: 41 | torchtext.data.Vocab and ``example``, changed as described. 42 | """ 43 | 44 | src = src_field.tokenize(example["src"]) 45 | # make a small vocab containing just the tokens in the source sequence 46 | unk = src_field.unk_token 47 | pad = src_field.pad_token 48 | src_ex_vocab = Vocab(Counter(src), specials=[unk, pad]) 49 | unk_idx = src_ex_vocab.stoi[unk] 50 | # Map source tokens to indices in the dynamic dict. 51 | src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src]) 52 | example["src_map"] = src_map 53 | example["src_ex_vocab"] = src_ex_vocab 54 | 55 | if "tgt" in example: 56 | tgt = tgt_field.tokenize(example["tgt"]) 57 | mask = torch.LongTensor( 58 | [unk_idx] + [src_ex_vocab.stoi[w] for w in tgt] + [unk_idx]) 59 | example["alignment"] = mask 60 | return src_ex_vocab, example 61 | 62 | 63 | class Dataset(TorchtextDataset): 64 | """Contain data and process it. 65 | 66 | A dataset is an object that accepts sequences of raw data (sentence pairs 67 | in the case of machine translation) and fields which describe how this 68 | raw data should be processed to produce tensors. When a dataset is 69 | instantiated, it applies the fields' preprocessing pipeline (but not 70 | the bit that numericalizes it or turns it into batch tensors) to the raw 71 | data, producing a list of :class:`torchtext.data.Example` objects. 72 | torchtext's iterators then know how to use these examples to make batches. 73 | 74 | Args: 75 | fields (dict[str, Field]): a dict with the structure 76 | returned by :func:`onmt.inputters.get_fields()`. Usually 77 | that means the dataset side, ``"src"`` or ``"tgt"``. Keys match 78 | the keys of items yielded by the ``readers``, while values 79 | are lists of (name, Field) pairs. 
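
# Plain-Python sketch of the per-example copy vocabulary that _dynamic_dict
# builds above: source tokens get local indices, and target tokens are mapped
# against that local vocab with <unk> (index 0) as fallback. Tokens are made up.
src = "james harden scored 30 points".split()
tgt = "harden added 30 points".split()
stoi = {"<unk>": 0, "<blank>": 1}
for tok in src:
    stoi.setdefault(tok, len(stoi))
src_map = [stoi[t] for t in src]                        # [2, 3, 4, 5, 6]
alignment = [0] + [stoi.get(t, 0) for t in tgt] + [0]   # UNK slots stand in for BOS/EOS
print(src_map, alignment)                               # ... [0, 3, 0, 5, 6, 0]
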
An attribute with this 80 | name will be created for each :class:`torchtext.data.Example` 81 | object and its value will be the result of applying the Field 82 | to the data that matches the key. The advantage of having 83 | sequences of fields for each piece of raw input is that it allows 84 | the dataset to store multiple "views" of each input, which allows 85 | for easy implementation of token-level features, mixed word- 86 | and character-level models, and so on. (See also 87 | :class:`onmt.inputters.TextMultiField`.) 88 | readers (Iterable[onmt.inputters.DataReaderBase]): Reader objects 89 | for disk-to-dict. The yielded dicts are then processed 90 | according to ``fields``. 91 | data (Iterable[Tuple[str, Any]]): (name, ``data_arg``) pairs 92 | where ``data_arg`` is passed to the ``read()`` method of the 93 | reader in ``readers`` at that position. (See the reader object for 94 | details on the ``Any`` type.) 95 | dirs (Iterable[str or NoneType]): A list of directories where 96 | data is contained. See the reader object for more details. 97 | sort_key (Callable[[torchtext.data.Example], Any]): A function 98 | for determining the value on which data is sorted (i.e. length). 99 | filter_pred (Callable[[torchtext.data.Example], bool]): A function 100 | that accepts Example objects and returns a boolean value 101 | indicating whether to include that example in the dataset. 102 | 103 | Attributes: 104 | src_vocabs (List[torchtext.data.Vocab]): Used with dynamic dict/copy 105 | attention. There is a very short vocab for each src example. 106 | It contains just the source words, e.g. so that the generator can 107 | predict to copy them. 108 | """ 109 | 110 | def __init__(self, fields, readers, data, dirs, sort_key, 111 | filter_pred=None): 112 | self.sort_key = sort_key 113 | can_copy = 'src_map' in fields and 'alignment' in fields 114 | 115 | read_iters = [r.read(dat[1], dat[0], dir_) for r, dat, dir_ 116 | in zip(readers, data, dirs)] 117 | 118 | # self.src_vocabs is used in collapse_copy_scores and Translator.py 119 | self.src_vocabs = [] 120 | examples = [] 121 | for ex_dict in starmap(_join_dicts, zip(*read_iters)): 122 | if can_copy: 123 | src_field = fields['src'] 124 | tgt_field = fields['tgt'] 125 | # this assumes src_field and tgt_field are both text 126 | src_ex_vocab, ex_dict = _dynamic_dict( 127 | ex_dict, src_field.base_field, tgt_field.base_field) 128 | self.src_vocabs.append(src_ex_vocab) 129 | ex_fields = {k: [(k, v)] for k, v in fields.items() if 130 | k in ex_dict} 131 | ex = Example.fromdict(ex_dict, ex_fields) 132 | examples.append(ex) 133 | 134 | # fields needs to have only keys that examples have as attrs 135 | fields = [] 136 | for _, nf_list in ex_fields.items(): 137 | assert len(nf_list) == 1 138 | fields.append(nf_list[0]) 139 | 140 | super(Dataset, self).__init__(examples, fields, filter_pred) 141 | 142 | def __getattr__(self, attr): 143 | # avoid infinite recursion when fields isn't defined 144 | if 'fields' not in vars(self): 145 | raise AttributeError 146 | if attr in self.fields: 147 | return (getattr(x, attr) for x in self.examples) 148 | else: 149 | raise AttributeError 150 | 151 | def save(self, path, remove_fields=True): 152 | if remove_fields: 153 | self.fields = [] 154 | torch.save(self, path) 155 | 156 | @staticmethod 157 | def config(fields): 158 | readers, data, dirs = [], [], [] 159 | for name, field in fields: 160 | if field["data"] is not None: 161 | readers.append(field["reader"]) 162 | data.append((name, field["data"])) 163 | 
dirs.append(field["dir"]) 164 | return readers, data, dirs 165 | -------------------------------------------------------------------------------- /onmt/inputters/image_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import torch 6 | from torchtext.data import Field 7 | 8 | from onmt.inputters.datareader_base import DataReaderBase 9 | 10 | # domain specific dependencies 11 | try: 12 | from PIL import Image 13 | from torchvision import transforms 14 | import cv2 15 | except ImportError: 16 | Image, transforms, cv2 = None, None, None 17 | 18 | 19 | class ImageDataReader(DataReaderBase): 20 | """Read image data from disk. 21 | 22 | Args: 23 | truncate (tuple[int] or NoneType): maximum img size. Use 24 | ``(0,0)`` or ``None`` for unlimited. 25 | channel_size (int): Number of channels per image. 26 | 27 | Raises: 28 | onmt.inputters.datareader_base.MissingDependencyException: If 29 | importing any of ``PIL``, ``torchvision``, or ``cv2`` fail. 30 | """ 31 | 32 | def __init__(self, truncate=None, channel_size=3): 33 | self._check_deps() 34 | self.truncate = truncate 35 | self.channel_size = channel_size 36 | 37 | @classmethod 38 | def from_opt(cls, opt): 39 | return cls(channel_size=opt.image_channel_size) 40 | 41 | @classmethod 42 | def _check_deps(cls): 43 | if any([Image is None, transforms is None, cv2 is None]): 44 | cls._raise_missing_dep( 45 | "PIL", "torchvision", "cv2") 46 | 47 | def read(self, images, side, img_dir=None): 48 | """Read data into dicts. 49 | 50 | Args: 51 | images (str or Iterable[str]): Sequence of image paths or 52 | path to file containing audio paths. 53 | In either case, the filenames may be relative to ``src_dir`` 54 | (default behavior) or absolute. 55 | side (str): Prefix used in return dict. Usually 56 | ``"src"`` or ``"tgt"``. 57 | img_dir (str): Location of source image files. See ``images``. 58 | 59 | Yields: 60 | a dictionary containing image data, path and index for each line. 
61 | """ 62 | if isinstance(images, str): 63 | images = DataReaderBase._read_file(images) 64 | 65 | for i, filename in enumerate(images): 66 | filename = filename.decode("utf-8").strip() 67 | img_path = os.path.join(img_dir, filename) 68 | if not os.path.exists(img_path): 69 | img_path = filename 70 | 71 | assert os.path.exists(img_path), \ 72 | 'img path %s not found' % filename 73 | 74 | if self.channel_size == 1: 75 | img = transforms.ToTensor()( 76 | Image.fromarray(cv2.imread(img_path, 0))) 77 | else: 78 | img = transforms.ToTensor()(Image.open(img_path)) 79 | if self.truncate and self.truncate != (0, 0): 80 | if not (img.size(1) <= self.truncate[0] 81 | and img.size(2) <= self.truncate[1]): 82 | continue 83 | yield {side: img, side + '_path': filename, 'indices': i} 84 | 85 | 86 | def img_sort_key(ex): 87 | """Sort using the size of the image: (width, height).""" 88 | return ex.src.size(2), ex.src.size(1) 89 | 90 | 91 | def batch_img(data, vocab): 92 | """Pad and batch a sequence of images.""" 93 | c = data[0].size(0) 94 | h = max([t.size(1) for t in data]) 95 | w = max([t.size(2) for t in data]) 96 | imgs = torch.zeros(len(data), c, h, w).fill_(1) 97 | for i, img in enumerate(data): 98 | imgs[i, :, 0:img.size(1), 0:img.size(2)] = img 99 | return imgs 100 | 101 | 102 | def image_fields(**kwargs): 103 | img = Field( 104 | use_vocab=False, dtype=torch.float, 105 | postprocessing=batch_img, sequential=False) 106 | return img 107 | -------------------------------------------------------------------------------- /onmt/inputters/vec_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torchtext.data import Field 5 | 6 | from onmt.inputters.datareader_base import DataReaderBase 7 | 8 | try: 9 | import numpy as np 10 | except ImportError: 11 | np = None 12 | 13 | 14 | class VecDataReader(DataReaderBase): 15 | """Read feature vector data from disk. 16 | Raises: 17 | onmt.inputters.datareader_base.MissingDependencyException: If 18 | importing ``np`` fails. 19 | """ 20 | 21 | def __init__(self): 22 | self._check_deps() 23 | 24 | @classmethod 25 | def _check_deps(cls): 26 | if np is None: 27 | cls._raise_missing_dep("np") 28 | 29 | def read(self, vecs, side, vec_dir=None): 30 | """Read data into dicts. 31 | Args: 32 | vecs (str or Iterable[str]): Sequence of feature vector paths or 33 | path to file containing feature vector paths. 34 | In either case, the filenames may be relative to ``vec_dir`` 35 | (default behavior) or absolute. 36 | side (str): Prefix used in return dict. Usually 37 | ``"src"`` or ``"tgt"``. 38 | vec_dir (str): Location of source vectors. See ``vecs``. 39 | Yields: 40 | A dictionary containing feature vector data. 41 | """ 42 | 43 | if isinstance(vecs, str): 44 | vecs = DataReaderBase._read_file(vecs) 45 | 46 | for i, filename in enumerate(vecs): 47 | filename = filename.decode("utf-8").strip() 48 | vec_path = os.path.join(vec_dir, filename) 49 | if not os.path.exists(vec_path): 50 | vec_path = filename 51 | 52 | assert os.path.exists(vec_path), \ 53 | 'vec path %s not found' % filename 54 | 55 | vec = np.load(vec_path) 56 | yield {side: torch.from_numpy(vec), 57 | side + "_path": filename, "indices": i} 58 | 59 | 60 | def vec_sort_key(ex): 61 | """Sort using the length of the vector sequence.""" 62 | return ex.src.shape[0] 63 | 64 | 65 | class VecSeqField(Field): 66 | """Defines an vector datatype and instructions for converting to Tensor. 
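
# Minimal sketch of the image batching in batch_img above: every image in the
# batch is padded with ones up to the largest height and width (toy tensors).
import torch

data = [torch.zeros(3, 10, 12), torch.zeros(3, 8, 20)]
c = data[0].size(0)
h = max(t.size(1) for t in data)
w = max(t.size(2) for t in data)
imgs = torch.full((len(data), c, h, w), 1.0)
for i, img in enumerate(data):
    imgs[i, :, :img.size(1), :img.size(2)] = img
print(imgs.shape)   # torch.Size([2, 3, 10, 20])
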
67 | See :class:`Fields` for attribute descriptions. 68 | """ 69 | 70 | def __init__(self, preprocessing=None, postprocessing=None, 71 | include_lengths=False, batch_first=False, pad_index=0, 72 | is_target=False): 73 | super(VecSeqField, self).__init__( 74 | sequential=True, use_vocab=False, init_token=None, 75 | eos_token=None, fix_length=False, dtype=torch.float, 76 | preprocessing=preprocessing, postprocessing=postprocessing, 77 | lower=False, tokenize=None, include_lengths=include_lengths, 78 | batch_first=batch_first, pad_token=pad_index, unk_token=None, 79 | pad_first=False, truncate_first=False, stop_words=None, 80 | is_target=is_target 81 | ) 82 | 83 | def pad(self, minibatch): 84 | """Pad a batch of examples to the length of the longest example. 85 | Args: 86 | minibatch (List[torch.FloatTensor]): A list of audio data, 87 | each having shape ``(len, n_feats, feat_dim)`` 88 | where len is variable. 89 | Returns: 90 | torch.FloatTensor or Tuple[torch.FloatTensor, List[int]]: The 91 | padded tensor of shape 92 | ``(batch_size, max_len, n_feats, feat_dim)``. 93 | and a list of the lengths if `self.include_lengths` is `True` 94 | else just returns the padded tensor. 95 | """ 96 | 97 | assert not self.pad_first and not self.truncate_first \ 98 | and not self.fix_length and self.sequential 99 | minibatch = list(minibatch) 100 | lengths = [x.size(0) for x in minibatch] 101 | max_len = max(lengths) 102 | nfeats = minibatch[0].size(1) 103 | feat_dim = minibatch[0].size(2) 104 | feats = torch.full((len(minibatch), max_len, nfeats, feat_dim), 105 | self.pad_token) 106 | for i, (feat, len_) in enumerate(zip(minibatch, lengths)): 107 | feats[i, 0:len_, :, :] = feat 108 | if self.include_lengths: 109 | return (feats, lengths) 110 | return feats 111 | 112 | def numericalize(self, arr, device=None): 113 | """Turn a batch of examples that use this field into a Variable. 114 | If the field has ``include_lengths=True``, a tensor of lengths will be 115 | included in the return value. 116 | Args: 117 | arr (torch.FloatTensor or Tuple(torch.FloatTensor, List[int])): 118 | List of tokenized and padded examples, or tuple of List of 119 | tokenized and padded examples and List of lengths of each 120 | example if self.include_lengths is True. 121 | device (str or torch.device): See `Field.numericalize`. 
122 | """ 123 | 124 | assert self.use_vocab is False 125 | if self.include_lengths and not isinstance(arr, tuple): 126 | raise ValueError("Field has include_lengths set to True, but " 127 | "input data is not a tuple of " 128 | "(data batch, batch lengths).") 129 | if isinstance(arr, tuple): 130 | arr, lengths = arr 131 | lengths = torch.tensor(lengths, dtype=torch.int, device=device) 132 | arr = arr.to(device) 133 | 134 | if self.postprocessing is not None: 135 | arr = self.postprocessing(arr, None) 136 | 137 | if self.sequential and not self.batch_first: 138 | arr = arr.permute(1, 0, 2, 3) 139 | if self.sequential: 140 | arr = arr.contiguous() 141 | 142 | if self.include_lengths: 143 | return arr, lengths 144 | return arr 145 | 146 | 147 | def vec_fields(**kwargs): 148 | vec = VecSeqField(pad_index=0, include_lengths=True) 149 | return vec 150 | -------------------------------------------------------------------------------- /onmt/models/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining models.""" 2 | from onmt.models.model_saver import build_model_saver, ModelSaver 3 | from onmt.models.model import NMTModel 4 | 5 | __all__ = ["build_model_saver", "ModelSaver", "NMTModel"] 6 | -------------------------------------------------------------------------------- /onmt/models/model.py: -------------------------------------------------------------------------------- 1 | """ Onmt NMT Model base class definition """ 2 | import torch.nn as nn 3 | 4 | 5 | class NMTModel(nn.Module): 6 | """ 7 | Core trainable object in OpenNMT. Implements a trainable interface 8 | for a simple, generic encoder + decoder model. 9 | 10 | Args: 11 | encoder (onmt.encoders.EncoderBase): an encoder object 12 | decoder (onmt.decoders.DecoderBase): a decoder object 13 | """ 14 | 15 | def __init__(self, encoder, decoder): 16 | super(NMTModel, self).__init__() 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | 20 | def forward(self, src, tgt, lengths, bptt=False, with_align=False): 21 | """Forward propagate a `src` and `tgt` pair for training. 22 | Possible initialized with a beginning decoder state. 23 | 24 | Args: 25 | src (Tensor): A source sequence passed to encoder. 26 | typically for inputs this will be a padded `LongTensor` 27 | of size ``(len, batch, features)``. However, may be an 28 | image or other generic input depending on encoder. 29 | tgt (LongTensor): A target sequence passed to decoder. 30 | Size ``(tgt_len, batch, features)``. 31 | lengths(LongTensor): The src lengths, pre-padding ``(batch,)``. 32 | bptt (Boolean): A flag indicating if truncated bptt is set. 33 | If reset then init_state 34 | with_align (Boolean): A flag indicating whether output alignment, 35 | Only valid for transformer decoder. 
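
# Tiny sketch of the teacher-forcing slicing performed in forward() below:
# the decoder reads the target shifted right and the loss compares its output
# with the target shifted left (the token ids here are made up).
import torch

tgt = torch.tensor([2, 15, 27, 3]).view(-1, 1, 1)   # (tgt_len, batch, feats), 2=BOS, 3=EOS
dec_in = tgt[:-1]                                    # fed to the decoder
gold = tgt[1:]                                       # compared against dec_out
print(dec_in.view(-1).tolist(), gold.view(-1).tolist())   # [2, 15, 27] [15, 27, 3]
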
36 | 37 | Returns: 38 | (FloatTensor, dict[str, FloatTensor]): 39 | 40 | * decoder output ``(tgt_len, batch, hidden)`` 41 | * dictionary attention dists of ``(tgt_len, batch, src_len)`` 42 | """ 43 | dec_in = tgt[:-1] # exclude last target from inputs 44 | 45 | enc_state, memory_bank, lengths = self.encoder(src, lengths) 46 | 47 | if bptt is False: 48 | self.decoder.init_state(src, memory_bank, enc_state) 49 | dec_out, attns = self.decoder(dec_in, memory_bank, 50 | memory_lengths=lengths, 51 | with_align=with_align) 52 | return dec_out, attns 53 | 54 | def update_dropout(self, dropout): 55 | self.encoder.update_dropout(dropout) 56 | self.decoder.update_dropout(dropout) 57 | -------------------------------------------------------------------------------- /onmt/models/model_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from collections import deque 5 | from onmt.utils.logging import logger 6 | 7 | from copy import deepcopy 8 | 9 | 10 | def build_model_saver(model_opt, opt, model, fields, optim): 11 | model_saver = ModelSaver(opt.save_model, 12 | model, 13 | model_opt, 14 | fields, 15 | optim, 16 | opt.keep_checkpoint) 17 | return model_saver 18 | 19 | 20 | class ModelSaverBase(object): 21 | """Base class for model saving operations 22 | 23 | Inherited classes must implement private methods: 24 | * `_save` 25 | * `_rm_checkpoint 26 | """ 27 | 28 | def __init__(self, base_path, model, model_opt, fields, optim, 29 | keep_checkpoint=-1): 30 | self.base_path = base_path 31 | self.model = model 32 | self.model_opt = model_opt 33 | self.fields = fields 34 | self.optim = optim 35 | self.last_saved_step = None 36 | self.keep_checkpoint = keep_checkpoint 37 | if keep_checkpoint > 0: 38 | self.checkpoint_queue = deque([], maxlen=keep_checkpoint) 39 | 40 | def save(self, step, moving_average=None): 41 | """Main entry point for model saver 42 | 43 | It wraps the `_save` method with checks and apply `keep_checkpoint` 44 | related logic 45 | """ 46 | 47 | if self.keep_checkpoint == 0 or step == self.last_saved_step: 48 | return 49 | 50 | save_model = self.model 51 | if moving_average: 52 | model_params_data = [] 53 | for avg, param in zip(moving_average, save_model.parameters()): 54 | model_params_data.append(param.data) 55 | param.data = avg.data 56 | 57 | chkpt, chkpt_name = self._save(step, save_model) 58 | self.last_saved_step = step 59 | 60 | if moving_average: 61 | for param_data, param in zip(model_params_data, 62 | save_model.parameters()): 63 | param.data = param_data 64 | 65 | if self.keep_checkpoint > 0: 66 | if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen: 67 | todel = self.checkpoint_queue.popleft() 68 | self._rm_checkpoint(todel) 69 | self.checkpoint_queue.append(chkpt_name) 70 | 71 | def _save(self, step): 72 | """Save a resumable checkpoint. 
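
# Hedged sketch of the moving-average swap done in save() above: parameters are
# temporarily replaced by their EMA shadow before the checkpoint is written,
# then restored (toy Linear model; the "shadow" weights here are just clones).
import torch

model = torch.nn.Linear(4, 2)
moving_average = [p.data.clone() for p in model.parameters()]   # stand-in EMA

backup = []
for avg, param in zip(moving_average, model.parameters()):
    backup.append(param.data)
    param.data = avg
# ... the actual _save()/torch.save call would happen at this point ...
for saved, param in zip(backup, model.parameters()):
    param.data = saved
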
73 | 74 | Args: 75 | step (int): step number 76 | 77 | Returns: 78 | (object, str): 79 | 80 | * checkpoint: the saved object 81 | * checkpoint_name: name (or path) of the saved checkpoint 82 | """ 83 | 84 | raise NotImplementedError() 85 | 86 | def _rm_checkpoint(self, name): 87 | """Remove a checkpoint 88 | 89 | Args: 90 | name(str): name that indentifies the checkpoint 91 | (it may be a filepath) 92 | """ 93 | 94 | raise NotImplementedError() 95 | 96 | 97 | class ModelSaver(ModelSaverBase): 98 | """Simple model saver to filesystem""" 99 | 100 | def _save(self, step, model): 101 | model_state_dict = model.state_dict() 102 | model_state_dict = {k: v for k, v in model_state_dict.items() 103 | if 'generator' not in k} 104 | generator_state_dict = model.generator.state_dict() 105 | 106 | # NOTE: We need to trim the vocab to remove any unk tokens that 107 | # were not originally here. 108 | 109 | vocab = deepcopy(self.fields) 110 | for side in ["src", "tgt"]: 111 | keys_to_pop = [] 112 | if hasattr(vocab[side], "fields"): 113 | unk_token = vocab[side].fields[0][1].vocab.itos[0] 114 | for key, value in vocab[side].fields[0][1].vocab.stoi.items(): 115 | if value == 0 and key != unk_token: 116 | keys_to_pop.append(key) 117 | for key in keys_to_pop: 118 | vocab[side].fields[0][1].vocab.stoi.pop(key, None) 119 | 120 | checkpoint = { 121 | 'model': model_state_dict, 122 | 'generator': generator_state_dict, 123 | 'vocab': vocab, 124 | 'opt': self.model_opt, 125 | 'optim': self.optim.state_dict(), 126 | } 127 | 128 | logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) 129 | checkpoint_path = '%s_step_%d.pt' % (self.base_path, step) 130 | torch.save(checkpoint, checkpoint_path) 131 | return checkpoint, checkpoint_path 132 | 133 | def _rm_checkpoint(self, name): 134 | os.remove(name) 135 | -------------------------------------------------------------------------------- /onmt/models/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | """ Implementation of ONMT RNN for Input Feeding Decoding """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class StackedLSTM(nn.Module): 7 | """ 8 | Our own implementation of stacked LSTM. 9 | Needed for the decoder, because we do input feeding. 10 | """ 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTM, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | self.layers = nn.ModuleList() 17 | 18 | for _ in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, input_feed, hidden): 23 | h_0, c_0 = hidden 24 | h_1, c_1 = [], [] 25 | for i, layer in enumerate(self.layers): 26 | h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i])) 27 | input_feed = h_1_i 28 | if i + 1 != self.num_layers: 29 | input_feed = self.dropout(input_feed) 30 | h_1 += [h_1_i] 31 | c_1 += [c_1_i] 32 | 33 | h_1 = torch.stack(h_1) 34 | c_1 = torch.stack(c_1) 35 | 36 | return input_feed, (h_1, c_1) 37 | 38 | 39 | class StackedGRU(nn.Module): 40 | """ 41 | Our own implementation of stacked GRU. 42 | Needed for the decoder, because we do input feeding. 
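
# Usage sketch for the StackedLSTM defined above (assumes the repository root
# is on PYTHONPATH and its requirements are installed; sizes are arbitrary).
# Hidden states are indexed per layer, so they are (num_layers, batch, rnn_size).
import torch
from onmt.models.stacked_rnn import StackedLSTM

cell = StackedLSTM(num_layers=2, input_size=8, rnn_size=6, dropout=0.1)
h = torch.zeros(2, 3, 6)
c = torch.zeros(2, 3, 6)
out, (h, c) = cell(torch.randn(3, 8), (h, c))   # one input-feeding step
print(out.shape)                                # torch.Size([3, 6])
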
43 | """ 44 | 45 | def __init__(self, num_layers, input_size, rnn_size, dropout): 46 | super(StackedGRU, self).__init__() 47 | self.dropout = nn.Dropout(dropout) 48 | self.num_layers = num_layers 49 | self.layers = nn.ModuleList() 50 | 51 | for _ in range(num_layers): 52 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 53 | input_size = rnn_size 54 | 55 | def forward(self, input_feed, hidden): 56 | h_1 = [] 57 | for i, layer in enumerate(self.layers): 58 | h_1_i = layer(input_feed, hidden[0][i]) 59 | input_feed = h_1_i 60 | if i + 1 != self.num_layers: 61 | input_feed = self.dropout(input_feed) 62 | h_1 += [h_1_i] 63 | 64 | h_1 = torch.stack(h_1) 65 | return input_feed, (h_1,) 66 | -------------------------------------------------------------------------------- /onmt/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ Attention and normalization modules """ 2 | from onmt.modules.util_class import Elementwise 3 | from onmt.modules.gate import context_gate_factory, ContextGate 4 | from onmt.modules.global_attention import GlobalAttention 5 | from onmt.modules.hierarchical_attention import HierarchicalAttention 6 | from onmt.modules.conv_multi_step_attention import ConvMultiStepAttention 7 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss, \ 8 | CopyGeneratorLossCompute 9 | from onmt.modules.multi_headed_attn import MultiHeadedAttention 10 | #from onmt.modules.self_attention import MultiHeadSelfAttention 11 | from onmt.modules.embeddings import Embeddings, PositionalEncoding, \ 12 | VecEmbedding 13 | from onmt.modules.table_embeddings import TableEmbeddings 14 | from onmt.modules.weight_norm import WeightNormConv2d 15 | from onmt.modules.average_attn import AverageAttention 16 | from onmt.modules.glu import GatedLinear 17 | 18 | 19 | __all__ = ["Elementwise", "context_gate_factory", "ContextGate", 20 | "GlobalAttention", "ConvMultiStepAttention", "CopyGenerator", 21 | "CopyGeneratorLoss", "CopyGeneratorLossCompute", 22 | "MultiHeadedAttention", "Embeddings", "PositionalEncoding", 23 | "WeightNormConv2d", "AverageAttention", "VecEmbedding", 24 | "GatedLinear", "HierarchicalAttention", "TableEmbeddings"] 25 | -------------------------------------------------------------------------------- /onmt/modules/average_attn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Average Attention module.""" 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from onmt.modules.position_ffn import PositionwiseFeedForward 8 | 9 | 10 | class AverageAttention(nn.Module): 11 | """ 12 | Average Attention module from 13 | "Accelerating Neural Transformer via an Average Attention Network" 14 | :cite:`DBLP:journals/corr/abs-1805-00631`. 
15 | 16 | Args: 17 | model_dim (int): the dimension of keys/values/queries, 18 | must be divisible by head_count 19 | dropout (float): dropout parameter 20 | """ 21 | 22 | def __init__(self, model_dim, dropout=0.1, aan_useffn=False): 23 | self.model_dim = model_dim 24 | self.aan_useffn = aan_useffn 25 | super(AverageAttention, self).__init__() 26 | if aan_useffn: 27 | self.average_layer = PositionwiseFeedForward(model_dim, model_dim, 28 | dropout) 29 | self.gating_layer = nn.Linear(model_dim * 2, model_dim * 2) 30 | 31 | def cumulative_average_mask(self, batch_size, inputs_len, device): 32 | """ 33 | Builds the mask to compute the cumulative average as described in 34 | :cite:`DBLP:journals/corr/abs-1805-00631` -- Figure 3 35 | 36 | Args: 37 | batch_size (int): batch size 38 | inputs_len (int): length of the inputs 39 | 40 | Returns: 41 | (FloatTensor): 42 | 43 | * A Tensor of shape ``(batch_size, input_len, input_len)`` 44 | """ 45 | 46 | triangle = torch.tril(torch.ones(inputs_len, inputs_len, 47 | dtype=torch.float, device=device)) 48 | weights = torch.ones(1, inputs_len, dtype=torch.float, device=device) \ 49 | / torch.arange(1, inputs_len + 1, dtype=torch.float, device=device) 50 | mask = triangle * weights.transpose(0, 1) 51 | 52 | return mask.unsqueeze(0).expand(batch_size, inputs_len, inputs_len) 53 | 54 | def cumulative_average(self, inputs, mask_or_step, 55 | layer_cache=None, step=None): 56 | """ 57 | Computes the cumulative average as described in 58 | :cite:`DBLP:journals/corr/abs-1805-00631` -- Equations (1) (5) (6) 59 | 60 | Args: 61 | inputs (FloatTensor): sequence to average 62 | ``(batch_size, input_len, dimension)`` 63 | mask_or_step: if cache is set, this is assumed 64 | to be the current step of the 65 | dynamic decoding. Otherwise, it is the mask matrix 66 | used to compute the cumulative average. 67 | layer_cache: a dictionary containing the cumulative average 68 | of the previous step. 69 | 70 | Returns: 71 | a tensor of the same shape and type as ``inputs``. 
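
# Numeric sketch checking that the two cumulative-average paths above agree:
# the masked matmul used at training time and the incremental update used at
# decoding time (toy tensor, batch of 1).
import torch

x = torch.randn(1, 4, 3)                              # (batch, len, dim)
tri = torch.tril(torch.ones(4, 4))
weights = torch.ones(1, 4) / torch.arange(1., 5.)
mask = tri * weights.transpose(0, 1)                  # row t averages steps 0..t
full = torch.matmul(mask, x)                          # training-path result

prev = torch.zeros(1, 1, 3)
for step in range(4):                                 # decoding-path result
    prev = (x[:, step:step + 1] + step * prev) / (step + 1)
print(torch.allclose(full[:, 3:4], prev))             # True
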
72 | """ 73 | 74 | if layer_cache is not None: 75 | step = mask_or_step 76 | average_attention = (inputs + step * 77 | layer_cache["prev_g"]) / (step + 1) 78 | layer_cache["prev_g"] = average_attention 79 | return average_attention 80 | else: 81 | mask = mask_or_step 82 | return torch.matmul(mask.to(inputs.dtype), inputs) 83 | 84 | def forward(self, inputs, mask=None, layer_cache=None, step=None): 85 | """ 86 | Args: 87 | inputs (FloatTensor): ``(batch_size, input_len, model_dim)`` 88 | 89 | Returns: 90 | (FloatTensor, FloatTensor): 91 | 92 | * gating_outputs ``(batch_size, input_len, model_dim)`` 93 | * average_outputs average attention 94 | ``(batch_size, input_len, model_dim)`` 95 | """ 96 | 97 | batch_size = inputs.size(0) 98 | inputs_len = inputs.size(1) 99 | average_outputs = self.cumulative_average( 100 | inputs, self.cumulative_average_mask(batch_size, 101 | inputs_len, inputs.device) 102 | if layer_cache is None else step, layer_cache=layer_cache) 103 | if self.aan_useffn: 104 | average_outputs = self.average_layer(average_outputs) 105 | gating_outputs = self.gating_layer(torch.cat((inputs, 106 | average_outputs), -1)) 107 | input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2) 108 | gating_outputs = torch.sigmoid(input_gate) * inputs + \ 109 | torch.sigmoid(forget_gate) * average_outputs 110 | 111 | return gating_outputs, average_outputs 112 | -------------------------------------------------------------------------------- /onmt/modules/conv_multi_step_attention.py: -------------------------------------------------------------------------------- 1 | """ Multi Step Attention for CNN """ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | SCALE_WEIGHT = 0.5 ** 0.5 9 | 10 | 11 | def seq_linear(linear, x): 12 | """ linear transform for 3-d tensor """ 13 | batch, hidden_size, length, _ = x.size() 14 | h = linear(torch.transpose(x, 1, 2).contiguous().view( 15 | batch * length, hidden_size)) 16 | return torch.transpose(h.view(batch, length, hidden_size, 1), 1, 2) 17 | 18 | 19 | class ConvMultiStepAttention(nn.Module): 20 | """ 21 | Conv attention takes a key matrix, a value matrix and a query vector. 22 | Attention weight is calculated by key matrix with the query vector 23 | and sum on the value matrix. And the same operation is applied 24 | in each decode conv layer. 
25 | """ 26 | 27 | def __init__(self, input_size): 28 | super(ConvMultiStepAttention, self).__init__() 29 | self.linear_in = nn.Linear(input_size, input_size) 30 | self.mask = None 31 | 32 | def apply_mask(self, mask): 33 | """ Apply mask """ 34 | self.mask = mask 35 | 36 | def forward(self, base_target_emb, input_from_dec, encoder_out_top, 37 | encoder_out_combine): 38 | """ 39 | Args: 40 | base_target_emb: target emb tensor 41 | input_from_dec: output of decode conv 42 | encoder_out_top: the key matrix for calculation of attetion weight, 43 | which is the top output of encode conv 44 | encoder_out_combine: 45 | the value matrix for the attention-weighted sum, 46 | which is the combination of base emb and top output of encode 47 | """ 48 | 49 | # checks 50 | # batch, channel, height, width = base_target_emb.size() 51 | batch, _, height, _ = base_target_emb.size() 52 | # batch_, channel_, height_, width_ = input_from_dec.size() 53 | batch_, _, height_, _ = input_from_dec.size() 54 | aeq(batch, batch_) 55 | aeq(height, height_) 56 | 57 | # enc_batch, enc_channel, enc_height = encoder_out_top.size() 58 | enc_batch, _, enc_height = encoder_out_top.size() 59 | # enc_batch_, enc_channel_, enc_height_ = encoder_out_combine.size() 60 | enc_batch_, _, enc_height_ = encoder_out_combine.size() 61 | 62 | aeq(enc_batch, enc_batch_) 63 | aeq(enc_height, enc_height_) 64 | 65 | preatt = seq_linear(self.linear_in, input_from_dec) 66 | target = (base_target_emb + preatt) * SCALE_WEIGHT 67 | target = torch.squeeze(target, 3) 68 | target = torch.transpose(target, 1, 2) 69 | pre_attn = torch.bmm(target, encoder_out_top) 70 | 71 | if self.mask is not None: 72 | pre_attn.data.masked_fill_(self.mask, -float('inf')) 73 | 74 | attn = F.softmax(pre_attn, dim=2) 75 | 76 | context_output = torch.bmm( 77 | attn, torch.transpose(encoder_out_combine, 1, 2)) 78 | context_output = torch.transpose( 79 | torch.unsqueeze(context_output, 3), 1, 2) 80 | return context_output, attn 81 | -------------------------------------------------------------------------------- /onmt/modules/gate.py: -------------------------------------------------------------------------------- 1 | """ ContextGate module """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def context_gate_factory(gate_type, embeddings_size, decoder_size, 7 | attention_size, output_size): 8 | """Returns the correct ContextGate class""" 9 | 10 | gate_types = {'source': SourceContextGate, 11 | 'target': TargetContextGate, 12 | 'both': BothContextGate} 13 | 14 | assert gate_type in gate_types, "Not valid ContextGate type: {0}".format( 15 | gate_type) 16 | return gate_types[gate_type](embeddings_size, decoder_size, attention_size, 17 | output_size) 18 | 19 | 20 | class ContextGate(nn.Module): 21 | """ 22 | Context gate is a decoder module that takes as input the previous word 23 | embedding, the current decoder state and the attention state, and 24 | produces a gate. 25 | The gate can be used to select the input from the target side context 26 | (decoder state), from the source context (attention state) or both. 
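
# Hedged usage sketch of context_gate_factory defined above (assumes the repo
# root is on PYTHONPATH; all sizes and tensors are arbitrary toy values).
import torch
from onmt.modules.gate import context_gate_factory

gate = context_gate_factory("both", embeddings_size=16, decoder_size=32,
                            attention_size=32, output_size=32)
prev_emb = torch.randn(5, 16)     # previous word embedding, (batch, emb_dim)
dec_state = torch.randn(5, 32)    # decoder state
attn_state = torch.randn(5, 32)   # attention context
print(gate(prev_emb, dec_state, attn_state).shape)   # torch.Size([5, 32])
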
27 | """ 28 | 29 | def __init__(self, embeddings_size, decoder_size, 30 | attention_size, output_size): 31 | super(ContextGate, self).__init__() 32 | input_size = embeddings_size + decoder_size + attention_size 33 | self.gate = nn.Linear(input_size, output_size, bias=True) 34 | self.sig = nn.Sigmoid() 35 | self.source_proj = nn.Linear(attention_size, output_size) 36 | self.target_proj = nn.Linear(embeddings_size + decoder_size, 37 | output_size) 38 | 39 | def forward(self, prev_emb, dec_state, attn_state): 40 | input_tensor = torch.cat((prev_emb, dec_state, attn_state), dim=1) 41 | z = self.sig(self.gate(input_tensor)) 42 | proj_source = self.source_proj(attn_state) 43 | proj_target = self.target_proj( 44 | torch.cat((prev_emb, dec_state), dim=1)) 45 | return z, proj_source, proj_target 46 | 47 | 48 | class SourceContextGate(nn.Module): 49 | """Apply the context gate only to the source context""" 50 | 51 | def __init__(self, embeddings_size, decoder_size, 52 | attention_size, output_size): 53 | super(SourceContextGate, self).__init__() 54 | self.context_gate = ContextGate(embeddings_size, decoder_size, 55 | attention_size, output_size) 56 | self.tanh = nn.Tanh() 57 | 58 | def forward(self, prev_emb, dec_state, attn_state): 59 | z, source, target = self.context_gate( 60 | prev_emb, dec_state, attn_state) 61 | return self.tanh(target + z * source) 62 | 63 | 64 | class TargetContextGate(nn.Module): 65 | """Apply the context gate only to the target context""" 66 | 67 | def __init__(self, embeddings_size, decoder_size, 68 | attention_size, output_size): 69 | super(TargetContextGate, self).__init__() 70 | self.context_gate = ContextGate(embeddings_size, decoder_size, 71 | attention_size, output_size) 72 | self.tanh = nn.Tanh() 73 | 74 | def forward(self, prev_emb, dec_state, attn_state): 75 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 76 | return self.tanh(z * target + source) 77 | 78 | 79 | class BothContextGate(nn.Module): 80 | """Apply the context gate to both contexts""" 81 | 82 | def __init__(self, embeddings_size, decoder_size, 83 | attention_size, output_size): 84 | super(BothContextGate, self).__init__() 85 | self.context_gate = ContextGate(embeddings_size, decoder_size, 86 | attention_size, output_size) 87 | self.tanh = nn.Tanh() 88 | 89 | def forward(self, prev_emb, dec_state, attn_state): 90 | z, source, target = self.context_gate(prev_emb, dec_state, attn_state) 91 | return self.tanh((1. - z) * target + z * source) 92 | -------------------------------------------------------------------------------- /onmt/modules/glu.py: -------------------------------------------------------------------------------- 1 | """Comes directly from fairseq""" 2 | import torch, math 3 | 4 | 5 | class Downsample(torch.nn.Module): 6 | """ 7 | Selects every nth element along the last dim, where n is the index 8 | """ 9 | def __init__(self, in_dim, step): 10 | super().__init__() 11 | self._step = step 12 | self._in_dim = in_dim 13 | 14 | if in_dim % step != 0: 15 | raise ValueError('in_dim should be a multiple of step. 
' 16 | f'Got {in_dim} and {step}.') 17 | self.index = torch.LongTensor(range(0, in_dim, step)) 18 | 19 | def forward(self, input): 20 | return input.index_select(dim=-1, index=self.index.to(input.device)) 21 | 22 | def extra_repr(self): 23 | return f'{self._in_dim}, {self._in_dim//self._step}' 24 | 25 | 26 | def Linear(in_features, out_features, dropout=0., bias=True): 27 | """Weight-normalized Linear layer (input: B x T x C)""" 28 | m = torch.nn.Linear(in_features, out_features, bias=bias) 29 | m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features)) 30 | m.bias.data.zero_() 31 | return torch.nn.utils.weight_norm(m) 32 | 33 | 34 | class GatedLinear(torch.nn.Module): 35 | def __init__(self, in_features, out_features, depth=2, 36 | downsample=0, dropout=0., bias=True): 37 | """ 38 | Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units. 39 | GLU units split the input in half to use one as values and one as gates: 40 | glu([a; b]) = a * sigmoid(b) 41 | """ 42 | super().__init__() 43 | 44 | self._num_layers = depth 45 | self._bias = bias 46 | self._dropout = dropout 47 | self._downsample = isinstance(downsample, int) and downsample > 0 48 | self.glu = torch.nn.GLU(dim=-1) 49 | 50 | # In order to halve the dims at each step and end on out_features 51 | # we need to start with out_feature * 2^depth and decrease the power 52 | # of 2 at each depth. 53 | if self._downsample: 54 | self.linear_in = torch.nn.Sequential( 55 | Downsample(in_features, downsample), 56 | Linear(in_features//downsample, out_features * pow(2, depth), dropout, bias) 57 | ) 58 | else: 59 | if in_features != out_features * pow(2, depth): 60 | raise ValueError('When not using downsampling, in_features should be ' 61 | 'equal to out_feature * 2^depth. ' 62 | f'Got {in_features} != {out_features} * 2^{depth}') 63 | 64 | self.linear_layers = torch.nn.ModuleList([ 65 | Linear(out_features * pow(2, depth - k), 66 | out_features * pow(2, depth - k), 67 | dropout, bias) 68 | for k in range(1, depth+1) 69 | ]) 70 | 71 | def forward(self, input): 72 | output = self.linear_in(input) if self._downsample else input 73 | for linear in self.linear_layers: 74 | output = linear(self.glu(output)) 75 | return output -------------------------------------------------------------------------------- /onmt/modules/position_ffn.py: -------------------------------------------------------------------------------- 1 | """Position feed-forward network from "Attention is All You Need".""" 2 | 3 | import torch.nn as nn 4 | 5 | 6 | class PositionwiseFeedForward(nn.Module): 7 | """ A two-layer Feed-Forward-Network with residual layer norm. 8 | 9 | Args: 10 | d_model (int): the size of input for the first-layer of the FFN. 11 | d_ff (int): the hidden layer size of the second-layer 12 | of the FNN. 13 | dropout (float): dropout probability in :math:`[0, 1)`. 14 | """ 15 | 16 | def __init__(self, d_model, d_ff, dropout=0.1): 17 | super(PositionwiseFeedForward, self).__init__() 18 | self.w_1 = nn.Linear(d_model, d_ff) 19 | self.w_2 = nn.Linear(d_ff, d_model) 20 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 21 | self.dropout_1 = nn.Dropout(dropout) 22 | self.relu = nn.ReLU() 23 | self.dropout_2 = nn.Dropout(dropout) 24 | 25 | def forward(self, x): 26 | """Layer definition. 27 | 28 | Args: 29 | x: ``(batch_size, input_len, model_dim)`` 30 | 31 | Returns: 32 | (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 
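
# Small sketch of the GLU trick used by GatedLinear above: the last dimension
# is split in half, one half gates the other, so each GLU halves the feature
# size -- hence the in_features == out_features * 2^depth constraint.
import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)
a, b = x.chunk(2, dim=-1)
assert torch.allclose(F.glu(x, dim=-1), a * torch.sigmoid(b))
print(F.glu(x, dim=-1).shape)   # torch.Size([2, 5, 4])
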
33 | """ 34 | 35 | inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x)))) 36 | output = self.dropout_2(self.w_2(inter)) 37 | return output + x 38 | 39 | def update_dropout(self, dropout): 40 | self.dropout_1.p = dropout 41 | self.dropout_2.p = dropout 42 | -------------------------------------------------------------------------------- /onmt/modules/self_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom reimplementation of torch.nn.MultiHeadAttention 3 | 4 | It's actually the same module, with more or less flewibility at times, 5 | and a more flexible use of the mask (different mask per element of the batch) 6 | """ 7 | from torch._jit_internal import weak_module, weak_script_method 8 | from torch.nn.init import constant_ 9 | from torch.nn.parameter import Parameter 10 | from torch.nn.init import xavier_uniform_ 11 | from torch.nn import functional as F 12 | from onmt.utils.misc import tile 13 | from onmt.modules import GatedLinear 14 | import torch 15 | 16 | 17 | @weak_module 18 | class MultiHeadSelfAttention(torch.nn.Module): 19 | """ 20 | if glu_depth is not zero, we use GatedLinear layers instead of regular layers. 21 | """ 22 | def __init__(self, embed_dim, num_heads, dropout=0., glu_depth=0, bias=True): 23 | super().__init__() 24 | self.embed_dim = embed_dim 25 | self.num_heads = num_heads 26 | self.dropout = dropout 27 | self.head_dim = embed_dim // num_heads 28 | msg = "embed_dim must be divisible by num_heads, got {} and {}" 29 | assert self.head_dim * num_heads == self.embed_dim, msg.format(embed_dim, num_heads) 30 | self.scaling = self.head_dim ** -0.5 31 | 32 | self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim)) 33 | if bias: 34 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim)) 35 | else: 36 | self.register_parameter('in_proj_bias', None) 37 | self.out_proj = torch.nn.Linear(embed_dim, embed_dim, bias=bias) 38 | 39 | # Gated Linear Unit 40 | self._use_glu = isinstance(glu_depth, int) and glu_depth > 0 41 | if self._use_glu: 42 | if not self.head_dim % pow(2, glu_depth) == 0: 43 | raise ValueError('When using GLU you need to use a head_dim that is ' 44 | 'a multiple of two to the power glu_depth. ' 45 | f'Got {self.head_dim} % 2^{glu_depth} != 0') 46 | glu_out_dim = self.head_dim // pow(2, glu_depth) 47 | self.key_glu = GatedLinear(self.head_dim, glu_out_dim, glu_depth) 48 | self.query_glu = GatedLinear(self.head_dim, glu_out_dim, glu_depth) 49 | 50 | self._reset_parameters() 51 | 52 | def _reset_parameters(self): 53 | xavier_uniform_(self.in_proj_weight[:self.embed_dim, :]) 54 | xavier_uniform_(self.in_proj_weight[self.embed_dim:(self.embed_dim * 2), :]) 55 | xavier_uniform_(self.in_proj_weight[(self.embed_dim * 2):, :]) 56 | 57 | xavier_uniform_(self.out_proj.weight) 58 | if self.in_proj_bias is not None: 59 | constant_(self.in_proj_bias, 0.) 60 | constant_(self.out_proj.bias, 0.) 
61 | 62 | @weak_script_method 63 | def forward(self, input, attn_mask=None): 64 | """ 65 | Inputs of forward function 66 | input: [target length, batch size, embed dim] 67 | attn_mask [(batch size), sequence_length, sequence_length] 68 | 69 | Outputs of forward function 70 | attn_output: [target length, batch size, embed dim] 71 | attn_output_weights: [batch size, target length, sequence length] 72 | """ 73 | 74 | seq_len, bsz, embed_dim = input.size() 75 | assert embed_dim == self.embed_dim 76 | 77 | # self-attention 78 | q, k, v = F.linear(input, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 79 | q *= self.scaling 80 | 81 | # Cut q, k, v in num_heads part 82 | q = q.contiguous().view(seq_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) 83 | k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 84 | v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) 85 | 86 | # Gated Linear Unit 87 | if self._use_glu: 88 | q = self.query_glu(q) 89 | k = self.key_glu(k) 90 | 91 | # batch matrix multply query against key 92 | # attn_output_weights is [bsz * num_heads, seq_len, seq_len] 93 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 94 | 95 | assert list(attn_output_weights.size()) == [bsz * self.num_heads, seq_len, seq_len] 96 | 97 | if attn_mask is not None: 98 | if attn_mask.dim() == 2: 99 | # We use the same mask for each item in the batch 100 | attn_mask = attn_mask.unsqueeze(0) 101 | elif attn_mask.dim() == 3: 102 | # Each item in the batch has its own mask. 103 | # We need to inflate the mask to go with all heads 104 | attn_mask = tile(attn_mask, count=self.num_heads, dim=0) 105 | else: 106 | # Don't known what we would be doing here... 107 | raise RuntimeError(f'Wrong mask dim: {attn_mask.dim()}') 108 | 109 | # The mask should be either 0 of -inf to go with softmax 110 | attn_output_weights += attn_mask 111 | 112 | attn_output_weights = F.softmax( 113 | attn_output_weights.float(), dim=-1, 114 | dtype=torch.float32 if attn_output_weights.dtype == torch.float16 else attn_output_weights.dtype) 115 | attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training) 116 | 117 | attn_output = torch.bmm(attn_output_weights, v) 118 | assert list(attn_output.size()) == [bsz * self.num_heads, seq_len, self.head_dim] 119 | attn_output = attn_output.transpose(0, 1).contiguous().view(seq_len, bsz, embed_dim) 120 | attn_output = self.out_proj(attn_output) 121 | 122 | # average attention weights over heads 123 | attn_output_weights = attn_output_weights.view(bsz, self.num_heads, seq_len, seq_len) 124 | attn_output_weights = attn_output_weights.sum(dim=1) / self.num_heads 125 | 126 | return attn_output, attn_output_weights -------------------------------------------------------------------------------- /onmt/modules/sparse_activations.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of sparsemax (Martins & Astudillo, 2016). See 3 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 
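
# Bare-bones sketch of the masked scaled dot-product attention implemented in
# MultiHeadSelfAttention.forward above (single head, additive 0/-inf mask;
# all tensors are toy values).
import torch
import torch.nn.functional as F

seq_len, dim = 4, 8
x = torch.randn(seq_len, dim)
q = k = v = x
scores = (q * dim ** -0.5) @ k.transpose(0, 1)    # (seq_len, seq_len)
mask = torch.zeros(seq_len, seq_len)
mask[:, -1] = float("-inf")                        # e.g. hide the last position
attn = F.softmax(scores + mask, dim=-1)
print((attn @ v).shape)                            # torch.Size([4, 8])
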
4 | 5 | By Ben Peters and Vlad Niculae 6 | """ 7 | 8 | import torch 9 | from torch.autograd import Function 10 | import torch.nn as nn 11 | 12 | 13 | def _make_ix_like(input, dim=0): 14 | d = input.size(dim) 15 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 16 | view = [1] * input.dim() 17 | view[0] = -1 18 | return rho.view(view).transpose(0, dim) 19 | 20 | 21 | def _threshold_and_support(input, dim=0): 22 | """Sparsemax building block: compute the threshold 23 | 24 | Args: 25 | input: any dimension 26 | dim: dimension along which to apply the sparsemax 27 | 28 | Returns: 29 | the threshold value 30 | """ 31 | 32 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 33 | input_cumsum = input_srt.cumsum(dim) - 1 34 | rhos = _make_ix_like(input, dim) 35 | support = rhos * input_srt > input_cumsum 36 | 37 | support_size = support.sum(dim=dim).unsqueeze(dim) 38 | tau = input_cumsum.gather(dim, support_size - 1) 39 | tau /= support_size.to(input.dtype) 40 | return tau, support_size 41 | 42 | 43 | class SparsemaxFunction(Function): 44 | 45 | @staticmethod 46 | def forward(ctx, input, dim=0): 47 | """sparsemax: normalizing sparse transform (a la softmax) 48 | 49 | Parameters: 50 | input (Tensor): any shape 51 | dim: dimension along which to apply sparsemax 52 | 53 | Returns: 54 | output (Tensor): same shape as input 55 | """ 56 | ctx.dim = dim 57 | max_val, _ = input.max(dim=dim, keepdim=True) 58 | input -= max_val # same numerical stability trick as for softmax 59 | tau, supp_size = _threshold_and_support(input, dim=dim) 60 | output = torch.clamp(input - tau, min=0) 61 | ctx.save_for_backward(supp_size, output) 62 | return output 63 | 64 | @staticmethod 65 | def backward(ctx, grad_output): 66 | supp_size, output = ctx.saved_tensors 67 | dim = ctx.dim 68 | grad_input = grad_output.clone() 69 | grad_input[output == 0] = 0 70 | 71 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 72 | v_hat = v_hat.unsqueeze(dim) 73 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 74 | return grad_input, None 75 | 76 | 77 | sparsemax = SparsemaxFunction.apply 78 | 79 | 80 | class Sparsemax(nn.Module): 81 | 82 | def __init__(self, dim=0): 83 | self.dim = dim 84 | super(Sparsemax, self).__init__() 85 | 86 | def forward(self, input): 87 | return sparsemax(input, self.dim) 88 | 89 | 90 | class LogSparsemax(nn.Module): 91 | 92 | def __init__(self, dim=0): 93 | self.dim = dim 94 | super(LogSparsemax, self).__init__() 95 | 96 | def forward(self, input): 97 | return torch.log(sparsemax(input, self.dim)) 98 | -------------------------------------------------------------------------------- /onmt/modules/sparse_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | from onmt.modules.sparse_activations import _threshold_and_support 5 | from onmt.utils.misc import aeq 6 | 7 | 8 | class SparsemaxLossFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, input, target): 12 | """ 13 | input (FloatTensor): ``(n, num_classes)``. 
14 | target (LongTensor): ``(n,)``, the indices of the target classes 15 | """ 16 | input_batch, classes = input.size() 17 | target_batch = target.size(0) 18 | aeq(input_batch, target_batch) 19 | 20 | z_k = input.gather(1, target.unsqueeze(1)).squeeze() 21 | tau_z, support_size = _threshold_and_support(input, dim=1) 22 | support = input > tau_z 23 | x = torch.where( 24 | support, input**2 - tau_z**2, 25 | torch.tensor(0.0, device=input.device) 26 | ).sum(dim=1) 27 | ctx.save_for_backward(input, target, tau_z) 28 | # clamping necessary because of numerical errors: loss should be lower 29 | # bounded by zero, but negative values near zero are possible without 30 | # the clamp 31 | return torch.clamp(x / 2 - z_k + 0.5, min=0.0) 32 | 33 | @staticmethod 34 | def backward(ctx, grad_output): 35 | input, target, tau_z = ctx.saved_tensors 36 | sparsemax_out = torch.clamp(input - tau_z, min=0) 37 | delta = torch.zeros_like(sparsemax_out) 38 | delta.scatter_(1, target.unsqueeze(1), 1) 39 | return sparsemax_out - delta, None 40 | 41 | 42 | sparsemax_loss = SparsemaxLossFunction.apply 43 | 44 | 45 | class SparsemaxLoss(nn.Module): 46 | """ 47 | An implementation of sparsemax loss, first proposed in 48 | :cite:`DBLP:journals/corr/MartinsA16`. If using 49 | a sparse output layer, it is not possible to use negative log likelihood 50 | because the loss is infinite in the case the target is assigned zero 51 | probability. Inputs to SparsemaxLoss are arbitrary dense real-valued 52 | vectors (like in nn.CrossEntropyLoss), not probability vectors (like in 53 | nn.NLLLoss). 54 | """ 55 | 56 | def __init__(self, weight=None, ignore_index=-100, 57 | reduction='elementwise_mean'): 58 | assert reduction in ['elementwise_mean', 'sum', 'none'] 59 | self.reduction = reduction 60 | self.weight = weight 61 | self.ignore_index = ignore_index 62 | super(SparsemaxLoss, self).__init__() 63 | 64 | def forward(self, input, target): 65 | loss = sparsemax_loss(input, target) 66 | if self.ignore_index >= 0: 67 | ignored_positions = target == self.ignore_index 68 | size = float((target.size(0) - ignored_positions.sum()).item()) 69 | loss.masked_fill_(ignored_positions, 0.0) 70 | else: 71 | size = float(target.size(0)) 72 | if self.reduction == 'sum': 73 | loss = loss.sum() 74 | elif self.reduction == 'elementwise_mean': 75 | loss = loss.sum() / size 76 | return loss 77 | -------------------------------------------------------------------------------- /onmt/modules/structured_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.cuda 4 | 5 | 6 | class MatrixTree(nn.Module): 7 | """Implementation of the matrix-tree theorem for computing marginals 8 | of non-projective dependency parsing. This attention layer is used 9 | in the paper "Learning Structured Text Representations" 10 | :cite:`DBLP:journals/corr/LiuL17d`. 
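    Concretely, ``forward`` treats ``exp(input)`` as arc scores, builds the
    corresponding graph Laplacian (root scores are taken from the diagonal of
    ``input``), and reads arc and root marginals off the inverse Laplacian,
    one batch element at a time.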
11 | """ 12 | 13 | def __init__(self, eps=1e-5): 14 | self.eps = eps 15 | super(MatrixTree, self).__init__() 16 | 17 | def forward(self, input): 18 | laplacian = input.exp() + self.eps 19 | output = input.clone() 20 | for b in range(input.size(0)): 21 | lap = laplacian[b].masked_fill( 22 | torch.eye(input.size(1), device=input.device).ne(0), 0) 23 | lap = -lap + torch.diag(lap.sum(0)) 24 | # store roots on diagonal 25 | lap[0] = input[b].diag().exp() 26 | inv_laplacian = lap.inverse() 27 | 28 | factor = inv_laplacian.diag().unsqueeze(1)\ 29 | .expand_as(input[b]).transpose(0, 1) 30 | term1 = input[b].exp().mul(factor).clone() 31 | term2 = input[b].exp().mul(inv_laplacian.transpose(0, 1)).clone() 32 | term1[:, 0] = 0 33 | term2[0] = 0 34 | output[b] = term1 - term2 35 | roots_output = input[b].diag().exp().mul( 36 | inv_laplacian.transpose(0, 1)[0]) 37 | output[b] = output[b] + torch.diag(roots_output) 38 | return output 39 | -------------------------------------------------------------------------------- /onmt/modules/table_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class TableEmbeddings(torch.nn.Module): 5 | """ 6 | Now that I think about it, we can do more efficiently than rewritting the 7 | onmt module. I will in the future but for now this code works as is, 8 | so I won't chance breaking it! 9 | 10 | These embeddings follow the table structure: a table is an unordered set 11 | of tuple (pos, value) where pos can be viewed as column name. As 12 | such, TableEmbeddings' forward returns embeddings for pos and value. 13 | Furthermore, the value embedding can be merged with the pos embedding. 14 | 15 | Most argument names are not very fitting but stay the same 16 | as onmt.modules.Embeddings 17 | """ 18 | 19 | def __init__(self, 20 | word_vec_size, # dim of the value embeddings 21 | word_vocab_size, # size of the value vocabulary 22 | word_padding_idx, # idx of 23 | feat_vec_size, # dim of the pos embeddings 24 | feat_vec_exponent, # instead of feat_vec_size 25 | feat_vocab_size, # size of the pos vocabulary 26 | feat_padding_idx, # idx of 27 | merge="concat", # decide to merge the pos and value 28 | merge_activation='ReLU', # used if merge is mlp 29 | dropout=0, 30 | ent_idx=None): 31 | 32 | super().__init__() 33 | 34 | assert ent_idx is not None 35 | self.ent_idx = ent_idx 36 | 37 | self.word_padding_idx = word_padding_idx 38 | self.word_vec_size = word_vec_size 39 | 40 | if feat_vec_size < 0: 41 | if not 0 < feat_vec_exponent <= 1: 42 | raise ValueError('feat_vec_exponent should be between 0 and 1') 43 | feat_vec_size = int(feat_vocab_size ** feat_vec_exponent) 44 | 45 | self.value_embeddings = torch.nn.Embedding(word_vocab_size, 46 | word_vec_size, padding_idx=word_padding_idx) 47 | self.pos_embeddings = torch.nn.Embedding(feat_vocab_size, 48 | feat_vec_size, padding_idx=feat_padding_idx) 49 | 50 | self._merge = merge 51 | if merge is None: 52 | self.embedding_size = self.word_vec_size 53 | elif merge == 'concat': 54 | self.embedding_size = self.word_vec_size + self.feat_vec_size 55 | elif merge == 'sum': 56 | assert self.word_vec_size == self.feat_vec_size 57 | self.embedding_size = self.word_vec_size 58 | elif merge == 'mlp': 59 | self.embedding_size = self.word_vec_size 60 | val_dim = self.value_embeddings.embedding_dim 61 | pos_dim = self.pos_embeddings.embedding_dim 62 | in_dim = val_dim + pos_dim 63 | self.merge = torch.nn.Linear(in_dim, val_dim) 64 | 65 | if merge_activation is None: 66 | 
self.activation = None 67 | elif merge_activation == 'ReLU': 68 | self.activation = torch.nn.ReLU() 69 | elif merge_activation == 'Tanh': 70 | self.activation = torch.nn.Tanh() 71 | else: 72 | raise ValueError(f'Unknown activation {merge_activation}') 73 | else: 74 | raise ValueError('merge should be one of [None|concat|sum|mlp]') 75 | 76 | 77 | @property 78 | def word_lut(self): 79 | """Word look-up table.""" 80 | return self.value_embeddings 81 | 82 | def load_pretrained_vectors(self, emb_file): 83 | """ 84 | place holder for onmt compatibility 85 | """ 86 | if emb_file: 87 | raise NotImplementedError 88 | 89 | def forward(self, inputs): 90 | # unpack the inputs as cell values and pos (column name) 91 | values, pos = [item.squeeze(2) for item in inputs.split(1, dim=2)] 92 | 93 | # embed them separatly and maybe merge them 94 | values = self.value_embeddings(values) 95 | pos = self.pos_embeddings(pos) 96 | 97 | if self._merge is None: 98 | return values, pos 99 | if self._merge == 'sum': 100 | values = values + pos 101 | return values, pos 102 | 103 | values = torch.cat((values, pos), 2) 104 | if self._merge == 'concat': 105 | return values, pos 106 | if self._merge == 'mlp': 107 | values = self.merge(values) 108 | if self.activation: 109 | values = self.activation(values) 110 | return values, pos 111 | -------------------------------------------------------------------------------- /onmt/modules/util_class.py: -------------------------------------------------------------------------------- 1 | """ Misc classes """ 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | # At the moment this class is only used by embeddings.Embeddings look-up tables 7 | class Elementwise(nn.ModuleList): 8 | """ 9 | A simple network container. 10 | Parameters are a list of modules. 11 | Inputs are a 3d Tensor whose last dimension is the same length 12 | as the list. 13 | Outputs are the result of applying modules to inputs elementwise. 14 | An optional merge parameter allows the outputs to be reduced to a 15 | single Tensor. 16 | """ 17 | 18 | def __init__(self, merge=None, *args): 19 | assert merge in [None, 'first', 'concat', 'sum', 'mlp'] 20 | self.merge = merge 21 | super(Elementwise, self).__init__(*args) 22 | 23 | def forward(self, inputs): 24 | inputs_ = [feat.squeeze(2) for feat in inputs.split(1, dim=2)] 25 | assert len(self) == len(inputs_) 26 | outputs = [f(x) for f, x in zip(self, inputs_)] 27 | if self.merge == 'first': 28 | return outputs[0] 29 | elif self.merge == 'concat' or self.merge == 'mlp': 30 | return torch.cat(outputs, 2) 31 | elif self.merge == 'sum': 32 | return sum(outputs) 33 | else: 34 | return outputs 35 | 36 | 37 | class Cast(nn.Module): 38 | """ 39 | Basic layer that casts its input to a specific data type. The same tensor 40 | is returned if the data type is already correct. 
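    For instance, ``Cast(torch.float32)`` can be placed inside an
    ``nn.Sequential`` to up-cast half-precision activations before an
    operation that expects float32.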
41 | """ 42 | 43 | def __init__(self, dtype): 44 | super(Cast, self).__init__() 45 | self._dtype = dtype 46 | 47 | def forward(self, x): 48 | return x.to(self._dtype) 49 | -------------------------------------------------------------------------------- /onmt/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaijuML/data-to-text-hierarchical/da88d2d4491266fccc39ac1cc1fbb56bd7bbc30c/onmt/tests/__init__.py -------------------------------------------------------------------------------- /onmt/tests/rebuild_test_models.sh: -------------------------------------------------------------------------------- 1 | # # Retrain the models used for CI. 2 | # # Should be done rarely, indicates a major breaking change. 3 | my_python=python 4 | 5 | ############### TEST regular RNN choose either -rnn_type LSTM / GRU / SRU and set input_feed 0 for SRU 6 | if true; then 7 | rm data/*.pt 8 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 9 | 10 | $my_python train.py -data data/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 1 -train_steps 10000 -optim adam -learning_rate 0.001 -rnn_type LSTM -input_feed 0 11 | #-truncated_decoder 5 12 | #-label_smoothing 0.1 13 | 14 | mv tmp*e10.pt onmt/tests/test_model.pt 15 | rm tmp*.pt 16 | fi 17 | # 18 | # 19 | ############### TEST CNN 20 | if false; then 21 | rm data/*.pt 22 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 23 | 24 | $my_python train.py -data data/data -save_model /tmp/tmp -world_size 1 -gpu_ranks 0 -rnn_size 256 -word_vec_size 256 -layers 2 -train_steps 10000 -optim adam -learning_rate 0.001 -encoder_type cnn -decoder_type cnn 25 | 26 | 27 | mv /tmp/tmp*e10.pt onmt/tests/test_model.pt 28 | 29 | rm /tmp/tmp*.pt 30 | fi 31 | # 32 | ################# MORPH DATA 33 | if true; then 34 | rm data/morph/*.pt 35 | $my_python preprocess.py -train_src data/morph/src.train -train_tgt data/morph/tgt.train -valid_src data/morph/src.valid -valid_tgt data/morph/tgt.valid -save_data data/morph/data 36 | 37 | $my_python train.py -data data/morph/data -save_model tmp -world_size 1 -gpu_ranks 0 -rnn_size 400 -word_vec_size 100 -layers 1 -train_steps 8000 -optim adam -learning_rate 0.001 38 | 39 | 40 | mv tmp*e8.pt onmt/tests/test_model2.pt 41 | 42 | rm tmp*.pt 43 | fi 44 | ############### TEST TRANSFORMER 45 | if false; then 46 | rm data/*.pt 47 | $my_python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data data/data -src_vocab_size 1000 -tgt_vocab_size 1000 -share_vocab 48 | 49 | 50 | $my_python train.py -data data/data -save_model /tmp/tmp -batch_type tokens -batch_size 1024 -accum_count 4 \ 51 | -layers 4 -rnn_size 256 -word_vec_size 256 -encoder_type transformer -decoder_type transformer -share_embedding \ 52 | -train_steps 10000 -world_size 1 -gpu_ranks 0 -max_generator_batches 4 -dropout 0.1 -normalization tokens \ 53 | -max_grad_norm 0 -optim adam -decay_method noam -learning_rate 2 -label_smoothing 0.1 \ 54 | -position_encoding -param_init 0 -warmup_steps 100 -param_init_glorot -adam_beta2 0.998 55 | # 56 | mv /tmp/tmp*e10.pt 
onmt/tests/test_model.pt 57 | rm /tmp/tmp*.pt 58 | fi 59 | # 60 | if false; then 61 | $my_python translate.py -gpu 0 -model onmt/tests/test_model.pt \ 62 | -src data/src-val.txt -output onmt/tests/output_hyp.txt -beam 5 -batch_size 16 63 | 64 | fi 65 | 66 | 67 | -------------------------------------------------------------------------------- /onmt/tests/test_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Here come the tests for attention types and their compatibility 3 | """ 4 | import unittest 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | import onmt 9 | 10 | 11 | class TestAttention(unittest.TestCase): 12 | 13 | def test_masked_global_attention(self): 14 | 15 | source_lengths = torch.IntTensor([7, 3, 5, 2]) 16 | # illegal_weights_mask = torch.ByteTensor([ 17 | # [0, 0, 0, 0, 0, 0, 0], 18 | # [0, 0, 0, 1, 1, 1, 1], 19 | # [0, 0, 0, 0, 0, 1, 1], 20 | # [0, 0, 1, 1, 1, 1, 1]]) 21 | 22 | batch_size = source_lengths.size(0) 23 | dim = 20 24 | 25 | memory_bank = Variable(torch.randn(batch_size, 26 | source_lengths.max(), dim)) 27 | hidden = Variable(torch.randn(batch_size, dim)) 28 | 29 | attn = onmt.modules.GlobalAttention(dim) 30 | 31 | _, alignments = attn(hidden, memory_bank, 32 | memory_lengths=source_lengths) 33 | # TODO: fix for pytorch 0.3 34 | # illegal_weights = alignments.masked_select(illegal_weights_mask) 35 | 36 | # self.assertEqual(0.0, illegal_weights.data.sum()) 37 | -------------------------------------------------------------------------------- /onmt/tests/test_copy_generator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.copy_generator import CopyGenerator, CopyGeneratorLoss 3 | 4 | import itertools 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch.nn.functional import softmax 9 | 10 | from onmt.tests.utils_for_tests import product_dict 11 | 12 | 13 | class TestCopyGenerator(unittest.TestCase): 14 | INIT_CASES = list(product_dict( 15 | input_size=[172], 16 | output_size=[319], 17 | pad_idx=[0, 39], 18 | )) 19 | PARAMS = list(product_dict( 20 | batch_size=[1, 14], 21 | max_seq_len=[23], 22 | tgt_max_len=[50], 23 | n_extra_words=[107] 24 | )) 25 | 26 | @classmethod 27 | def dummy_inputs(cls, params, init_case): 28 | hidden = torch.randn((params["batch_size"] * params["tgt_max_len"], 29 | init_case["input_size"])) 30 | attn = torch.randn((params["batch_size"] * params["tgt_max_len"], 31 | params["max_seq_len"])) 32 | src_map = torch.randn((params["max_seq_len"], params["batch_size"], 33 | params["n_extra_words"])) 34 | return hidden, attn, src_map 35 | 36 | @classmethod 37 | def expected_shape(cls, params, init_case): 38 | return params["tgt_max_len"] * params["batch_size"], \ 39 | init_case["output_size"] + params["n_extra_words"] 40 | 41 | def test_copy_gen_forward_shape(self): 42 | for params, init_case in itertools.product( 43 | self.PARAMS, self.INIT_CASES): 44 | cgen = CopyGenerator(**init_case) 45 | dummy_in = self.dummy_inputs(params, init_case) 46 | res = cgen(*dummy_in) 47 | expected_shape = self.expected_shape(params, init_case) 48 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 49 | 50 | def test_copy_gen_outp_has_no_prob_of_pad(self): 51 | for params, init_case in itertools.product( 52 | self.PARAMS, self.INIT_CASES): 53 | cgen = CopyGenerator(**init_case) 54 | dummy_in = self.dummy_inputs(params, init_case) 55 | res = cgen(*dummy_in) 56 | self.assertTrue( 57 | res[:, 
init_case["pad_idx"]].allclose(torch.tensor(0.0))) 58 | 59 | def test_copy_gen_trainable_params_update(self): 60 | for params, init_case in itertools.product( 61 | self.PARAMS, self.INIT_CASES): 62 | cgen = CopyGenerator(**init_case) 63 | trainable_params = {n: p for n, p in cgen.named_parameters() 64 | if p.requires_grad} 65 | assert len(trainable_params) > 0 # sanity check 66 | old_weights = deepcopy(trainable_params) 67 | dummy_in = self.dummy_inputs(params, init_case) 68 | res = cgen(*dummy_in) 69 | pretend_loss = res.sum() 70 | pretend_loss.backward() 71 | dummy_optim = torch.optim.SGD(trainable_params.values(), 1) 72 | dummy_optim.step() 73 | for param_name in old_weights.keys(): 74 | self.assertTrue( 75 | trainable_params[param_name] 76 | .ne(old_weights[param_name]).any(), 77 | param_name + " " + init_case.__str__()) 78 | 79 | 80 | class TestCopyGeneratorLoss(unittest.TestCase): 81 | INIT_CASES = list(product_dict( 82 | vocab_size=[172], 83 | unk_index=[0, 39], 84 | ignore_index=[1, 17], # pad idx 85 | force_copy=[True, False] 86 | )) 87 | PARAMS = list(product_dict( 88 | batch_size=[1, 14], 89 | tgt_max_len=[50], 90 | n_extra_words=[107] 91 | )) 92 | 93 | @classmethod 94 | def dummy_inputs(cls, params, init_case): 95 | n_unique_src_words = 13 96 | scores = torch.randn((params["batch_size"] * params["tgt_max_len"], 97 | init_case["vocab_size"] + n_unique_src_words)) 98 | scores = softmax(scores, dim=1) 99 | align = torch.randint(0, n_unique_src_words, 100 | (params["batch_size"] * params["tgt_max_len"],)) 101 | target = torch.randint(0, init_case["vocab_size"], 102 | (params["batch_size"] * params["tgt_max_len"],)) 103 | target[0] = init_case["unk_index"] 104 | target[1] = init_case["ignore_index"] 105 | return scores, align, target 106 | 107 | @classmethod 108 | def expected_shape(cls, params, init_case): 109 | return (params["batch_size"] * params["tgt_max_len"],) 110 | 111 | def test_copy_loss_forward_shape(self): 112 | for params, init_case in itertools.product( 113 | self.PARAMS, self.INIT_CASES): 114 | loss = CopyGeneratorLoss(**init_case) 115 | dummy_in = self.dummy_inputs(params, init_case) 116 | res = loss(*dummy_in) 117 | expected_shape = self.expected_shape(params, init_case) 118 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 119 | 120 | def test_copy_loss_ignore_index_is_ignored(self): 121 | for params, init_case in itertools.product( 122 | self.PARAMS, self.INIT_CASES): 123 | loss = CopyGeneratorLoss(**init_case) 124 | scores, align, target = self.dummy_inputs(params, init_case) 125 | res = loss(scores, align, target) 126 | should_be_ignored = (target == init_case["ignore_index"]).nonzero() 127 | assert len(should_be_ignored) > 0 # otherwise not testing anything 128 | self.assertTrue(res[should_be_ignored].allclose(torch.tensor(0.0))) 129 | 130 | def test_copy_loss_output_range_is_positive(self): 131 | for params, init_case in itertools.product( 132 | self.PARAMS, self.INIT_CASES): 133 | loss = CopyGeneratorLoss(**init_case) 134 | dummy_in = self.dummy_inputs(params, init_case) 135 | res = loss(*dummy_in) 136 | self.assertTrue((res >= 0).all()) 137 | -------------------------------------------------------------------------------- /onmt/tests/test_embeddings.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.embeddings import Embeddings 3 | 4 | import itertools 5 | from copy import deepcopy 6 | 7 | import torch 8 | 9 | from onmt.tests.utils_for_tests import product_dict 10 | 
11 | 12 | class TestEmbeddings(unittest.TestCase): 13 | INIT_CASES = list(product_dict( 14 | word_vec_size=[172], 15 | word_vocab_size=[319], 16 | word_padding_idx=[17], 17 | position_encoding=[False, True], 18 | feat_merge=["first", "concat", "sum", "mlp"], 19 | feat_vec_exponent=[-1, 1.1, 0.7], 20 | feat_vec_size=[0, 200], 21 | feat_padding_idx=[[], [29], [0, 1]], 22 | feat_vocab_sizes=[[], [39], [401, 39]], 23 | dropout=[0, 0.5], 24 | fix_word_vecs=[False, True] 25 | )) 26 | PARAMS = list(product_dict( 27 | batch_size=[1, 14], 28 | max_seq_len=[23] 29 | )) 30 | 31 | @classmethod 32 | def case_is_degenerate(cls, case): 33 | no_feats = len(case["feat_vocab_sizes"]) == 0 34 | if case["feat_merge"] != "first" and no_feats: 35 | return True 36 | if case["feat_merge"] == "first" and not no_feats: 37 | return True 38 | if case["feat_merge"] == "concat" and case["feat_vec_exponent"] != -1: 39 | return True 40 | if no_feats and case["feat_vec_exponent"] != -1: 41 | return True 42 | if len(case["feat_vocab_sizes"]) != len(case["feat_padding_idx"]): 43 | return True 44 | if case["feat_vec_size"] == 0 and case["feat_vec_exponent"] <= 0: 45 | return True 46 | if case["feat_merge"] == "sum": 47 | if case["feat_vec_exponent"] != -1: 48 | return True 49 | if case["feat_vec_size"] != 0: 50 | return True 51 | if case["feat_vec_size"] != 0 and case["feat_vec_exponent"] != -1: 52 | return True 53 | return False 54 | 55 | @classmethod 56 | def cases(cls): 57 | for case in cls.INIT_CASES: 58 | if not cls.case_is_degenerate(case): 59 | yield case 60 | 61 | @classmethod 62 | def dummy_inputs(cls, params, init_case): 63 | max_seq_len = params["max_seq_len"] 64 | batch_size = params["batch_size"] 65 | fv_sizes = init_case["feat_vocab_sizes"] 66 | n_words = init_case["word_vocab_size"] 67 | voc_sizes = [n_words] + fv_sizes 68 | pad_idxs = [init_case["word_padding_idx"]] + \ 69 | init_case["feat_padding_idx"] 70 | lengths = torch.randint(0, max_seq_len, (batch_size,)) 71 | lengths[0] = max_seq_len 72 | inps = torch.empty((max_seq_len, batch_size, len(voc_sizes)), 73 | dtype=torch.long) 74 | for f, (voc_size, pad_idx) in enumerate(zip(voc_sizes, pad_idxs)): 75 | for b, len_ in enumerate(lengths): 76 | inps[:len_, b, f] = torch.randint(0, voc_size-1, (len_,)) 77 | inps[len_:, b, f] = pad_idx 78 | return inps 79 | 80 | @classmethod 81 | def expected_shape(cls, params, init_case): 82 | wvs = init_case["word_vec_size"] 83 | fvs = init_case["feat_vec_size"] 84 | nf = len(init_case["feat_vocab_sizes"]) 85 | size = wvs 86 | if init_case["feat_merge"] not in {"sum", "mlp"}: 87 | size += nf * fvs 88 | return params["max_seq_len"], params["batch_size"], size 89 | 90 | def test_embeddings_forward_shape(self): 91 | for params, init_case in itertools.product(self.PARAMS, self.cases()): 92 | emb = Embeddings(**init_case) 93 | dummy_in = self.dummy_inputs(params, init_case) 94 | res = emb(dummy_in) 95 | expected_shape = self.expected_shape(params, init_case) 96 | self.assertEqual(res.shape, expected_shape, init_case.__str__()) 97 | 98 | def test_embeddings_trainable_params(self): 99 | for params, init_case in itertools.product(self.PARAMS, 100 | self.cases()): 101 | emb = Embeddings(**init_case) 102 | trainable_params = {n: p for n, p in emb.named_parameters() 103 | if p.requires_grad} 104 | # first check there's nothing unexpectedly not trainable 105 | for key in emb.state_dict(): 106 | if key not in trainable_params: 107 | if key.endswith("emb_luts.0.weight") and \ 108 | init_case["fix_word_vecs"]: 109 | # ok: word embeddings 
shouldn't be trainable 110 | # if word vecs are fixed 111 | continue 112 | if key.endswith(".pe.pe"): 113 | # ok: positional encodings shouldn't be trainable 114 | assert init_case["position_encoding"] 115 | continue 116 | else: 117 | self.fail("Param {:s} is unexpectedly not " 118 | "trainable.".format(key)) 119 | # then check nothing unexpectedly trainable 120 | if init_case["fix_word_vecs"]: 121 | self.assertFalse( 122 | any(trainable_param.endswith("emb_luts.0.weight") 123 | for trainable_param in trainable_params), 124 | "Word embedding is trainable but word vecs are fixed.") 125 | if init_case["position_encoding"]: 126 | self.assertFalse( 127 | any(trainable_p.endswith(".pe.pe") 128 | for trainable_p in trainable_params), 129 | "Positional encoding is trainable.") 130 | 131 | def test_embeddings_trainable_params_update(self): 132 | for params, init_case in itertools.product(self.PARAMS, self.cases()): 133 | emb = Embeddings(**init_case) 134 | trainable_params = {n: p for n, p in emb.named_parameters() 135 | if p.requires_grad} 136 | if len(trainable_params) > 0: 137 | old_weights = deepcopy(trainable_params) 138 | dummy_in = self.dummy_inputs(params, init_case) 139 | res = emb(dummy_in) 140 | pretend_loss = res.sum() 141 | pretend_loss.backward() 142 | dummy_optim = torch.optim.SGD(trainable_params.values(), 1) 143 | dummy_optim.step() 144 | for param_name in old_weights.keys(): 145 | self.assertTrue( 146 | trainable_params[param_name] 147 | .ne(old_weights[param_name]).any(), 148 | param_name + " " + init_case.__str__()) 149 | -------------------------------------------------------------------------------- /onmt/tests/test_image_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.inputters.image_dataset import ImageDataReader 3 | 4 | import os 5 | import shutil 6 | 7 | import cv2 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class TestImageDataReader(unittest.TestCase): 13 | # this test touches the file system, so it could be considered an 14 | # integration test 15 | _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 16 | _IMG_DATA_DIRNAME = "test_image_data" 17 | _IMG_DATA_DIR = os.path.join(_THIS_DIR, _IMG_DATA_DIRNAME) 18 | _IMG_DATA_FMT = "test_img_{:d}.png" 19 | _IMG_DATA_PATH_FMT = os.path.join(_IMG_DATA_DIR, _IMG_DATA_FMT) 20 | 21 | _IMG_LIST_DIR = "test_image_filenames" 22 | # file to hold full paths to image data 23 | _IMG_LIST_PATHS_FNAME = "test_files.txt" 24 | _IMG_LIST_PATHS_PATH = os.path.join( 25 | _IMG_LIST_DIR, _IMG_LIST_PATHS_FNAME) 26 | # file to hold image paths relative to _IMG_DATA_DIR (i.e. 
file names) 27 | _IMG_LIST_FNAMES_FNAME = "test_fnames.txt" 28 | _IMG_LIST_FNAMES_PATH = os.path.join( 29 | _IMG_LIST_DIR, _IMG_LIST_FNAMES_FNAME) 30 | 31 | # it's ok if non-image files co-exist with image files in the data dir 32 | _JUNK_FILE = os.path.join( 33 | _IMG_DATA_DIR, "this_is_junk.txt") 34 | 35 | _N_EXAMPLES = 20 36 | _N_CHANNELS = 3 37 | 38 | @classmethod 39 | def setUpClass(cls): 40 | if not os.path.exists(cls._IMG_DATA_DIR): 41 | os.makedirs(cls._IMG_DATA_DIR) 42 | if not os.path.exists(cls._IMG_LIST_DIR): 43 | os.makedirs(cls._IMG_LIST_DIR) 44 | 45 | with open(cls._JUNK_FILE, "w") as f: 46 | f.write("this is some garbage\nShould have no impact.") 47 | 48 | with open(cls._IMG_LIST_PATHS_PATH, "w") as f_list_fnames, \ 49 | open(cls._IMG_LIST_FNAMES_PATH, "w") as f_list_paths: 50 | cls.n_rows = torch.randint(30, 314, (cls._N_EXAMPLES,)) 51 | cls.n_cols = torch.randint(30, 314, (cls._N_EXAMPLES,)) 52 | for i in range(cls._N_EXAMPLES): 53 | img = np.random.randint( 54 | 0, 255, (cls.n_rows[i], cls.n_cols[i], cls._N_CHANNELS)) 55 | f_path = cls._IMG_DATA_PATH_FMT.format(i) 56 | cv2.imwrite(f_path, img) 57 | f_name_short = cls._IMG_DATA_FMT.format(i) 58 | f_list_fnames.write(f_name_short + "\n") 59 | f_list_paths.write(f_path + "\n") 60 | 61 | @classmethod 62 | def tearDownClass(cls): 63 | shutil.rmtree(cls._IMG_DATA_DIR) 64 | shutil.rmtree(cls._IMG_LIST_DIR) 65 | 66 | def test_read_from_dir_and_data_file_containing_filenames(self): 67 | rdr = ImageDataReader(channel_size=self._N_CHANNELS) 68 | i = 0 # initialize since there's a sanity check on i 69 | for i, img in enumerate(rdr.read( 70 | self._IMG_LIST_FNAMES_PATH, "src", self._IMG_DATA_DIR)): 71 | self.assertEqual( 72 | img["src"].shape, 73 | (self._N_CHANNELS, self.n_rows[i], self.n_cols[i])) 74 | self.assertEqual(img["src_path"], 75 | self._IMG_DATA_PATH_FMT.format(i)) 76 | self.assertGreater(i, 0, "No image data was read.") 77 | 78 | def test_read_from_dir_and_data_file_containing_paths(self): 79 | rdr = ImageDataReader(channel_size=self._N_CHANNELS) 80 | i = 0 # initialize since there's a sanity check on i 81 | for i, img in enumerate(rdr.read( 82 | self._IMG_LIST_PATHS_PATH, "src", self._IMG_DATA_DIR)): 83 | self.assertEqual( 84 | img["src"].shape, 85 | (self._N_CHANNELS, self.n_rows[i], self.n_cols[i])) 86 | self.assertEqual(img["src_path"], 87 | self._IMG_DATA_FMT.format(i)) 88 | self.assertGreater(i, 0, "No image data was read.") 89 | 90 | 91 | class TestImageDataReader1Channel(TestImageDataReader): 92 | _N_CHANNELS = 1 93 | -------------------------------------------------------------------------------- /onmt/tests/test_model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaijuML/data-to-text-hierarchical/da88d2d4491266fccc39ac1cc1fbb56bd7bbc30c/onmt/tests/test_model.pt -------------------------------------------------------------------------------- /onmt/tests/test_model2.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaijuML/data-to-text-hierarchical/da88d2d4491266fccc39ac1cc1fbb56bd7bbc30c/onmt/tests/test_model2.pt -------------------------------------------------------------------------------- /onmt/tests/test_preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from __future__ import print_function 4 | 5 | import configargparse 6 | import copy 7 | import unittest 8 | import 
glob 9 | import os 10 | import codecs 11 | 12 | import onmt 13 | import onmt.inputters 14 | import onmt.opts 15 | import onmt.bin.preprocess as preprocess 16 | 17 | 18 | parser = configargparse.ArgumentParser(description='preprocess.py') 19 | onmt.opts.preprocess_opts(parser) 20 | 21 | SAVE_DATA_PREFIX = 'data/test_preprocess' 22 | 23 | default_opts = [ 24 | '-data_type', 'text', 25 | '-train_src', 'data/src-train.txt', 26 | '-train_tgt', 'data/tgt-train.txt', 27 | '-valid_src', 'data/src-val.txt', 28 | '-valid_tgt', 'data/tgt-val.txt', 29 | '-save_data', SAVE_DATA_PREFIX 30 | ] 31 | 32 | opt = parser.parse_known_args(default_opts)[0] 33 | 34 | 35 | class TestData(unittest.TestCase): 36 | def __init__(self, *args, **kwargs): 37 | super(TestData, self).__init__(*args, **kwargs) 38 | self.opt = opt 39 | 40 | def dataset_build(self, opt): 41 | fields = onmt.inputters.get_fields("text", 0, 0) 42 | 43 | if hasattr(opt, 'src_vocab') and len(opt.src_vocab) > 0: 44 | with codecs.open(opt.src_vocab, 'w', 'utf-8') as f: 45 | f.write('a\nb\nc\nd\ne\nf\n') 46 | if hasattr(opt, 'tgt_vocab') and len(opt.tgt_vocab) > 0: 47 | with codecs.open(opt.tgt_vocab, 'w', 'utf-8') as f: 48 | f.write('a\nb\nc\nd\ne\nf\n') 49 | 50 | src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt) 51 | tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt) 52 | align_reader = onmt.inputters.str2reader["text"].from_opt(opt) 53 | preprocess.build_save_dataset( 54 | 'train', fields, src_reader, tgt_reader, align_reader, opt) 55 | 56 | preprocess.build_save_dataset( 57 | 'valid', fields, src_reader, tgt_reader, align_reader, opt) 58 | 59 | # Remove the generated *pt files. 60 | for pt in glob.glob(SAVE_DATA_PREFIX + '*.pt'): 61 | os.remove(pt) 62 | if hasattr(opt, 'src_vocab') and os.path.exists(opt.src_vocab): 63 | os.remove(opt.src_vocab) 64 | if hasattr(opt, 'tgt_vocab') and os.path.exists(opt.tgt_vocab): 65 | os.remove(opt.tgt_vocab) 66 | 67 | 68 | def _add_test(param_setting, methodname): 69 | """ 70 | Adds a Test to TestData according to settings 71 | 72 | Args: 73 | param_setting: list of tuples of (param, setting) 74 | methodname: name of the method that gets called 75 | """ 76 | 77 | def test_method(self): 78 | if param_setting: 79 | opt = copy.deepcopy(self.opt) 80 | for param, setting in param_setting: 81 | setattr(opt, param, setting) 82 | else: 83 | opt = self.opt 84 | getattr(self, methodname)(opt) 85 | if param_setting: 86 | name = 'test_' + methodname + "_" + "_".join( 87 | str(param_setting).split()) 88 | else: 89 | name = 'test_' + methodname + '_standard' 90 | setattr(TestData, name, test_method) 91 | test_method.__name__ = name 92 | 93 | 94 | test_databuild = [[], 95 | [('src_vocab_size', 1), 96 | ('tgt_vocab_size', 1)], 97 | [('src_vocab_size', 10000), 98 | ('tgt_vocab_size', 10000)], 99 | [('src_seq_len', 1)], 100 | [('src_seq_len', 5000)], 101 | [('src_seq_length_trunc', 1)], 102 | [('src_seq_length_trunc', 5000)], 103 | [('tgt_seq_len', 1)], 104 | [('tgt_seq_len', 5000)], 105 | [('tgt_seq_length_trunc', 1)], 106 | [('tgt_seq_length_trunc', 5000)], 107 | [('shuffle', 0)], 108 | [('lower', True)], 109 | [('dynamic_dict', True)], 110 | [('share_vocab', True)], 111 | [('dynamic_dict', True), 112 | ('share_vocab', True)], 113 | [('dynamic_dict', True), 114 | ('shard_size', 500000)], 115 | [('src_vocab', '/tmp/src_vocab.txt'), 116 | ('tgt_vocab', '/tmp/tgt_vocab.txt')], 117 | ] 118 | 119 | for p in test_databuild: 120 | _add_test(p, 'dataset_build') 121 | 122 | # Test image preprocessing 123 | 
test_databuild = [[], 124 | [('tgt_vocab_size', 1)], 125 | [('tgt_vocab_size', 10000)], 126 | [('tgt_seq_len', 1)], 127 | [('tgt_seq_len', 5000)], 128 | [('tgt_seq_length_trunc', 1)], 129 | [('tgt_seq_length_trunc', 5000)], 130 | [('shuffle', 0)], 131 | [('lower', True)], 132 | [('shard_size', 5)], 133 | [('shard_size', 50)], 134 | [('tgt_vocab', '/tmp/tgt_vocab.txt')], 135 | ] 136 | test_databuild_common = [('data_type', 'img'), 137 | ('src_dir', '/tmp/im2text/images'), 138 | ('train_src', ['/tmp/im2text/src-train-head.txt']), 139 | ('train_tgt', ['/tmp/im2text/tgt-train-head.txt']), 140 | ('valid_src', '/tmp/im2text/src-val-head.txt'), 141 | ('valid_tgt', '/tmp/im2text/tgt-val-head.txt'), 142 | ] 143 | for p in test_databuild: 144 | _add_test(p + test_databuild_common, 'dataset_build') 145 | 146 | # Test audio preprocessing 147 | test_databuild = [[], 148 | [('tgt_vocab_size', 1)], 149 | [('tgt_vocab_size', 10000)], 150 | [('src_seq_len', 1)], 151 | [('src_seq_len', 5000)], 152 | [('src_seq_length_trunc', 3200)], 153 | [('src_seq_length_trunc', 5000)], 154 | [('tgt_seq_len', 1)], 155 | [('tgt_seq_len', 5000)], 156 | [('tgt_seq_length_trunc', 1)], 157 | [('tgt_seq_length_trunc', 5000)], 158 | [('shuffle', 0)], 159 | [('lower', True)], 160 | [('shard_size', 5)], 161 | [('shard_size', 50)], 162 | [('tgt_vocab', '/tmp/tgt_vocab.txt')], 163 | ] 164 | test_databuild_common = [('data_type', 'audio'), 165 | ('src_dir', '/tmp/speech/an4_dataset'), 166 | ('train_src', ['/tmp/speech/src-train-head.txt']), 167 | ('train_tgt', ['/tmp/speech/tgt-train-head.txt']), 168 | ('valid_src', '/tmp/speech/src-val-head.txt'), 169 | ('valid_tgt', '/tmp/speech/tgt-val-head.txt'), 170 | ('sample_rate', 16000), 171 | ('window_size', 0.04), 172 | ('window_stride', 0.02), 173 | ('window', 'hamming'), 174 | ] 175 | for p in test_databuild: 176 | _add_test(p + test_databuild_common, 'dataset_build') 177 | -------------------------------------------------------------------------------- /onmt/tests/test_simple.py: -------------------------------------------------------------------------------- 1 | import onmt 2 | 3 | 4 | def test_load(): 5 | onmt 6 | pass 7 | -------------------------------------------------------------------------------- /onmt/tests/test_structured_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from onmt.modules.structured_attention import MatrixTree 3 | 4 | import torch 5 | 6 | 7 | class TestStructuredAttention(unittest.TestCase): 8 | def test_matrix_tree_marg_pdfs_sum_to_1(self): 9 | dtree = MatrixTree() 10 | q = torch.rand(1, 5, 5) 11 | marg = dtree.forward(q) 12 | self.assertTrue( 13 | marg.sum(1).allclose(torch.tensor(1.0))) 14 | -------------------------------------------------------------------------------- /onmt/tests/utils_for_tests.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | 4 | def product_dict(**kwargs): 5 | keys = kwargs.keys() 6 | vals = kwargs.values() 7 | for instance in itertools.product(*vals): 8 | yield dict(zip(keys, instance)) 9 | -------------------------------------------------------------------------------- /onmt/train_single.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Training on a single process.""" 3 | import os 4 | 5 | import torch 6 | 7 | from onmt.inputters.inputter import build_dataset_iter, \ 8 | load_old_vocab, old_style_vocab, build_dataset_iter_multiple 9 | 
from onmt.model_builder import build_model 10 | from onmt.utils.optimizers import Optimizer 11 | from onmt.utils.misc import set_random_seed 12 | from onmt.trainer import build_trainer 13 | from onmt.models import build_model_saver 14 | from onmt.utils.logging import init_logger, logger 15 | from onmt.utils.parse import ArgumentParser 16 | 17 | 18 | def _check_save_model_path(opt): 19 | save_model_path = os.path.abspath(opt.save_model) 20 | model_dirname = os.path.dirname(save_model_path) 21 | if not os.path.exists(model_dirname): 22 | os.makedirs(model_dirname) 23 | 24 | 25 | def _tally_parameters(model): 26 | enc = 0 27 | dec = 0 28 | for name, param in model.named_parameters(): 29 | if 'encoder' in name: 30 | enc += param.nelement() 31 | else: 32 | dec += param.nelement() 33 | return enc + dec, enc, dec 34 | 35 | 36 | def configure_process(opt, device_id): 37 | if device_id >= 0: 38 | torch.cuda.set_device(device_id) 39 | set_random_seed(opt.seed, device_id >= 0) 40 | 41 | 42 | def main(opt, device_id, batch_queue=None, semaphore=None): 43 | # NOTE: It's important that ``opt`` has been validated and updated 44 | # at this point. 45 | configure_process(opt, device_id) 46 | init_logger(opt.log_file) 47 | assert len(opt.accum_count) == len(opt.accum_steps), \ 48 | 'Number of accum_count values must match number of accum_steps' 49 | # Load checkpoint if we resume from a previous training. 50 | if opt.train_from: 51 | logger.info('Loading checkpoint from %s' % opt.train_from) 52 | checkpoint = torch.load(opt.train_from, 53 | map_location=lambda storage, loc: storage) 54 | model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) 55 | ArgumentParser.update_model_opts(model_opt) 56 | ArgumentParser.validate_model_opts(model_opt) 57 | logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) 58 | vocab = checkpoint['vocab'] 59 | else: 60 | checkpoint = None 61 | model_opt = opt 62 | vocab = torch.load(opt.data + '.vocab.pt') 63 | 64 | # check for code where vocab is saved instead of fields 65 | # (in the future this will be done in a smarter way) 66 | if old_style_vocab(vocab): 67 | fields = load_old_vocab( 68 | vocab, opt.model_type, dynamic_dict=opt.copy_attn) 69 | else: 70 | fields = vocab 71 | 72 | # Report src and tgt vocab sizes, including for features 73 | for side in ['src', 'tgt']: 74 | f = fields[side] 75 | try: 76 | f_iter = iter(f) 77 | except TypeError: 78 | f_iter = [(side, f)] 79 | for sn, sf in f_iter: 80 | if sf.use_vocab: 81 | logger.info(' * %s vocab size = %d' % (sn, len(sf.vocab))) 82 | 83 | # Build model. 84 | model = build_model(model_opt, opt, fields, checkpoint) 85 | n_params, enc, dec = _tally_parameters(model) 86 | logger.info('encoder: %d' % enc) 87 | logger.info('decoder: %d' % dec) 88 | logger.info('* number of parameters: %d' % n_params) 89 | _check_save_model_path(opt) 90 | 91 | # Build optimizer. 
92 | optim = Optimizer.from_opt(model, opt, checkpoint=checkpoint) 93 | 94 | # Build model saver 95 | model_saver = build_model_saver(model_opt, opt, model, fields, optim) 96 | 97 | trainer = build_trainer( 98 | opt, device_id, model, fields, optim, model_saver=model_saver) 99 | 100 | if batch_queue is None: 101 | if len(opt.data_ids) > 1: 102 | train_shards = [] 103 | for train_id in opt.data_ids: 104 | shard_base = "train_" + train_id 105 | train_shards.append(shard_base) 106 | train_iter = build_dataset_iter_multiple(train_shards, fields, opt) 107 | else: 108 | if opt.data_ids[0] is not None: 109 | shard_base = "train_" + opt.data_ids[0] 110 | else: 111 | shard_base = "train" 112 | train_iter = build_dataset_iter(shard_base, fields, opt) 113 | 114 | else: 115 | assert semaphore is not None, \ 116 | "Using batch_queue requires semaphore as well" 117 | 118 | def _train_iter(): 119 | while True: 120 | batch = batch_queue.get() 121 | semaphore.release() 122 | yield batch 123 | 124 | train_iter = _train_iter() 125 | 126 | valid_iter = build_dataset_iter( 127 | "valid", fields, opt, is_train=False) 128 | 129 | if len(opt.gpu_ranks): 130 | logger.info('Starting training on GPU: %s' % opt.gpu_ranks) 131 | else: 132 | logger.info('Starting training on CPU, could be very slow') 133 | train_steps = opt.train_steps 134 | if opt.single_pass and train_steps > 0: 135 | logger.warning("Option single_pass is enabled, ignoring train_steps.") 136 | train_steps = 0 137 | 138 | trainer.train( 139 | train_iter, 140 | train_steps, 141 | save_checkpoint_steps=opt.save_checkpoint_steps, 142 | valid_iter=valid_iter, 143 | valid_steps=opt.valid_steps) 144 | 145 | if trainer.report_manager.tensorboard_writer is not None: 146 | trainer.report_manager.tensorboard_writer.close() 147 | -------------------------------------------------------------------------------- /onmt/translate/__init__.py: -------------------------------------------------------------------------------- 1 | """ Modules for translation """ 2 | from onmt.translate.translator import Translator 3 | from onmt.translate.translation import Translation, TranslationBuilder 4 | from onmt.translate.beam_search import BeamSearch, GNMTGlobalScorer 5 | from onmt.translate.decode_strategy import DecodeStrategy 6 | from onmt.translate.greedy_search import GreedySearch 7 | from onmt.translate.penalties import PenaltyBuilder 8 | from onmt.translate.translation_server import TranslationServer, \ 9 | ServerModelError 10 | 11 | __all__ = ['Translator', 'Translation', 'BeamSearch', 12 | 'GNMTGlobalScorer', 'TranslationBuilder', 13 | 'PenaltyBuilder', 'TranslationServer', 'ServerModelError', 14 | "DecodeStrategy", "GreedySearch"] 15 | -------------------------------------------------------------------------------- /onmt/translate/penalties.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | 4 | 5 | class PenaltyBuilder(object): 6 | """Returns the Length and Coverage Penalty function for Beam Search. 7 | 8 | Args: 9 | length_pen (str): option name of length pen 10 | cov_pen (str): option name of cov pen 11 | 12 | Attributes: 13 | has_cov_pen (bool): Whether coverage penalty is None (applying it 14 | is a no-op). Note that the converse isn't true. Setting beta 15 | to 0 should force coverage length to be a no-op. 16 | has_len_pen (bool): Whether length penalty is None (applying it 17 | is a no-op). Note that the converse isn't true. 
Setting alpha 18 | to 1 should force length penalty to be a no-op. 19 | coverage_penalty (callable[[FloatTensor, float], FloatTensor]): 20 | Calculates the coverage penalty. 21 | length_penalty (callable[[int, float], float]): Calculates 22 | the length penalty. 23 | """ 24 | 25 | def __init__(self, cov_pen, length_pen): 26 | self.has_cov_pen = not self._pen_is_none(cov_pen) 27 | self.coverage_penalty = self._coverage_penalty(cov_pen) 28 | self.has_len_pen = not self._pen_is_none(length_pen) 29 | self.length_penalty = self._length_penalty(length_pen) 30 | 31 | @staticmethod 32 | def _pen_is_none(pen): 33 | return pen == "none" or pen is None 34 | 35 | def _coverage_penalty(self, cov_pen): 36 | if cov_pen == "wu": 37 | return self.coverage_wu 38 | elif cov_pen == "summary": 39 | return self.coverage_summary 40 | elif self._pen_is_none(cov_pen): 41 | return self.coverage_none 42 | else: 43 | raise NotImplementedError("No '{:s}' coverage penalty.".format( 44 | cov_pen)) 45 | 46 | def _length_penalty(self, length_pen): 47 | if length_pen == "wu": 48 | return self.length_wu 49 | elif length_pen == "avg": 50 | return self.length_average 51 | elif self._pen_is_none(length_pen): 52 | return self.length_none 53 | else: 54 | raise NotImplementedError("No '{:s}' length penalty.".format( 55 | length_pen)) 56 | 57 | # Below are all the different penalty terms implemented so far. 58 | # Subtract coverage penalty from topk log probs. 59 | # Divide topk log probs by length penalty. 60 | 61 | def coverage_wu(self, cov, beta=0.): 62 | """GNMT coverage re-ranking score. 63 | 64 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 65 | ``cov`` is expected to be sized ``(*, seq_len)``, where ``*`` is 66 | probably ``batch_size x beam_size`` but could be several 67 | dimensions like ``(batch_size, beam_size)``. If ``cov`` is attention, 68 | then the ``seq_len`` axis probably sums to (almost) 1. 69 | """ 70 | 71 | penalty = -torch.min(cov, cov.clone().fill_(1.0)).log().sum(-1) 72 | return beta * penalty 73 | 74 | def coverage_summary(self, cov, beta=0.): 75 | """Our summary penalty.""" 76 | penalty = torch.max(cov, cov.clone().fill_(1.0)).sum(-1) 77 | penalty -= cov.size(-1) 78 | return beta * penalty 79 | 80 | def coverage_none(self, cov, beta=0.): 81 | """Returns zero as penalty""" 82 | none = torch.zeros((1,), device=cov.device, 83 | dtype=torch.float) 84 | if cov.dim() == 3: 85 | none = none.unsqueeze(0) 86 | return none 87 | 88 | def length_wu(self, cur_len, alpha=0.): 89 | """GNMT length re-ranking score. 90 | 91 | See "Google's Neural Machine Translation System" :cite:`wu2016google`. 
92 | """ 93 | 94 | return ((5 + cur_len) / 6.0) ** alpha 95 | 96 | def length_average(self, cur_len, alpha=0.): 97 | """Returns the current sequence length.""" 98 | return cur_len 99 | 100 | def length_none(self, cur_len, alpha=0.): 101 | """Returns unmodified scores.""" 102 | return 1.0 103 | -------------------------------------------------------------------------------- /onmt/translate/process_zh.py: -------------------------------------------------------------------------------- 1 | from pyhanlp import HanLP 2 | from snownlp import SnowNLP 3 | import pkuseg 4 | 5 | 6 | # Chinese segmentation 7 | def zh_segmentator(line): 8 | return " ".join(pkuseg.pkuseg().cut(line)) 9 | 10 | 11 | # Chinese simplify -> Chinese traditional standard 12 | def zh_traditional_standard(line): 13 | return HanLP.convertToTraditionalChinese(line) 14 | 15 | 16 | # Chinese simplify -> Chinese traditional (HongKong) 17 | def zh_traditional_hk(line): 18 | return HanLP.s2hk(line) 19 | 20 | 21 | # Chinese simplify -> Chinese traditional (Taiwan) 22 | def zh_traditional_tw(line): 23 | return HanLP.s2tw(line) 24 | 25 | 26 | # Chinese traditional -> Chinese simplify (v1) 27 | def zh_simplify(line): 28 | return HanLP.convertToSimplifiedChinese(line) 29 | 30 | 31 | # Chinese traditional -> Chinese simplify (v2) 32 | def zh_simplify_v2(line): 33 | return SnowNLP(line).han 34 | -------------------------------------------------------------------------------- /onmt/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module defining various utilities.""" 2 | from onmt.utils.misc import split_corpus, aeq, use_gpu, set_random_seed 3 | from onmt.utils.alignment import make_batch_align_matrix 4 | from onmt.utils.report_manager import ReportMgr, build_report_manager 5 | from onmt.utils.statistics import Statistics 6 | from onmt.utils.optimizers import MultipleOptimizer, \ 7 | Optimizer, AdaFactor 8 | from onmt.utils.earlystopping import EarlyStopping, scorers_from_opts 9 | 10 | __all__ = ["split_corpus", "aeq", "use_gpu", "set_random_seed", "ReportMgr", 11 | "build_report_manager", "Statistics", 12 | "MultipleOptimizer", "Optimizer", "AdaFactor", "EarlyStopping", 13 | "scorers_from_opts", "make_batch_align_matrix"] 14 | -------------------------------------------------------------------------------- /onmt/utils/alignment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from itertools import accumulate 5 | 6 | 7 | def make_batch_align_matrix(index_tensor, size=None, normalize=False): 8 | """ 9 | Convert a sparse index_tensor into a batch of alignment matrix, 10 | with row normalize to the sum of 1 if set normalize. 11 | 12 | Args: 13 | index_tensor (LongTensor): ``(N, 3)`` of [batch_id, tgt_id, src_id] 14 | size (List[int]): Size of the sparse tensor. 15 | normalize (bool): if normalize the 2nd dim of resulting tensor. 
16 | """ 17 | n_fill, device = index_tensor.size(0), index_tensor.device 18 | value_tensor = torch.ones([n_fill], dtype=torch.float) 19 | dense_tensor = torch.sparse_coo_tensor( 20 | index_tensor.t(), value_tensor, size=size, device=device).to_dense() 21 | if normalize: 22 | row_sum = dense_tensor.sum(-1, keepdim=True) # sum by row(tgt) 23 | # threshold on 1 to avoid div by 0 24 | torch.nn.functional.threshold(row_sum, 1, 1, inplace=True) 25 | dense_tensor.div_(row_sum) 26 | return dense_tensor 27 | 28 | 29 | def extract_alignment(align_matrix, tgt_mask, src_lens, n_best): 30 | """ 31 | Extract a batched align_matrix into its src indice alignment lists, 32 | with tgt_mask to filter out invalid tgt position as EOS/PAD. 33 | BOS already excluded from tgt_mask in order to match prediction. 34 | 35 | Args: 36 | align_matrix (Tensor): ``(B, tgt_len, src_len)``, 37 | attention head normalized by Softmax(dim=-1) 38 | tgt_mask (BoolTensor): ``(B, tgt_len)``, True for EOS, PAD. 39 | src_lens (LongTensor): ``(B,)``, containing valid src length 40 | n_best (int): a value indicating number of parallel translation. 41 | * B: denote flattened batch as B = batch_size * n_best. 42 | 43 | Returns: 44 | alignments (List[List[FloatTensor]]): ``(batch_size, n_best,)``, 45 | containing valid alignment matrix for each translation. 46 | """ 47 | batch_size_n_best = align_matrix.size(0) 48 | assert batch_size_n_best % n_best == 0 49 | 50 | alignments = [[] for _ in range(batch_size_n_best // n_best)] 51 | 52 | # treat alignment matrix one by one as each have different lengths 53 | for i, (am_b, tgt_mask_b, src_len) in enumerate( 54 | zip(align_matrix, tgt_mask, src_lens)): 55 | valid_tgt = ~tgt_mask_b 56 | valid_tgt_len = valid_tgt.sum() 57 | # get valid alignment (sub-matrix from full paded aligment matrix) 58 | am_valid_tgt = am_b.masked_select(valid_tgt.unsqueeze(-1)) \ 59 | .view(valid_tgt_len, -1) 60 | valid_alignment = am_valid_tgt[:, :src_len] # only keep valid src 61 | alignments[i // n_best].append(valid_alignment) 62 | 63 | return alignments 64 | 65 | 66 | def build_align_pharaoh(valid_alignment): 67 | """Convert valid alignment matrix to i-j Pharaoh format.(0 indexed)""" 68 | align_pairs = [] 69 | tgt_align_src_id = valid_alignment.argmax(dim=-1) 70 | 71 | for tgt_id, src_id in enumerate(tgt_align_src_id.tolist()): 72 | align_pairs.append(str(src_id) + "-" + str(tgt_id)) 73 | align_pairs.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id 74 | align_pairs.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id 75 | return align_pairs 76 | 77 | 78 | def to_word_align(src, tgt, subword_align, mode): 79 | """Convert subword alignment to word alignment. 80 | 81 | Args: 82 | src (string): tokenized sentence in source language. 83 | tgt (string): tokenized sentence in target language. 84 | subword_align (string): align_pharaoh correspond to src-tgt. 85 | mode (string): tokenization mode used by src and tgt, 86 | choose from ["joiner", "spacer"]. 87 | 88 | Returns: 89 | word_align (string): converted alignments correspand to 90 | detokenized src-tgt. 
91 | """ 92 | src, tgt = src.strip().split(), tgt.strip().split() 93 | subword_align = {(int(a), int(b)) for a, b in (x.split("-") 94 | for x in subword_align.split())} 95 | if mode == 'joiner': 96 | src_map = subword_map_by_joiner(src, marker='■') 97 | tgt_map = subword_map_by_joiner(tgt, marker='■') 98 | elif mode == 'spacer': 99 | src_map = subword_map_by_spacer(src, marker='▁') 100 | tgt_map = subword_map_by_spacer(tgt, marker='▁') 101 | else: 102 | raise ValueError("Invalid value for argument mode!") 103 | word_align = list({"{}-{}".format(src_map[a], tgt_map[b]) 104 | for a, b in subword_align}) 105 | word_align.sort(key=lambda x: int(x.split('-')[-1])) # sort by tgt_id 106 | word_align.sort(key=lambda x: int(x.split('-')[0])) # sort by src_id 107 | return " ".join(word_align) 108 | 109 | 110 | def subword_map_by_joiner(subwords, marker='■'): 111 | """Return word id for each subword token (annotate by joiner).""" 112 | flags = [0] * len(subwords) 113 | for i, tok in enumerate(subwords): 114 | if tok.endswith(marker): 115 | flags[i] = 1 116 | if tok.startswith(marker): 117 | assert i >= 1 and flags[i-1] != 1, \ 118 | "Sentence `{}` not correct!".format(" ".join(subwords)) 119 | flags[i-1] = 1 120 | marker_acc = list(accumulate([0] + flags[:-1])) 121 | word_group = [(i - maker_sofar) for i, maker_sofar 122 | in enumerate(marker_acc)] 123 | return word_group 124 | 125 | 126 | def subword_map_by_spacer(subwords, marker='▁'): 127 | """Return word id for each subword token (annotate by spacer).""" 128 | word_group = list(accumulate([int(marker in x) for x in subwords])) 129 | if word_group[0] == 1: # when dummy prefix is set 130 | word_group = [item - 1 for item in word_group] 131 | return word_group 132 | -------------------------------------------------------------------------------- /onmt/utils/cnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of "Convolutional Sequence to Sequence Learning" 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | 8 | import onmt.modules 9 | 10 | SCALE_WEIGHT = 0.5 ** 0.5 11 | 12 | 13 | def shape_transform(x): 14 | """ Tranform the size of the tensors to fit for conv input. 
""" 15 | return torch.unsqueeze(torch.transpose(x, 1, 2), 3) 16 | 17 | 18 | class GatedConv(nn.Module): 19 | """ Gated convolution for CNN class """ 20 | 21 | def __init__(self, input_size, width=3, dropout=0.2, nopad=False): 22 | super(GatedConv, self).__init__() 23 | self.conv = onmt.modules.WeightNormConv2d( 24 | input_size, 2 * input_size, kernel_size=(width, 1), stride=(1, 1), 25 | padding=(width // 2 * (1 - nopad), 0)) 26 | init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout))**0.5) 27 | self.dropout = nn.Dropout(dropout) 28 | 29 | def forward(self, x_var): 30 | x_var = self.dropout(x_var) 31 | x_var = self.conv(x_var) 32 | out, gate = x_var.split(int(x_var.size(1) / 2), 1) 33 | out = out * torch.sigmoid(gate) 34 | return out 35 | 36 | 37 | class StackedCNN(nn.Module): 38 | """ Stacked CNN class """ 39 | 40 | def __init__(self, num_layers, input_size, cnn_kernel_width=3, 41 | dropout=0.2): 42 | super(StackedCNN, self).__init__() 43 | self.dropout = dropout 44 | self.num_layers = num_layers 45 | self.layers = nn.ModuleList() 46 | for _ in range(num_layers): 47 | self.layers.append( 48 | GatedConv(input_size, cnn_kernel_width, dropout)) 49 | 50 | def forward(self, x): 51 | for conv in self.layers: 52 | x = x + conv(x) 53 | x *= SCALE_WEIGHT 54 | return x 55 | -------------------------------------------------------------------------------- /onmt/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ Pytorch Distributed utils 2 | This piece of code was heavily inspired by the equivalent of Fairseq-py 3 | https://github.com/pytorch/fairseq 4 | """ 5 | 6 | 7 | from __future__ import print_function 8 | 9 | import math 10 | import pickle 11 | import torch.distributed 12 | 13 | from onmt.utils.logging import logger 14 | 15 | 16 | def is_master(opt, device_id): 17 | return opt.gpu_ranks[device_id] == 0 18 | 19 | 20 | def multi_init(opt, device_id): 21 | dist_init_method = 'tcp://{master_ip}:{master_port}'.format( 22 | master_ip=opt.master_ip, 23 | master_port=opt.master_port) 24 | dist_world_size = opt.world_size 25 | torch.distributed.init_process_group( 26 | backend=opt.gpu_backend, init_method=dist_init_method, 27 | world_size=dist_world_size, rank=opt.gpu_ranks[device_id]) 28 | gpu_rank = torch.distributed.get_rank() 29 | if not is_master(opt, device_id): 30 | logger.disabled = True 31 | 32 | return gpu_rank 33 | 34 | 35 | def all_reduce_and_rescale_tensors(tensors, rescale_denom, 36 | buffer_size=10485760): 37 | """All-reduce and rescale tensors in chunks of the specified size. 38 | 39 | Args: 40 | tensors: list of Tensors to all-reduce 41 | rescale_denom: denominator for rescaling summed Tensors 42 | buffer_size: all-reduce chunk size in bytes 43 | """ 44 | # buffer size in bytes, determine equiv. 
# of elements based on data type 45 | buffer_t = tensors[0].new( 46 | math.ceil(buffer_size / tensors[0].element_size())).zero_() 47 | buffer = [] 48 | 49 | def all_reduce_buffer(): 50 | # copy tensors into buffer_t 51 | offset = 0 52 | for t in buffer: 53 | numel = t.numel() 54 | buffer_t[offset:offset+numel].copy_(t.view(-1)) 55 | offset += numel 56 | 57 | # all-reduce and rescale 58 | torch.distributed.all_reduce(buffer_t[:offset]) 59 | buffer_t.div_(rescale_denom) 60 | 61 | # copy all-reduced buffer back into tensors 62 | offset = 0 63 | for t in buffer: 64 | numel = t.numel() 65 | t.view(-1).copy_(buffer_t[offset:offset+numel]) 66 | offset += numel 67 | 68 | filled = 0 69 | for t in tensors: 70 | sz = t.numel() * t.element_size() 71 | if sz > buffer_size: 72 | # tensor is bigger than buffer, all-reduce and rescale directly 73 | torch.distributed.all_reduce(t) 74 | t.div_(rescale_denom) 75 | elif filled + sz > buffer_size: 76 | # buffer is full, all-reduce and replace buffer with grad 77 | all_reduce_buffer() 78 | buffer = [t] 79 | filled = sz 80 | else: 81 | # add tensor to buffer 82 | buffer.append(t) 83 | filled += sz 84 | 85 | if len(buffer) > 0: 86 | all_reduce_buffer() 87 | 88 | 89 | def all_gather_list(data, max_size=4096): 90 | """Gathers arbitrary data from all nodes into a list.""" 91 | world_size = torch.distributed.get_world_size() 92 | if not hasattr(all_gather_list, '_in_buffer') or \ 93 | max_size != all_gather_list._in_buffer.size(): 94 | all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size) 95 | all_gather_list._out_buffers = [ 96 | torch.cuda.ByteTensor(max_size) 97 | for i in range(world_size) 98 | ] 99 | in_buffer = all_gather_list._in_buffer 100 | out_buffers = all_gather_list._out_buffers 101 | 102 | enc = pickle.dumps(data) 103 | enc_size = len(enc) 104 | if enc_size + 2 > max_size: 105 | raise ValueError( 106 | 'encoded data exceeds max_size: {}'.format(enc_size + 2)) 107 | assert max_size < 255*256 108 | in_buffer[0] = enc_size // 255 # this encoding works for max_size < 65k 109 | in_buffer[1] = enc_size % 255 110 | in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc)) 111 | 112 | torch.distributed.all_gather(out_buffers, in_buffer.cuda()) 113 | 114 | results = [] 115 | for i in range(world_size): 116 | out_buffer = out_buffers[i] 117 | size = (255 * out_buffer[0].item()) + out_buffer[1].item() 118 | 119 | bytes_list = bytes(out_buffer[2:size+2].tolist()) 120 | result = pickle.loads(bytes_list) 121 | results.append(result) 122 | return results 123 | -------------------------------------------------------------------------------- /onmt/utils/earlystopping.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Enum 3 | from onmt.utils.logging import logger 4 | 5 | 6 | class PatienceEnum(Enum): 7 | IMPROVING = 0 8 | DECREASING = 1 9 | STOPPED = 2 10 | 11 | 12 | class Scorer(object): 13 | def __init__(self, best_score, name): 14 | self.best_score = best_score 15 | self.name = name 16 | 17 | def is_improving(self, stats): 18 | raise NotImplementedError() 19 | 20 | def is_decreasing(self, stats): 21 | raise NotImplementedError() 22 | 23 | def update(self, stats): 24 | self.best_score = self._caller(stats) 25 | 26 | def __call__(self, stats, **kwargs): 27 | return self._caller(stats) 28 | 29 | def _caller(self, stats): 30 | raise NotImplementedError() 31 | 32 | 33 | class PPLScorer(Scorer): 34 | 35 | def __init__(self): 36 | super(PPLScorer, self).__init__(float("inf"), "ppl") 37 | 38 | def 
is_improving(self, stats): 39 | return stats.ppl() < self.best_score 40 | 41 | def is_decreasing(self, stats): 42 | return stats.ppl() > self.best_score 43 | 44 | def _caller(self, stats): 45 | return stats.ppl() 46 | 47 | 48 | class AccuracyScorer(Scorer): 49 | 50 | def __init__(self): 51 | super(AccuracyScorer, self).__init__(float("-inf"), "acc") 52 | 53 | def is_improving(self, stats): 54 | return stats.accuracy() > self.best_score 55 | 56 | def is_decreasing(self, stats): 57 | return stats.accuracy() < self.best_score 58 | 59 | def _caller(self, stats): 60 | return stats.accuracy() 61 | 62 | 63 | DEFAULT_SCORERS = [PPLScorer(), AccuracyScorer()] 64 | 65 | 66 | SCORER_BUILDER = { 67 | "ppl": PPLScorer, 68 | "accuracy": AccuracyScorer 69 | } 70 | 71 | 72 | def scorers_from_opts(opt): 73 | if opt.early_stopping_criteria is None: 74 | return DEFAULT_SCORERS 75 | else: 76 | scorers = [] 77 | for criterion in set(opt.early_stopping_criteria): 78 | assert criterion in SCORER_BUILDER.keys(), \ 79 | "Criterion {} not found".format(criterion) 80 | scorers.append(SCORER_BUILDER[criterion]()) 81 | return scorers 82 | 83 | 84 | class EarlyStopping(object): 85 | 86 | def __init__(self, tolerance, scorers=DEFAULT_SCORERS): 87 | """ 88 | Callable class to keep track of early stopping. 89 | 90 | Args: 91 | tolerance(int): number of validation steps without improving 92 | scorer(fn): list of scorers to validate performance on dev 93 | """ 94 | 95 | self.tolerance = tolerance 96 | self.stalled_tolerance = self.tolerance 97 | self.current_tolerance = self.tolerance 98 | self.early_stopping_scorers = scorers 99 | self.status = PatienceEnum.IMPROVING 100 | self.current_step_best = 0 101 | 102 | def __call__(self, valid_stats, step): 103 | """ 104 | Update the internal state of early stopping mechanism, whether to 105 | continue training or stop the train procedure. 106 | 107 | Checks whether the scores from all pre-chosen scorers improved. If 108 | every metric improve, then the status is switched to improving and the 109 | tolerance is reset. If every metric deteriorate, then the status is 110 | switched to decreasing and the tolerance is also decreased; if the 111 | tolerance reaches 0, then the status is changed to stopped. 112 | Finally, if some improved and others not, then it's considered stalled; 113 | after tolerance number of stalled, the status is switched to stopped. 114 | 115 | :param valid_stats: Statistics of dev set 116 | """ 117 | 118 | if self.status == PatienceEnum.STOPPED: 119 | # Don't do anything 120 | return 121 | 122 | if all([scorer.is_improving(valid_stats) for scorer 123 | in self.early_stopping_scorers]): 124 | self._update_increasing(valid_stats, step) 125 | 126 | elif all([scorer.is_decreasing(valid_stats) for scorer 127 | in self.early_stopping_scorers]): 128 | self._update_decreasing() 129 | 130 | else: 131 | self._update_stalled() 132 | 133 | def _update_stalled(self): 134 | self.stalled_tolerance -= 1 135 | 136 | logger.info( 137 | "Stalled patience: {}/{}".format(self.stalled_tolerance, 138 | self.tolerance)) 139 | 140 | if self.stalled_tolerance == 0: 141 | logger.info( 142 | "Training finished after stalled validations. Early Stop!" 
143 | ) 144 | self._log_best_step() 145 | 146 | self._decreasing_or_stopped_status_update(self.stalled_tolerance) 147 | 148 | def _update_increasing(self, valid_stats, step): 149 | self.current_step_best = step 150 | for scorer in self.early_stopping_scorers: 151 | logger.info( 152 | "Model is improving {}: {:g} --> {:g}.".format( 153 | scorer.name, scorer.best_score, scorer(valid_stats)) 154 | ) 155 | # Update best score of each criterion 156 | scorer.update(valid_stats) 157 | 158 | # Reset tolerance 159 | self.current_tolerance = self.tolerance 160 | self.stalled_tolerance = self.tolerance 161 | 162 | # Update current status 163 | self.status = PatienceEnum.IMPROVING 164 | 165 | def _update_decreasing(self): 166 | # Decrease tolerance 167 | self.current_tolerance -= 1 168 | 169 | # Log 170 | logger.info( 171 | "Decreasing patience: {}/{}".format(self.current_tolerance, 172 | self.tolerance) 173 | ) 174 | # Stop once tolerance is exhausted 175 | if self.current_tolerance == 0: 176 | logger.info("Training finished after not improving. Early Stop!") 177 | self._log_best_step() 178 | 179 | self._decreasing_or_stopped_status_update(self.current_tolerance) 180 | 181 | def _log_best_step(self): 182 | logger.info("Best model found at step {}".format( 183 | self.current_step_best)) 184 | 185 | def _decreasing_or_stopped_status_update(self, tolerance): 186 | self.status = PatienceEnum.DECREASING \ 187 | if tolerance > 0 \ 188 | else PatienceEnum.STOPPED 189 | 190 | def is_improving(self): 191 | return self.status == PatienceEnum.IMPROVING 192 | 193 | def has_stopped(self): 194 | return self.status == PatienceEnum.STOPPED 195 | -------------------------------------------------------------------------------- /onmt/utils/logging.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import absolute_import 3 | 4 | import logging 5 | from logging.handlers import RotatingFileHandler 6 | logger = logging.getLogger() 7 | 8 | 9 | def init_logger(log_file=None, log_file_level=logging.NOTSET): 10 | log_format = logging.Formatter("[%(asctime)s %(levelname)s] %(message)s") 11 | logger = logging.getLogger() 12 | logger.setLevel(logging.INFO) 13 | 14 | console_handler = logging.StreamHandler() 15 | console_handler.setFormatter(log_format) 16 | logger.handlers = [console_handler] 17 | 18 | if log_file and log_file != '': 19 | file_handler = RotatingFileHandler( 20 | log_file, maxBytes=1000, backupCount=10) 21 | file_handler.setLevel(log_file_level) 22 | file_handler.setFormatter(log_format) 23 | logger.addHandler(file_handler) 24 | 25 | return logger 26 | -------------------------------------------------------------------------------- /onmt/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import random 5 | import inspect 6 | from itertools import islice, repeat, tee 7 | import os 8 | 9 | 10 | def split_corpus(path, shard_size, default=None): 11 | """Yield a `list` containing `shard_size` lines of `path`, 12 | or repeatedly generate `default` if `path` is None. 13 | """ 14 | if path is not None: 15 | return _split_corpus(path, shard_size) 16 | else: 17 | return repeat(default) 18 | 19 | 20 | def _split_corpus(path, shard_size): 21 | """Yield a `list` containing `shard_size` lines of `path`.
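    Lines are read in binary mode (as bytes); if `shard_size` is not positive,
    the whole file is yielded as a single shard.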
22 | """ 23 | with open(path, "rb") as f: 24 | if shard_size <= 0: 25 | yield f.readlines() 26 | else: 27 | while True: 28 | shard = list(islice(f, shard_size)) 29 | if not shard: 30 | break 31 | yield shard 32 | 33 | 34 | def aeq(*args): 35 | """ 36 | Assert all arguments have the same value 37 | """ 38 | arguments = (arg for arg in args) 39 | first = next(arguments) 40 | assert all(arg == first for arg in arguments), \ 41 | "Not all arguments have the same value: " + str(args) 42 | 43 | 44 | def sequence_mask(lengths, max_len=None): 45 | """ 46 | Creates a boolean mask from sequence lengths. 47 | """ 48 | batch_size = lengths.numel() 49 | max_len = max_len or lengths.max() 50 | return (torch.arange(0, max_len, device=lengths.device) 51 | .type_as(lengths) 52 | .repeat(batch_size, 1) 53 | .lt(lengths.unsqueeze(1))) 54 | 55 | 56 | def tile(x, count, dim=0): 57 | """ 58 | Tiles x on dimension dim count times. 59 | """ 60 | perm = list(range(len(x.size()))) 61 | if dim != 0: 62 | perm[0], perm[dim] = perm[dim], perm[0] 63 | x = x.permute(perm).contiguous() 64 | out_size = list(x.size()) 65 | out_size[0] *= count 66 | batch = x.size(0) 67 | x = x.view(batch, -1) \ 68 | .transpose(0, 1) \ 69 | .repeat(count, 1) \ 70 | .transpose(0, 1) \ 71 | .contiguous() \ 72 | .view(*out_size) 73 | if dim != 0: 74 | x = x.permute(perm).contiguous() 75 | return x 76 | 77 | 78 | def use_gpu(opt): 79 | """ 80 | Creates a boolean if gpu used 81 | """ 82 | return (hasattr(opt, 'gpu_ranks') and len(opt.gpu_ranks) > 0) or \ 83 | (hasattr(opt, 'gpu') and opt.gpu > -1) 84 | 85 | 86 | def set_random_seed(seed, is_cuda): 87 | """Sets the random seed.""" 88 | if seed > 0: 89 | torch.manual_seed(seed) 90 | # this one is needed for torchtext random call (shuffled iterator) 91 | # in multi gpu it ensures datasets are read in the same order 92 | random.seed(seed) 93 | # some cudnn methods can be random even after fixing the seed 94 | # unless you tell it to be deterministic 95 | torch.backends.cudnn.deterministic = True 96 | 97 | if is_cuda and seed > 0: 98 | # These ensure same initialization in multi gpu mode 99 | torch.cuda.manual_seed(seed) 100 | 101 | 102 | def generate_relative_positions_matrix(length, max_relative_positions, 103 | cache=False): 104 | """Generate the clipped relative positions matrix 105 | for a given length and maximum relative positions""" 106 | if cache: 107 | distance_mat = torch.arange(-length+1, 1, 1).unsqueeze(0) 108 | else: 109 | range_vec = torch.arange(length) 110 | range_mat = range_vec.unsqueeze(-1).expand(-1, length).transpose(0, 1) 111 | distance_mat = range_mat - range_mat.transpose(0, 1) 112 | distance_mat_clipped = torch.clamp(distance_mat, 113 | min=-max_relative_positions, 114 | max=max_relative_positions) 115 | # Shift values to be >= 0 116 | final_mat = distance_mat_clipped + max_relative_positions 117 | return final_mat 118 | 119 | 120 | def relative_matmul(x, z, transpose): 121 | """Helper function for relative positions attention.""" 122 | batch_size = x.shape[0] 123 | heads = x.shape[1] 124 | length = x.shape[2] 125 | x_t = x.permute(2, 0, 1, 3) 126 | x_t_r = x_t.reshape(length, heads * batch_size, -1) 127 | if transpose: 128 | z_t = z.transpose(1, 2) 129 | x_tz_matmul = torch.matmul(x_t_r, z_t) 130 | else: 131 | x_tz_matmul = torch.matmul(x_t_r, z) 132 | x_tz_matmul_r = x_tz_matmul.reshape(length, batch_size, heads, -1) 133 | x_tz_matmul_r_t = x_tz_matmul_r.permute(1, 2, 0, 3) 134 | return x_tz_matmul_r_t 135 | 136 | 137 | def fn_args(fun): 138 | """Returns the list of 
function arguments name.""" 139 | return inspect.getfullargspec(fun).args 140 | 141 | 142 | def nwise(iterable, n=2): 143 | iterables = tee(iterable, n) 144 | [next(iterables[i]) for i in range(n) for j in range(i)] 145 | return zip(*iterables) 146 | 147 | 148 | def report_matrix(row_label, column_label, matrix): 149 | header_format = "{:>10.10} " + "{:>10.7} " * len(row_label) 150 | row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label) 151 | output = header_format.format("", *row_label) + '\n' 152 | for word, row in zip(column_label, matrix): 153 | max_index = row.index(max(row)) 154 | row_format = row_format.replace( 155 | "{:>10.7f} ", "{:*>10.7f} ", max_index + 1) 156 | row_format = row_format.replace( 157 | "{:*>10.7f} ", "{:>10.7f} ", max_index) 158 | output += row_format.format(word, *row) + '\n' 159 | row_format = "{:>10.10} " + "{:>10.7f} " * len(row_label) 160 | return output 161 | 162 | 163 | def check_model_config(model_config, root): 164 | # we need to check the model path + any tokenizer path 165 | for model in model_config["models"]: 166 | model_path = os.path.join(root, model) 167 | if not os.path.exists(model_path): 168 | raise FileNotFoundError( 169 | "{} from model {} does not exist".format( 170 | model_path, model_config["id"])) 171 | if "tokenizer" in model_config.keys(): 172 | if "params" in model_config["tokenizer"].keys(): 173 | for k, v in model_config["tokenizer"]["params"].items(): 174 | if k.endswith("path"): 175 | tok_path = os.path.join(root, v) 176 | if not os.path.exists(tok_path): 177 | raise FileNotFoundError( 178 | "{} from model {} does not exist".format( 179 | tok_path, model_config["id"])) 180 | -------------------------------------------------------------------------------- /onmt/utils/report_manager.py: -------------------------------------------------------------------------------- 1 | """ Report manager utility """ 2 | from __future__ import print_function 3 | import time 4 | from datetime import datetime 5 | 6 | import onmt 7 | 8 | from onmt.utils.logging import logger 9 | 10 | 11 | def build_report_manager(opt, gpu_rank): 12 | if opt.tensorboard and gpu_rank == 0: 13 | from torch.utils.tensorboard import SummaryWriter 14 | tensorboard_log_dir = opt.tensorboard_log_dir 15 | 16 | if not opt.train_from: 17 | tensorboard_log_dir += datetime.now().strftime("/%b-%d_%H-%M-%S") 18 | 19 | writer = SummaryWriter(tensorboard_log_dir, comment="Unmt") 20 | else: 21 | writer = None 22 | 23 | report_mgr = ReportMgr(opt.report_every, start_time=-1, 24 | tensorboard_writer=writer) 25 | return report_mgr 26 | 27 | 28 | class ReportMgrBase(object): 29 | """ 30 | Report Manager Base class 31 | Inherited classes should override: 32 | * `_report_training` 33 | * `_report_step` 34 | """ 35 | 36 | def __init__(self, report_every, start_time=-1.): 37 | """ 38 | Args: 39 | report_every(int): Report status every this many sentences 40 | start_time(float): manually set report start time. Negative values 41 | means that you will need to set it later or use `start()` 42 | """ 43 | self.report_every = report_every 44 | self.start_time = start_time 45 | 46 | def start(self): 47 | self.start_time = time.time() 48 | 49 | def log(self, *args, **kwargs): 50 | logger.info(*args, **kwargs) 51 | 52 | def report_training(self, step, num_steps, learning_rate, 53 | report_stats, multigpu=False): 54 | """ 55 | This is the user-defined batch-level traing progress 56 | report function. 57 | 58 | Args: 59 | step(int): current step count. 60 | num_steps(int): total number of batches. 
61 | learning_rate(float): current learning rate. 62 | report_stats(Statistics): old Statistics instance. 63 | Returns: 64 | report_stats(Statistics): updated Statistics instance. 65 | """ 66 | if self.start_time < 0: 67 | raise ValueError("""ReportMgr needs to be started 68 | (set 'start_time' or use 'start()'""") 69 | 70 | if step % self.report_every == 0: 71 | if multigpu: 72 | report_stats = \ 73 | onmt.utils.Statistics.all_gather_stats(report_stats) 74 | self._report_training( 75 | step, num_steps, learning_rate, report_stats) 76 | return onmt.utils.Statistics() 77 | else: 78 | return report_stats 79 | 80 | def _report_training(self, *args, **kwargs): 81 | """ To be overridden """ 82 | raise NotImplementedError() 83 | 84 | def report_step(self, lr, step, train_stats=None, valid_stats=None): 85 | """ 86 | Report stats of a step 87 | 88 | Args: 89 | train_stats(Statistics): training stats 90 | valid_stats(Statistics): validation stats 91 | lr(float): current learning rate 92 | """ 93 | self._report_step( 94 | lr, step, train_stats=train_stats, valid_stats=valid_stats) 95 | 96 | def _report_step(self, *args, **kwargs): 97 | raise NotImplementedError() 98 | 99 | 100 | class ReportMgr(ReportMgrBase): 101 | def __init__(self, report_every, start_time=-1., tensorboard_writer=None): 102 | """ 103 | A report manager that writes statistics on standard output as well as 104 | (optionally) TensorBoard 105 | 106 | Args: 107 | report_every(int): Report status every this many sentences 108 | tensorboard_writer(:obj:`tensorboard.SummaryWriter`): 109 | The TensorBoard Summary writer to use or None 110 | """ 111 | super(ReportMgr, self).__init__(report_every, start_time) 112 | self.tensorboard_writer = tensorboard_writer 113 | 114 | def maybe_log_tensorboard(self, stats, prefix, learning_rate, step): 115 | if self.tensorboard_writer is not None: 116 | stats.log_tensorboard( 117 | prefix, self.tensorboard_writer, learning_rate, step) 118 | 119 | def _report_training(self, step, num_steps, learning_rate, 120 | report_stats): 121 | """ 122 | See base class method `ReportMgrBase.report_training`. 123 | """ 124 | report_stats.output(step, num_steps, 125 | learning_rate, self.start_time) 126 | 127 | self.maybe_log_tensorboard(report_stats, 128 | "progress", 129 | learning_rate, 130 | step) 131 | report_stats = onmt.utils.Statistics() 132 | 133 | return report_stats 134 | 135 | def _report_step(self, lr, step, train_stats=None, valid_stats=None): 136 | """ 137 | See base class method `ReportMgrBase.report_step`. 138 | """ 139 | if train_stats is not None: 140 | self.log('Train perplexity: %g' % train_stats.ppl()) 141 | self.log('Train accuracy: %g' % train_stats.accuracy()) 142 | 143 | self.maybe_log_tensorboard(train_stats, 144 | "train", 145 | lr, 146 | step) 147 | 148 | if valid_stats is not None: 149 | self.log('Validation perplexity: %g' % valid_stats.ppl()) 150 | self.log('Validation accuracy: %g' % valid_stats.accuracy()) 151 | 152 | self.maybe_log_tensorboard(valid_stats, 153 | "valid", 154 | lr, 155 | step) 156 | -------------------------------------------------------------------------------- /onmt/utils/rnn_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | RNN tools 3 | """ 4 | import torch.nn as nn 5 | import onmt.models 6 | 7 | 8 | def rnn_factory(rnn_type, **kwargs): 9 | """ rnn factory, Use pytorch version when available. """ 10 | no_pack_padded_seq = False 11 | if rnn_type == "SRU": 12 | # SRU doesn't support PackedSequence. 
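# The flag below is returned to the caller, which is expected to skip
# pack_padded_sequence in that case.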
13 | no_pack_padded_seq = True 14 | rnn = onmt.models.sru.SRU(**kwargs) 15 | else: 16 | rnn = getattr(nn, rnn_type)(**kwargs) 17 | return rnn, no_pack_padded_seq 18 | -------------------------------------------------------------------------------- /onmt/utils/statistics.py: -------------------------------------------------------------------------------- 1 | """ Statistics calculation utility """ 2 | from __future__ import division 3 | import time 4 | import math 5 | import sys 6 | 7 | from onmt.utils.logging import logger 8 | 9 | 10 | class Statistics(object): 11 | """ 12 | Accumulator for loss statistics. 13 | Currently calculates: 14 | 15 | * accuracy 16 | * perplexity 17 | * elapsed time 18 | """ 19 | 20 | def __init__(self, loss=0, n_words=0, n_correct=0): 21 | self.loss = loss 22 | self.n_words = n_words 23 | self.n_correct = n_correct 24 | self.n_src_words = 0 25 | self.start_time = time.time() 26 | 27 | @staticmethod 28 | def all_gather_stats(stat, max_size=4096): 29 | """ 30 | Gather a `Statistics` object across multiple processes/nodes 31 | 32 | Args: 33 | stat(:obj:Statistics): the statistics object to gather 34 | across all processes/nodes 35 | max_size(int): max buffer size to use 36 | 37 | Returns: 38 | `Statistics`, the updated stats object 39 | """ 40 | stats = Statistics.all_gather_stats_list([stat], max_size=max_size) 41 | return stats[0] 42 | 43 | @staticmethod 44 | def all_gather_stats_list(stat_list, max_size=4096): 45 | """ 46 | Gather a `Statistics` list across all processes/nodes 47 | 48 | Args: 49 | stat_list(list([`Statistics`])): list of statistics objects to 50 | gather across all processes/nodes 51 | max_size(int): max buffer size to use 52 | 53 | Returns: 54 | our_stats(list([`Statistics`])): list of updated stats 55 | """ 56 | from torch.distributed import get_rank 57 | from onmt.utils.distributed import all_gather_list 58 | 59 | # Get a list of world_size lists with len(stat_list) Statistics objects 60 | all_stats = all_gather_list(stat_list, max_size=max_size) 61 | 62 | our_rank = get_rank() 63 | our_stats = all_stats[our_rank] 64 | for other_rank, stats in enumerate(all_stats): 65 | if other_rank == our_rank: 66 | continue 67 | for i, stat in enumerate(stats): 68 | our_stats[i].update(stat, update_n_src_words=True) 69 | return our_stats 70 | 71 | def update(self, stat, update_n_src_words=False): 72 | """ 73 | Update statistics by summing values with another `Statistics` object 74 | 75 | Args: 76 | stat: another `Statistics` object 77 | update_n_src_words(bool): whether to update (sum) `n_src_words` 78 | or not 79 | 80 | """ 81 | self.loss += stat.loss 82 | self.n_words += stat.n_words 83 | self.n_correct += stat.n_correct 84 | 85 | if update_n_src_words: 86 | self.n_src_words += stat.n_src_words 87 | 88 | def accuracy(self): 89 | """ compute accuracy """ 90 | return 100 * (self.n_correct / self.n_words) 91 | 92 | def xent(self): 93 | """ compute cross entropy """ 94 | return self.loss / self.n_words 95 | 96 | def ppl(self): 97 | """ compute perplexity """ 98 | return math.exp(min(self.loss / self.n_words, 100)) 99 | 100 | def elapsed_time(self): 101 | """ compute elapsed time """ 102 | return time.time() - self.start_time 103 | 104 | def output(self, step, num_steps, learning_rate, start): 105 | """Write out statistics to stdout. 106 | 107 | Args: 108 | step (int): current step 109 | num_steps (int): total number of steps 110 | start (int): start time of step.
111 | """ 112 | t = self.elapsed_time() 113 | step_fmt = "%2d" % step 114 | if num_steps > 0: 115 | step_fmt = "%s/%5d" % (step_fmt, num_steps) 116 | logger.info( 117 | ("Step %s; acc: %6.2f; ppl: %5.2f; xent: %4.2f; " + 118 | "lr: %7.5f; %3.0f/%3.0f tok/s; %6.0f sec") 119 | % (step_fmt, 120 | self.accuracy(), 121 | self.ppl(), 122 | self.xent(), 123 | learning_rate, 124 | self.n_src_words / (t + 1e-5), 125 | self.n_words / (t + 1e-5), 126 | time.time() - start)) 127 | sys.stdout.flush() 128 | 129 | def log_tensorboard(self, prefix, writer, learning_rate, step): 130 | """ display statistics to tensorboard """ 131 | t = self.elapsed_time() 132 | writer.add_scalar(prefix + "/xent", self.xent(), step) 133 | writer.add_scalar(prefix + "/ppl", self.ppl(), step) 134 | writer.add_scalar(prefix + "/accuracy", self.accuracy(), step) 135 | writer.add_scalar(prefix + "/tgtper", self.n_words / t, step) 136 | writer.add_scalar(prefix + "/lr", learning_rate, step) 137 | -------------------------------------------------------------------------------- /outputs.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KaijuML/data-to-text-hierarchical/da88d2d4491266fccc39ac1cc1fbb56bd7bbc30c/outputs.zip -------------------------------------------------------------------------------- /preprocess.cfg: -------------------------------------------------------------------------------- 1 | # 2 | train_src: "data/train_input.txt" 3 | train_tgt: "data/train_output.txt" 4 | valid_src: "data/valid_input.txt" 5 | valid_tgt: "data/valid_output.txt" 6 | save_data: "experiments/exp-1/data/data" 7 | src_vocab_size: 50000 8 | tgt_vocab_size: 50000 9 | src_seq_length: 1000 # we do not truncate the source table 10 | tgt_seq_length: 1000 # we do not truncate the target sentences 11 | 12 | dynamic_dict: true # Dynamic dict is used by the copy mechanism 13 | share_vocab: true # SRC and TGT are both written in english 14 | 15 | lower: true 16 | 17 | log_file: 'experiments/exp-1/data/preprocess-log.txt' 18 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.preprocess import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==19.3.0 2 | backcall==0.1.0 3 | bleach==3.3.0 4 | certifi==2019.11.28 5 | cffi==1.13.2 6 | chardet==3.0.4 7 | ConfigArgParse==0.14.0 8 | cycler==0.10.0 9 | decorator==4.4.1 10 | defusedxml==0.6.0 11 | entrypoints==0.3 12 | idna==2.8 13 | importlib-metadata==1.3.0 14 | ipykernel==5.1.3 15 | ipython==7.10.2 16 | ipython-genutils==0.2.0 17 | ipywidgets==7.5.1 18 | jedi==0.15.1 19 | Jinja2==2.11.3 20 | json5==0.8.5 21 | jsonschema==3.2.0 22 | jupyter==1.0.0 23 | jupyter-client==5.3.4 24 | jupyter-console==6.0.0 25 | jupyter-core==4.6.1 26 | jupyterlab==1.2.4 27 | jupyterlab-server==1.0.6 28 | kiwisolver==1.1.0 29 | lab==5.1 30 | MarkupSafe==1.1.1 31 | matplotlib==3.1.2 32 | mistune==0.8.4 33 | more-itertools==8.0.2 34 | nbconvert==5.6.1 35 | nbformat==4.4.0 36 | notebook==6.1.5 37 | numpy==1.17.4 38 | pandocfilters==1.4.2 39 | parso==0.5.2 40 | pexpect==4.7.0 41 | pickleshare==0.7.5 42 | prometheus-client==0.7.1 43 | prompt-toolkit==2.0.10 44 | ptyprocess==0.6.0 45 | pycparser==2.19 46 | 
Pygments==2.7.4 47 | pyparsing==2.4.5 48 | pyrsistent==0.15.6 49 | python-dateutil==2.8.1 50 | PyYAML==5.4 51 | pyzmq==18.1.1 52 | qtconsole==4.6.0 53 | requests==2.22.0 54 | Send2Trash==1.5.0 55 | simplejson==3.17.0 56 | six==1.13.0 57 | terminado==0.8.3 58 | testpath==0.4.4 59 | torch==1.1.0 60 | torchtext==0.4.0 61 | tornado==6.0.3 62 | tqdm==4.40.2 63 | traitlets==4.3.3 64 | urllib3==1.25.7 65 | wcwidth==0.1.7 66 | webencodings==0.5.1 67 | widgetsnbextension==3.5.1 68 | zipp==0.6.0 69 | -------------------------------------------------------------------------------- /train.cfg: -------------------------------------------------------------------------------- 1 | # Model/Embeddings 2 | word_vec_size: 600 # Word embedding size for src and tgt 3 | share_embeddings: True # Share embeddings from src and tgt 4 | 5 | # Model/Embedding Features 6 | feat_vec_size: 20 # Attribute embedding size 7 | feat_merge: mlp # Merge action for incorporating feature embeddings [concat|sum|mlp] 8 | feat_merge_activation: ReLU 9 | 10 | 11 | # Model Structure 12 | model_type: table # Type of source model to use [text|table|img|audio] 13 | model_dtype: fp32 14 | encoder_type: htransformer # Type of encoder [rnn|brnn|transformer|htransformer|cnn] 15 | decoder_type: hrnn # Type of decoder [rnn|transformer|cnn|hrnn] 16 | param_init: 0.1 # Uniform distribution with support (-param_init, +param_init) 17 | 18 | # We put sizes we wish to change manually at -1 19 | layers: -1 20 | enc_layers: -1 21 | heads: -1 22 | glu_depth: -1 23 | 24 | # Encoder sizes 25 | transformer_ff: 1024 # Size of hidden transformer feed-forward 26 | units_layers: 2 27 | chunks_layers: 2 28 | units_head: 2 29 | chunks_head: 2 30 | units_glu_depth: 1 31 | chunks_glu_depth: 1 32 | 33 | # Decoder sizes 34 | dec_layers: 2 35 | rnn_size: 600 36 | input_feed: 1 37 | bridge: True 38 | rnn_type: LSTM 39 | 40 | 41 | # Model/Attention 42 | global_attention: general # Type of attn to use [dot|general|mlp|none] 43 | global_attention_function: softmax # [softmax|sparsemax] 44 | self_attn_type: scaled-dot # self attn type in transformer [scaled-dot|average] 45 | generator_function: softmax 46 | use_pos: True # whether using attributes in attention layers 47 | 48 | # Model/Copy 49 | copy_attn: True 50 | reuse_copy_attn: True # Reuse standard attention for copy 51 | copy_attn_force: True # When available, train to copy 52 | 53 | 54 | # Files and logs 55 | data: experiments/exp-1/data/data # path to datafile from preprocess.py 56 | save_model: experiments/exp-1/models/model # path to store checkpoints 57 | log_file: experiments/exp-1/train-log.txt 58 | 59 | report_every: 50 # log current loss every X steps 60 | save_checkpoint_steps: 500 # save a cp every X steps 61 | 62 | 63 | # Gpu related: 64 | gpu_ranks: [0] # ids of gpus to use 65 | world_size: 1 # total number of distributed processes 66 | gpu_backend: nccl # type of torch distributed backend 67 | gpu_verbose_level: 0 68 | master_ip: localhost 69 | master_port: 10000 70 | seed: 123 71 | 72 | 73 | # Optimization & training 74 | batch_size: 32 75 | batch_type: sents 76 | normalization: sents 77 | accum_count: [2] # Update weights every X batches 78 | accum_steps: [0] # steps at which accum counts value changes 79 | valid_steps: 500 # run models on validation set every X steps 80 | train_steps: 30000 81 | optim: adam 82 | max_grad_norm: 5 83 | dropout: .5 84 | adam_beta1: 0.9 85 | adam_beta2: 0.999 86 | label_smoothing: 0.0 87 | average_decay: 0 88 | average_every: 1 89 | 90 | # Learning rate 91 | 
learning_rate: 0.001 92 | learning_rate_decay: 0.5 # lr *= lr_decay 93 | start_decay_step: 5000 94 | decay_steps: 10000 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.train import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | -------------------------------------------------------------------------------- /translate.cfg: -------------------------------------------------------------------------------- 1 | data_type: text 2 | src: data/test_input.txt 3 | output: experiments/exp-1/gens/test/predictions_30K.txt 4 | model: experiments/exp-1/models/model_step_30000.pt 5 | dynamic_dict: true 6 | seed: 829 # default random seed of OpenNMT 7 | beam_size: 10 8 | min_length: 275 9 | max_length: 600 10 | block_ngram_repeat: 10 11 | log_file: experiments/exp-1/translation-log.txt 12 | batch_size: 64 13 | gpu: 0 14 | -------------------------------------------------------------------------------- /translate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from onmt.bin.translate import main 3 | 4 | 5 | if __name__ == "__main__": 6 | main() 7 | --------------------------------------------------------------------------------