├── signjoey ├── __init__.py ├── external_metrics │   └── mscoco_rouge.py ├── phoenix_utils │   └── phoenix_cleanup.py ├── __main__.py ├── dataset.py ├── loss.py ├── EnsembleTransformer.py ├── compression.py ├── batch.py ├── utils.py ├── local_attention.py ├── initialization.py ├── attention.py ├── helpers.py ├── metrics.py ├── vocabulary.py ├── data.py ├── embeddings.py ├── encoders.py ├── builders.py ├── transformer_layers.py └── search.py ├── base_annotations ├── phoenix14t.dev ├── phoenix14t.test └── phoenix14t.train ├── data ├── download.sh └── gls.vocab ├── requirements.txt ├── README.md └── configs ├── extest.yaml ├── example.yaml ├── example2.yaml └── exampleEnsemble.yaml /signjoey/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /base_annotations/phoenix14t.dev: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avoskou/Stochastic-Transformer-Networks-with-Linear-Competing-Units-Application-to-end-to-end-SL-Translatio/HEAD/base_annotations/phoenix14t.dev -------------------------------------------------------------------------------- /base_annotations/phoenix14t.test: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avoskou/Stochastic-Transformer-Networks-with-Linear-Competing-Units-Application-to-end-to-end-SL-Translatio/HEAD/base_annotations/phoenix14t.test -------------------------------------------------------------------------------- /base_annotations/phoenix14t.train: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avoskou/Stochastic-Transformer-Networks-with-Linear-Competing-Units-Application-to-end-to-end-SL-Translatio/HEAD/base_annotations/phoenix14t.train --------------------------------------------------------------------------------
/data/download.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | wget "http://cihancamgoz.com/files/cvpr2020/phoenix14t.pami0.train" 3 | wget "http://cihancamgoz.com/files/cvpr2020/phoenix14t.pami0.dev" 4 | wget "http://cihancamgoz.com/files/cvpr2020/phoenix14t.pami0.test" 5 | --------------------------------------------------------------------------------
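The files fetched by `data/download.sh` are gzip-compressed pickles, loaded in `signjoey/dataset.py` via `load_dataset_file`. As a quick sanity check after downloading, a minimal sketch (assuming the files sit in the working directory and follow the key layout used in `dataset.py`):

```python
# Sketch: peek into one downloaded split (gzip-compressed pickle of sample dicts).
import gzip
import pickle

with gzip.open("phoenix14t.pami0.dev", "rb") as f:
    samples = pickle.load(f)

s = samples[0]
print(s["name"], s["signer"])  # sequence id and signer id
print(s["gloss"])              # gloss annotation (not used to train this model)
print(s["text"])               # spoken-language ground truth
print(s["sign"].shape)         # frame-level feature tensor (feature_size 1024 in the configs)
```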
/signjoey/__main__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import sys 5 | from signjoey.training import train 6 | from signjoey.prediction import test 7 | 8 | sys.path.append("/vol/research/extol/personal/cihan/code/SignJoey") 9 | 10 | 11 | def main(): 12 | ap = argparse.ArgumentParser("Joey NMT") 13 | 14 | ap.add_argument("mode", choices=["train", "test"], help="train a model or test") 15 | 16 | ap.add_argument("config_path", type=str, help="path to YAML config file") 17 | 18 | ap.add_argument("--ckpt", type=str, help="checkpoint for prediction") 19 | 20 | ap.add_argument( 21 | "--output_path", type=str, help="path for saving translation output" 22 | ) 23 | ap.add_argument("--gpu_id", type=str, default="0", help="gpu to run your job on") 24 | args = ap.parse_args() 25 | 26 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 27 | 28 | if args.mode == "train": 29 | train(cfg_file=args.config_path) 30 | elif args.mode == "test": 31 |
test(cfg_file=args.config_path, ckpt=args.ckpt, output_path=args.output_path) 32 | else: 33 | raise ValueError("Unknown mode") 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | asn1crypto==1.2.0 3 | astor==0.8.1 4 | attrs==19.3.0 5 | cachetools==4.0.0 6 | certifi==2019.11.28 7 | cffi==1.13.2 8 | chardet==3.0.4 9 | #conda==4.8.0 10 | #conda-package-handling==1.6.0 11 | cryptography==2.8 12 | cycler==0.10.0 13 | gast==0.2.2 14 | google-auth==1.10.2 15 | google-auth-oauthlib==0.4.1 16 | google-pasta==0.1.8 17 | grpcio==1.26.0 18 | h5py==2.10.0 19 | idna==2.8 20 | importlib-metadata==1.5.0 21 | joblib==0.14.1 22 | Keras-Applications==1.0.8 23 | Keras-Preprocessing==1.1.0 24 | keyboard==0.13.5 25 | kiwisolver==1.1.0 26 | Markdown==3.1.1 27 | matplotlib==3.1.2 28 | more-itertools==8.2.0 29 | natsort==7.0.1 30 | numpy==1.18.1 31 | oauthlib==3.1.0 32 | opencv-python==4.2.0.32 33 | opt-einsum==3.1.0 34 | packaging==20.3 35 | pandas==1.0.3 36 | Pillow==7.0.0 37 | plotly==4.5.0 38 | pluggy==0.13.1 39 | portalocker==1.5.2 40 | protobuf==3.11.2 41 | py==1.8.1 42 | pyasn1==0.4.8 43 | pyasn1-modules==0.2.8 44 | pycosat==0.6.3 45 | pycparser==2.19 46 | pyOpenSSL==19.1.0 47 | pyparsing==2.4.6 48 | PySocks==1.7.1 49 | pytest==5.4.0 50 | python-dateutil==2.8.1 51 | pytz==2019.3 52 | PyYAML==5.3 53 | requests==2.22.0 54 | requests-oauthlib==1.3.0 55 | retrying==1.3.3 56 | rsa==4.0 57 | ruamel-yaml==0.15.87 58 | sacrebleu==1.4.4 59 | scikit-learn==0.22.2.post1 60 | scipy==1.4.1 61 | sentencepiece==0.1.85 62 | six==1.13.0 63 | tensorboard==2.1.0 64 | tensorflow==2.1.0 65 | tensorflow-estimator==2.1.0 66 | termcolor==1.1.0 67 | torch==1.7.0 68 | torchtext==0.8.0 69 | torchvision==0.8.1 70 | tqdm==4.40.2 71 | typing==3.7.4.1 72 | urllib3==1.25.7 73 | #warmup-scheduler==0.1.1 74 | wcwidth==0.1.8 75 | Werkzeug==0.16.0 76 | wrapt==1.11.2 77 | xmltodict==0.12.0 78 | zipp==3.1.0 79 | 80 | -------------------------------------------------------------------------------- /signjoey/external_metrics/mscoco_rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : mscoco_rouge.py 4 | # 5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004) 6 | # 7 | # Creation Date : 2015-01-07 06:03 8 | # Author : Ramakrishna Vedantam 9 | 10 | 11 | def my_lcs(string, sub): 12 | """ 13 | Calculates longest common subsequence for a pair of tokenized strings 14 | :param string : list of str : tokens from a string split using whitespace 15 | :param sub : list of str : shorter string, also split using whitespace 16 | :returns: length (list of int): length of the longest common subsequence between the two strings 17 | 18 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS 19 | """ 20 | if len(string) < len(sub): 21 | sub, string = string, sub 22 | 23 | lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)] 24 | 25 | for j in range(1, len(sub) + 1): 26 | for i in range(1, len(string) + 1): 27 | if string[i - 1] == sub[j - 1]: 28 | lengths[i][j] = lengths[i - 1][j - 1] + 1 29 | else: 30 | lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1]) 31 | 32 | return lengths[len(string)][len(sub)] 33 | 34 | 35 | def calc_score(hypotheses, references, beta=1.2): 36 | """ 37 | 
Compute ROUGE-L score given one candidate and references for an image 38 | :param hypotheses: str : candidate sentence to be evaluated 39 | :param references: list of str : COCO reference sentences for the particular image to be evaluated 40 | :returns score: int (ROUGE-L score for the candidate evaluated against references) 41 | """ 42 | assert len(hypotheses) == 1 43 | assert len(references) > 0 44 | prec = [] 45 | rec = [] 46 | 47 | # split into tokens 48 | token_c = hypotheses[0].split(" ") 49 | 50 | for reference in references: 51 | # split into tokens 52 | token_r = reference.split(" ") 53 | # compute the longest common subsequence 54 | lcs = my_lcs(token_r, token_c) 55 | prec.append(lcs / float(len(token_c))) 56 | rec.append(lcs / float(len(token_r))) 57 | 58 | prec_max = max(prec) 59 | rec_max = max(rec) 60 | 61 | if prec_max != 0 and rec_max != 0: 62 | score = ((1 + beta ** 2) * prec_max * rec_max) / float( 63 | rec_max + beta ** 2 * prec_max 64 | ) 65 | else: 66 | score = 0.0 67 | return score 68 | -------------------------------------------------------------------------------- /signjoey/dataset.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Data module 4 | """ 5 | from torchtext import data 6 | from torchtext.data import Field, RawField 7 | from typing import List, Tuple 8 | import pickle 9 | import gzip 10 | import torch 11 | 12 | 13 | def load_dataset_file(filename): 14 | with gzip.open(filename, "rb") as f: 15 | loaded_object = pickle.load(f) 16 | return loaded_object 17 | 18 | 19 | class SignTranslationDataset(data.Dataset): 20 | """Defines a dataset for machine translation.""" 21 | 22 | @staticmethod 23 | def sort_key(ex): 24 | return data.interleave_keys(len(ex.sgn), len(ex.txt)) 25 | 26 | def __init__( 27 | self, 28 | path: str, 29 | fields: Tuple[RawField, RawField, Field, Field, Field], 30 | **kwargs 31 | ): 32 | """Create a SignTranslationDataset given paths and fields. 33 | 34 | Arguments: 35 | path: Common prefix of paths to the data files for both languages. 36 | exts: A tuple containing the extension to path for each language. 37 | fields: A tuple containing the fields that will be used for data 38 | in each language. 39 | Remaining keyword arguments: Passed to the constructor of 40 | data.Dataset. 
41 | """ 42 | if not isinstance(fields[0], (tuple, list)): 43 | fields = [ 44 | ("sequence", fields[0]), 45 | ("signer", fields[1]), 46 | ("sgn", fields[2]), 47 | ("gls", fields[3]), 48 | ("txt", fields[4]), 49 | ] 50 | 51 | if not isinstance(path, list): 52 | path = [path] 53 | 54 | samples = {} 55 | for annotation_file in path: 56 | tmp = load_dataset_file(annotation_file) 57 | for s in tmp: 58 | seq_id = s["name"] 59 | if seq_id in samples: 60 | assert samples[seq_id]["name"] == s["name"] 61 | assert samples[seq_id]["signer"] == s["signer"] 62 | assert samples[seq_id]["gloss"] == s["gloss"] 63 | assert samples[seq_id]["text"] == s["text"] 64 | samples[seq_id]["sign"] = torch.cat( 65 | [samples[seq_id]["sign"], s["sign"]], axis=1 66 | ) 67 | else: 68 | samples[seq_id] = { 69 | "name": s["name"], 70 | "signer": s["signer"], 71 | "gloss": s["gloss"], 72 | "text": s["text"], 73 | "sign": s["sign"], 74 | } 75 | 76 | examples = [] 77 | for s in samples: 78 | sample = samples[s] 79 | examples.append( 80 | data.Example.fromlist( 81 | [ 82 | sample["name"], 83 | sample["signer"], 84 | # This is for numerical stability 85 | sample["sign"] + 1e-8, 86 | sample["gloss"].strip(), 87 | sample["text"].strip(), 88 | ], 89 | fields, 90 | ) 91 | ) 92 | super().__init__(examples, fields, **kwargs) 93 | -------------------------------------------------------------------------------- /signjoey/loss.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Module to implement training loss 4 | """ 5 | 6 | import torch 7 | from torch import nn, Tensor 8 | from torch.autograd import Variable 9 | 10 | 11 | class XentLoss(nn.Module): 12 | """ 13 | Cross-Entropy Loss with optional label smoothing 14 | """ 15 | 16 | def __init__(self, pad_index: int, smoothing: float = 0.0): 17 | super(XentLoss, self).__init__() 18 | self.smoothing = smoothing 19 | self.pad_index = pad_index 20 | if self.smoothing <= 0.0: 21 | # standard xent loss 22 | self.criterion = nn.NLLLoss(ignore_index=self.pad_index, reduction="sum") 23 | else: 24 | # custom label-smoothed loss, computed with KL divergence loss 25 | self.criterion = nn.KLDivLoss(reduction="sum") 26 | 27 | def _smooth_targets(self, targets: Tensor, vocab_size: int): 28 | """ 29 | Smooth target distribution. All non-reference words get uniform 30 | probability mass according to "smoothing". 31 | 32 | :param targets: target indices, batch*seq_len 33 | :param vocab_size: size of the output vocabulary 34 | :return: smoothed target distributions, batch*seq_len x vocab_size 35 | """ 36 | # batch*seq_len x vocab_size 37 | smooth_dist = targets.new_zeros((targets.size(0), vocab_size)).float() 38 | # fill distribution uniformly with smoothing 39 | smooth_dist.fill_(self.smoothing / (vocab_size - 2)) 40 | # assign true label the probability of 1-smoothing ("confidence") 41 | smooth_dist.scatter_(1, targets.unsqueeze(1).data, 1.0 - self.smoothing) 42 | # give padding probability of 0 everywhere 43 | smooth_dist[:, self.pad_index] = 0 44 | # masking out padding area (sum of probabilities for padding area = 0) 45 | padding_positions = torch.nonzero(targets.data == self.pad_index) 46 | # pylint: disable=len-as-condition 47 | if len(padding_positions) > 0: 48 | smooth_dist.index_fill_(0, padding_positions.squeeze(), 0.0) 49 | return Variable(smooth_dist, requires_grad=False) 50 | 51 | # pylint: disable=arguments-differ 52 | def forward(self, log_probs, targets): 53 | """ 54 | Compute the cross-entropy between logits and targets. 
55 | 56 | If label smoothing is used, target distributions are not one-hot, but 57 | "1-smoothing" for the correct target token and the rest of the 58 | probability mass is uniformly spread across the other tokens. 59 | 60 | :param log_probs: log probabilities as predicted by model 61 | :param targets: target indices 62 | :return: 63 | """ 64 | if self.smoothing > 0: 65 | targets = self._smooth_targets( 66 | targets=targets.contiguous().view(-1), vocab_size=log_probs.size(-1) 67 | ) 68 | # targets: distributions with batch*seq_len x vocab_size 69 | assert ( 70 | log_probs.contiguous().view(-1, log_probs.size(-1)).shape 71 | == targets.shape 72 | ) 73 | else: 74 | # targets: indices with batch*seq_len 75 | targets = targets.contiguous().view(-1) 76 | loss = self.criterion( 77 | log_probs.contiguous().view(-1, log_probs.size(-1)), targets 78 | ) 79 | return loss 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stochastic Transformer Networks with Linear Competing Units: Application to end-to-end Sign Language Translation 2 | 3 | 4 | Automating sign language translation (SLT) is a challenging 5 | real-world application. Despite its societal importance, 6 | though, research progress in the field remains rather 7 | poor. Crucially, existing methods that yield viable performance 8 | necessitate the availability of laborious-to-obtain 9 | gloss sequence groundtruth. In this paper, we attenuate 10 | this need by introducing an end-to-end SLT model that does 11 | not entail explicit use of glosses; the model only needs text 12 | groundtruth. This is in stark contrast to existing 13 | end-to-end models that use gloss sequence groundtruth, either in 14 | the form of a modality that is recognized at an 15 | intermediate model stage, or in the form of a parallel output process, 16 | jointly trained with the SLT model. Our approach constitutes 17 | a Transformer network with a novel type of layers that 18 | combines: (i) local winner-takes-all (LWTA) layers with 19 | stochastic winner sampling, instead of conventional ReLU 20 | layers, (ii) stochastic weights with posterior distributions 21 | estimated via variational inference, and (iii) a weight 22 | compression technique at inference time that exploits estimated 23 | posterior variance to perform massive, almost lossless 24 | compression. We demonstrate that our approach can reach the 25 | currently best reported BLEU-4 score on the PHOENIX 26 | 2014T benchmark, but without making use of glosses for 27 | model training, and with a memory footprint reduced by 28 | more than 70%. 29 | 30 | The code is based on: 31 | 1. Sign Language Transformers: Joint End-to-end Sign Language Recognition and Translation. 32 | 2. Joey NMT (https://github.com/joeynmt/joeynmt) 33 | 3. Nonparametric Bayesian Deep Networks with Local Competition 34 | 4. Bayesian Compression for Deep Learning
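For intuition about the stochastic LWTA layers described in the abstract above, the following is a minimal, illustrative sketch, not the exact layer implemented in `signjoey/`: units are grouped into blocks of competitors and one winner per block is sampled via a Gumbel-softmax relaxation, so only the winner passes its linear response through.

```python
# Illustrative sketch of a stochastic local winner-takes-all (LWTA) activation.
# Assumptions: block size `competitors`, winner sampled with the Gumbel-softmax
# relaxation; the repository's actual layers may differ in detail.
import torch
import torch.nn.functional as F


def stochastic_lwta(x: torch.Tensor, competitors: int = 4, tau: float = 0.67) -> torch.Tensor:
    """x: (..., features), with the feature dimension divisible by `competitors`."""
    *lead, feat = x.shape
    assert feat % competitors == 0
    blocks = x.view(*lead, feat // competitors, competitors)
    # Sample a (relaxed) one-hot winner per block, using the pre-activations as logits.
    winners = F.gumbel_softmax(blocks, tau=tau, hard=True, dim=-1)
    # The winner keeps its linear response; the other competitors in the block output zero.
    return (blocks * winners).view(*lead, feat)
```

An activation of this kind stands in for ReLU inside the encoder/decoder feed-forward blocks (cf. the `activation: lwta` and `lwta_competitors: 4` entries in the configs).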
35 | 36 | 37 | ## Reference 38 | 39 | Please cite: 40 | 41 | @inproceedings{voskou2021stochastic, 42 | title={Stochastic transformer networks with linear competing units: Application to end-to-end sl translation}, 43 | author={Voskou, Andreas and Panousis, Konstantinos P and Kosmopoulos, Dimitrios and Metaxas, Dimitris N and Chatzis, Sotirios}, 44 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 45 | pages={11946--11955}, 46 | year={2021} 47 | } 48 | 49 | 50 | ## Requirements 51 | * python 3.6+ 52 | * Download the feature files using the `data/download.sh` script. 53 | * Install required packages using the `requirements.txt` file. 54 | 55 | `pip install -r requirements.txt` 56 | 57 | Tested on a single GPU (not tested on CPU or multiple GPUs). 58 | 59 | 60 | 61 | ## Usage 62 | 63 | To train a model: 64 | 65 | `python -m signjoey train configs/example.yaml` 66 | 67 | To test an existing model: 68 | 69 | `python -m signjoey test configs/example.yaml` 70 | 71 | 72 | 73 | Note that the default data directory is `./data`. If you download the feature files somewhere else, you need to update the `data_path` parameter in your config file. 74 | 75 | 76 | 77 | ## Acknowledgement 78 | This research was partially supported by the Research Promotion Foundation of Cyprus, through the grant INTERNATIONAL/USA/0118/0037 (Dimitrios Kosmopoulos, Dimitris Metaxas), and the European Union’s Horizon 2020 research and innovation program, under grant agreement No 872139, project aiD (Andreas Voskou, Sotirios Chatzis). 79 | -------------------------------------------------------------------------------- /configs/extest.yaml: -------------------------------------------------------------------------------- 1 | name: example_experiment 2 | data: 3 | data_path: ./data/ 4 | version: phoenix_2014_trans 5 | sgn: sign 6 | txt: text 7 | gls: gloss 8 | train: phoenix14t.pami0.train 9 | dev: phoenix14t.pami0.dev 10 | test: phoenix14t.pami0.test 11 | feature_size: 1024 12 | level: word 13 | txt_lowercase: true 14 | max_sent_length: 400 15 | random_train_subset: -1 16 | random_dev_subset: -1 17 | gls_vocab: ./data/gls.vocab 18 | txt_vocab: ./data/txt.vocab 19 | batch_size: 32 20 | testing: 21 | translation_beam_sizes: 22 | - 1 23 | - 2 24 | - 3 25 | - 4 26 | - 5 27 | - 6 28 | translation_beam_alphas: 29 | - -1 30 | - 0 31 | - 1 32 | - 2 33 | training: 34 | # load_model: "./SavedModels/example/13400.ckpt" 35 | reset_best_ckpt: true 36 | reset_scheduler: true 37 | reset_optimizer: true 38 | random_seed: 42 39 | model_dir: "./SavedModels/example5" 40 | recognition_loss_weight: 0 41 | translation_loss_weight: 1.0 42 | kl_weight: 1 43 | eval_metric: bleu 44 | optimizer: adam 45 | learning_rate: 0.001 46 | batch_size: 32 47 | eval_batch_size: 32 48 | num_valid_log: 5 49 | epochs: 500 50 | early_stopping_metric: eval_metric 51 | batch_type: sentence 52 | translation_normalization: batch 53 | eval_recognition_beam_size: 1 54 | eval_translation_beam_size: 1 55 | eval_translation_beam_alpha: 0 56 | overwrite: true 57 | shuffle: true 58 | use_cuda: true 59 | translation_max_output_length: 30 60 | keep_last_ckpts: 1 61 | batch_multiplier: 1 62 | logging_freq: 20 63 | validation_freq: 80 64 | betas: 65 | - 0.9 66 | - 0.998 67 | scheduling: plateau 68 | learning_rate_min: 0.00001 69 | patience: 5 70 | decrease_factor: 0.8 71 | label_smoothing: 0.0 72 | model: 73 | gloss_input: false 74 | initializer: xavier 75 | bias_initializer: zeros 76 | init_gain: 1.0 77 |
embed_initializer: xavier 78 | embed_init_gain: 1.0 79 | tied_softmax: false 80 | simplified_inference: True 81 | inference_sample_size: 4 82 | encoder: 83 | skip_encoder: false 84 | type: transformer 85 | bayesian_attention: true 86 | bayesian_feedforward: true 87 | ibp: false 88 | activation: lwta 89 | lwta_competitors: 4 90 | num_layers: 2 91 | num_heads: 8 92 | embeddings: 93 | embedding_dim: 512 94 | scale: false 95 | bayesian: true 96 | ibp: false 97 | dropout: 0.2 98 | norm_type: batch 99 | activation_type: lwta 100 | lwta_competitors: 4 101 | hidden_size: 512 102 | ff_size: 2048 103 | dropout: 0.2 104 | decoder: 105 | type: transformer 106 | num_layers: 2 107 | num_heads: 8 108 | bayesian_attention: true 109 | bayesian_feedforward: true 110 | bayesian_output: true 111 | ibp: false 112 | activation: lwta 113 | lwta_competitors: 4 114 | embeddings: 115 | embedding_dim: 512 116 | scale: False 117 | bayesian: False 118 | dropout: 0.2 119 | norm_type: batch 120 | hidden_size: 512 121 | ff_size: 2048 122 | dropout: 0.2 123 | -------------------------------------------------------------------------------- /configs/example.yaml: -------------------------------------------------------------------------------- 1 | name: example_experiment 2 | data: 3 | data_path: ./data/ 4 | version: phoenix_2014_trans 5 | sgn: sign 6 | txt: text 7 | gls: gloss 8 | train: phoenix14t.pami0.train 9 | dev: phoenix14t.pami0.dev 10 | test: phoenix14t.pami0.test 11 | feature_size: 1024 12 | level: word 13 | txt_lowercase: true 14 | max_sent_length: 400 15 | random_train_subset: -1 16 | random_dev_subset: -1 17 | gls_vocab: ./data/gls.vocab 18 | txt_vocab: ./data/txt.vocab 19 | batch_size: 32 20 | testing: 21 | translation_beam_sizes: 22 | - 1 23 | - 2 24 | - 3 25 | - 4 26 | - 5 27 | - 6 28 | translation_beam_alphas: 29 | - -1 30 | - 0 31 | - 1 32 | - 2 33 | - 3 34 | training: 35 | # load_model: "./SavedModels/example/13400.ckpt" 36 | reset_best_ckpt: true 37 | reset_scheduler: true 38 | reset_optimizer: true 39 | random_seed: 44 40 | model_dir: "./SavedModels/example22" 41 | recognition_loss_weight: 0 42 | translation_loss_weight: 1.0 43 | kl_weight: 1 44 | eval_metric: bleu 45 | optimizer: adam 46 | learning_rate: 0.001 47 | batch_size: 32 48 | eval_batch_size: 32 49 | num_valid_log: 5 50 | epochs: 500 51 | early_stopping_metric: eval_metric 52 | batch_type: sentence 53 | translation_normalization: batch 54 | eval_recognition_beam_size: 1 55 | eval_translation_beam_size: 1 56 | eval_translation_beam_alpha: 0 57 | overwrite: true 58 | shuffle: true 59 | use_cuda: true 60 | translation_max_output_length: 30 61 | keep_last_ckpts: 1 62 | batch_multiplier: 1 63 | logging_freq: 20 64 | validation_freq: 80 65 | betas: 66 | - 0.9 67 | - 0.998 68 | scheduling: plateau 69 | learning_rate_min: 0.00001 70 | patience: 6 71 | decrease_factor: 0.8 72 | label_smoothing: 0.0 73 | model: 74 | gloss_input: false 75 | initializer: xavier 76 | bias_initializer: zeros 77 | init_gain: 1.0 78 | embed_initializer: xavier 79 | embed_init_gain: 1.0 80 | tied_softmax: false 81 | simplified_inference: true 82 | inference_sample_size: 4 83 | encoder: 84 | skip_encoder: false 85 | type: transformer 86 | bayesian_attention: true 87 | bayesian_feedforward: true 88 | ibp: false 89 | activation: lwta 90 | lwta_competitors: 4 91 | num_layers: 2 92 | num_heads: 8 93 | embeddings: 94 | embedding_dim: 512 95 | scale: false 96 | bayesian: true 97 | ibp: false 98 | dropout: 0.2 99 | norm_type: batch 100 | activation_type: lwta 101 | lwta_competitors: 4 
102 | hidden_size: 512 103 | ff_size: 2048 104 | dropout: 0.2 105 | decoder: 106 | type: transformer 107 | num_layers: 2 108 | num_heads: 8 109 | bayesian_attention: true 110 | bayesian_feedforward: true 111 | bayesian_output: true 112 | ibp: false 113 | activation: lwta 114 | lwta_competitors: 4 115 | embeddings: 116 | embedding_dim: 512 117 | scale: False 118 | bayesian: False 119 | dropout: 0.2 120 | norm_type: batch 121 | hidden_size: 512 122 | ff_size: 2048 123 | dropout: 0.2 124 | -------------------------------------------------------------------------------- /configs/example2.yaml: -------------------------------------------------------------------------------- 1 | name: example_experiment 2 | data: 3 | data_path: ./data/ 4 | version: phoenix_2014_trans 5 | sgn: sign 6 | txt: text 7 | gls: gloss 8 | train: phoenix14t.pami0.train 9 | dev: phoenix14t.pami0.dev 10 | test: phoenix14t.pami0.test 11 | feature_size: 1024 12 | level: word 13 | txt_lowercase: true 14 | max_sent_length: 400 15 | random_train_subset: -1 16 | random_dev_subset: -1 17 | gls_vocab: ./data/gls.vocab 18 | txt_vocab: ./data/txt.vocab 19 | batch_size: 32 20 | testing: 21 | translation_beam_sizes: 22 | - 1 23 | - 2 24 | - 3 25 | - 4 26 | - 5 27 | - 6 28 | translation_beam_alphas: 29 | - -1 30 | - 0 31 | - 1 32 | - 2 33 | - 3 34 | training: 35 | # load_model: "./SavedModels/example11/19920.ckpt" 36 | reset_best_ckpt: true 37 | reset_scheduler: true 38 | reset_optimizer: true 39 | random_seed: 54 40 | model_dir: "./SavedModels/example23" 41 | recognition_loss_weight: 0 42 | translation_loss_weight: 1.0 43 | kl_weight: 1 44 | eval_metric: bleu 45 | optimizer: adam 46 | learning_rate: 0.001 47 | batch_size: 32 48 | eval_batch_size: 32 49 | num_valid_log: 5 50 | epochs: 500 51 | early_stopping_metric: eval_metric 52 | batch_type: sentence 53 | translation_normalization: batch 54 | eval_recognition_beam_size: 1 55 | eval_translation_beam_size: 1 56 | eval_translation_beam_alpha: 0 57 | overwrite: true 58 | shuffle: true 59 | use_cuda: true 60 | translation_max_output_length: 30 61 | keep_last_ckpts: 1 62 | batch_multiplier: 1 63 | logging_freq: 20 64 | validation_freq: 80 65 | betas: 66 | - 0.9 67 | - 0.998 68 | scheduling: plateau 69 | learning_rate_min: 0.00001 70 | patience: 5 71 | decrease_factor: 0.8 72 | label_smoothing: 0.0 73 | model: 74 | gloss_input: false 75 | initializer: xavier 76 | bias_initializer: zeros 77 | init_gain: 1.0 78 | embed_initializer: xavier 79 | embed_init_gain: 1.0 80 | tied_softmax: false 81 | simplified_inference: true 82 | inference_sample_size: 4 83 | encoder: 84 | skip_encoder: false 85 | type: transformer 86 | bayesian_attention: true 87 | bayesian_feedforward: true 88 | ibp: false 89 | activation: lwta 90 | lwta_competitors: 4 91 | num_layers: 2 92 | num_heads: 8 93 | embeddings: 94 | embedding_dim: 512 95 | scale: false 96 | bayesian: true 97 | ibp: false 98 | dropout: 0.2 99 | norm_type: batch 100 | activation_type: lwta 101 | lwta_competitors: 4 102 | hidden_size: 512 103 | ff_size: 2048 104 | dropout: 0.2 105 | decoder: 106 | type: transformer 107 | num_layers: 2 108 | num_heads: 8 109 | bayesian_attention: true 110 | bayesian_feedforward: true 111 | bayesian_output: true 112 | ibp: false 113 | activation: lwta 114 | lwta_competitors: 4 115 | embeddings: 116 | embedding_dim: 512 117 | scale: False 118 | bayesian: False 119 | dropout: 0.2 120 | norm_type: batch 121 | hidden_size: 512 122 | ff_size: 2048 123 | dropout: 0.2 124 | 
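Any of the configs above can also be launched without the CLI wrapper, mirroring the entry points in `signjoey/__main__.py`. A sketch (the checkpoint path is a placeholder, not a file the repository guarantees to produce under that name):

```python
# Programmatic equivalent of `python -m signjoey train ...` / `python -m signjoey test ...`.
import os

from signjoey.training import train
from signjoey.prediction import test

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # select a GPU, as the CLI does

train(cfg_file="configs/example2.yaml")

# After training, evaluate a saved checkpoint (path illustrative only):
# test(cfg_file="configs/example2.yaml",
#      ckpt="./SavedModels/example23/best.ckpt",
#      output_path="./SavedModels/example23/test_output")
```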
-------------------------------------------------------------------------------- /configs/exampleEnsemble.yaml: -------------------------------------------------------------------------------- 1 | name: example_experiment 2 | data: 3 | data_path: ./data/ 4 | version: phoenix_2014_trans 5 | sgn: sign 6 | txt: text 7 | gls: gloss 8 | train: phoenix14t.pami0.train 9 | dev: phoenix14t.pami0.dev 10 | test: phoenix14t.pami0.test 11 | feature_size: 1024 12 | level: word 13 | txt_lowercase: true 14 | max_sent_length: 400 15 | random_train_subset: -1 16 | random_dev_subset: -1 17 | gls_vocab: ./data/gls.vocab 18 | txt_vocab: ./data/txt.vocab 19 | batch_size: 32 20 | testing: 21 | translation_beam_sizes: 22 | - 1 23 | - 2 24 | - 3 25 | - 4 26 | - 5 27 | - 6 28 | translation_beam_alphas: 29 | - -1 30 | - 0 31 | - 1 32 | - 2 33 | - 3 34 | training: 35 | # load_model: "./SavedModels/example/13400.ckpt" 36 | reset_best_ckpt: true 37 | reset_scheduler: true 38 | reset_optimizer: true 39 | random_seed: 44 40 | model_dir: 41 | - "./SavedModels/example22" 42 | - "./SavedModels/example21" 43 | - "./SavedModels/example20" 44 | - "./SavedModels/example19" 45 | - "./SavedModels/example18" 46 | recognition_loss_weight: 0 47 | translation_loss_weight: 1.0 48 | kl_weight: 1 49 | eval_metric: bleu 50 | optimizer: adam 51 | learning_rate: 0.001 52 | batch_size: 32 53 | eval_batch_size: 32 54 | num_valid_log: 5 55 | epochs: 500 56 | early_stopping_metric: eval_metric 57 | batch_type: sentence 58 | translation_normalization: batch 59 | eval_recognition_beam_size: 1 60 | eval_translation_beam_size: 1 61 | eval_translation_beam_alpha: 0 62 | overwrite: true 63 | shuffle: true 64 | use_cuda: true 65 | translation_max_output_length: 30 66 | keep_last_ckpts: 1 67 | batch_multiplier: 1 68 | logging_freq: 20 69 | validation_freq: 80 70 | betas: 71 | - 0.9 72 | - 0.998 73 | scheduling: plateau 74 | learning_rate_min: 0.00001 75 | patience: 6 76 | decrease_factor: 0.8 77 | label_smoothing: 0.0 78 | model: 79 | gloss_input: false 80 | initializer: xavier 81 | bias_initializer: zeros 82 | init_gain: 1.0 83 | embed_initializer: xavier 84 | embed_init_gain: 1.0 85 | tied_softmax: false 86 | simplified_inference: true 87 | inference_sample_size: 2 88 | encoder: 89 | skip_encoder: false 90 | type: transformer 91 | bayesian_attention: true 92 | bayesian_feedforward: true 93 | ibp: false 94 | activation: lwta 95 | lwta_competitors: 4 96 | num_layers: 2 97 | num_heads: 8 98 | embeddings: 99 | embedding_dim: 512 100 | scale: false 101 | bayesian: true 102 | ibp: false 103 | dropout: 0.2 104 | norm_type: batch 105 | activation_type: lwta 106 | lwta_competitors: 4 107 | hidden_size: 512 108 | ff_size: 2048 109 | dropout: 0.2 110 | decoder: 111 | type: transformer 112 | num_layers: 2 113 | num_heads: 8 114 | bayesian_attention: true 115 | bayesian_feedforward: true 116 | bayesian_output: true 117 | ibp: false 118 | activation: lwta 119 | lwta_competitors: 4 120 | embeddings: 121 | embedding_dim: 512 122 | scale: False 123 | bayesian: False 124 | dropout: 0.2 125 | norm_type: batch 126 | hidden_size: 512 127 | ff_size: 2048 128 | dropout: 0.2 129 | -------------------------------------------------------------------------------- /signjoey/phoenix_utils/phoenix_cleanup.py: -------------------------------------------------------------------------------- 1 | from itertools import groupby 2 | import re 3 | 4 | 5 | def clean_phoenix_2014(prediction): 6 | # TODO (Cihan): Python version of the evaluation script provided 7 | # by the phoenix2014 
dataset (not phoenix2014t). This should work 8 | # as intended but further tests are required to make sure it is 9 | # consistent with the bash/sed based clean up script. 10 | 11 | prediction = prediction.strip() 12 | prediction = re.sub(r"loc-", "", prediction) 13 | prediction = re.sub(r"cl-", "", prediction) 14 | prediction = re.sub(r"qu-", "", prediction) 15 | prediction = re.sub(r"poss-", "", prediction) 16 | prediction = re.sub(r"lh-", "", prediction) 17 | prediction = re.sub(r"S0NNE", "SONNE", prediction) 18 | prediction = re.sub(r"HABEN2", "HABEN", prediction) 19 | prediction = re.sub(r"__EMOTION__", "", prediction) 20 | prediction = re.sub(r"__PU__", "", prediction) 21 | prediction = re.sub(r"__LEFTHAND__", "", prediction) 22 | prediction = re.sub(r"WIE AUSSEHEN", "WIE-AUSSEHEN", prediction) 23 | prediction = re.sub(r"ZEIGEN ", "ZEIGEN-BILDSCHIRM ", prediction) 24 | prediction = re.sub(r"ZEIGEN$", "ZEIGEN-BILDSCHIRM", prediction) 25 | prediction = re.sub(r"^([A-Z]) ([A-Z][+ ])", r"\1+\2", prediction) 26 | prediction = re.sub(r"[ +]([A-Z]) ([A-Z]) ", r" \1+\2 ", prediction) 27 | prediction = re.sub(r"([ +][A-Z]) ([A-Z][ +])", r"\1+\2", prediction) 28 | prediction = re.sub(r"([ +][A-Z]) ([A-Z][ +])", r"\1+\2", prediction) 29 | prediction = re.sub(r"([ +][A-Z]) ([A-Z][ +])", r"\1+\2", prediction) 30 | prediction = re.sub(r"([ +]SCH) ([A-Z][ +])", r"\1+\2", prediction) 31 | prediction = re.sub(r"([ +]NN) ([A-Z][ +])", r"\1+\2", prediction) 32 | prediction = re.sub(r"([ +][A-Z]) (NN[ +])", r"\1+\2", prediction) 33 | prediction = re.sub(r"([ +][A-Z]) ([A-Z])$", r"\1+\2", prediction) 34 | prediction = re.sub(r"([A-Z][A-Z])RAUM", r"\1", prediction) 35 | prediction = re.sub(r"-PLUSPLUS", "", prediction) 36 | prediction = re.sub(r" +", " ", prediction) 37 | prediction = re.sub(r"(?bhnij', q, self.weights.type(q.dtype)) * self.scale 72 | return shift(emb) 73 | 74 | # main class 75 | 76 | class LocalAttention(nn.Module): 77 | def __init__(self, window_size, causal = False, look_backward = 1, look_forward = None, dropout = 0., shared_qk = False, rel_pos_emb_config = None, autopad = False): 78 | super().__init__() 79 | look_forward = default(look_forward, 0 if causal else 1) 80 | assert not (causal and look_forward > 0), 'you cannot look forward if causal' 81 | linear=nn.Linear 82 | self.window_size = window_size 83 | self.causal = causal 84 | self.look_backward = look_backward 85 | self.look_forward = look_forward 86 | self.autopad = autopad 87 | 88 | self.dropout = nn.Dropout(dropout) 89 | 90 | self.shared_qk = shared_qk 91 | 92 | self.rel_pos = None 93 | if rel_pos_emb_config is not None: 94 | dim_head, heads = rel_pos_emb_config 95 | rel_pos_length = window_size * (1 + look_forward + look_backward) 96 | self.heads = heads 97 | self.rel_pos = RelativePositionalEmbedding(dim_head, heads, rel_pos_length) 98 | self.k_layer = linear(512, 512) 99 | 100 | self.v_layer = linear(512, 512) 101 | self.q_layer = linear(512, 512) 102 | 103 | self.output_layer=nn.Linear(512,512) 104 | def forward(self, q, k, v, input_mask = None): 105 | 106 | k = self.k_layer(k) 107 | v = self.v_layer(v) 108 | q = self.q_layer(q) 109 | shape = q.shape 110 | 111 | merge_into_batch = lambda t: t.reshape(-1, *t.shape[-2:]) 112 | q, k, v = map(merge_into_batch, (q, k, v)) 113 | 114 | if self.autopad: 115 | orig_t = q.shape[1] 116 | q, k, v = map(lambda t: pad_to_multiple(t, self.window_size, dim = -2), (q, k, v)) 117 | 118 | window_size, causal, look_backward, look_forward, shared_qk = self.window_size, self.causal, 
self.look_backward, self.look_forward, self.shared_qk 119 | b, t, e, device, dtype = *q.shape, q.device, q.dtype 120 | assert (t % window_size) == 0, f'sequence length {t} must be divisible by window size {window_size} for local attention' 121 | 122 | windows = t // window_size 123 | 124 | if shared_qk: 125 | k = F.normalize(k, 2, dim=-1).type_as(q) 126 | 127 | ticker = torch.arange(t, device=device, dtype=dtype)[None, :] 128 | b_t = ticker.reshape(1, windows, window_size) 129 | 130 | bucket_fn = lambda t: t.reshape(b, windows, window_size, -1) 131 | bq, bk, bv = map(bucket_fn, (q, k, v)) 132 | 133 | look_around_kwargs = {'backward': look_backward, 'forward': look_forward} 134 | bk = look_around(bk, **look_around_kwargs) 135 | bv = look_around(bv, **look_around_kwargs) 136 | 137 | bq_t = b_t 138 | bq_k = look_around(b_t, **look_around_kwargs) 139 | 140 | dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (e ** -0.5) 141 | 142 | if self.rel_pos is not None: 143 | rel_attn = self.rel_pos(bq.view(-1, self.heads, *bq.shape[1:])).reshape_as(dots) 144 | dots = dots + rel_attn 145 | 146 | mask_value = max_neg_value(dots) 147 | 148 | if shared_qk: 149 | mask = bq_t[:, :, :, None] == bq_k[:, :, None, :] 150 | dots.masked_fill_(mask, TOKEN_SELF_ATTN_VALUE) 151 | del mask 152 | 153 | if causal: 154 | mask = bq_t[:, :, :, None] < bq_k[:, :, None, :] 155 | dots.masked_fill_(mask, mask_value) 156 | del mask 157 | 158 | mask = bq_k[:, :, None, :] == -1 159 | dots.masked_fill_(mask, mask_value) 160 | del mask 161 | 162 | if input_mask is not None: 163 | h = b // input_mask.shape[0] 164 | if self.autopad: 165 | input_mask = pad_to_multiple(input_mask, window_size, dim=-1, value=False) 166 | input_mask = input_mask.reshape(-1, windows, window_size) 167 | mq = mk = input_mask 168 | mk = look_around(mk, pad_value=False, **look_around_kwargs) 169 | mask = (mq[:, :, :, None] * mk[:, :, None, :]) 170 | mask = merge_dims(0, 1, expand_dim(mask, 1, h)) 171 | dots.masked_fill_(~mask, mask_value) 172 | del mask 173 | 174 | attn = dots.softmax(dim=-1) 175 | attn = self.dropout(attn) 176 | 177 | out = torch.einsum('bhij,bhje->bhie', attn, bv) 178 | out = out.reshape(-1, t, e) 179 | 180 | if self.autopad: 181 | out = out[:, :orig_t, :] 182 | out = self.output_layer(out) 183 | return out.reshape(*shape) 184 | -------------------------------------------------------------------------------- /signjoey/initialization.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | """ 4 | Implements custom initialization 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch import Tensor 12 | from torch.nn.init import _calculate_fan_in_and_fan_out 13 | 14 | 15 | def orthogonal_rnn_init_(cell: nn.RNNBase, gain: float = 1.0): 16 | """ 17 | Orthogonal initialization of recurrent weights 18 | RNN parameters contain 3 or 4 matrices in one parameter, so we slice it. 19 | """ 20 | with torch.no_grad(): 21 | for _, hh, _, _ in cell.all_weights: 22 | for i in range(0, hh.size(0), cell.hidden_size): 23 | nn.init.orthogonal_(hh.data[i : i + cell.hidden_size], gain=gain) 24 | 25 | 26 | def lstm_forget_gate_init_(cell: nn.RNNBase, value: float = 1.0) -> None: 27 | """ 28 | Initialize LSTM forget gates with `value`. 
29 | 30 | :param cell: LSTM cell 31 | :param value: initial value, default: 1 32 | """ 33 | with torch.no_grad(): 34 | for _, _, ih_b, hh_b in cell.all_weights: 35 | l = len(ih_b) 36 | ih_b.data[l // 4 : l // 2].fill_(value) 37 | hh_b.data[l // 4 : l // 2].fill_(value) 38 | 39 | 40 | def xavier_uniform_n_(w: Tensor, gain: float = 1.0, n: int = 4) -> None: 41 | """ 42 | Xavier initializer for parameters that combine multiple matrices in one 43 | parameter for efficiency. This is e.g. used for GRU and LSTM parameters, 44 | where e.g. all gates are computed at the same time by 1 big matrix. 45 | 46 | :param w: parameter 47 | :param gain: default 1 48 | :param n: default 4 49 | """ 50 | with torch.no_grad(): 51 | fan_in, fan_out = _calculate_fan_in_and_fan_out(w) 52 | assert fan_out % n == 0, "fan_out should be divisible by n" 53 | fan_out //= n 54 | std = gain * math.sqrt(2.0 / (fan_in + fan_out)) 55 | a = math.sqrt(3.0) * std 56 | nn.init.uniform_(w, -a, a) 57 | 58 | 59 | # pylint: disable=too-many-branches 60 | def initialize_model(model: nn.Module, cfg: dict, txt_padding_idx: int) -> None: 61 | """ 62 | This initializes a model based on the provided config. 63 | 64 | All initializer configuration is part of the `model` section of the 65 | configuration file. 66 | For an example, see e.g. `https://github.com/joeynmt/joeynmt/ 67 | blob/master/configs/iwslt_envi_xnmt.yaml#L47` 68 | 69 | The main initializer is set using the `initializer` key. 70 | Possible values are `xavier`, `uniform`, `normal` or `zeros`. 71 | (`xavier` is the default). 72 | 73 | When an initializer is set to `uniform`, then `init_weight` sets the 74 | range for the values (-init_weight, init_weight). 75 | 76 | When an initializer is set to `normal`, then `init_weight` sets the 77 | standard deviation for the weights (with mean 0). 78 | 79 | The word embedding initializer is set using `embed_initializer` and takes 80 | the same values. The default is `normal` with `embed_init_weight = 0.01`. 81 | 82 | Biases are initialized separately using `bias_initializer`. 83 | The default is `zeros`, but you can use the same initializers as 84 | the main initializer. 85 | 86 | Set `init_rnn_orthogonal` to True if you want RNN orthogonal initialization 87 | (for recurrent matrices). Default is False. 88 | 89 | `lstm_forget_gate` controls how the LSTM forget gate is initialized. 90 | Default is `1`. 
91 | 92 | :param model: model to initialize 93 | :param cfg: the model configuration 94 | :param txt_padding_idx: index of spoken language text padding token 95 | """ 96 | 97 | # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal 98 | gain = float(cfg.get("init_gain", 1.0)) # for xavier 99 | init = cfg.get("initializer", "xavier") 100 | init_weight = float(cfg.get("init_weight", 0.01)) 101 | 102 | embed_init = cfg.get("embed_initializer", "normal") 103 | embed_init_weight = float(cfg.get("embed_init_weight", 0.01)) 104 | embed_gain = float(cfg.get("embed_init_gain", 1.0)) # for xavier 105 | 106 | bias_init = cfg.get("bias_initializer", "zeros") 107 | bias_init_weight = float(cfg.get("bias_init_weight", 0.01)) 108 | 109 | # pylint: disable=unnecessary-lambda, no-else-return 110 | def _parse_init(s, scale, _gain): 111 | scale = float(scale) 112 | assert scale > 0.0, "incorrect init_weight" 113 | if s.lower() == "xavier": 114 | return lambda p: nn.init.xavier_uniform_(p, gain=_gain) 115 | elif s.lower() == "uniform": 116 | return lambda p: nn.init.uniform_(p, a=-scale, b=scale) 117 | elif s.lower() == "normal": 118 | return lambda p: nn.init.normal_(p, mean=0.0, std=scale) 119 | elif s.lower() == "zeros": 120 | return lambda p: nn.init.zeros_(p) 121 | else: 122 | raise ValueError("unknown initializer") 123 | 124 | init_fn_ = _parse_init(init, init_weight, gain) 125 | try: 126 | embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain) 127 | except: 128 | #BayesianEmbedding Init external 129 | pass 130 | bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain) 131 | 132 | with torch.no_grad(): 133 | for name, p in model.named_parameters(): 134 | 135 | if "txt_embed" in name: 136 | try: 137 | if "lut" in name: 138 | embed_init_fn_(p) 139 | except: 140 | #BayesianEmbedding Init external 141 | pass 142 | 143 | elif "bias" in name: 144 | bias_init_fn_(p) 145 | 146 | elif len(p.size()) > 1: 147 | 148 | # RNNs combine multiple matrices is one, which messes up 149 | # xavier initialization 150 | if init == "xavier" and "rnn" in name: 151 | n = 1 152 | if "encoder" in name: 153 | n = 4 if isinstance(model.encoder.rnn, nn.LSTM) else 3 154 | elif "decoder" in name: 155 | n = 4 if isinstance(model.decoder.rnn, nn.LSTM) else 3 156 | xavier_uniform_n_(p.data, gain=gain, n=n) 157 | else: 158 | init_fn_(p) 159 | 160 | # zero out paddings 161 | if model.txt_embed is not None: 162 | model.txt_embed.lut.weight.data[txt_padding_idx].zero_() 163 | 164 | orthogonal = cfg.get("init_rnn_orthogonal", False) 165 | lstm_forget_gate = cfg.get("lstm_forget_gate", 1.0) 166 | 167 | # encoder rnn orthogonal initialization & LSTM forget gate 168 | if hasattr(model.encoder, "rnn"): 169 | 170 | if orthogonal: 171 | orthogonal_rnn_init_(model.encoder.rnn) 172 | 173 | if isinstance(model.encoder.rnn, nn.LSTM): 174 | lstm_forget_gate_init_(model.encoder.rnn, lstm_forget_gate) 175 | 176 | # decoder rnn orthogonal initialization & LSTM forget gate 177 | if hasattr(model.decoder, "rnn"): 178 | 179 | if orthogonal: 180 | orthogonal_rnn_init_(model.decoder.rnn) 181 | 182 | if isinstance(model.decoder.rnn, nn.LSTM): 183 | lstm_forget_gate_init_(model.decoder.rnn, lstm_forget_gate) 184 | -------------------------------------------------------------------------------- /signjoey/attention.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Attention modules 4 | """ 5 | 6 | import torch 7 | from torch import Tensor 8 | import torch.nn 
as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class AttentionMechanism(nn.Module): 13 | """ 14 | Base attention class 15 | """ 16 | 17 | def forward(self, *inputs): 18 | raise NotImplementedError("Implement this.") 19 | 20 | 21 | class BahdanauAttention(AttentionMechanism): 22 | """ 23 | Implements Bahdanau (MLP) attention 24 | 25 | Section A.1.2 in https://arxiv.org/pdf/1409.0473.pdf. 26 | """ 27 | 28 | def __init__(self, hidden_size=1, key_size=1, query_size=1): 29 | """ 30 | Creates attention mechanism. 31 | 32 | :param hidden_size: size of the projection for query and key 33 | :param key_size: size of the attention input keys 34 | :param query_size: size of the query 35 | """ 36 | 37 | super(BahdanauAttention, self).__init__() 38 | 39 | self.key_layer = nn.Linear(key_size, hidden_size, bias=False) 40 | self.query_layer = nn.Linear(query_size, hidden_size, bias=False) 41 | self.energy_layer = nn.Linear(hidden_size, 1, bias=False) 42 | 43 | self.proj_keys = None # to store projected keys 44 | self.proj_query = None # projected query 45 | 46 | # pylint: disable=arguments-differ 47 | def forward(self, query: Tensor = None, mask: Tensor = None, values: Tensor = None): 48 | """ 49 | Bahdanau MLP attention forward pass. 50 | 51 | :param query: the item (decoder state) to compare with the keys/memory, 52 | shape (batch_size, 1, decoder.hidden_size) 53 | :param mask: mask out keys position (0 in invalid positions, 1 else), 54 | shape (batch_size, 1, sgn_length) 55 | :param values: values (encoder states), 56 | shape (batch_size, sgn_length, encoder.hidden_size) 57 | :return: context vector of shape (batch_size, 1, value_size), 58 | attention probabilities of shape (batch_size, 1, sgn_length) 59 | """ 60 | self._check_input_shapes_forward(query=query, mask=mask, values=values) 61 | 62 | assert mask is not None, "mask is required" 63 | assert self.proj_keys is not None, "projection keys have to get pre-computed" 64 | 65 | # We first project the query (the decoder state). 66 | # The projected keys (the encoder states) were already pre-computated. 67 | self.compute_proj_query(query) 68 | 69 | # Calculate scores. 70 | # proj_keys: batch x sgn_len x hidden_size 71 | # proj_query: batch x 1 x hidden_size 72 | scores = self.energy_layer(torch.tanh(self.proj_query + self.proj_keys)) 73 | # scores: batch x sgn_len x 1 74 | 75 | scores = scores.squeeze(2).unsqueeze(1) 76 | # scores: batch x 1 x time 77 | 78 | # mask out invalid positions by filling the masked out parts with -inf 79 | scores = torch.where(mask, scores, scores.new_full([1], float("-inf"))) 80 | 81 | # turn scores to probabilities 82 | alphas = F.softmax(scores, dim=-1) # batch x 1 x time 83 | 84 | # the context vector is the weighted sum of the values 85 | context = alphas @ values # batch x 1 x value_size 86 | 87 | return context, alphas 88 | 89 | def compute_proj_keys(self, keys: Tensor): 90 | """ 91 | Compute the projection of the keys. 92 | Is efficient if pre-computed before receiving individual queries. 93 | 94 | :param keys: 95 | :return: 96 | """ 97 | self.proj_keys = self.key_layer(keys) 98 | 99 | def compute_proj_query(self, query: Tensor): 100 | """ 101 | Compute the projection of the query. 102 | 103 | :param query: 104 | :return: 105 | """ 106 | self.proj_query = self.query_layer(query) 107 | 108 | def _check_input_shapes_forward( 109 | self, query: torch.Tensor, mask: torch.Tensor, values: torch.Tensor 110 | ): 111 | """ 112 | Make sure that inputs to `self.forward` are of correct shape. 
113 | Same input semantics as for `self.forward`. 114 | 115 | :param query: 116 | :param mask: 117 | :param values: 118 | :return: 119 | """ 120 | assert query.shape[0] == values.shape[0] == mask.shape[0] 121 | assert query.shape[1] == 1 == mask.shape[1] 122 | assert query.shape[2] == self.query_layer.in_features 123 | assert values.shape[2] == self.key_layer.in_features 124 | assert mask.shape[2] == values.shape[1] 125 | 126 | def __repr__(self): 127 | return "BahdanauAttention" 128 | 129 | 130 | class LuongAttention(AttentionMechanism): 131 | """ 132 | Implements Luong (bilinear / multiplicative) attention. 133 | 134 | Eq. 8 ("general") in http://aclweb.org/anthology/D15-1166. 135 | """ 136 | 137 | def __init__(self, hidden_size: int = 1, key_size: int = 1): 138 | """ 139 | Creates attention mechanism. 140 | 141 | :param hidden_size: size of the key projection layer, has to be equal 142 | to decoder hidden size 143 | :param key_size: size of the attention input keys 144 | """ 145 | 146 | super(LuongAttention, self).__init__() 147 | self.key_layer = nn.Linear( 148 | in_features=key_size, out_features=hidden_size, bias=False 149 | ) 150 | self.proj_keys = None # projected keys 151 | 152 | # pylint: disable=arguments-differ 153 | def forward( 154 | self, 155 | query: torch.Tensor = None, 156 | mask: torch.Tensor = None, 157 | values: torch.Tensor = None, 158 | ): 159 | """ 160 | Luong (multiplicative / bilinear) attention forward pass. 161 | Computes context vectors and attention scores for a given query and 162 | all masked values and returns them. 163 | 164 | :param query: the item (decoder state) to compare with the keys/memory, 165 | shape (batch_size, 1, decoder.hidden_size) 166 | :param mask: mask out keys position (0 in invalid positions, 1 else), 167 | shape (batch_size, 1, sgn_length) 168 | :param values: values (encoder states), 169 | shape (batch_size, sgn_length, encoder.hidden_size) 170 | :return: context vector of shape (batch_size, 1, value_size), 171 | attention probabilities of shape (batch_size, 1, sgn_length) 172 | """ 173 | self._check_input_shapes_forward(query=query, mask=mask, values=values) 174 | 175 | assert self.proj_keys is not None, "projection keys have to get pre-computed" 176 | assert mask is not None, "mask is required" 177 | 178 | # scores: batch_size x 1 x sgn_length 179 | scores = query @ self.proj_keys.transpose(1, 2) 180 | 181 | # mask out invalid positions by filling the masked out parts with -inf 182 | scores = torch.where(mask, scores, scores.new_full([1], float("-inf"))) 183 | 184 | # turn scores to probabilities 185 | alphas = F.softmax(scores, dim=-1) # batch x 1 x sgn_len 186 | 187 | # the context vector is the weighted sum of the values 188 | context = alphas @ values # batch x 1 x values_size 189 | 190 | return context, alphas 191 | 192 | def compute_proj_keys(self, keys: Tensor): 193 | """ 194 | Compute the projection of the keys and assign them to `self.proj_keys`. 195 | This pre-computation is efficiently done for all keys 196 | before receiving individual queries. 197 | 198 | :param keys: shape (batch_size, sgn_length, encoder.hidden_size) 199 | """ 200 | # proj_keys: batch x sgn_len x hidden_size 201 | self.proj_keys = self.key_layer(keys) 202 | 203 | def _check_input_shapes_forward( 204 | self, query: torch.Tensor, mask: torch.Tensor, values: torch.Tensor 205 | ): 206 | """ 207 | Make sure that inputs to `self.forward` are of correct shape. 208 | Same input semantics as for `self.forward`. 
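# Usage sketch for the attention modules above (sizes below are placeholders, not
# values taken from this code base): the keys are projected once with
# compute_proj_keys(), then forward() is called once per decoder step.
import torch
attn = LuongAttention(hidden_size=8, key_size=16)
values = torch.randn(2, 5, 16)                  # (batch, sgn_length, encoder_hidden)
mask = torch.ones(2, 1, 5, dtype=torch.bool)    # True at valid key positions
query = torch.randn(2, 1, 8)                    # (batch, 1, decoder_hidden)
attn.compute_proj_keys(keys=values)             # pre-compute key projections once
context, alphas = attn(query=query, mask=mask, values=values)
# context: (2, 1, 16); alphas: (2, 1, 5) and sums to 1 over the last dimension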
209 | 210 | :param query: 211 | :param mask: 212 | :param values: 213 | :return: 214 | """ 215 | assert query.shape[0] == values.shape[0] == mask.shape[0] 216 | assert query.shape[1] == 1 == mask.shape[1] 217 | assert query.shape[2] == self.key_layer.out_features 218 | assert values.shape[2] == self.key_layer.in_features 219 | assert mask.shape[2] == values.shape[1] 220 | 221 | def __repr__(self): 222 | return "LuongAttention" 223 | -------------------------------------------------------------------------------- /signjoey/helpers.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Collection of helper functions 4 | """ 5 | import copy 6 | import glob 7 | import os 8 | import os.path 9 | import errno 10 | import shutil 11 | import random 12 | import logging 13 | from sys import platform 14 | from logging import Logger 15 | from typing import Callable, Optional 16 | import numpy as np 17 | 18 | import torch 19 | from torch import nn, Tensor 20 | from torchtext.data import Dataset 21 | import yaml 22 | from signjoey.vocabulary import GlossVocabulary, TextVocabulary 23 | 24 | 25 | def make_model_dir(model_dir: str, overwrite: bool = False) -> str: 26 | """ 27 | Create a new directory for the model. 28 | 29 | :param model_dir: path to model directory 30 | :param overwrite: whether to overwrite an existing directory 31 | :return: path to model directory 32 | """ 33 | if os.path.isdir(model_dir): 34 | if not overwrite: 35 | raise FileExistsError("Model directory exists and overwriting is disabled.") 36 | # delete previous directory to start with empty dir again 37 | shutil.rmtree(model_dir) 38 | os.makedirs(model_dir) 39 | return model_dir 40 | 41 | 42 | def make_logger(model_dir: str, log_file: str = "train.log") -> Logger: 43 | """ 44 | Create a logger for logging the training process. 45 | 46 | :param model_dir: path to logging directory 47 | :param log_file: path to logging file 48 | :return: logger object 49 | """ 50 | logger = logging.getLogger(__name__) 51 | if not logger.handlers: 52 | logger.setLevel(level=logging.DEBUG) 53 | fh = logging.FileHandler("{}/{}".format(model_dir, log_file)) 54 | fh.setLevel(level=logging.DEBUG) 55 | logger.addHandler(fh) 56 | formatter = logging.Formatter("%(asctime)s %(message)s") 57 | fh.setFormatter(formatter) 58 | if platform == "linux": 59 | sh = logging.StreamHandler() 60 | sh.setLevel(logging.INFO) 61 | sh.setFormatter(formatter) 62 | logging.getLogger("").addHandler(sh) 63 | logger.info("Hello! This is Joey-NMT.") 64 | return logger 65 | 66 | 67 | def log_cfg(cfg: dict, logger: Logger, prefix: str = "cfg"): 68 | """ 69 | Write configuration to log. 70 | 71 | :param cfg: configuration to log 72 | :param logger: logger that defines where log is written to 73 | :param prefix: prefix for logging 74 | """ 75 | for k, v in cfg.items(): 76 | if isinstance(v, dict): 77 | p = ".".join([prefix, k]) 78 | log_cfg(v, logger, prefix=p) 79 | else: 80 | p = ".".join([prefix, k]) 81 | logger.info("{:34s} : {}".format(p, v)) 82 | 83 | 84 | def clones(module: nn.Module, n: int) -> nn.ModuleList: 85 | """ 86 | Produce N identical layers. Transformer helper function. 
87 | 88 | :param module: the module to clone 89 | :param n: clone this many times 90 | :return cloned modules 91 | """ 92 | return nn.ModuleList([copy.deepcopy(module) for _ in range(n)]) 93 | 94 | 95 | def subsequent_mask(size: int) -> Tensor: 96 | """ 97 | Mask out subsequent positions (to prevent attending to future positions) 98 | Transformer helper function. 99 | 100 | :param size: size of mask (2nd and 3rd dim) 101 | :return: Tensor with 0s and 1s of shape (1, size, size) 102 | """ 103 | mask = np.triu(np.ones((1, size, size)), k=1).astype("uint8") 104 | return torch.from_numpy(mask) == 0 105 | 106 | 107 | def set_seed(seed: int): 108 | """ 109 | Set the random seed for modules torch, numpy and random. 110 | 111 | :param seed: random seed 112 | """ 113 | torch.manual_seed(seed) 114 | np.random.seed(seed) 115 | random.seed(seed) 116 | 117 | 118 | def log_data_info( 119 | train_data: Dataset, 120 | valid_data: Dataset, 121 | test_data: Dataset, 122 | gls_vocab: GlossVocabulary, 123 | txt_vocab: TextVocabulary, 124 | logging_function: Callable[[str], None], 125 | ): 126 | """ 127 | Log statistics of data and vocabulary. 128 | 129 | :param train_data: 130 | :param valid_data: 131 | :param test_data: 132 | :param gls_vocab: 133 | :param txt_vocab: 134 | :param logging_function: 135 | """ 136 | logging_function( 137 | "Data set sizes: \n\ttrain {:d},\n\tvalid {:d},\n\ttest {:d}".format( 138 | len(train_data), 139 | len(valid_data), 140 | len(test_data) if test_data is not None else 0, 141 | ) 142 | ) 143 | 144 | logging_function( 145 | "First training example:\n\t[GLS] {}\n\t[TXT] {}".format( 146 | " ".join(vars(train_data[0])["gls"]), " ".join(vars(train_data[0])["txt"]) 147 | ) 148 | ) 149 | 150 | logging_function( 151 | "First 10 words (gls): {}".format( 152 | " ".join("(%d) %s" % (i, t) for i, t in enumerate(gls_vocab.itos[:10])) 153 | ) 154 | ) 155 | logging_function( 156 | "First 10 words (txt): {}".format( 157 | " ".join("(%d) %s" % (i, t) for i, t in enumerate(txt_vocab.itos[:10])) 158 | ) 159 | ) 160 | 161 | logging_function("Number of unique glosses (types): {}".format(len(gls_vocab))) 162 | logging_function("Number of unique words (types): {}".format(len(txt_vocab))) 163 | 164 | 165 | def load_config(path="configs/default.yaml") -> dict: 166 | """ 167 | Loads and parses a YAML configuration file. 168 | 169 | :param path: path to YAML configuration file 170 | :return: configuration dictionary 171 | """ 172 | with open(path, "r", encoding="utf-8") as ymlfile: 173 | cfg = yaml.safe_load(ymlfile) 174 | return cfg 175 | 176 | 177 | def bpe_postprocess(string) -> str: 178 | """ 179 | Post-processor for BPE output. Recombines BPE-split tokens. 180 | 181 | :param string: 182 | :return: post-processed string 183 | """ 184 | return string.replace("@@ ", "") 185 | 186 | 187 | def get_latest_checkpoint(ckpt_dir: str) -> Optional[str]: 188 | """ 189 | Returns the latest checkpoint (by time) from the given directory. 190 | If there is no checkpoint in this directory, returns None 191 | 192 | :param ckpt_dir: 193 | :return: latest checkpoint file 194 | """ 195 | list_of_files = glob.glob("{}/*.ckpt".format(ckpt_dir)) 196 | latest_checkpoint = None 197 | if list_of_files: 198 | latest_checkpoint = max(list_of_files, key=os.path.getctime) 199 | return latest_checkpoint 200 | 201 | 202 | def load_checkpoint(path: str, use_cuda: bool = True) -> dict: 203 | """ 204 | Load model from saved checkpoint. 
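# Usage sketch for the checkpoint helpers above ("models/sign_model" is a placeholder
# directory, not a path defined in this repository):
import torch
ckpt_path = get_latest_checkpoint(ckpt_dir="models/sign_model")
if ckpt_path is not None:
    checkpoint = load_checkpoint(path=ckpt_path, use_cuda=torch.cuda.is_available())
    # `checkpoint` is the raw dict produced by torch.save during training; which keys
    # it contains depends on the training loop that wrote it.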
205 | 206 | :param path: path to checkpoint 207 | :param use_cuda: using cuda or not 208 | :return: checkpoint (dict) 209 | """ 210 | assert os.path.isfile(path), "Checkpoint %s not found" % path 211 | checkpoint = torch.load(path, map_location="cuda" if use_cuda else "cpu") 212 | return checkpoint 213 | 214 | 215 | # from onmt 216 | def tile(x: Tensor, count: int, dim=0) -> Tensor: 217 | """ 218 | Tiles x on dimension dim count times. From OpenNMT. Used for beam search. 219 | 220 | :param x: tensor to tile 221 | :param count: number of tiles 222 | :param dim: dimension along which the tensor is tiled 223 | :return: tiled tensor 224 | """ 225 | if isinstance(x, tuple): 226 | h, c = x 227 | return tile(h, count, dim=dim), tile(c, count, dim=dim) 228 | 229 | perm = list(range(len(x.size()))) 230 | if dim != 0: 231 | perm[0], perm[dim] = perm[dim], perm[0] 232 | x = x.permute(perm).contiguous() 233 | out_size = list(x.size()) 234 | out_size[0] *= count 235 | batch = x.size(0) 236 | x = ( 237 | x.view(batch, -1) 238 | .transpose(0, 1) 239 | .repeat(count, 1) 240 | .transpose(0, 1) 241 | .contiguous() 242 | .view(*out_size) 243 | ) 244 | if dim != 0: 245 | x = x.permute(perm).contiguous() 246 | return x 247 | 248 | 249 | def freeze_params(module: nn.Module): 250 | """ 251 | Freeze the parameters of this module, 252 | i.e. do not update them during training 253 | 254 | :param module: freeze parameters of this module 255 | """ 256 | for _, p in module.named_parameters(): 257 | p.requires_grad = False 258 | 259 | 260 | def symlink_update(target, link_name): 261 | try: 262 | os.symlink(target, link_name) 263 | except FileExistsError as e: 264 | if e.errno == errno.EEXIST: 265 | os.remove(link_name) 266 | os.symlink(target, link_name) 267 | else: 268 | raise e 269 | -------------------------------------------------------------------------------- /signjoey/metrics.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | This module holds various MT evaluation metrics. 4 | """ 5 | 6 | from signjoey.external_metrics import sacrebleu 7 | from signjoey.external_metrics import mscoco_rouge 8 | import numpy as np 9 | 10 | WER_COST_DEL = 3 11 | WER_COST_INS = 3 12 | WER_COST_SUB = 4 13 | 14 | 15 | def chrf(references, hypotheses): 16 | """ 17 | Character F-score from sacrebleu 18 | 19 | :param hypotheses: list of hypotheses (strings) 20 | :param references: list of references (strings) 21 | :return: 22 | """ 23 | return ( 24 | sacrebleu.corpus_chrf(hypotheses=hypotheses, references=references).score * 100 25 | ) 26 | 27 | 28 | def bleu(references, hypotheses): 29 | """ 30 | Raw corpus BLEU from sacrebleu (without tokenization) 31 | 32 | :param hypotheses: list of hypotheses (strings) 33 | :param references: list of references (strings) 34 | :return: 35 | """ 36 | bleu_scores = sacrebleu.raw_corpus_bleu( 37 | sys_stream=hypotheses, ref_streams=[references] 38 | ).scores 39 | scores = {} 40 | for n in range(len(bleu_scores)): 41 | scores["bleu" + str(n + 1)] = bleu_scores[n] 42 | return scores 43 | 44 | 45 | def token_accuracy(references, hypotheses, level="word"): 46 | """ 47 | Compute the accuracy of hypothesis tokens: correct tokens / all tokens 48 | Tokens are correct if they appear in the same position in the reference. 
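# Usage sketch for the corpus-level metrics above; both take plain lists of strings
# (the sentences here are made-up examples):
hyps = ["der wetterbericht fuer morgen", "es wird regnen"]
refs = ["der wetterbericht fuer morgen", "es regnet morgen"]
print(bleu(references=refs, hypotheses=hyps))   # {'bleu1': ..., ..., 'bleu4': ...}
print(chrf(references=refs, hypotheses=hyps))   # character F-score scaled to 0-100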
49 | 50 | :param hypotheses: list of hypotheses (strings) 51 | :param references: list of references (strings) 52 | :param level: segmentation level, either "word", "bpe", or "char" 53 | :return: 54 | """ 55 | correct_tokens = 0 56 | all_tokens = 0 57 | split_char = " " if level in ["word", "bpe"] else "" 58 | assert len(hypotheses) == len(references) 59 | for hyp, ref in zip(hypotheses, references): 60 | all_tokens += len(hyp) 61 | for h_i, r_i in zip(hyp.split(split_char), ref.split(split_char)): 62 | # min(len(h), len(r)) tokens considered 63 | if h_i == r_i: 64 | correct_tokens += 1 65 | return (correct_tokens / all_tokens) * 100 if all_tokens > 0 else 0.0 66 | 67 | 68 | def sequence_accuracy(references, hypotheses): 69 | """ 70 | Compute the accuracy of hypothesis tokens: correct tokens / all tokens 71 | Tokens are correct if they appear in the same position in the reference. 72 | 73 | :param hypotheses: list of hypotheses (strings) 74 | :param references: list of references (strings) 75 | :return: 76 | """ 77 | assert len(hypotheses) == len(references) 78 | correct_sequences = sum( 79 | [1 for (hyp, ref) in zip(hypotheses, references) if hyp == ref] 80 | ) 81 | return (correct_sequences / len(hypotheses)) * 100 if hypotheses else 0.0 82 | 83 | 84 | def rouge(references, hypotheses): 85 | rouge_score = 0 86 | n_seq = len(hypotheses) 87 | 88 | for h, r in zip(hypotheses, references): 89 | rouge_score += mscoco_rouge.calc_score(hypotheses=[h], references=[r]) / n_seq 90 | 91 | return rouge_score * 100 92 | 93 | 94 | def wer_list(references, hypotheses): 95 | total_error = total_del = total_ins = total_sub = total_ref_len = 0 96 | 97 | for r, h in zip(references, hypotheses): 98 | res = wer_single(r=r, h=h) 99 | total_error += res["num_err"] 100 | total_del += res["num_del"] 101 | total_ins += res["num_ins"] 102 | total_sub += res["num_sub"] 103 | total_ref_len += res["num_ref"] 104 | 105 | wer = (total_error / total_ref_len) * 100 106 | del_rate = (total_del / total_ref_len) * 100 107 | ins_rate = (total_ins / total_ref_len) * 100 108 | sub_rate = (total_sub / total_ref_len) * 100 109 | 110 | return { 111 | "wer": wer, 112 | "del_rate": del_rate, 113 | "ins_rate": ins_rate, 114 | "sub_rate": sub_rate, 115 | } 116 | 117 | 118 | def wer_single(r, h): 119 | r = r.strip().split() 120 | h = h.strip().split() 121 | edit_distance_matrix = edit_distance(r=r, h=h) 122 | alignment, alignment_out = get_alignment(r=r, h=h, d=edit_distance_matrix) 123 | 124 | num_cor = np.sum([s == "C" for s in alignment]) 125 | num_del = np.sum([s == "D" for s in alignment]) 126 | num_ins = np.sum([s == "I" for s in alignment]) 127 | num_sub = np.sum([s == "S" for s in alignment]) 128 | num_err = num_del + num_ins + num_sub 129 | num_ref = len(r) 130 | 131 | return { 132 | "alignment": alignment, 133 | "alignment_out": alignment_out, 134 | "num_cor": num_cor, 135 | "num_del": num_del, 136 | "num_ins": num_ins, 137 | "num_sub": num_sub, 138 | "num_err": num_err, 139 | "num_ref": num_ref, 140 | } 141 | 142 | 143 | def edit_distance(r, h): 144 | """ 145 | Original Code from https://github.com/zszyellow/WER-in-python/blob/master/wer.py 146 | This function is to calculate the edit distance of reference sentence and the hypothesis sentence. 147 | Main algorithm used is dynamic programming. 148 | Attributes: 149 | r -> the list of words produced by splitting reference sentence. 150 | h -> the list of words produced by splitting hypothesis sentence. 
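# Usage sketch for the WER computation above (made-up sentences). The alignment uses
# the weighted costs defined at the top of this module
# (WER_COST_DEL = WER_COST_INS = 3, WER_COST_SUB = 4) rather than unit costs.
refs = ["heute regnet es", "morgen scheint die sonne"]
hyps = ["heute regnet", "morgen scheint sonne"]
scores = wer_list(references=refs, hypotheses=hyps)
# scores["wer"], scores["del_rate"], scores["ins_rate"] and scores["sub_rate"] are
# percentages relative to the total reference length.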
151 | """ 152 | d = np.zeros((len(r) + 1) * (len(h) + 1), dtype=np.uint8).reshape( 153 | (len(r) + 1, len(h) + 1) 154 | ) 155 | for i in range(len(r) + 1): 156 | for j in range(len(h) + 1): 157 | if i == 0: 158 | # d[0][j] = j 159 | d[0][j] = j * WER_COST_INS 160 | elif j == 0: 161 | d[i][0] = i * WER_COST_DEL 162 | for i in range(1, len(r) + 1): 163 | for j in range(1, len(h) + 1): 164 | if r[i - 1] == h[j - 1]: 165 | d[i][j] = d[i - 1][j - 1] 166 | else: 167 | substitute = d[i - 1][j - 1] + WER_COST_SUB 168 | insert = d[i][j - 1] + WER_COST_INS 169 | delete = d[i - 1][j] + WER_COST_DEL 170 | d[i][j] = min(substitute, insert, delete) 171 | return d 172 | 173 | 174 | def get_alignment(r, h, d): 175 | """ 176 | Original Code from https://github.com/zszyellow/WER-in-python/blob/master/wer.py 177 | This function is to get the list of steps in the process of dynamic programming. 178 | Attributes: 179 | r -> the list of words produced by splitting reference sentence. 180 | h -> the list of words produced by splitting hypothesis sentence. 181 | d -> the matrix built when calculating the editing distance of h and r. 182 | """ 183 | x = len(r) 184 | y = len(h) 185 | max_len = 3 * (x + y) 186 | 187 | alignlist = [] 188 | align_ref = "" 189 | align_hyp = "" 190 | alignment = "" 191 | 192 | while True: 193 | if (x <= 0 and y <= 0) or (len(alignlist) > max_len): 194 | break 195 | elif x >= 1 and y >= 1 and d[x][y] == d[x - 1][y - 1] and r[x - 1] == h[y - 1]: 196 | align_hyp = " " + h[y - 1] + align_hyp 197 | align_ref = " " + r[x - 1] + align_ref 198 | alignment = " " * (len(r[x - 1]) + 1) + alignment 199 | alignlist.append("C") 200 | x = max(x - 1, 0) 201 | y = max(y - 1, 0) 202 | elif x >= 1 and y >= 1 and d[x][y] == d[x - 1][y - 1] + WER_COST_SUB: 203 | ml = max(len(h[y - 1]), len(r[x - 1])) 204 | align_hyp = " " + h[y - 1].ljust(ml) + align_hyp 205 | align_ref = " " + r[x - 1].ljust(ml) + align_ref 206 | alignment = " " + "S" + " " * (ml - 1) + alignment 207 | alignlist.append("S") 208 | x = max(x - 1, 0) 209 | y = max(y - 1, 0) 210 | elif y >= 1 and d[x][y] == d[x][y - 1] + WER_COST_INS: 211 | align_hyp = " " + h[y - 1] + align_hyp 212 | align_ref = " " + "*" * len(h[y - 1]) + align_ref 213 | alignment = " " + "I" + " " * (len(h[y - 1]) - 1) + alignment 214 | alignlist.append("I") 215 | x = max(x, 0) 216 | y = max(y - 1, 0) 217 | else: 218 | align_hyp = " " + "*" * len(r[x - 1]) + align_hyp 219 | align_ref = " " + r[x - 1] + align_ref 220 | alignment = " " + "D" + " " * (len(r[x - 1]) - 1) + alignment 221 | alignlist.append("D") 222 | x = max(x - 1, 0) 223 | y = max(y, 0) 224 | 225 | align_ref = align_ref[1:] 226 | align_hyp = align_hyp[1:] 227 | alignment = alignment[1:] 228 | 229 | return ( 230 | alignlist[::-1], 231 | {"align_ref": align_ref, "align_hyp": align_hyp, "alignment": alignment}, 232 | ) 233 | -------------------------------------------------------------------------------- /signjoey/vocabulary.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import numpy as np 3 | 4 | from collections import defaultdict, Counter 5 | from typing import List 6 | from torchtext.data import Dataset 7 | 8 | SIL_TOKEN = "" 9 | UNK_TOKEN = "" 10 | PAD_TOKEN = "" 11 | BOS_TOKEN = "" 12 | EOS_TOKEN = "" 13 | 14 | 15 | class Vocabulary: 16 | """ Vocabulary represents mapping between tokens and indices. 
""" 17 | 18 | def __init__(self): 19 | # don't rename stoi and itos since needed for torchtext 20 | # warning: stoi grows with unknown tokens, don't use for saving or size 21 | self.specials = [] 22 | self.itos = [] 23 | self.stoi = None 24 | self.DEFAULT_UNK_ID = None 25 | 26 | def _from_list(self, tokens: List[str] = None): 27 | """ 28 | Make vocabulary from list of tokens. 29 | Tokens are assumed to be unique and pre-selected. 30 | Special symbols are added if not in list. 31 | 32 | :param tokens: list of tokens 33 | """ 34 | self.add_tokens(tokens=self.specials + tokens) 35 | assert len(self.stoi) == len(self.itos) 36 | 37 | def _from_file(self, file: str): 38 | """ 39 | Make vocabulary from contents of file. 40 | File format: token with index i is in line i. 41 | 42 | :param file: path to file where the vocabulary is loaded from 43 | """ 44 | tokens = [] 45 | with open(file, "r", encoding="utf-8") as open_file: 46 | for line in open_file: 47 | tokens.append(line.strip("\n")) 48 | self._from_list(tokens) 49 | 50 | def __str__(self) -> str: 51 | return self.stoi.__str__() 52 | 53 | def to_file(self, file: str): 54 | """ 55 | Save the vocabulary to a file, by writing token with index i in line i. 56 | 57 | :param file: path to file where the vocabulary is written 58 | """ 59 | with open(file, "w", encoding="utf-8") as open_file: 60 | for t in self.itos: 61 | open_file.write("{}\n".format(t)) 62 | 63 | def add_tokens(self, tokens: List[str]): 64 | """ 65 | Add list of tokens to vocabulary 66 | 67 | :param tokens: list of tokens to add to the vocabulary 68 | """ 69 | for t in tokens: 70 | new_index = len(self.itos) 71 | # add to vocab if not already there 72 | if t not in self.itos: 73 | self.itos.append(t) 74 | self.stoi[t] = new_index 75 | 76 | def is_unk(self, token: str) -> bool: 77 | """ 78 | Check whether a token is covered by the vocabulary 79 | 80 | :param token: 81 | :return: True if covered, False otherwise 82 | """ 83 | return self.stoi[token] == self.DEFAULT_UNK_ID() 84 | 85 | def __len__(self) -> int: 86 | return len(self.itos) 87 | 88 | 89 | class TextVocabulary(Vocabulary): 90 | def __init__(self, tokens: List[str] = None, file: str = None): 91 | """ 92 | Create vocabulary from list of tokens or file. 93 | 94 | Special tokens are added if not already in file or list. 95 | File format: token with index i is in line i. 96 | 97 | :param tokens: list of tokens 98 | :param file: file to load vocabulary from 99 | """ 100 | super().__init__() 101 | self.specials = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN] 102 | self.DEFAULT_UNK_ID = lambda: 0 103 | self.stoi = defaultdict(self.DEFAULT_UNK_ID) 104 | 105 | if tokens is not None: 106 | self._from_list(tokens) 107 | elif file is not None: 108 | self._from_file(file) 109 | 110 | def array_to_sentence(self, array: np.array, cut_at_eos=True) -> List[str]: 111 | """ 112 | Converts an array of IDs to a sentence, optionally cutting the result 113 | off at the end-of-sequence token. 
114 | 115 | :param array: 1D array containing indices 116 | :param cut_at_eos: cut the decoded sentences at the first 117 | :return: list of strings (tokens) 118 | """ 119 | sentence = [] 120 | for i in array: 121 | s = self.itos[i] 122 | if cut_at_eos and s == EOS_TOKEN: 123 | break 124 | sentence.append(s) 125 | return sentence 126 | 127 | def arrays_to_sentences(self, arrays: np.array, cut_at_eos=True) -> List[List[str]]: 128 | """ 129 | Convert multiple arrays containing sequences of token IDs to their 130 | sentences, optionally cutting them off at the end-of-sequence token. 131 | 132 | :param arrays: 2D array containing indices 133 | :param cut_at_eos: cut the decoded sentences at the first 134 | :return: list of list of strings (tokens) 135 | """ 136 | sentences = [] 137 | for array in arrays: 138 | sentences.append(self.array_to_sentence(array=array, cut_at_eos=cut_at_eos)) 139 | return sentences 140 | 141 | 142 | class GlossVocabulary(Vocabulary): 143 | def __init__(self, tokens: List[str] = None, file: str = None): 144 | """ 145 | Create vocabulary from list of tokens or file. 146 | 147 | Special tokens are added if not already in file or list. 148 | File format: token with index i is in line i. 149 | 150 | :param tokens: list of tokens 151 | :param file: file to load vocabulary from 152 | """ 153 | super().__init__() 154 | self.specials = [SIL_TOKEN, UNK_TOKEN, PAD_TOKEN] 155 | self.DEFAULT_UNK_ID = lambda: 1 156 | self.stoi = defaultdict(self.DEFAULT_UNK_ID) 157 | 158 | if tokens is not None: 159 | self._from_list(tokens) 160 | elif file is not None: 161 | self._from_file(file) 162 | 163 | # TODO (Cihan): This bit is hardcoded so that the silence token 164 | # is the first label to be able to do CTC calculations (decoding etc.) 165 | # Might fix in the future. 166 | assert self.stoi[SIL_TOKEN] == 0 167 | 168 | def arrays_to_sentences(self, arrays: np.array) -> List[List[str]]: 169 | gloss_sequences = [] 170 | for array in arrays: 171 | sequence = [] 172 | for i in array: 173 | sequence.append(self.itos[i]) 174 | gloss_sequences.append(sequence) 175 | return gloss_sequences 176 | 177 | 178 | def filter_min(counter: Counter, minimum_freq: int): 179 | """ Filter counter by min frequency """ 180 | filtered_counter = Counter({t: c for t, c in counter.items() if c >= minimum_freq}) 181 | return filtered_counter 182 | 183 | 184 | def sort_and_cut(counter: Counter, limit: int): 185 | """ Cut counter to most frequent, 186 | sorted numerically and alphabetically""" 187 | # sort by frequency, then alphabetically 188 | tokens_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0]) 189 | tokens_and_frequencies.sort(key=lambda tup: tup[1], reverse=True) 190 | vocab_tokens = [i[0] for i in tokens_and_frequencies[:limit]] 191 | return vocab_tokens 192 | 193 | 194 | def build_vocab( 195 | field: str, max_size: int, min_freq: int, dataset: Dataset, vocab_file: str = None 196 | ) -> Vocabulary: 197 | """ 198 | Builds vocabulary for a torchtext `field` from given`dataset` or 199 | `vocab_file`. 200 | 201 | :param field: attribute e.g. 
"src" 202 | :param max_size: maximum size of vocabulary 203 | :param min_freq: minimum frequency for an item to be included 204 | :param dataset: dataset to load data for field from 205 | :param vocab_file: file to store the vocabulary, 206 | if not None, load vocabulary from here 207 | :return: Vocabulary created from either `dataset` or `vocab_file` 208 | """ 209 | 210 | if vocab_file is not None: 211 | # load it from file 212 | if field == "gls": 213 | vocab = GlossVocabulary(file=vocab_file) 214 | elif field == "txt": 215 | vocab = TextVocabulary(file=vocab_file) 216 | else: 217 | raise ValueError("Unknown vocabulary type") 218 | else: 219 | tokens = [] 220 | for i in dataset.examples: 221 | if field == "gls": 222 | tokens.extend(i.gls) 223 | elif field == "txt": 224 | tokens.extend(i.txt) 225 | else: 226 | raise ValueError("Unknown field type") 227 | 228 | counter = Counter(tokens) 229 | if min_freq > -1: 230 | counter = filter_min(counter, min_freq) 231 | vocab_tokens = sort_and_cut(counter, max_size) 232 | assert len(vocab_tokens) <= max_size 233 | 234 | if field == "gls": 235 | vocab = GlossVocabulary(tokens=vocab_tokens) 236 | elif field == "txt": 237 | vocab = TextVocabulary(tokens=vocab_tokens) 238 | else: 239 | raise ValueError("Unknown vocabulary type") 240 | 241 | assert len(vocab) <= max_size + len(vocab.specials) 242 | assert vocab.itos[vocab.DEFAULT_UNK_ID()] == UNK_TOKEN 243 | 244 | for i, s in enumerate(vocab.specials): 245 | if i != vocab.DEFAULT_UNK_ID(): 246 | assert not vocab.is_unk(s) 247 | 248 | return vocab 249 | -------------------------------------------------------------------------------- /signjoey/data.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Data module 4 | """ 5 | import os 6 | import sys 7 | import random 8 | 9 | import torch 10 | from torchtext import data 11 | from torchtext.data import Dataset, Iterator 12 | import socket 13 | from signjoey.dataset import SignTranslationDataset 14 | from signjoey.vocabulary import ( 15 | build_vocab, 16 | Vocabulary, 17 | UNK_TOKEN, 18 | EOS_TOKEN, 19 | BOS_TOKEN, 20 | PAD_TOKEN, 21 | ) 22 | 23 | 24 | def load_data(data_cfg: dict) -> (Dataset, Dataset, Dataset, Vocabulary, Vocabulary): 25 | """ 26 | Load train, dev and optionally test data as specified in configuration. 27 | Vocabularies are created from the training set with a limit of `voc_limit` 28 | tokens and a minimum token frequency of `voc_min_freq` 29 | (specified in the configuration dictionary). 30 | 31 | The training data is filtered to include sentences up to `max_sent_length` 32 | on source and target side. 33 | 34 | If you set ``random_train_subset``, a random selection of this size is used 35 | from the training set instead of the full training set. 36 | 37 | If you set ``random_dev_subset``, a random selection of this size is used 38 | from the dev development instead of the full development set. 
39 | 40 | :param data_cfg: configuration dictionary for data 41 | ("data" part of configuration file) 42 | :return: 43 | - train_data: training dataset 44 | - dev_data: development dataset 45 | - test_data: test dataset if given, otherwise None 46 | - gls_vocab: gloss vocabulary extracted from training data 47 | - txt_vocab: spoken text vocabulary extracted from training data 48 | """ 49 | 50 | data_path = data_cfg.get("data_path", "./data") 51 | 52 | if isinstance(data_cfg["train"], list): 53 | train_paths = [os.path.join(data_path, x) for x in data_cfg["train"]] 54 | dev_paths = [os.path.join(data_path, x) for x in data_cfg["dev"]] 55 | test_paths = [os.path.join(data_path, x) for x in data_cfg["test"]] 56 | pad_feature_size = sum(data_cfg["feature_size"]) 57 | 58 | else: 59 | train_paths = os.path.join(data_path, data_cfg["train"]) 60 | dev_paths = os.path.join(data_path, data_cfg["dev"]) 61 | test_paths = os.path.join(data_path, data_cfg["test"]) 62 | pad_feature_size = data_cfg["feature_size"] 63 | 64 | level = data_cfg["level"] 65 | txt_lowercase = data_cfg["txt_lowercase"] 66 | max_sent_length = data_cfg["max_sent_length"] 67 | 68 | def tokenize_text(text): 69 | if level == "char": 70 | return list(text) 71 | else: 72 | return text.split() 73 | 74 | def tokenize_features(features): 75 | ft_list = torch.split(features, 1, dim=0) 76 | return [ft.squeeze() for ft in ft_list] 77 | 78 | # NOTE (Cihan): The something was necessary to match the function signature. 79 | def stack_features(features, something): 80 | return torch.stack([torch.stack(ft, dim=0) for ft in features], dim=0) 81 | 82 | sequence_field = data.RawField() 83 | signer_field = data.RawField() 84 | 85 | sgn_field = data.Field( 86 | use_vocab=False, 87 | init_token=None, 88 | dtype=torch.float32, 89 | preprocessing=tokenize_features, 90 | tokenize=lambda features: features, # TODO (Cihan): is this necessary? 
91 | batch_first=True, 92 | include_lengths=True, 93 | postprocessing=stack_features, 94 | pad_token=torch.zeros((pad_feature_size,)), 95 | ) 96 | 97 | gls_field = data.Field( 98 | pad_token=PAD_TOKEN, 99 | tokenize=tokenize_text, 100 | batch_first=True, 101 | lower=False, 102 | include_lengths=True, 103 | ) 104 | 105 | txt_field = data.Field( 106 | init_token=BOS_TOKEN, 107 | eos_token=EOS_TOKEN, 108 | pad_token=PAD_TOKEN, 109 | tokenize=tokenize_text, 110 | unk_token=UNK_TOKEN, 111 | batch_first=True, 112 | lower=txt_lowercase, 113 | include_lengths=True, 114 | ) 115 | 116 | train_data = SignTranslationDataset( 117 | path=train_paths, 118 | fields=(sequence_field, signer_field, sgn_field, gls_field, txt_field), 119 | filter_pred=lambda x: len(vars(x)["sgn"]) <= max_sent_length 120 | and len(vars(x)["txt"]) <= max_sent_length, 121 | ) 122 | 123 | gls_max_size = data_cfg.get("gls_voc_limit", sys.maxsize) 124 | gls_min_freq = data_cfg.get("gls_voc_min_freq", 1) 125 | txt_max_size = data_cfg.get("txt_voc_limit", sys.maxsize) 126 | txt_min_freq = data_cfg.get("txt_voc_min_freq", 1) 127 | 128 | gls_vocab_file = data_cfg.get("gls_vocab", None) 129 | txt_vocab_file = data_cfg.get("txt_vocab", None) 130 | 131 | gls_vocab = build_vocab( 132 | field="gls", 133 | min_freq=gls_min_freq, 134 | max_size=gls_max_size, 135 | dataset=train_data, 136 | vocab_file=gls_vocab_file, 137 | ) 138 | txt_vocab = build_vocab( 139 | field="txt", 140 | min_freq=txt_min_freq, 141 | max_size=txt_max_size, 142 | dataset=train_data, 143 | vocab_file=txt_vocab_file, 144 | ) 145 | random_train_subset = data_cfg.get("random_train_subset", -1) 146 | if random_train_subset > -1: 147 | # select this many training examples randomly and discard the rest 148 | keep_ratio = random_train_subset / len(train_data) 149 | keep, _ = train_data.split( 150 | split_ratio=[keep_ratio, 1 - keep_ratio], random_state=random.getstate() 151 | ) 152 | train_data = keep 153 | 154 | dev_data = SignTranslationDataset( 155 | path=dev_paths, 156 | fields=(sequence_field, signer_field, sgn_field, gls_field, txt_field), 157 | ) 158 | random_dev_subset = data_cfg.get("random_dev_subset", -1) 159 | if random_dev_subset > -1: 160 | # select this many development examples randomly and discard the rest 161 | keep_ratio = random_dev_subset / len(dev_data) 162 | keep, _ = dev_data.split( 163 | split_ratio=[keep_ratio, 1 - keep_ratio], random_state=random.getstate() 164 | ) 165 | dev_data = keep 166 | 167 | # check if target exists 168 | test_data = SignTranslationDataset( 169 | path=test_paths, 170 | fields=(sequence_field, signer_field, sgn_field, gls_field, txt_field), 171 | ) 172 | 173 | gls_field.vocab = gls_vocab 174 | txt_field.vocab = txt_vocab 175 | return train_data, dev_data, test_data, gls_vocab, txt_vocab 176 | 177 | 178 | # TODO (Cihan): I don't like this use of globals. 179 | # Need to find a more elegant solution for this it at some point. 
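# Typical call chain for this module, given a data configuration dictionary like the
# one sketched in the load_data docstring above (batch size is a placeholder;
# make_data_iter is defined further below):
train_data, dev_data, test_data, gls_vocab, txt_vocab = load_data(data_cfg=data_cfg)
train_iter = make_data_iter(
    dataset=train_data, batch_size=32, batch_type="sentence", train=True, shuffle=True
)
dev_iter = make_data_iter(dataset=dev_data, batch_size=32, train=False, shuffle=False)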
180 | # pylint: disable=global-at-module-level 181 | global max_sgn_in_batch, max_gls_in_batch, max_txt_in_batch 182 | 183 | 184 | # pylint: disable=unused-argument,global-variable-undefined 185 | def token_batch_size_fn(new, count, sofar): 186 | """Compute batch size based on number of tokens (+padding)""" 187 | global max_sgn_in_batch, max_gls_in_batch, max_txt_in_batch 188 | if count == 1: 189 | max_sgn_in_batch = 0 190 | max_gls_in_batch = 0 191 | max_txt_in_batch = 0 192 | max_sgn_in_batch = max(max_sgn_in_batch, len(new.sgn)) 193 | max_gls_in_batch = max(max_gls_in_batch, len(new.gls)) 194 | max_txt_in_batch = max(max_txt_in_batch, len(new.txt) + 2) 195 | sgn_elements = count * max_sgn_in_batch 196 | gls_elements = count * max_gls_in_batch 197 | txt_elements = count * max_txt_in_batch 198 | return max(sgn_elements, gls_elements, txt_elements) 199 | 200 | 201 | def make_data_iter( 202 | dataset: Dataset, 203 | batch_size: int, 204 | batch_type: str = "sentence", 205 | train: bool = False, 206 | shuffle: bool = False, 207 | ) -> Iterator: 208 | """ 209 | Returns a torchtext iterator for a torchtext dataset. 210 | 211 | :param dataset: torchtext dataset containing sgn and optionally txt 212 | :param batch_size: size of the batches the iterator prepares 213 | :param batch_type: measure batch size by sentence count or by token count 214 | :param train: whether it's training time, when turned off, 215 | bucketing, sorting within batches and shuffling is disabled 216 | :param shuffle: whether to shuffle the data before each epoch 217 | (no effect if set to True for testing) 218 | :return: torchtext iterator 219 | """ 220 | 221 | batch_size_fn = token_batch_size_fn if batch_type == "token" else None 222 | 223 | if train: 224 | # optionally shuffle and sort during training 225 | data_iter = data.BucketIterator( 226 | repeat=False, 227 | sort=False, 228 | dataset=dataset, 229 | batch_size=batch_size, 230 | batch_size_fn=batch_size_fn, 231 | train=True, 232 | sort_within_batch=True, 233 | sort_key=lambda x: len(x.sgn), 234 | shuffle=shuffle, 235 | ) 236 | else: 237 | # don't sort/shuffle for validation/inference 238 | data_iter = data.BucketIterator( 239 | repeat=False, 240 | dataset=dataset, 241 | batch_size=batch_size, 242 | batch_size_fn=batch_size_fn, 243 | train=False, 244 | sort=False, 245 | ) 246 | 247 | return data_iter 248 | -------------------------------------------------------------------------------- /signjoey/embeddings.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch import nn, Tensor 5 | import torch.nn.functional as F 6 | from signjoey.helpers import freeze_params 7 | from signjoey.layers import DenseBayesian,EmbeddingBayesian 8 | 9 | 10 | def get_activation(activation_type): 11 | if activation_type == "relu": 12 | return nn.ReLU() 13 | elif activation_type == "relu6": 14 | return nn.ReLU6() 15 | elif activation_type == "prelu": 16 | return nn.PReLU() 17 | elif activation_type == "selu": 18 | return nn.SELU() 19 | elif activation_type == "celu": 20 | return nn.CELU() 21 | elif activation_type == "gelu": 22 | return nn.GELU() 23 | elif activation_type == "sigmoid": 24 | return nn.Sigmoid() 25 | elif activation_type == "softplus": 26 | return nn.Softplus() 27 | elif activation_type == "softshrink": 28 | return nn.Softshrink() 29 | elif activation_type == "softsign": 30 | return nn.Softsign() 31 | elif activation_type == "tanh": 32 | return nn.Tanh() 33 | elif activation_type == "tanhshrink": 
34 | return nn.Tanhshrink() 35 | else: 36 | raise ValueError("Unknown activation type {}".format(activation_type)) 37 | 38 | 39 | class MaskedNorm(nn.Module): 40 | """ 41 | Original Code from: 42 | https://discuss.pytorch.org/t/batchnorm-for-different-sized-samples-in-batch/44251/8 43 | """ 44 | 45 | def __init__(self, norm_type, num_groups, num_features): 46 | super().__init__() 47 | self.norm_type = norm_type 48 | if self.norm_type == "batch": 49 | self.norm = nn.BatchNorm1d(num_features=num_features) 50 | elif self.norm_type == "group": 51 | self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_features) 52 | elif self.norm_type == "layer": 53 | self.norm = nn.LayerNorm(normalized_shape=num_features) 54 | else: 55 | raise ValueError("Unsupported Normalization Layer") 56 | 57 | self.num_features = num_features 58 | 59 | def forward(self, x: Tensor, mask: Tensor): 60 | if self.training: 61 | reshaped = x.reshape([-1, self.num_features]) 62 | reshaped_mask = mask.reshape([-1, 1]) > 0 63 | selected = torch.masked_select(reshaped, reshaped_mask).reshape( 64 | [-1, self.num_features] 65 | ) 66 | batch_normed = self.norm(selected) 67 | scattered = reshaped.masked_scatter(reshaped_mask, batch_normed) 68 | return scattered.reshape([x.shape[0], -1, self.num_features]) 69 | else: 70 | reshaped = x.reshape([-1, self.num_features]) 71 | batched_normed = self.norm(reshaped) 72 | return batched_normed.reshape([x.shape[0], -1, self.num_features]) 73 | 74 | 75 | 76 | 77 | class Embeddings(nn.Module): 78 | 79 | """ 80 | Simple embeddings class 81 | """ 82 | 83 | # pylint: disable=unused-argument 84 | def __init__( 85 | self, 86 | embedding_dim: int = 64, 87 | num_heads: int = 8, 88 | scale: bool = False, 89 | scale_factor: float = None, 90 | norm_type: str = None, 91 | activation_type: str = 'relu', 92 | lwta_competitors: int = 4, 93 | vocab_size: int = 0, 94 | padding_idx: int = 1, 95 | freeze: bool = False, 96 | bayesian : bool = False, 97 | inference_sample_size : int = 1, 98 | **kwargs 99 | ): 100 | """ 101 | Create new embeddings for the vocabulary. 102 | Use scaling for the Transformer. 
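# Construction sketch for the text embeddings (dimensions and vocabulary size are
# placeholders; padding_idx=1 matches the position of the padding symbol in
# TextVocabulary). With bayesian=True the lookup table is the variational
# EmbeddingBayesian layer from signjoey.layers and, assuming it keeps the usual
# (batch, length, dim) layout, evaluation stacks inference_sample_size stochastic
# draws along a new trailing dimension.
import torch
txt_embed = Embeddings(
    embedding_dim=512, num_heads=8, scale=True,
    vocab_size=3000, padding_idx=1,
    bayesian=True, inference_sample_size=10,
)
tokens = torch.tensor([[2, 5, 7]])
txt_embed.train()
print(txt_embed(tokens).shape)   # (1, 3, 512): a single draw during training
txt_embed.eval()
print(txt_embed(tokens).shape)   # (1, 3, 512, 10): samples stacked on the last axis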
103 | 104 | :param embedding_dim: 105 | :param scale: 106 | :param vocab_size: 107 | :param padding_idx: 108 | :param freeze: freeze the embeddings during training 109 | """ 110 | super().__init__() 111 | 112 | self.bayesian=bayesian 113 | self.embedding_dim = embedding_dim 114 | self.vocab_size = vocab_size 115 | if bayesian: 116 | self.inference_sample_size=inference_sample_size 117 | self.lut = EmbeddingBayesian(vocab_size, self.embedding_dim, padding_idx=padding_idx, 118 | input_features=vocab_size, output_features=self.embedding_dim, competitors=4, 119 | activation='lwta',kl_w=0.1) 120 | else: 121 | self.inference_sample_size=1 122 | self.lut = nn.Embedding(vocab_size, self.embedding_dim, padding_idx=padding_idx) 123 | 124 | 125 | self.norm_type = norm_type 126 | if self.norm_type: 127 | self.norm = MaskedNorm( 128 | norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim 129 | ) 130 | 131 | self.activation_type = activation_type 132 | if self.activation_type and not self.bayesian : 133 | self.activation = get_activation(activation_type) 134 | 135 | self.scale = scale 136 | if self.scale: 137 | if scale_factor: 138 | self.scale_factor = scale_factor 139 | else: 140 | self.scale_factor = math.sqrt(self.embedding_dim) 141 | 142 | if freeze: 143 | freeze_params(self) 144 | 145 | # pylint: disable=arguments-differ 146 | def forward_(self, x: Tensor, mask: Tensor = None) -> Tensor: 147 | """ 148 | Perform lookup for input `x` in the embedding table. 149 | 150 | :param mask: token masks 151 | :param x: index in the vocabulary 152 | :return: embedded representation for `x` 153 | """ 154 | 155 | x = self.lut(x) 156 | 157 | if self.norm_type: 158 | x = self.norm(x, mask) 159 | 160 | if self.activation_type and not self.bayesian: 161 | x = self.activation(x) 162 | 163 | if self.scale: 164 | return x * self.scale_factor 165 | else: 166 | return x 167 | 168 | def __repr__(self): 169 | return "%s(embedding_dim=%d, vocab_size=%d)" % ( 170 | self.__class__.__name__, 171 | self.embedding_dim, 172 | self.vocab_size, 173 | ) 174 | # pylint: disable=arguments-differ 175 | def forward( self, x: Tensor, mask: Tensor = None) -> Tensor: 176 | if self.training : 177 | return self.forward_(x,mask) 178 | else: 179 | 180 | out=[] 181 | #for i in range(self.inference_sample_size): 182 | for i in range(self.inference_sample_size): 183 | 184 | x_= self.forward_(x,mask) 185 | 186 | out.append(torch.unsqueeze(x_,-1)) 187 | out=torch.cat(out,-1) 188 | 189 | 190 | 191 | return out 192 | 193 | 194 | class SpatialEmbeddings(nn.Module): 195 | 196 | 197 | # pylint: disable=unused-argument 198 | def __init__( 199 | self, 200 | embedding_dim: int, 201 | input_size: int, 202 | num_heads: int, 203 | freeze: bool = False, 204 | norm_type: str = None, 205 | activation_type: str = None, 206 | lwta_competitors: int = 4, 207 | scale: bool = False, 208 | scale_factor: float = None, 209 | bayesian : bool = False, 210 | ibp : bool = False, 211 | inference_sample_size : int = 1, 212 | **kwargs 213 | ): 214 | """ 215 | Create new embeddings for the vocabulary. 216 | Use scaling for the Transformer. 
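# Construction sketch for the sign-frame embeddings (feature and model sizes are
# placeholders): per-frame CNN features are projected to the model dimension, and a
# frame mask is needed for the masked normalisation.
import torch
sgn_embed = SpatialEmbeddings(
    embedding_dim=512, input_size=1024, num_heads=8,
    norm_type="batch", activation_type="softsign", scale=True,
)
sgn_embed.train()
frames = torch.randn(2, 60, 1024)   # (batch, num_frames, feature_size)
frame_mask = torch.ones(2, 60)      # non-zero where frames are valid
print(sgn_embed(frames, frame_mask).shape)   # (2, 60, 512) in training mode
# With bayesian=True the projection becomes a DenseBayesian (LWTA) layer and, as for
# Embeddings above, evaluation adds a trailing sample dimension of size
# inference_sample_size.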
217 | 218 | :param embedding_dim: 219 | :param input_size: 220 | :param freeze: freeze the embeddings during training 221 | """ 222 | super().__init__() 223 | 224 | 225 | self.embedding_dim = embedding_dim 226 | self.input_size = input_size 227 | self.bayesian=bayesian 228 | if self.bayesian: 229 | self.inference_sample_size=inference_sample_size 230 | else: 231 | self.inference_sample_size=1 232 | if bayesian: 233 | self.ln = DenseBayesian(self.input_size, self.embedding_dim, competitors =lwta_competitors , 234 | activation = activation_type, prior_mean=0, prior_scale=1. , kl_w=0.1, ibp = ibp) 235 | 236 | else: 237 | self.ln = nn.Linear(self.input_size, self.embedding_dim) 238 | 239 | self.norm_type = norm_type 240 | if self.norm_type: 241 | self.norm = MaskedNorm( 242 | norm_type=norm_type, num_groups=num_heads, num_features=embedding_dim 243 | ) 244 | 245 | self.activation_type = activation_type 246 | if bayesian: 247 | self.activation_type = False 248 | else: 249 | self.activation_type = activation_type 250 | if self.activation_type: 251 | self.activation = get_activation(activation_type) 252 | 253 | self.scale = scale 254 | if self.scale: 255 | if scale_factor: 256 | self.scale_factor = scale_factor 257 | else: 258 | self.scale_factor = math.sqrt(self.embedding_dim) 259 | 260 | if freeze: 261 | freeze_params(self) 262 | 263 | 264 | # pylint: disable=arguments-differ 265 | def forward_(self, x: Tensor, mask: Tensor) -> Tensor: 266 | """ 267 | :param mask: frame masks 268 | :param x: input frame features 269 | :return: embedded representation for `x` 270 | """ 271 | 272 | #x = self.ln(x.transpose(-1,-2)).transpose(-1,-2) 273 | x=self.ln(x) 274 | 275 | if self.norm_type: 276 | x = self.norm(x, mask) 277 | 278 | if self.activation_type and (not self.bayesian): 279 | x = self.activation(x) 280 | 281 | if self.scale: 282 | return x * self.scale_factor 283 | else: 284 | return x 285 | 286 | # pylint: disable=arguments-differ 287 | def forward( self, x: Tensor, mask: Tensor) -> Tensor: 288 | if self.training : 289 | return self.forward_(x,mask) 290 | else: 291 | 292 | out=[] 293 | #for i in range(self.inference_sample_size): 294 | for i in range(self.inference_sample_size): 295 | 296 | x_= self.forward_(x,mask) 297 | 298 | out.append(torch.unsqueeze(x_,-1)) 299 | out=torch.cat(out,-1) 300 | 301 | 302 | 303 | return out 304 | 305 | def __repr__(self): 306 | return "%s(embedding_dim=%d, input_size=%d)" % ( 307 | self.__class__.__name__, 308 | self.embedding_dim, 309 | self.input_size, 310 | ) 311 | -------------------------------------------------------------------------------- /signjoey/encoders.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | 9 | from signjoey.helpers import freeze_params 10 | from signjoey.transformer_layers import TransformerEncoderLayer, PositionalEncoding 11 | from signjoey.layers import DenseBayesian 12 | 13 | # pylint: disable=abstract-method 14 | class Encoder(nn.Module): 15 | """ 16 | Base encoder class 17 | """ 18 | 19 | @property 20 | def output_size(self): 21 | """ 22 | Return the output size 23 | 24 | :return: 25 | """ 26 | return self._output_size 27 | 28 | 29 | class RecurrentEncoder(Encoder): 30 | """Encodes a sequence of word embeddings""" 31 | 32 | # pylint: disable=unused-argument 33 | def __init__( 34 | self, 35 
| rnn_type: str = "gru", 36 | hidden_size: int = 1, 37 | emb_size: int = 1, 38 | num_layers: int = 1, 39 | dropout: float = 0.0, 40 | emb_dropout: float = 0.0, 41 | bidirectional: bool = True, 42 | freeze: bool = False, 43 | **kwargs 44 | ) -> None: 45 | """ 46 | Create a new recurrent encoder. 47 | 48 | :param rnn_type: RNN type: `gru` or `lstm`. 49 | :param hidden_size: Size of each RNN. 50 | :param emb_size: Size of the word embeddings. 51 | :param num_layers: Number of encoder RNN layers. 52 | :param dropout: Is applied between RNN layers. 53 | :param emb_dropout: Is applied to the RNN input (word embeddings). 54 | :param bidirectional: Use a bi-directional RNN. 55 | :param freeze: freeze the parameters of the encoder during training 56 | :param kwargs: 57 | """ 58 | 59 | super(RecurrentEncoder, self).__init__() 60 | 61 | self.emb_dropout = torch.nn.Dropout(p=emb_dropout, inplace=False) 62 | self.type = rnn_type 63 | self.emb_size = emb_size 64 | 65 | rnn = nn.GRU if rnn_type == "gru" else nn.LSTM 66 | 67 | self.rnn = rnn( 68 | emb_size, 69 | hidden_size, 70 | num_layers, 71 | batch_first=True, 72 | bidirectional=bidirectional, 73 | dropout=dropout if num_layers > 1 else 0.0, 74 | ) 75 | 76 | self._output_size = 2 * hidden_size if bidirectional else hidden_size 77 | 78 | if freeze: 79 | freeze_params(self) 80 | 81 | # pylint: disable=invalid-name, unused-argument 82 | def _check_shapes_input_forward( 83 | self, embed_src: Tensor, src_length: Tensor, mask: Tensor 84 | ) -> None: 85 | """ 86 | Make sure the shape of the inputs to `self.forward` are correct. 87 | Same input semantics as `self.forward`. 88 | 89 | :param embed_src: embedded source tokens 90 | :param src_length: source length 91 | :param mask: source mask 92 | """ 93 | assert embed_src.shape[0] == src_length.shape[0] 94 | assert embed_src.shape[2] == self.emb_size 95 | # assert mask.shape == embed_src.shape 96 | assert len(src_length.shape) == 1 97 | 98 | # pylint: disable=arguments-differ 99 | def forward( 100 | self, embed_src: Tensor, src_length: Tensor, mask: Tensor 101 | ) -> (Tensor, Tensor): 102 | """ 103 | Applies a bidirectional RNN to sequence of embeddings x. 104 | The input mini-batch x needs to be sorted by src length. 105 | x and mask should have the same dimensions [batch, time, dim]. 
106 | 107 | :param embed_src: embedded src inputs, 108 | shape (batch_size, src_len, embed_size) 109 | :param src_length: length of src inputs 110 | (counting tokens before padding), shape (batch_size) 111 | :param mask: indicates padding areas (zeros where padding), shape 112 | (batch_size, src_len, embed_size) 113 | :return: 114 | - output: hidden states with 115 | shape (batch_size, max_length, directions*hidden), 116 | - hidden_concat: last hidden state with 117 | shape (batch_size, directions*hidden) 118 | """ 119 | self._check_shapes_input_forward( 120 | embed_src=embed_src, src_length=src_length, mask=mask 121 | ) 122 | 123 | # apply dropout to the rnn input 124 | embed_src = self.emb_dropout(embed_src) 125 | 126 | packed = pack_padded_sequence(embed_src, src_length, batch_first=True) 127 | output, hidden = self.rnn(packed) 128 | 129 | # pylint: disable=unused-variable 130 | if isinstance(hidden, tuple): 131 | hidden, memory_cell = hidden 132 | 133 | output, _ = pad_packed_sequence(output, batch_first=True) 134 | # hidden: dir*layers x batch x hidden 135 | # output: batch x max_length x directions*hidden 136 | batch_size = hidden.size()[1] 137 | # separate final hidden states by layer and direction 138 | hidden_layerwise = hidden.view( 139 | self.rnn.num_layers, 140 | 2 if self.rnn.bidirectional else 1, 141 | batch_size, 142 | self.rnn.hidden_size, 143 | ) 144 | # final_layers: layers x directions x batch x hidden 145 | 146 | # concatenate the final states of the last layer for each directions 147 | # thanks to pack_padded_sequence final states don't include padding 148 | fwd_hidden_last = hidden_layerwise[-1:, 0] 149 | bwd_hidden_last = hidden_layerwise[-1:, 1] 150 | 151 | # only feed the final state of the top-most layer to the decoder 152 | # pylint: disable=no-member 153 | hidden_concat = torch.cat([fwd_hidden_last, bwd_hidden_last], dim=2).squeeze(0) 154 | # final: batch x directions*hidden 155 | return output, hidden_concat 156 | 157 | def __repr__(self): 158 | return "%s(%r)" % (self.__class__.__name__, self.rnn) 159 | 160 | 161 | 162 | class TransformerEncoder(Encoder): 163 | """ 164 | Transformer Encoder 165 | """ 166 | 167 | # pylint: disable=unused-argument 168 | def __init__( 169 | self, 170 | hidden_size: int = 512, 171 | ff_size: int = 2048, 172 | num_layers: int = 8, 173 | num_heads: int = 4, 174 | dropout: float = 0.2, 175 | emb_dropout: float = 0.2, 176 | freeze: bool = False, 177 | skip_encoder: bool = False, 178 | bayesian_attention: bool = False, 179 | bayesian_feedforward: bool = False, 180 | ibp: bool = False, 181 | inference_sample_size: int = 10, 182 | activation: str ='relu', 183 | lwta_competitors: int = 4, 184 | **kwargs 185 | ): 186 | """ 187 | Initializes the Transformer. 188 | :param hidden_size: hidden size and size of embeddings 189 | :param ff_size: position-wise feed-forward layer size. 190 | (Typically this is 2*hidden_size.) 191 | :param num_layers: number of layers 192 | :param num_heads: number of heads for multi-headed attention 193 | :param dropout: dropout probability for Transformer layers 194 | :param emb_dropout: Is applied to the input (word embeddings). 
195 | :param freeze: freeze the parameters of the encoder during training 196 | :param bayesian_attention: using or not gaussian weights 197 | :param bayesian_feedforward: using or not gaussian weights 198 | :param ibp: using or not ibp method 199 | :param inference_sample_size: sample for bayesian averaging 200 | :param lwta_competitors: the size of LWTA's competion block (U parameter) 201 | :param kwargs: 202 | """ 203 | super(TransformerEncoder, self).__init__() 204 | if skip_encoder: 205 | bayesian_attention=False 206 | bayesian_feedforward=False 207 | 208 | self.bayesian = bayesian_attention or bayesian_feedforward 209 | if self.bayesian: 210 | self.inference_sample_size=inference_sample_size 211 | else : 212 | self.inference_sample_size=1 213 | 214 | self.lwta_competitors=lwta_competitors 215 | 216 | # build all (num_layers) layers 217 | self.layers = nn.ModuleList( 218 | [ 219 | TransformerEncoderLayer( 220 | size=hidden_size, 221 | ff_size=ff_size, 222 | num_heads=num_heads, 223 | dropout=dropout, 224 | 225 | bayesian_attention=bayesian_attention, 226 | bayesian_feedforward=bayesian_feedforward, 227 | ibp=ibp, 228 | activation=activation, 229 | lwta_competitors=lwta_competitors 230 | ) 231 | for _ in range(num_layers) 232 | ] 233 | ) 234 | 235 | self.skip_encoder=skip_encoder 236 | self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-6) 237 | self.pe = PositionalEncoding(hidden_size) 238 | 239 | self.emb_dropout = nn.Dropout(p=emb_dropout) 240 | 241 | self._output_size = hidden_size 242 | 243 | if freeze: 244 | freeze_params(self) 245 | 246 | 247 | # pylint: disable=arguments-differ 248 | def forward( 249 | self, embed_src: Tensor, src_length: Tensor, mask: Tensor 250 | ) -> (Tensor, Tensor): 251 | if self.training: 252 | return self.forward_(embed_src,src_length,mask) 253 | else: 254 | 255 | out=[] 256 | embed_s=embed_src.shape[-1] 257 | inference_sample_size= max(self.inference_sample_size,embed_s) 258 | for i in range(inference_sample_size): 259 | 260 | x_, _= self.forward_(embed_src[...,i%embed_s],src_length,mask) 261 | 262 | out.append(torch.unsqueeze(x_,-1)) 263 | out=torch.cat(out,-1) 264 | 265 | 266 | 267 | return out, None 268 | 269 | #hidden forward (single run) 270 | def forward_( 271 | self, embed_src: Tensor, src_length: Tensor, mask: Tensor 272 | ) -> (Tensor, Tensor): 273 | """ 274 | Pass the input (and mask) through each layer in turn. 275 | Applies a Transformer encoder to sequence of embeddings x. 276 | The input mini-batch x needs to be sorted by src length. 277 | x and mask should have the same dimensions [batch, time, dim]. 
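# Construction sketch for the stochastic Transformer encoder (sizes are placeholders;
# the "lwta" activation string mirrors the one used for the Bayesian layers elsewhere
# in this code base):
encoder = TransformerEncoder(
    hidden_size=512, ff_size=2048, num_layers=3, num_heads=8,
    dropout=0.1, emb_dropout=0.1,
    bayesian_attention=True, bayesian_feedforward=True, ibp=False,
    activation="lwta", lwta_competitors=4, inference_sample_size=10,
)
# In training mode, embed_src is the usual (batch, time, hidden) tensor and a single
# stochastic pass is made. In evaluation mode, forward() expects the extra trailing
# sample dimension produced by the Bayesian embeddings, cycles over it via
# embed_src[..., i % S], and returns the outputs stacked along a new last dimension
# (the second return value, the hidden state, is None for the Transformer).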
278 | 279 | :param embed_src: embedded src inputs, 280 | shape (batch_size, src_len, embed_size) 281 | :param src_length: length of src inputs 282 | (counting tokens before padding), shape (batch_size) 283 | :param mask: indicates padding areas (zeros where padding), shape 284 | (batch_size, src_len, embed_size) 285 | :return: 286 | - output: hidden states with 287 | shape (batch_size, max_length, directions*hidden), 288 | - hidden_concat: last hidden state with 289 | shape (batch_size, directions*hidden) 290 | """ 291 | 292 | x = embed_src 293 | x = self.pe(x) # add position encoding to word embeddings 294 | x = self.emb_dropout(x) 295 | if not self.skip_encoder: 296 | for layer in self.layers: 297 | x = layer(x,mask ) 298 | 299 | 300 | x=self.layer_norm(x) 301 | return x, None 302 | 303 | def __repr__(self): 304 | return "%s(num_layers=%r, num_heads=%r)" % ( 305 | self.__class__.__name__, 306 | len(self.layers), 307 | self.layers[0].src_src_att.num_heads, 308 | ) 309 | -------------------------------------------------------------------------------- /signjoey/builders.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | Collection of builder functions 4 | """ 5 | from typing import Callable, Optional, Generator 6 | 7 | import torch 8 | from torch import nn 9 | 10 | # Learning Rate Scheduler 11 | from torch.optim import lr_scheduler 12 | 13 | # Optimization Algorithms 14 | from torch.optim import Optimizer 15 | 16 | 17 | def build_gradient_clipper(config: dict) -> Optional[Callable]: 18 | """ 19 | Define the function for gradient clipping as specified in configuration. 20 | If not specified, returns None. 21 | 22 | Current options: 23 | - "clip_grad_val": clip the gradients if they exceed this value, 24 | see `torch.nn.utils.clip_grad_value_` 25 | - "clip_grad_norm": clip the gradients if their norm exceeds this value, 26 | see `torch.nn.utils.clip_grad_norm_` 27 | 28 | :param config: dictionary with training configurations 29 | :return: clipping function (in-place) or None if no gradient clipping 30 | """ 31 | clip_grad_fun = None 32 | if "clip_grad_val" in config.keys(): 33 | clip_value = config["clip_grad_val"] 34 | clip_grad_fun = lambda params: nn.utils.clip_grad_value_( 35 | parameters=params, clip_value=clip_value 36 | ) 37 | elif "clip_grad_norm" in config.keys(): 38 | max_norm = config["clip_grad_norm"] 39 | clip_grad_fun = lambda params: nn.utils.clip_grad_norm_( 40 | parameters=params, max_norm=max_norm 41 | ) 42 | 43 | if "clip_grad_val" in config.keys() and "clip_grad_norm" in config.keys(): 44 | raise ValueError("You can only specify either clip_grad_val or clip_grad_norm.") 45 | 46 | return clip_grad_fun 47 | 48 | 49 | def build_optimizer(config: dict, parameters) -> Optimizer: 50 | """ 51 | Create an optimizer for the given parameters as specified in config. 52 | 53 | Except for the weight decay and initial learning rate, 54 | default optimizer settings are used. 55 | 56 | Currently supported configuration settings for "optimizer": 57 | - "sgd" (default): see `torch.optim.SGD` 58 | - "adam": see `torch.optim.adam` 59 | - "adagrad": see `torch.optim.adagrad` 60 | - "adadelta": see `torch.optim.adadelta` 61 | - "rmsprop": see `torch.optim.RMSprop` 62 | 63 | The initial learning rate is set according to "learning_rate" in the config. 64 | The weight decay is set according to "weight_decay" in the config. 65 | If they are not specified, the initial learning rate is set to 3.0e-4, the 66 | weight decay to 0. 
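    A minimal, purely illustrative call (the values are placeholders, not taken
    from the shipped configs) could look like::

        cfg = {"optimizer": "adam", "learning_rate": 1.0e-4, "weight_decay": 1.0e-3}
        optimizer = build_optimizer(config=cfg, parameters=model.parameters())

    where ``model`` stands for any ``torch.nn.Module``.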
67 | 
68 | Note that the scheduler state is saved in the checkpoint, so if you load
69 | a model for further training you have to use the same type of scheduler.
70 | 
71 | :param config: configuration dictionary
72 | :param parameters: parameters of the model to optimize
73 | :return: optimizer
74 | """
75 | optimizer_name = config.get("optimizer", "radam").lower()
76 | learning_rate = config.get("learning_rate", 3.0e-4)
77 | weight_decay = config.get("weight_decay", 0)
78 | eps = config.get("eps", 1.0e-8)
79 | 
80 | # Adam based optimizers
81 | betas = config.get("betas", (0.9, 0.999))
82 | amsgrad = config.get("amsgrad", False)
83 | 
84 | if optimizer_name == "adam":
85 | return torch.optim.Adam(
86 | params=parameters,
87 | lr=learning_rate,
88 | betas=betas,
89 | eps=eps,
90 | weight_decay=weight_decay,
91 | amsgrad=amsgrad,
92 | )
93 | elif optimizer_name == "adamw":
94 | return torch.optim.AdamW(
95 | params=parameters,
96 | lr=learning_rate,
97 | betas=betas,
98 | eps=eps,
99 | weight_decay=weight_decay,
100 | amsgrad=amsgrad,
101 | )
102 | elif optimizer_name == "adagrad":
103 | return torch.optim.Adagrad(
104 | params=parameters,
105 | lr=learning_rate,
106 | lr_decay=config.get("lr_decay", 0),
107 | weight_decay=weight_decay,
108 | eps=eps,
109 | )
110 | elif optimizer_name == "adadelta":
111 | return torch.optim.Adadelta(
112 | params=parameters,
113 | rho=config.get("rho", 0.9),
114 | eps=eps,
115 | lr=learning_rate,
116 | weight_decay=weight_decay,
117 | )
118 | elif optimizer_name == "rmsprop":
119 | return torch.optim.RMSprop(
120 | params=parameters,
121 | lr=learning_rate,
122 | momentum=config.get("momentum", 0),
123 | alpha=config.get("alpha", 0.99),
124 | eps=eps,
125 | weight_decay=weight_decay,
126 | )
127 | elif optimizer_name == "sgd":
128 | return torch.optim.SGD(
129 | params=parameters,
130 | lr=learning_rate,
131 | momentum=config.get("momentum", 0),
132 | weight_decay=weight_decay,
133 | )
134 | else:
135 | raise ValueError("Unknown optimizer {}.".format(optimizer_name))
136 | 
137 | 
138 | def build_scheduler(
139 | config: dict, optimizer: Optimizer, scheduler_mode: str, hidden_size: int = 0
140 | ) -> (Optional[lr_scheduler._LRScheduler], Optional[str]):
141 | """
142 | Create a learning rate scheduler if specified in config and
143 | determine when a scheduler step should be executed.
144 | 
145 | Current options:
146 | - "plateau": see `torch.optim.lr_scheduler.ReduceLROnPlateau`
147 | - "decaying": see `torch.optim.lr_scheduler.StepLR`
148 | - "exponential": see `torch.optim.lr_scheduler.ExponentialLR`
149 | - "noam": see `NoamScheduler` (defined in this module)
150 | 
151 | If no scheduler is specified, returns (None, None) which will result in
152 | a constant learning rate.
153 | 
154 | :param config: training configuration
155 | :param optimizer: optimizer for the scheduler, determines the set of
156 | parameters which the scheduler sets the learning rate for
157 | :param scheduler_mode: "min" or "max", depending on whether the validation
158 | score should be minimized or maximized.
159 | Only relevant for "plateau".
160 | :param hidden_size: encoder hidden size (required for NoamScheduler) 161 | :return: 162 | - scheduler: scheduler object, 163 | - scheduler_step_at: either "validation" or "epoch" 164 | """ 165 | scheduler_name = config["scheduling"].lower() 166 | 167 | if scheduler_name == "plateau": 168 | # learning rate scheduler 169 | return ( 170 | lr_scheduler.ReduceLROnPlateau( 171 | optimizer=optimizer, 172 | mode=scheduler_mode, 173 | verbose=False, 174 | threshold_mode="abs", 175 | factor=config.get("decrease_factor", 0.1), 176 | patience=config.get("patience", 10), 177 | ), 178 | "validation", 179 | ) 180 | elif scheduler_name == "cosineannealing": 181 | return ( 182 | lr_scheduler.CosineAnnealingLR( 183 | optimizer=optimizer, 184 | eta_min=config.get("eta_min", 0), 185 | T_max=config.get("t_max", 20), 186 | ), 187 | "epoch", 188 | ) 189 | elif scheduler_name == "cosineannealingwarmrestarts": 190 | return ( 191 | lr_scheduler.CosineAnnealingWarmRestarts( 192 | optimizer=optimizer, 193 | T_0=config.get("t_init", 10), 194 | T_mult=config.get("t_mult", 2), 195 | ), 196 | "step", 197 | ) 198 | elif scheduler_name == "decaying": 199 | return ( 200 | lr_scheduler.StepLR( 201 | optimizer=optimizer, step_size=config.get("decaying_step_size", 1) 202 | ), 203 | "epoch", 204 | ) 205 | elif scheduler_name == "exponential": 206 | return ( 207 | lr_scheduler.ExponentialLR( 208 | optimizer=optimizer, gamma=config.get("decrease_factor", 0.99) 209 | ), 210 | "epoch", 211 | ) 212 | elif scheduler_name == "noam": 213 | factor = config.get("learning_rate_factor", 1) 214 | warmup = config.get("learning_rate_warmup", 4000) 215 | return ( 216 | NoamScheduler( 217 | hidden_size=hidden_size, 218 | factor=factor, 219 | warmup=warmup, 220 | optimizer=optimizer, 221 | ), 222 | "step", 223 | ) 224 | elif scheduler_name == "warmupexponentialdecay": 225 | min_rate = config.get("learning_rate_min", 1.0e-5) 226 | decay_rate = config.get("learning_rate_decay", 0.1) 227 | warmup = config.get("learning_rate_warmup", 4000) 228 | peak_rate = config.get("learning_rate_peak", 1.0e-3) 229 | decay_length = config.get("learning_rate_decay_length", 10000) 230 | return ( 231 | WarmupExponentialDecayScheduler( 232 | min_rate=min_rate, 233 | decay_rate=decay_rate, 234 | warmup=warmup, 235 | optimizer=optimizer, 236 | peak_rate=peak_rate, 237 | decay_length=decay_length, 238 | ), 239 | "step", 240 | ) 241 | else: 242 | raise ValueError("Unknown learning scheduler {}.".format(scheduler_name)) 243 | 244 | 245 | class NoamScheduler: 246 | """ 247 | The Noam learning rate scheduler used in "Attention is all you need" 248 | See Eq. 3 in https://arxiv.org/pdf/1706.03762.pdf 249 | """ 250 | 251 | def __init__( 252 | self, 253 | hidden_size: int, 254 | optimizer: torch.optim.Optimizer, 255 | factor: float = 1, 256 | warmup: int = 4000, 257 | ): 258 | """ 259 | Warm-up, followed by learning rate decay. 
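        Concretely, ``_compute_rate`` below evaluates, at optimizer step ``s``::

            rate = factor * hidden_size ** (-0.5) * min(s ** (-0.5), s * warmup ** (-1.5))

        so the rate grows roughly linearly during the first ``warmup`` steps and
        afterwards decays with the inverse square root of the step number.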
260 | :param hidden_size: 261 | :param optimizer: 262 | :param factor: decay factor 263 | :param warmup: number of warmup steps 264 | """ 265 | self.optimizer = optimizer 266 | self._step = 0 267 | self.warmup = warmup 268 | self.factor = factor 269 | self.hidden_size = hidden_size 270 | self._rate = 0 271 | 272 | def step(self): 273 | """Update parameters and rate""" 274 | self._step += 1 275 | rate = self._compute_rate() 276 | for p in self.optimizer.param_groups: 277 | p["lr"] = rate 278 | self._rate = rate 279 | 280 | def _compute_rate(self): 281 | """Implement `lrate` above""" 282 | step = self._step 283 | return self.factor * ( 284 | self.hidden_size ** (-0.5) 285 | * min(step ** (-0.5), step * self.warmup ** (-1.5)) 286 | ) 287 | 288 | # pylint: disable=no-self-use 289 | def state_dict(self): 290 | return None 291 | 292 | 293 | class WarmupExponentialDecayScheduler: 294 | """ 295 | A learning rate scheduler similar to Noam, but modified: 296 | Keep the warm up period but make it so that the decay rate can be tuneable. 297 | The decay is exponential up to a given minimum rate. 298 | """ 299 | 300 | def __init__( 301 | self, 302 | optimizer: torch.optim.Optimizer, 303 | peak_rate: float = 1.0e-3, 304 | decay_length: int = 10000, 305 | warmup: int = 4000, 306 | decay_rate: float = 0.5, 307 | min_rate: float = 1.0e-5, 308 | ): 309 | """ 310 | Warm-up, followed by exponential learning rate decay. 311 | :param peak_rate: maximum learning rate at peak after warmup 312 | :param optimizer: 313 | :param decay_length: decay length after warmup 314 | :param decay_rate: decay rate after warmup 315 | :param warmup: number of warmup steps 316 | :param min_rate: minimum learning rate 317 | """ 318 | self.optimizer = optimizer 319 | self._step = 0 320 | self.warmup = warmup 321 | self.decay_length = decay_length 322 | self.peak_rate = peak_rate 323 | self._rate = 0 324 | self.decay_rate = decay_rate 325 | self.min_rate = min_rate 326 | 327 | def step(self): 328 | """Update parameters and rate""" 329 | self._step += 1 330 | rate = self._compute_rate() 331 | for p in self.optimizer.param_groups: 332 | p["lr"] = rate 333 | self._rate = rate 334 | 335 | def _compute_rate(self): 336 | """Implement `lrate` above""" 337 | step = self._step 338 | warmup = self.warmup 339 | 340 | if step < warmup: 341 | rate = step * self.peak_rate / warmup 342 | else: 343 | exponent = (step - warmup) / self.decay_length 344 | rate = self.peak_rate * (self.decay_rate ** exponent) 345 | return max(rate, self.min_rate) 346 | 347 | # pylint: disable=no-self-use 348 | def state_dict(self): 349 | return None 350 | -------------------------------------------------------------------------------- /signjoey/transformer_layers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from torch import Tensor 9 | from signjoey.layers import DenseBayesian 10 | import pandas as pd 11 | import numpy as np 12 | from signjoey.local_attention import LocalAttention 13 | 14 | 15 | 16 | class MultiHeadedAttention(nn.Module): 17 | """ 18 | Multi-Head Attention module from "Attention is All You Need" 19 | 20 | Implementation modified from OpenNMT-py. 21 | https://github.com/OpenNMT/OpenNMT-py 22 | """ 23 | kls=0 24 | def __init__(self, num_heads: int, size: int, dropout: float = 0.1,bayesian=False,ibp=False,sizek=None,scale_out=1.0): 25 | """ 26 | Create a multi-headed attention layer. 
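        When ``bayesian`` is set, the query/key/value and output projections are
        built as ``DenseBayesian`` layers (optionally with the IBP prior via
        ``ibp``) instead of plain ``nn.Linear`` layers; ``sizek`` lets the key and
        value inputs have a different dimensionality than the queries.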
27 | :param num_heads: the number of heads 28 | :param size: model size (must be divisible by num_heads) 29 | :param dropout: probability of dropping a unit 30 | """ 31 | super(MultiHeadedAttention, self).__init__() 32 | linear=nn.Linear 33 | if sizek==None: 34 | sizek=size 35 | 36 | assert size % num_heads == 0 37 | self.ran=False 38 | self.head_size = head_size = size // num_heads 39 | self.model_size = size 40 | self.num_heads = num_heads 41 | print(size) 42 | self.k_layer = linear(sizek, num_heads * head_size) 43 | 44 | self.v_layer = linear(sizek, num_heads * head_size) 45 | self.q_layer = linear(size, num_heads * head_size) 46 | 47 | self.output_layer=nn.Linear(size,size) 48 | if bayesian: 49 | self.k_layer = DenseBayesian(input_features=size, output_features=num_heads * head_size, 50 | competitors = 1, activation = 'linear',prior_mean=0, prior_scale=1. , ibp = ibp,name='atte_k') 51 | 52 | self.v_layer = DenseBayesian(input_features=size, output_features=num_heads * head_size, 53 | competitors = 1, activation = 'linear',prior_mean=0, prior_scale=1. , ibp = ibp,name='atte_v') 54 | 55 | self.q_layer = DenseBayesian(input_features=size, output_features=num_heads * head_size, 56 | competitors = 1, activation = 'linear',prior_mean=0, prior_scale=1. , ibp = ibp,name='atte_q') 57 | 58 | self.output_layer = DenseBayesian(input_features=size, output_features=size, competitors = 1,activation = 'linear', 59 | prior_mean=0, prior_scale=1. , ibp = ibp,name='atte_o',scale_out=scale_out) 60 | 61 | self.softmax = nn.Softmax(dim=-1) 62 | self.dropout = nn.Dropout(dropout) 63 | self.printcounter=0 64 | self.pe=PositionalEncoding(size) 65 | def forward(self, k: Tensor, v: Tensor, q: Tensor, mask: Tensor = None): 66 | """ 67 | Computes multi-headed attention. 68 | 69 | :param k: keys [B, M, D] with M being the sentence length. 70 | :param v: values [B, M, D] 71 | :param q: query [B, M, D] 72 | :param mask: optional mask [B, 1, M] 73 | :return: 74 | """ 75 | batch_size = k.size(0) 76 | num_heads = self.num_heads 77 | 78 | 79 | k = self.k_layer(k) 80 | v = self.v_layer(v) 81 | q = self.q_layer(q) 82 | 83 | 84 | # reshape q, k, v for our computation to [batch_size, num_heads, ..] 85 | k = k.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) 86 | v = v.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) 87 | q = q.view(batch_size, -1, num_heads, self.head_size).transpose(1, 2) 88 | 89 | # compute scores 90 | q = q / math.sqrt(self.head_size) 91 | 92 | scores = torch.matmul(q, k.transpose(2, 3)) 93 | 94 | 95 | # apply the mask (if we have one) 96 | # we add a dimension for the heads to it below: [B, 1, 1, M] 97 | if mask is not None: 98 | scores = scores.masked_fill(~mask.unsqueeze(1), float("-inf")) 99 | 100 | # apply attention dropout and compute context vectors. 
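        # At this point `scores` has shape [batch_size, num_heads, query_len, key_len]:
        # the queries were already scaled by 1/sqrt(head_size) above, and padded
        # key positions were filled with -inf, so the softmax over the last
        # dimension assigns them (numerically) zero attention weight.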
101 | attention = self.softmax(scores) 102 | 103 | # MultiHeadedAttention.kls+= torch.mean(torch.mean(-torch.log(attention))) 104 | attention = self.dropout(attention) 105 | 106 | # get context vector (select values with attention) and reshape 107 | # back to [B, M, D] 108 | context = torch.matmul(attention, v) 109 | context = ( 110 | context.transpose(1, 2) 111 | .contiguous() 112 | .view(batch_size, -1, num_heads * self.head_size) 113 | ) 114 | # output=context 115 | output = self.output_layer(context) 116 | 117 | return output 118 | 119 | 120 | # pylint: disable=arguments-differ 121 | class PositionwiseFeedForward(nn.Module): 122 | """ 123 | Position-wise Feed-forward layer 124 | Projects to ff_size and then back down to input_size. 125 | """ 126 | 127 | def __init__(self, input_size, ff_size, dropout=0.1,bayesian=False,ibp=False,activation='relu',lwta_competitors=4,scale_out=0.2): 128 | """ 129 | Initializes position-wise feed-forward layer. 130 | :param input_size: dimensionality of the input. 131 | :param ff_size: dimensionality of intermediate representation 132 | :param dropout: 133 | """ 134 | super(PositionwiseFeedForward, self).__init__() 135 | linear=nn.Linear 136 | self.ibp=ibp 137 | self.layer_norm = nn.LayerNorm(input_size, eps=1e-6) 138 | 139 | if bayesian: 140 | 141 | self.pwff_layer = nn.Sequential( 142 | DenseBayesian(input_size , ff_size, 143 | competitors = lwta_competitors, activation = activation, prior_mean=0, prior_scale=1. , kl_w=0.01,ibp = ibp), 144 | nn.Dropout(dropout), 145 | 146 | DenseBayesian(input_features=ff_size, output_features=input_size, 147 | competitors = 1, activation = 'linear',prior_mean=0, prior_scale=1. , ibp = ibp,out_w=False,scale_out=scale_out), 148 | 149 | 150 | nn.Dropout(dropout), 151 | ) 152 | else: 153 | self.pwff_layer = nn.Sequential( 154 | nn.Linear(input_size, ff_size), 155 | nn.ReLU(), 156 | nn.Dropout(dropout), 157 | nn.Linear(ff_size, input_size), 158 | nn.Dropout(dropout), 159 | ) 160 | 161 | def forward(self, x): 162 | x_norm = self.layer_norm(x) 163 | return (self.pwff_layer(x_norm)) + x 164 | 165 | 166 | # pylint: disable=arguments-differ 167 | class PositionalEncoding(nn.Module): 168 | """ 169 | Pre-compute position encodings (PE). 170 | In forward pass, this adds the position-encodings to the 171 | input for as many time steps as necessary. 172 | 173 | Implementation based on OpenNMT-py. 174 | https://github.com/OpenNMT/OpenNMT-py 175 | """ 176 | 177 | def __init__(self, size: int = 0, max_len: int = 5000): 178 | """ 179 | Positional Encoding with maximum length max_len 180 | :param size: 181 | :param max_len: 182 | :param dropout: 183 | """ 184 | if size % 2 != 0: 185 | raise ValueError( 186 | "Cannot use sin/cos positional encoding with " 187 | "odd dim (got dim={:d})".format(size) 188 | ) 189 | pe = torch.zeros(max_len, size) 190 | position = torch.arange(0, max_len).unsqueeze(1) 191 | div_term = torch.exp( 192 | (torch.arange(0, size, 2, dtype=torch.float) * -(math.log(10000.0) / size)) 193 | ) 194 | pe[:, 0::2] = torch.sin(position.float() * div_term) 195 | pe[:, 1::2] = torch.cos(position.float() * div_term) 196 | pe = pe.unsqueeze(0) # shape: [1, size, max_len] 197 | super(PositionalEncoding, self).__init__() 198 | self.register_buffer("pe", pe) 199 | self.dim = size 200 | 201 | def forward(self, emb): 202 | """Embed inputs. 
203 | Args: 204 | emb (FloatTensor): Sequence of word vectors 205 | ``(seq_len, batch_size, self.dim)`` 206 | """ 207 | # Add position encodings 208 | return emb + self.pe[:, : emb.size(1)] 209 | 210 | 211 | class TransformerEncoderLayer(nn.Module): 212 | """ 213 | One Transformer encoder layer has a Multi-head attention layer plus 214 | a position-wise feed-forward layer. 215 | """ 216 | 217 | def __init__( 218 | self, size: int = 0, ff_size: int = 0, num_heads: int = 0, dropout: float = 0.1, 219 | bayesian_attention=False,bayesian_feedforward=False,ibp=False,activation='relu',lwta_competitors=4 220 | ): 221 | """ 222 | A single Transformer layer. 223 | :param size: 224 | :param ff_size: 225 | :param num_heads: 226 | :param dropout: 227 | """ 228 | super(TransformerEncoderLayer, self).__init__() 229 | 230 | self.layer_norm = nn.LayerNorm(size, eps=1e-6) 231 | self.src_src_att = MultiHeadedAttention(num_heads, size, dropout=dropout, 232 | bayesian=bayesian_attention,ibp=ibp,scale_out=(0.125)) 233 | self.src_src_att.ran=True 234 | self.feed_forward = PositionwiseFeedForward( input_size=size, ff_size=ff_size, dropout=dropout, 235 | bayesian=bayesian_feedforward,ibp=ibp,activation=activation,lwta_competitors=lwta_competitors,scale_out=(0.2) 236 | ) 237 | self.dropout = nn.Dropout(dropout) 238 | self.size = size 239 | 240 | # pylint: disable=arguments-differ 241 | def forward(self, x: Tensor, mask: Tensor) -> Tensor: 242 | """ 243 | Forward pass for a single transformer encoder layer. 244 | First applies layer norm, then self attention, 245 | then dropout with residual connection (adding the input to the result), 246 | and then a position-wise feed-forward layer. 247 | 248 | :param x: layer input 249 | :param mask: input mask 250 | :return: output tensor 251 | """ 252 | x_norm = self.layer_norm(x) 253 | 254 | h = self.src_src_att(x_norm, x_norm, x_norm, mask) 255 | 256 | h = self.dropout(h) + x 257 | o = self.feed_forward(h) 258 | return o 259 | 260 | 261 | 262 | class TransformerDecoderLayer(nn.Module): 263 | """ 264 | Transformer decoder layer. 265 | 266 | Consists of self-attention, source-attention, and feed-forward. 267 | """ 268 | 269 | def __init__( 270 | self, size: int = 0, ff_size: int = 0, num_heads: int = 0, dropout: float = 0.1, 271 | bayesian_attention=False,bayesian_feedforward=False,ibp=False,activation='relu',lwta_competitors=4 272 | ): 273 | """ 274 | Represents a single Transformer decoder layer. 275 | 276 | It attends to the source representation and the previous decoder states. 
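        In ``forward`` below this amounts to: masked target self-attention
        (``trg_trg_att``), then source-target attention over the encoder memory
        (``src_trg_att``), then the position-wise feed-forward block, each
        preceded by layer normalization and combined with dropout and a
        residual connection.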
277 | 278 | :param size: model dimensionality 279 | :param ff_size: size of the feed-forward intermediate layer 280 | :param num_heads: number of heads 281 | :param dropout: dropout to apply to input 282 | """ 283 | super(TransformerDecoderLayer, self).__init__() 284 | self.size = size 285 | 286 | self.trg_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout, 287 | bayesian=bayesian_attention,ibp=ibp) 288 | 289 | 290 | self.src_trg_att = MultiHeadedAttention(num_heads, size, dropout=dropout, 291 | bayesian=bayesian_attention,ibp=ibp,sizek=size,scale_out=1.0) 292 | self.feed_forward = PositionwiseFeedForward( input_size=size, ff_size=ff_size, dropout=dropout, 293 | bayesian=bayesian_feedforward,ibp=ibp,activation=activation,lwta_competitors=lwta_competitors 294 | ) 295 | 296 | 297 | self.x_layer_norm = nn.LayerNorm(size, eps=1e-6) 298 | self.dec_layer_norm = nn.LayerNorm(size, eps=1e-6) 299 | 300 | 301 | self.dropout = nn.Dropout(dropout) 302 | 303 | # pylint: disable=arguments-differ 304 | def forward( 305 | self, 306 | x: Tensor = None, 307 | memory: Tensor = None, 308 | src_mask: Tensor = None, 309 | trg_mask: Tensor = None, 310 | ) -> Tensor: 311 | """ 312 | Forward pass of a single Transformer decoder layer. 313 | 314 | :param x: inputs 315 | :param memory: source representations 316 | :param src_mask: source mask 317 | :param trg_mask: target mask (so as to not condition on future steps) 318 | :return: output tensor 319 | """ 320 | # decoder/target self-attention 321 | x_norm = self.x_layer_norm(x) 322 | 323 | h1 = self.trg_trg_att(x_norm, x_norm, x_norm, mask=trg_mask) 324 | h1 = (self.dropout(h1)) + x 325 | 326 | # source-target attention 327 | h1_norm = self.dec_layer_norm(h1) 328 | 329 | 330 | 331 | h2 = self.src_trg_att(memory, memory, h1_norm, mask=src_mask) 332 | 333 | 334 | # final position-wise feed-forward layer 335 | o = self.feed_forward(self.dropout(h2) + h1) 336 | 337 | return o 338 | -------------------------------------------------------------------------------- /data/gls.vocab: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | REGEN 5 | REGION 6 | IX 7 | KOMMEN 8 | MORGEN 9 | NORD 10 | SONNE 11 | WOLKE 12 | GRAD 13 | NACHT 14 | SUED 15 | KOENNEN 16 | SCHNEE 17 | AUCH 18 | BISSCHEN 19 | MEHR 20 | HEUTE 21 | BIS 22 | GEWITTER 23 | WETTER 24 | WIND 25 | WEHEN 26 | ZWANZIG 27 | OST 28 | DANN 29 | WEST 30 | SCHAUER 31 | MOEGLICH 32 | ABEND 33 | FREUNDLICH 34 | NEBEL 35 | BERG 36 | NORDWEST 37 | JETZT 38 | ABER 39 | STURM 40 | TAG 41 | TEIL 42 | WIE-AUSSEHEN 43 | TIEF 44 | FLUSS 45 | MINUS 46 | SUEDOST 47 | MITTE 48 | WECHSELHAFT 49 | KLAR 50 | SCHWACH 51 | SONNTAG 52 | FREITAG 53 | KUEHL 54 | VIEL 55 | TROCKEN 56 | SAMSTAG 57 | HOCH 58 | STARK 59 | MAESSIG 60 | BESONDERS 61 | DONNERSTAG 62 | SONST 63 | MEISTENS 64 | IN-KOMMEND 65 | BLEIBEN 66 | UND 67 | KALT 68 | MITTWOCH 69 | MONTAG 70 | LAND 71 | TEMPERATUR 72 | NORDOST 73 | WARM 74 | NUR 75 | SUEDWEST 76 | ALPEN 77 | FUENF 78 | DEUTSCH 79 | GUT 80 | DIENSTAG 81 | MITTAG 82 | EINS 83 | KUESTE 84 | SCHOEN 85 | FROST 86 | ZWEI 87 | DREI 88 | MILD 89 | HIMMEL 90 | SEHEN 91 | IM-VERLAUF 92 | VIER 93 | HAUPTSAECHLICH 94 | LUFT 95 | SIEBEN 96 | DABEI 97 | FRISCH 98 | NOCH 99 | SCHNEIEN 100 | SCHON 101 | VERSCHWINDEN 102 | ZEHN 103 | ACHT 104 | NACH 105 | SECHS 106 | ANFANG 107 | MAXIMAL 108 | GLATT 109 | ENORM 110 | ZWISCHEN 111 | NULL 112 | DREISSIG 113 | PLUS 114 | VIERZEHN 115 | NEUN 116 | UEBERWIEGEND 117 | MAL 118 | STERN 119 | LIEB 120 | FUENFZEHN 121 | WENN 122 | 
ZWOELF 123 | DREIZEHN 124 | TEILWEISE 125 | ZUSCHAUER 126 | WIEDER 127 | BEWOELKT 128 | MEER 129 | WAHRSCHEINLICH 130 | BAYERN 131 | DRUCK 132 | SIEBZEHN 133 | WENIG 134 | LANG 135 | ODER 136 | DESHALB 137 | ELF 138 | SECHSZEHN 139 | EUROPA 140 | ZEIGEN-BILDSCHIRM 141 | WEITER 142 | poss-EUCH 143 | ACHTZEHN 144 | ORT 145 | SPAETER 146 | WOCHENENDE 147 | DAZU 148 | SEE 149 | HABEN 150 | RUHIG 151 | FRUEH 152 | NEU 153 | AB 154 | SO 155 | DURCHGEHEND 156 | LEICHT 157 | WALD 158 | WIE 159 | neg-HABEN 160 | NAECHSTE 161 | UNGEFAEHR 162 | BESSER 163 | DEUTSCHLAND 164 | MACHEN 165 | VORSICHT 166 | ORKAN 167 | NEUNZEHN 168 | STEIGEN 169 | UNWETTER 170 | ZONE 171 | HEISS 172 | WARNUNG 173 | MIT 174 | SINKEN 175 | VOR 176 | WIE-IMMER 177 | AUFLOESEN 178 | SPEZIELL 179 | UEBER 180 | BODEN 181 | NAH 182 | HIER 183 | GEFRIEREN 184 | GLEICH 185 | MANCHMAL 186 | WOCHE 187 | OKTOBER 188 | WUENSCHEN 189 | poss-SEIN 190 | NOVEMBER 191 | TATSAECHLICH 192 | DARUM 193 | SOMMER 194 | VORAUS 195 | DIENST 196 | DOCH 197 | GEFAHR 198 | SKANDINAVIEN 199 | JULI 200 | WINTER 201 | FEBRUAR 202 | TSCHUESS 203 | VERAENDERN 204 | ES-BEDEUTET 205 | SCHEINEN 206 | UMWANDELN 207 | BEGRUESSEN 208 | DEZEMBER 209 | HABEN2 210 | INFORMIEREN 211 | JANUAR 212 | MAI 213 | METER 214 | REST 215 | SEPTEMBER 216 | UNTERSCHIED 217 | BEDEUTEN 218 | FEUCHT 219 | KAUM 220 | LANGSAM 221 | APRIL 222 | AUGUST 223 | SCHWER 224 | HERBST 225 | OFT 226 | VON 227 | GRUND 228 | TRUEB 229 | GERADE 230 | ICH 231 | WENIGER 232 | ZUERST 233 | VIELLEICHT 234 | DAS-IST-ES 235 | ALLGAEU 236 | EIS 237 | JUNI 238 | NACHMITTAG 239 | SACHSEN 240 | TAGSUEBER 241 | UEBERALL 242 | EINFLUSS 243 | FRANKREICH 244 | IM-MOMENT 245 | KRAEFTIG 246 | SCHWARZ 247 | DU 248 | negalp-AUCH 249 | BLITZ 250 | BRAND 251 | GRENZE 252 | HEFTIG 253 | SAUER 254 | UNTER 255 | BEISPIEL 256 | DIESE 257 | MAERZ 258 | NASS 259 | SCHOTTLAND 260 | STUNDE 261 | DURCH 262 | L 263 | MISCHUNG 264 | SIEBTE 265 | ZEIT 266 | ANDERE 267 | AUFZIEHEN 268 | ENGLAND 269 | ERSTE 270 | HOEHE 271 | NICHT 272 | RUEGEN 273 | TAL 274 | THUERINGEN 275 | UNTEN 276 | ANGENEHM 277 | BURG 278 | EINIGE 279 | EINIGERMASSEN 280 | HERZ 281 | NOCHEINMAL 282 | PASSEN 283 | SAGEN 284 | AEHNLICH 285 | DONNER 286 | GANZTAGS 287 | HAGEL 288 | POLEN 289 | WERT 290 | B 291 | GRAUPEL 292 | MITTEILEN 293 | RECHNEN 294 | VORAUSSAGE 295 | WAS 296 | negalp-KEIN 297 | ACH 298 | FRUEHLING 299 | INSGESAMT 300 | SELTEN 301 | UEBERSCHWEMMUNG 302 | VERSCHIEDEN 303 | VOGEL 304 | ZWEITE 305 | AUFLOCKERUNG 306 | DASSELBE 307 | NICHTS 308 | SECHSTE 309 | SUPER 310 | VIERTE 311 | WARUM 312 | WO 313 | ACHTE 314 | BERLIN 315 | DAZWISCHEN 316 | E 317 | KEIN 318 | KURZ 319 | LOCKER 320 | LOS 321 | SCHNELL 322 | TROPFEN 323 | VORBEI 324 | ACHTUNG 325 | BADEN 326 | BISHER 327 | ES-GIBT 328 | FAST 329 | FUENFTE 330 | LETZTE 331 | LOCH 332 | MUESSEN 333 | SCHWUEL 334 | TAUEN 335 | TROTZDEM 336 | ZEHNTE 337 | ALS 338 | DAENEMARK 339 | DAUER 340 | DRITTE 341 | FUER 342 | KOELN 343 | NIEDER 344 | NOCH-NICHT 345 | RICHTUNG 346 | RUSSLAND 347 | STRASSE 348 | STROEMEN 349 | UND-DANN 350 | VERRINGERN 351 | VORHER 352 | neg-VIEL 353 | poss-BEI-UNS 354 | ALLE 355 | DA 356 | DANACH 357 | DICK 358 | EIN-PAAR-TAGE 359 | ERSTMAL 360 | FLACH 361 | HEUTE-NACHT 362 | HUNDERT 363 | IHR 364 | ITALIEN 365 | JA 366 | LITER 367 | MERKEN 368 | MORGENS 369 | NEUNTE 370 | NICHT-MEHR 371 | NORMAL 372 | OBEN 373 | SCHAFFEN 374 | SCHAUEN 375 | SCHLESWIG 376 | SECHSHUNDERT 377 | UMKEHREN 378 | WECHSEL 379 | AM-TAG 380 | BIS-JETZT 381 | EIFEL 382 | FUENFHUNDERT 383 | 
HOLSTEIN 384 | JAHR 385 | NEIN 386 | SOLL 387 | STELLENWEISE 388 | UM 389 | VERBREITEN 390 | VORTEIL 391 | WIRBEL 392 | WOHER 393 | AM 394 | AUF 395 | BRAUCHEN 396 | EIN-PAAR 397 | GEWESEN 398 | GRAU 399 | GRIECHENLAND 400 | LEIDER 401 | TYPISCH 402 | VORMITTAG 403 | WASSER 404 | WIRKLICH 405 | neg-REGEN 406 | BEREICH 407 | BEWEGEN 408 | BLAU 409 | F 410 | FUEHLEN 411 | GESTERN 412 | GROSS 413 | GROSSBRITANNIEN 414 | NAEHERN 415 | OBER 416 | RISIKO 417 | SCHLUSS 418 | SELBE 419 | SIE 420 | SPANIEN 421 | SYLT 422 | VERKEHR 423 | ZENTIMETER 424 | ZURUECK 425 | neg-KAUM 426 | ANKOMMEN 427 | AUF-JEDEN-FALL 428 | AUFTAUCHEN 429 | E+R+Z 430 | FLAECHENDECKEND 431 | HAELFTE 432 | HOCHWASSER 433 | INSEL 434 | ISLAND 435 | J+U+L+I 436 | NIEDERSACHSEN 437 | NORWEGEN 438 | POSITION 439 | PUNKT 440 | SCHLIMM 441 | SCHWEDEN 442 | SPAET 443 | STABIL 444 | STEIN 445 | UEBERMORGEN 446 | UHR 447 | VIERHUNDERT 448 | VIERZIG 449 | WEIHNACHTEN 450 | neg-NICHTS 451 | ABWECHSELN 452 | AUSSEHEN 453 | C+M 454 | CHANCE 455 | DICHT 456 | DREIHUNDERT 457 | ENTSCHULDIGUNG 458 | ETWAS 459 | FELD 460 | FOLGE 461 | GEBEN 462 | HARMLOS 463 | IMMER 464 | OSTERN 465 | PASSIEREN 466 | QUADRATMETER 467 | RHEIN 468 | ROT 469 | S 470 | SCHWIERIG 471 | SIEBENHUNDERT 472 | TAUSEND 473 | TUERKEI 474 | W 475 | ZWISCHEN-NULL 476 | ZWOELFTE 477 | poss-MEIN 478 | A 479 | ACHTHUNDERT 480 | ANDERS 481 | AUFPASSEN 482 | AUTO 483 | BEKANNTGEBEN 484 | BELGIEN 485 | BESTIMMT 486 | BLOCKIEREN 487 | BOEE 488 | BRANDENBURG 489 | BRINGEN 490 | DREHEN 491 | ERST 492 | FAHREN 493 | FALLEN 494 | GEMUETLICH 495 | GENAU 496 | GLEICH-BLEIBEN 497 | GLUECK 498 | H 499 | HART 500 | HAUPT 501 | IN 502 | KILOMETER 503 | KLEIN 504 | KNAPP 505 | NACH-HAUSE 506 | NEUNHUNDERT 507 | NIESELREGEN 508 | NORDSEE 509 | OBWOHL 510 | PRO 511 | QUELL 512 | RUND-UM-DIE-UHR 513 | S+H 514 | V 515 | WAR 516 | WIR 517 | WIRBELSTURM 518 | ZENTRUM 519 | ZWEI-TAG 520 | neg-FROST 521 | neg-KLAR 522 | neg-TROCKEN 523 | AACHEN 524 | ACHTZIG 525 | ALSO 526 | AN 527 | ANGEMESSEN 528 | AUSRICHTEN 529 | AUSWAEHLEN 530 | BALD 531 | BEKOMMEN 532 | BITTE 533 | BODENSEE 534 | BRITANNIEN 535 | CHAOS 536 | DAMEN 537 | DAUERND 538 | DUENN 539 | ELFTE 540 | ENDE 541 | ENDLICH 542 | ERWARTEN 543 | FERTIG 544 | FINNLAND 545 | FOEHN 546 | FRANKFURT 547 | FREI 548 | FUENFZIG 549 | GEBIRGE 550 | GEWOHNT 551 | HAAR 552 | HAFTEN 553 | HALB 554 | HALLO 555 | HAMBURG 556 | HERREN 557 | HESSEN 558 | HOEHER 559 | HOFFEN 560 | HOLLAND 561 | IN-BESTIMMT-ZEIT 562 | INTERNET 563 | IRLAND 564 | JEDEN-TAG 565 | K 566 | KUEHLER 567 | LAGE 568 | LEUTE 569 | LIEGEN 570 | MOEGEN 571 | NORDRHEIN-WESTFALEN 572 | OESTERREICH 573 | PAAR 574 | PFALZ 575 | PFINGSTEN 576 | PLOETZLICH 577 | PROBLEM 578 | PROZENT 579 | SCHADEN 580 | SIEBZIG 581 | SITUATION 582 | SITZ 583 | STRAHLEN 584 | TAUGEN 585 | TEXT 586 | TJA 587 | VERGLEICH 588 | VERMEIDEN 589 | VERSCHIEBEN 590 | VORBEREITEN 591 | WAHR 592 | WEG 593 | WEIT 594 | WER 595 | WICHTIG 596 | WIE-GEBLIEBEN 597 | WUERTTEMBERG 598 | Z+D+F 599 | ZIEHEN 600 | ZU-ENDE 601 | ZUSAMMENHANG 602 | neg-GRAD 603 | neg-KALT 604 | neg-NEBEL 605 | negalp-KOENNEN 606 | ABSINKEN 607 | AKTIV 608 | AKTUELL 609 | ALT 610 | AM-KUESTE 611 | AUFKLAREN 612 | AUFKOMMEN 613 | AUS 614 | AUSHALTEN 615 | AUSNAHME 616 | AUSSERGEWOEHNLICH 617 | BAUER 618 | BAUM 619 | BEISEITE 620 | BEOBACHTEN 621 | BERUHIGEN 622 | BLUETE 623 | BRAUN 624 | BREMEN 625 | DREI-MONATE 626 | DRESDEN 627 | DRUCKFLAECHE 628 | E+R 629 | EISEN 630 | ENTSPANNT 631 | ENTWICKELN 632 | ERDRUTSCH 633 | ERFAHREN 634 | ERFURT 635 | 
ERZ 636 | FLOCKEN 637 | FRAGEZEICHEN 638 | FREUEN 639 | FUER-ALLE 640 | GEHEN 641 | GEHOERT 642 | GEHT-SO 643 | GELB 644 | GEMISCHT 645 | GENUG 646 | GESAMT 647 | GLATTEIS 648 | GOLD 649 | GROB 650 | GUT-ABEND 651 | HEILIG 652 | HOEREN 653 | IN-DIESE-WOCHE 654 | KIEL 655 | KLAPPEN 656 | KOERPER 657 | MAINZ 658 | MOMENT 659 | MOND 660 | MORGEN-FRUEH 661 | MUENCHEN 662 | MUENSTER 663 | NORDPOL 664 | OB 665 | ORANGE 666 | P 667 | PAUSE 668 | POMMERN 669 | POSITIV 670 | RAND 671 | RAUM 672 | REIF 673 | REIN 674 | REKORD 675 | RHEINLAND 676 | RHEINLAND-PFALZ 677 | RICHTIG 678 | RODELN 679 | SAARLAND 680 | SCH 681 | SCH+H 682 | SCHAU-MAL 683 | SCHLAF 684 | SCHLECHT 685 | SCHLECHTER 686 | SCHON-WIEDER 687 | SEHR 688 | SKI 689 | SO-BLEIBEN 690 | SPAZIEREN 691 | SPITZE 692 | SPUEREN 693 | STAU 694 | STEIGEN-RUNTER 695 | STOCKEN 696 | STOERUNG 697 | STRENG 698 | T 699 | T-SHIRT 700 | TOLL 701 | TRENNEN 702 | TSCHECHIEN 703 | TUN 704 | UNGEMUETLICH 705 | UNSICHER 706 | VERLAUFEN 707 | WEIN 708 | WIRTSCHAFT 709 | WISSEN 710 | WOHNEN 711 | WUERZ 712 | ZU 713 | ZUFRIEDEN 714 | ZUG 715 | ZUSAMMEN 716 | ZWEIFEL 717 | neg-EINFLUSS 718 | neg-FUEHLEN 719 | neg-GEWITTER 720 | neg-HOEHE 721 | neg-IMMER 722 | neg-KOMMEN 723 | neg-MEHR 724 | neg-NEIN 725 | neg-NICHT-MEHR 726 | neg-NORD 727 | neg-REGION 728 | neg-SCHLIMM 729 | neg-SCHOEN 730 | neg-SEHEN 731 | neg-SONNE 732 | neg-WARTEN 733 | neg-WOLKE 734 | A+Z 735 | AB-JETZT 736 | ABFALLEN 737 | ABKUEHLEN 738 | ABSCHIED 739 | ABSCHNITT 740 | AENDERN 741 | AFRIKA 742 | ALLGEMEIN 743 | ALPENRAND 744 | ALPENTAL 745 | AM-MEER 746 | AM-RAND 747 | AMERIKA 748 | ANDERE-MOEGLICHKEIT 749 | ANGST 750 | ANHALT 751 | ANSAMMELN 752 | ANSCHAUEN 753 | ARM 754 | ATLANTIK 755 | AUFBLUEHEN 756 | AUFEINANDERTREFFEN 757 | AUFFUELLEN 758 | AUFHEITERN 759 | AUFHOEREN 760 | AUFTEILEN 761 | AUSEINANDER 762 | AUSSICHT 763 | AUTOMATISCH 764 | BADEN-WUERTTEMBERG 765 | BEDECKT 766 | BEDINGUNGEN 767 | BEGINN 768 | BEIDE 769 | BELAESTIGUNG 770 | BERGAB 771 | BERGAUF 772 | BERUF 773 | BESPRECHEN 774 | BESTE 775 | BETROFFEN 776 | BETT 777 | BIS-MITTE 778 | BIS-MORGEN 779 | BLATT 780 | BLEIBEN-GLEICH 781 | BLUMEN 782 | BROCKEN 783 | BRUCKBERG 784 | BUNT 785 | C 786 | D+A+M+I+A+N 787 | D+E 788 | DAFUER 789 | DANEBEN 790 | DARAUF 791 | DARUNTER 792 | DAS-WAR-ES 793 | DEMNAECHST 794 | DENKEN 795 | DIESMAL 796 | DOCH-SONST-NOCH 797 | DRAUSSEN 798 | DREIMAL 799 | DUESSELDORF 800 | DUMM 801 | DUNST 802 | DURCHEINANDER 803 | DURCHSCHNITT 804 | E+Z 805 | EBEN 806 | ECHT 807 | EIGENTLICH 808 | EIN 809 | EIN-JAHR 810 | EIN-WOCHE 811 | EINFACH 812 | EINSCHRAENKEN 813 | EINZELN 814 | EMPFINDLICH 815 | ENTFERNT 816 | ENTHALTEN 817 | ERHOEHEN 818 | ERLEICHERT 819 | ERNTE 820 | ERSCHROCKEN 821 | EUCH 822 | EWIG 823 | EXTREM 824 | F+E+H+M+E+R 825 | F+E+R+Z 826 | FACH 827 | FEHLT 828 | FEIER 829 | FEST 830 | FLIESSEN 831 | FLUT 832 | FREIZEIT 833 | FRONT 834 | FUENF-TAGE 835 | FUENF-UHR 836 | FUER-UNS 837 | G+U+NN+E+R 838 | GARTEN 839 | GENIESSEN 840 | GESCHWINDIGKEIT 841 | GETRENNT 842 | GIPFEL 843 | GLAUBEN 844 | GLEICH-WIE 845 | GLITZERN 846 | GOTT 847 | GRUEN 848 | HALTEN 849 | HANNOVER 850 | HAVEN 851 | HEINS 852 | HELL 853 | HERAB 854 | HERVORRAGEND 855 | HEUTE-ABEND 856 | HEUTE-MITTAG 857 | HINDERNIS 858 | HOLEN 859 | HUND 860 | HUT 861 | I+S 862 | IM-LAUFE 863 | INNERHALB 864 | INTERESSANT 865 | IRGENDWANN 866 | IRGENDWO 867 | J+L+I 868 | K+R+E+T+A 869 | KALENDER 870 | KALTFRONT 871 | KANADA 872 | KANAL 873 | KAPUTTGEGANGEN 874 | KARFREITAG 875 | KARTE 876 | KLEINIGKEIT 877 | KNOSPE-ABFALLEN 878 | 
KOBLENZ 879 | KOMMA 880 | KOMPLETT 881 | KONSTANT 882 | KORB 883 | KRATZEN 884 | KRISE 885 | KROATIEN 886 | KUCHEN 887 | KURVE 888 | LAERM 889 | LAHM 890 | LANDSCHAFT 891 | LAUFEN 892 | LAUSITZ 893 | LEBEN 894 | LEIPZIG 895 | LESEN 896 | LETZTE-WOCHE 897 | LICHT 898 | LUECKE 899 | M 900 | M+A+R+T+I+N 901 | M+L 902 | M+M 903 | M+R+Z 904 | MARKT 905 | MASCHINE 906 | MATSCH 907 | MECKLENBURG 908 | MECKLENBURG-VORPOMMERN 909 | MEHR-WENIG 910 | MEHRMALS 911 | MEINEN 912 | MERKWUERDIG 913 | MESSEN 914 | MINDESTENS 915 | MITBEKOMMEN 916 | MITEILEN 917 | MITNEHMEN 918 | MITZIEHEN 919 | MM 920 | MOEGLICHKEIT 921 | MONAT 922 | MUND 923 | NAECHSTE-WOCHE 924 | NATUR 925 | NEBEN 926 | NEUNZEHNTE 927 | NEUNZIG 928 | NIEDERUNG 929 | NIEDRIG 930 | NOCH-MEHR 931 | NORDRHEIN 932 | NRW 933 | NUMMER 934 | OHNE 935 | ORIENTIEREN 936 | OSTBAYERN 937 | P+O+P 938 | PFEIL 939 | PFLANZE 940 | PORTUGAL 941 | PUENKTLICH 942 | PULLOVER 943 | QUADRAT 944 | RAUSFALLEN 945 | REDUZIEREN 946 | ROSE 947 | ROSTOCK 948 | RUECKEN 949 | RUHRGEBIET 950 | RUMAENIEN 951 | S+H+M+V 952 | S+Y+L+T 953 | S0NNE 954 | SACKGASSE 955 | SAND 956 | SCH+E+W+E+B+Y 957 | SCHAETZEN 958 | SCHIRM 959 | SCHLAGSAHNE 960 | SCHMELZEN 961 | SCHNEE-AUF-BERG 962 | SCHNEEVERWEHUNG 963 | SCHRANK 964 | SCHUETZEN 965 | SCHULD 966 | SCHWITZEN 967 | SECHZIG 968 | SEI-DANK 969 | SEIT 970 | SEITE 971 | SICHER 972 | SIEBEN-WOCHE 973 | SLOWAKEI 974 | SONNE-SCHEINEN 975 | SONNENUNTERGANG 976 | SOWIESO 977 | SPAETESTEN 978 | SPORT 979 | SPRIESSEN 980 | STADT 981 | STAMM 982 | START 983 | STEHEN 984 | STEIGEN-OBEN 985 | STIMMT 986 | STREIFEN 987 | STROM 988 | STUTTGART 989 | SUCHEN 990 | T+V+U 991 | TANKEN 992 | TAU 993 | THEMA 994 | TIEFDRUCKZONE 995 | TRINKEN 996 | TROPISCH 997 | U 998 | UEBER-UNTER 999 | UEBERFLUTUNG 1000 | UMSTAENDLICH 1001 | UMSTELLEN 1002 | UND-SO-WEITER 1003 | UNGARN 1004 | UNSER 1005 | UNTERNEHMEN 1006 | UNWAHRSCHEINLICH 1007 | URLAUB 1008 | V+O+I+G+T 1009 | VERANTWORTLICH 1010 | VERBINDEN 1011 | VERDICHTEN 1012 | VEREINZELT 1013 | VERSPAETET 1014 | VERSUCHEN 1015 | VERTEILEN 1016 | VERTREIBEN 1017 | VERWOEHNT 1018 | VIDEO 1019 | VOLL 1020 | VOR-ALLEM 1021 | VOR-LETZTEN-TAGEN 1022 | VORDERSCHEIBE 1023 | VORPOMMERN 1024 | VORSTELLEN 1025 | VORUEBERGEHEND 1026 | W+E 1027 | WACHSEN 1028 | WANN 1029 | WARTEN 1030 | WASCH 1031 | WASSER-STEIGEN 1032 | WEIBER 1033 | WEIL 1034 | WEIT-SEHEN 1035 | WERDEN 1036 | WESER 1037 | WIE-LANG 1038 | WIEDER-ZURUECK 1039 | WIESE 1040 | WIEVIEL 1041 | WIRKEN 1042 | WUNDERBAR 1043 | WUNDERSCHOEN 1044 | Y+I+Y 1045 | Z+D+E 1046 | Z+Y+P 1047 | ZAHL 1048 | ZEHN-STUNDEN 1049 | ZEITSKALA 1050 | ZOOM 1051 | ZU-HAUSE 1052 | ZU-TUN 1053 | ZUSAMMENSTOSS 1054 | ZUSAMMENTREFFEN 1055 | neg-ACHTZEHN 1056 | neg-BEDEUTEN 1057 | neg-DEUTSCH 1058 | neg-FUENF 1059 | neg-GEMUETLICH 1060 | neg-GENUG 1061 | neg-GLEICH 1062 | neg-HART 1063 | neg-HEISS 1064 | neg-IN-DIESE-WOCHE 1065 | neg-MEISTENS 1066 | neg-MILD 1067 | neg-NACHT 1068 | neg-NAJA 1069 | neg-NOCH-NICHT 1070 | neg-NULL 1071 | neg-NUR 1072 | neg-RICHTIG 1073 | neg-SCHEINEN 1074 | neg-SCHLECHT 1075 | neg-SCHNEE 1076 | neg-SELTEN 1077 | neg-SPUEREN 1078 | neg-STARK 1079 | neg-TEIL 1080 | neg-THEMA 1081 | neg-VON 1082 | neg-WARM 1083 | neg-ZU-WARM 1084 | negalp-BRAUCHEN 1085 | negalp-GIBT 1086 | negalp-MUSS 1087 | negalp-PASSEN 1088 | -------------------------------------------------------------------------------- /signjoey/search.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import torch 3 | import 
torch.nn.functional as F 4 | from torch import Tensor 5 | import numpy as np 6 | 7 | from signjoey.decoders import Decoder, TransformerDecoder 8 | from signjoey.EnsembleTransformer import EnsembleTransformerDecoder 9 | from signjoey.embeddings import Embeddings 10 | from signjoey.helpers import tile 11 | from scipy import stats 12 | import pandas as pd 13 | __all__ = ["greedy", "transformer_greedy", "beam_search"] 14 | 15 | 16 | def greedy( 17 | src_mask: Tensor, 18 | embed: Embeddings, 19 | bos_index: int, 20 | eos_index: int, 21 | max_output_length: int, 22 | decoder: Decoder, 23 | encoder_output: Tensor, 24 | encoder_hidden: Tensor, 25 | ) -> (np.array, np.array): 26 | """ 27 | Greedy decoding. Select the token word highest probability at each time 28 | step. This function is a wrapper that calls recurrent_greedy for 29 | recurrent decoders and transformer_greedy for transformer decoders. 30 | 31 | :param src_mask: mask for source inputs, 0 for positions after 32 | :param embed: target embedding 33 | :param bos_index: index of in the vocabulary 34 | :param eos_index: index of in the vocabulary 35 | :param max_output_length: maximum length for the hypotheses 36 | :param decoder: decoder to use for greedy decoding 37 | :param encoder_output: encoder hidden states for attention 38 | :param encoder_hidden: encoder last state for decoder initialization 39 | :return: 40 | """ 41 | 42 | if isinstance(decoder, TransformerDecoder) or isinstance(decoder, EnsembleTransformerDecoder) : 43 | # Transformer greedy decoding 44 | greedy_fun = transformer_greedy 45 | else: 46 | # Recurrent greedy decoding 47 | greedy_fun = recurrent_greedy 48 | 49 | return greedy_fun( 50 | src_mask=src_mask, 51 | embed=embed, 52 | bos_index=bos_index, 53 | eos_index=eos_index, 54 | max_output_length=max_output_length, 55 | decoder=decoder, 56 | encoder_output=encoder_output, 57 | encoder_hidden=encoder_hidden, 58 | ) 59 | 60 | 61 | def recurrent_greedy( 62 | src_mask: Tensor, 63 | embed: Embeddings, 64 | bos_index: int, 65 | eos_index: int, 66 | max_output_length: int, 67 | decoder: Decoder, 68 | encoder_output: Tensor, 69 | encoder_hidden: Tensor, 70 | ) -> (np.array, np.array): 71 | """ 72 | Greedy decoding: in each step, choose the word that gets highest score. 73 | Version for recurrent decoder. 
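    Unlike ``transformer_greedy`` below, only the token predicted at the previous
    step is fed back into the decoder at each step, together with the recurrent
    hidden state and the previous attention vector (``unroll_steps=1``).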
74 | 75 | :param src_mask: mask for source inputs, 0 for positions after 76 | :param embed: target embedding 77 | :param bos_index: index of in the vocabulary 78 | :param eos_index: index of in the vocabulary 79 | :param max_output_length: maximum length for the hypotheses 80 | :param decoder: decoder to use for greedy decoding 81 | :param encoder_output: encoder hidden states for attention 82 | :param encoder_hidden: encoder last state for decoder initialization 83 | :return: 84 | - stacked_output: output hypotheses (2d array of indices), 85 | - stacked_attention_scores: attention scores (3d array) 86 | """ 87 | batch_size = src_mask.size(0) 88 | prev_y = src_mask.new_full( 89 | size=[batch_size, 1], fill_value=bos_index, dtype=torch.long 90 | ) 91 | output = [] 92 | attention_scores = [] 93 | hidden = None 94 | prev_att_vector = None 95 | finished = src_mask.new_zeros((batch_size, 1)).byte() 96 | 97 | # pylint: disable=unused-variable 98 | for t in range(max_output_length): 99 | # decode one single step 100 | logits, hidden, att_probs, prev_att_vector = decoder( 101 | encoder_output=encoder_output, 102 | encoder_hidden=encoder_hidden, 103 | src_mask=src_mask, 104 | trg_embed=embed(prev_y), 105 | hidden=hidden, 106 | prev_att_vector=prev_att_vector, 107 | unroll_steps=1, 108 | ) 109 | # logits: batch x time=1 x vocab (logits) 110 | 111 | # greedy decoding: choose arg max over vocabulary in each step 112 | next_word = torch.argmax(logits, dim=-1) # batch x time=1 113 | output.append(next_word.squeeze(1).detach().cpu().numpy()) 114 | prev_y = next_word 115 | attention_scores.append(att_probs.squeeze(1).detach().cpu().numpy()) 116 | # batch, max_src_lengths 117 | # check if previous symbol was 118 | is_eos = torch.eq(next_word, eos_index) 119 | finished += is_eos 120 | # stop predicting if reached for all elements in batch 121 | if (finished >= 1).sum() == batch_size: 122 | break 123 | 124 | stacked_output = np.stack(output, axis=1) # batch, time 125 | stacked_attention_scores = np.stack(attention_scores, axis=1) 126 | return stacked_output, stacked_attention_scores 127 | 128 | 129 | # pylint: disable=unused-argument 130 | def transformer_greedy( 131 | src_mask: Tensor, 132 | embed: Embeddings, 133 | bos_index: int, 134 | eos_index: int, 135 | max_output_length: int, 136 | decoder: Decoder, 137 | encoder_output: Tensor, 138 | encoder_hidden: Tensor, 139 | ) -> (np.array, None): 140 | """ 141 | Special greedy function for transformer, since it works differently. 142 | The transformer remembers all previous states and attends to them. 
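    Concretely, at every step the whole prefix ``ys`` generated so far is
    re-embedded and passed through the decoder, and only the logits of the last
    position (``logits[:, -1]``) are used to pick the next token; no incremental
    state caching is done.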
143 | 144 | :param src_mask: mask for source inputs, 0 for positions after 145 | :param embed: target embedding layer 146 | :param bos_index: index of in the vocabulary 147 | :param eos_index: index of in the vocabulary 148 | :param max_output_length: maximum length for the hypotheses 149 | :param decoder: decoder to use for greedy decoding 150 | :param encoder_output: encoder hidden states for attention 151 | :param encoder_hidden: encoder final state (unused in Transformer) 152 | :return: 153 | - stacked_output: output hypotheses (2d array of indices), 154 | - stacked_attention_scores: attention scores (3d array) 155 | """ 156 | 157 | batch_size = src_mask.size(0) 158 | 159 | # start with BOS-symbol for each sentence in the batch 160 | ys = encoder_output.new_full([batch_size, 1], bos_index, dtype=torch.long) 161 | 162 | # a subsequent mask is intersected with this in decoder forward pass 163 | trg_mask = src_mask.new_ones([1, 1, 1]) 164 | finished = src_mask.new_zeros((batch_size)).byte() 165 | 166 | for _ in range(max_output_length): 167 | 168 | trg_embed = embed(ys) # embed the previous tokens 169 | 170 | # pylint: disable=unused-variable 171 | with torch.no_grad(): 172 | logits, out, _, _ = decoder( 173 | trg_embed=trg_embed, 174 | encoder_output=encoder_output, 175 | encoder_hidden=None, 176 | src_mask=src_mask, 177 | unroll_steps=None, 178 | hidden=None, 179 | trg_mask=trg_mask, 180 | ) 181 | 182 | logits = logits[:, -1] 183 | _, next_word = torch.max(logits, dim=1) 184 | next_word = next_word.data 185 | ys = torch.cat([ys, next_word.unsqueeze(-1)], dim=1) 186 | 187 | # check if previous symbol was 188 | is_eos = torch.eq(next_word, eos_index) 189 | finished += is_eos 190 | # stop predicting if reached for all elements in batch 191 | if (finished >= 1).sum() == batch_size: 192 | break 193 | 194 | ys = ys[:, 1:] # remove BOS-symbol 195 | return ys.detach().cpu().numpy(), None 196 | 197 | 198 | # pylint: disable=too-many-statements,too-many-branches 199 | def beam_search( 200 | decoder: Decoder, 201 | size: int, 202 | bos_index: int, 203 | eos_index: int, 204 | pad_index: int, 205 | encoder_output: Tensor, 206 | encoder_hidden: Tensor, 207 | src_mask: Tensor, 208 | max_output_length: int, 209 | alpha: float, 210 | embed: Embeddings, 211 | n_best: int = 1, 212 | ) -> (np.array, np.array): 213 | """ 214 | Beam search with size k. 215 | Inspired by OpenNMT-py, adapted for Transformer. 216 | 217 | In each decoding step, find the k most likely partial hypotheses. 218 | 219 | :param decoder: 220 | :param size: size of the beam 221 | :param bos_index: 222 | :param eos_index: 223 | :param pad_index: 224 | :param encoder_output: 225 | :param encoder_hidden: 226 | :param src_mask: 227 | :param max_output_length: 228 | :param alpha: `alpha` factor for length penalty 229 | :param embed: 230 | :param n_best: return this many hypotheses, <= beam (currently only 1) 231 | :return: 232 | - stacked_output: output hypotheses (2d array of indices), 233 | - stacked_attention_scores: attention scores (3d array) 234 | """ 235 | assert size > 0, "Beam size must be >0." 
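    # From here on the batch is expanded ("tiled") beam-size times: every row of
    # encoder_output / src_mask / alive_seq corresponds to one live hypothesis,
    # and beam_offset maps a (batch element, beam index) pair back to its row in
    # this flattened batch*size layout.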
236 | assert n_best <= size, "Can only return {} best hypotheses.".format(size) 237 | 238 | # init 239 | transformer = isinstance(decoder, TransformerDecoder) or isinstance(decoder, EnsembleTransformerDecoder) 240 | batch_size = src_mask.size(0) 241 | att_vectors = None # not used for Transformer 242 | 243 | # Recurrent models only: initialize RNN hidden state 244 | # pylint: disable=protected-access 245 | if not transformer: 246 | hidden = decoder._init_hidden(encoder_hidden) 247 | else: 248 | hidden = None 249 | 250 | # tile encoder states and decoder initial states beam_size times 251 | if hidden is not None: 252 | hidden = tile(hidden, size, dim=1) # layers x batch*k x dec_hidden_size 253 | 254 | encoder_output = tile( 255 | encoder_output.contiguous(), size, dim=0 256 | ) # batch*k x src_len x enc_hidden_size 257 | src_mask = tile(src_mask, size, dim=0) # batch*k x 1 x src_len 258 | 259 | # Transformer only: create target mask 260 | if transformer: 261 | trg_mask = src_mask.new_ones([1, 1, 1]) # transformer only 262 | else: 263 | trg_mask = None 264 | 265 | # numbering elements in the batch 266 | batch_offset = torch.arange( 267 | batch_size, dtype=torch.long, device=encoder_output.device 268 | ) 269 | 270 | # numbering elements in the extended batch, i.e. beam size copies of each 271 | # batch element 272 | beam_offset = torch.arange( 273 | 0, batch_size * size, step=size, dtype=torch.long, device=encoder_output.device 274 | ) 275 | 276 | # keeps track of the top beam size hypotheses to expand for each element 277 | # in the batch to be further decoded (that are still "alive") 278 | alive_seq = torch.full( 279 | [batch_size * size, 1], 280 | bos_index, 281 | dtype=torch.long, 282 | device=encoder_output.device, 283 | ) 284 | 285 | # Give full probability to the first beam on the first step. 286 | topk_log_probs = torch.zeros(batch_size, size, device=encoder_output.device) 287 | topk_log_probs[:, 1:] = float("-inf") 288 | 289 | # Structure that holds finished hypotheses. 290 | hypotheses = [[] for _ in range(batch_size)] 291 | 292 | results = { 293 | "predictions": [[] for _ in range(batch_size)], 294 | "scores": [[] for _ in range(batch_size)], 295 | "gold_score": [0] * batch_size, 296 | } 297 | 298 | for step in range(max_output_length): 299 | 300 | # This decides which part of the predicted sentence we feed to the 301 | # decoder to make the next prediction. 302 | # For Transformer, we feed the complete predicted sentence so far. 303 | # For Recurrent models, only feed the previous target word prediction 304 | if transformer: # Transformer 305 | decoder_input = alive_seq # complete prediction so far 306 | else: # Recurrent 307 | decoder_input = alive_seq[:, -1].view(-1, 1) # only the last word 308 | 309 | # expand current hypotheses 310 | # decode one single step 311 | # logits: logits for final softmax 312 | # pylint: disable=unused-variable 313 | trg_embed = embed(decoder_input) 314 | logits, hidden, att_scores, att_vectors = decoder( 315 | encoder_output=encoder_output, 316 | encoder_hidden=encoder_hidden, 317 | src_mask=src_mask, 318 | trg_embed=trg_embed, 319 | hidden=hidden, 320 | prev_att_vector=att_vectors, 321 | unroll_steps=1, 322 | trg_mask=trg_mask, # subsequent mask for Transformer only 323 | ) 324 | 325 | # For the Transformer we made predictions for all time steps up to 326 | # this point, so we only want to know about the last time step. 
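        # (Recurrent decoders already return a single step, so nothing is sliced.)
        # The step log-probabilities are added to the running beam scores and,
        # if alpha > -1, divided by the GNMT-style length penalty
        # ((5 + step + 1) / 6) ** alpha before the top-k continuations are chosen,
        # which keeps longer hypotheses competitive for alpha > 0.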
327 | if transformer: 328 | logits = logits[:, -1] # keep only the last time step 329 | hidden = None # we don't need to keep it for transformer 330 | 331 | # batch*k x trg_vocab 332 | log_probs = F.log_softmax(logits, dim=-1).squeeze(1) 333 | 334 | # multiply probs by the beam probability (=add logprobs) 335 | log_probs += topk_log_probs.view(-1).unsqueeze(1) 336 | curr_scores = log_probs.clone() 337 | 338 | # compute length penalty 339 | if alpha > -1: 340 | length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha 341 | curr_scores /= length_penalty 342 | 343 | # flatten log_probs into a list of possibilities 344 | curr_scores = curr_scores.reshape(-1, size * decoder.output_size) 345 | 346 | # pick currently best top k hypotheses (flattened order) 347 | topk_scores, topk_ids = curr_scores.topk(size, dim=-1) 348 | 349 | if alpha > -1: 350 | # recover original log probs 351 | topk_log_probs = topk_scores * length_penalty 352 | else: 353 | topk_log_probs = topk_scores.clone() 354 | 355 | # reconstruct beam origin and true word ids from flattened order 356 | topk_beam_index = topk_ids.floor_divide(decoder.output_size) 357 | topk_ids = topk_ids.fmod(decoder.output_size) 358 | 359 | # map beam_index to batch_index in the flat representation 360 | batch_index = topk_beam_index + beam_offset[ 361 | : topk_beam_index.size(0) 362 | ].unsqueeze(1) 363 | select_indices = batch_index.view(-1) 364 | 365 | # append latest prediction 366 | alive_seq = torch.cat( 367 | [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1 368 | ) # batch_size*k x hyp_len 369 | 370 | is_finished = topk_ids.eq(eos_index) 371 | if step + 1 == max_output_length: 372 | is_finished.fill_(True) 373 | # end condition is whether the top beam is finished 374 | end_condition = is_finished[:, 0].eq(True) 375 | 376 | # save finished hypotheses 377 | if is_finished.any(): 378 | predictions = alive_seq.view(-1, size, alive_seq.size(-1)) 379 | for i in range(is_finished.size(0)): 380 | b = batch_offset[i] 381 | if end_condition[i]: 382 | is_finished[i].fill_(True) 383 | finished_hyp = is_finished[i].nonzero().view(-1) 384 | # store finished hypotheses for this batch 385 | for j in finished_hyp: 386 | # Check if the prediction has more than one EOS. 387 | # If it has more than one EOS, it means that the prediction should have already 388 | # been added to the hypotheses, so you don't have to add them again. 
389 | if (predictions[i, j, 1:] == eos_index).nonzero().numel() < 2: 390 | hypotheses[b].append( 391 | ( 392 | topk_scores[i, j], 393 | predictions[i, j, 1:], 394 | ) # ignore start_token 395 | ) 396 | # if the batch reached the end, save the n_best hypotheses 397 | if end_condition[i]: 398 | best_hyp = sorted(hypotheses[b], key=lambda x: x[0], reverse=True) 399 | for n, (score, pred) in enumerate(best_hyp): 400 | if n >= n_best: 401 | break 402 | results["scores"][b].append(score) 403 | results["predictions"][b].append(pred) 404 | non_finished = end_condition.eq(False).nonzero().view(-1) 405 | # if all sentences are translated, no need to go further 406 | # pylint: disable=len-as-condition 407 | if len(non_finished) == 0: 408 | break 409 | # remove finished batches for the next step 410 | topk_log_probs = topk_log_probs.index_select(0, non_finished) 411 | batch_index = batch_index.index_select(0, non_finished) 412 | batch_offset = batch_offset.index_select(0, non_finished) 413 | alive_seq = predictions.index_select(0, non_finished).view( 414 | -1, alive_seq.size(-1) 415 | ) 416 | 417 | # reorder indices, outputs and masks 418 | select_indices = batch_index.view(-1) 419 | encoder_output = encoder_output.index_select(0, select_indices) 420 | src_mask = src_mask.index_select(0, select_indices) 421 | 422 | if hidden is not None and not transformer: 423 | if isinstance(hidden, tuple): 424 | # for LSTMs, states are tuples of tensors 425 | h, c = hidden 426 | h = h.index_select(1, select_indices) 427 | c = c.index_select(1, select_indices) 428 | hidden = (h, c) 429 | else: 430 | # for GRUs, states are single tensors 431 | hidden = hidden.index_select(1, select_indices) 432 | 433 | if att_vectors is not None: 434 | att_vectors = att_vectors.index_select(0, select_indices) 435 | 436 | def pad_and_stack_hyps(hyps, pad_value): 437 | filled = ( 438 | np.ones((len(hyps), max([h.shape[0] for h in hyps])), dtype=int) * pad_value 439 | ) 440 | for j, h in enumerate(hyps): 441 | for k, i in enumerate(h): 442 | filled[j, k] = i 443 | return filled 444 | 445 | # from results to stacked outputs 446 | assert n_best == 1 447 | # only works for n_best=1 for now 448 | 449 | final_outputs = pad_and_stack_hyps( 450 | [r[0].cpu().numpy() for r in results["predictions"]], pad_value=pad_index 451 | ) 452 | 453 | return final_outputs, None 454 | 455 | 456 | --------------------------------------------------------------------------------