├── .gitmodules ├── cress ├── __init__.py ├── test_scripts │ ├── avg_epoch.sh │ ├── test.en-x.st.sh │ └── test.en-x.mt.sh ├── models │ ├── __init__.py │ └── hubert_transformer.py ├── tasks │ ├── __init__.py │ ├── speech_to_text_modified.py │ └── speech_and_text_translation.py ├── datasets │ ├── __init__.py │ ├── audio_utils.py │ ├── speech_and_text_translation_dataset.py │ └── speech_to_text_dataset.py ├── criterions │ ├── __init__.py │ ├── speech_and_text_translation_criterion.py │ ├── speech_and_text_translation_with_oracle_reg_criterion.py │ └── speech_and_text_translation_with_oracle_reg_adaptive_criterion.py └── train_scripts │ ├── train.en-x.postln.wmt_pretrain.sh │ ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.sh │ ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.sh │ ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress.sh │ └── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress_adaptive.sh └── README.md /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fairseq"] 2 | path = fairseq 3 | url = https://github.com/facebookresearch/fairseq 4 | -------------------------------------------------------------------------------- /cress/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterions import * 2 | from .models import * 3 | from .tasks import * 4 | from .datasets import * 5 | 6 | print("fairseq plugins loaded...") -------------------------------------------------------------------------------- /cress/test_scripts/avg_epoch.sh: -------------------------------------------------------------------------------- 1 | ckpt=$1 2 | python scripts/average_checkpoints.py \ 3 | --inputs checkpoints/$ckpt \ 4 | --num-epoch-checkpoints 10 \ 5 | --output checkpoints/$ckpt/avg_last_10_epoch.pt -------------------------------------------------------------------------------- /cress/models/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("cress.models." + file_name) 9 | -------------------------------------------------------------------------------- /cress/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("cress.tasks." + file_name) 9 | -------------------------------------------------------------------------------- /cress/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("cress.datasets." 
+ file_name) 9 | -------------------------------------------------------------------------------- /cress/test_scripts/test.en-x.st.sh: -------------------------------------------------------------------------------- 1 | ckpt=$1 2 | lang=$2 3 | lenpen=$3 4 | fairseq-generate data/mustc/en-$lang \ 5 | --user-dir cress \ 6 | --config-yaml config.yaml --gen-subset tst-COMMON --task speech_to_text_modified \ 7 | --path $ckpt \ 8 | --max-source-positions 900000 \ 9 | --max-tokens 2000000 --beam 8 --lenpen $lenpen --scoring sacrebleu 10 | -------------------------------------------------------------------------------- /cress/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | 4 | # automatically import any Python files in the criterions/ directory 5 | for file in os.listdir(os.path.dirname(__file__)): 6 | if file.endswith(".py") and not file.startswith("_"): 7 | file_name = file[: file.find(".py")] 8 | importlib.import_module("cress.criterions." + file_name) 9 | -------------------------------------------------------------------------------- /cress/test_scripts/test.en-x.mt.sh: -------------------------------------------------------------------------------- 1 | ckpt=$1 2 | lang=$2 3 | lenpen=$3 4 | fairseq-generate data/mustc/en-$lang --text-data data/mustc/en-$lang/binary --tgt-lang $lang \ 5 | --user-dir cress \ 6 | --config-yaml config.yaml --gen-subset test --task speech_and_text_translation \ 7 | --path $ckpt \ 8 | --ext-mt-training \ 9 | --max-tokens 2000000 --max-tokens-text 4096 --beam 8 --lenpen $lenpen --scoring sacrebleu -------------------------------------------------------------------------------- /cress/train_scripts/train.en-x.postln.wmt_pretrain.sh: -------------------------------------------------------------------------------- 1 | tgt=$1 2 | exp=en-$tgt.postln.wmt_pretrain 3 | fairseq-train data/mustc/en-$tgt --text-data data/wmt/en-$tgt/mustc_wmt_en_$tgt/binary --tgt-lang $tgt \ 4 | --user-dir cress \ 5 | --config-yaml config.yaml --train-subset train --valid-subset dev \ 6 | --save-dir checkpoints/${exp} --num-workers 4 --max-tokens 1000000 --max-tokens-text 8192 --max-update 250000 \ 7 | --task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \ 8 | --arch hubert_transformer_postln --optimizer adam --adam-betas '(0.9, 0.98)' --lr 7e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0 \ 9 | --no-progress-bar --log-format json --log-interval 100 \ 10 | --save-interval-updates 5000 \ 11 | --ddp-backend=legacy_ddp \ 12 | --warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 1 \ 13 | --layernorm-embedding \ 14 | --fp16 \ 15 | --ext-mt-training \ 16 | --hubert-model-path checkpoints/hubert_base_ls960.pt 17 | -------------------------------------------------------------------------------- /cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.sh: -------------------------------------------------------------------------------- 1 | tgt=$1 2 | exp=en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain 3 | fairseq-train data/mustc/en-$tgt --text-data data/mustc/en-$tgt/binary/ --tgt-lang $tgt \ 4 | --user-dir cress \ 5 | --config-yaml config.yaml --train-subset train --valid-subset dev \ 6 | --save-dir checkpoints/${exp} --num-workers 4 --max-tokens 1000000 --max-tokens-text 8192 --max-update 100000 \ 7 | --task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \ 8 | --arch hubert_transformer_postln 
--optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt \ 9 | --no-progress-bar --log-format json --log-interval 100 \ 10 | --ddp-backend=legacy_ddp \ 11 | --warmup-updates 8000 --clip-norm 10.0 --seed 1 --update-freq 1 \ 12 | --layernorm-embedding \ 13 | --patience 10 \ 14 | --fp16 \ 15 | --ext-mt-training \ 16 | --hubert-model-path checkpoints/hubert_base_ls960.pt \ 17 | --load-pretrained-mt-encoder-decoder-from checkpoints/en-$tgt.postln.wmt_pretrain/avg_last_5_epoch.pt -------------------------------------------------------------------------------- /cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.sh: -------------------------------------------------------------------------------- 1 | tgt=$1 2 | exp=en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt 3 | fairseq-train data/mustc/en-$tgt --text-data data/mustc/en-$tgt/binary/ --tgt-lang $tgt \ 4 | --user-dir cress \ 5 | --config-yaml config.yaml --train-subset train --valid-subset dev \ 6 | --save-dir checkpoints/${exp} --num-workers 4 --max-tokens 2000000 --batch-size 32 --max-tokens-text 4096 --max-update 100000 \ 7 | --task speech_and_text_translation --criterion speech_and_text_translation --label-smoothing 0.1 \ 8 | --arch hubert_transformer_postln --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \ 9 | --no-progress-bar --log-format json --log-interval 100 \ 10 | --ddp-backend=legacy_ddp \ 11 | --warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \ 12 | --layernorm-embedding \ 13 | --patience 10 \ 14 | --fp16 \ 15 | --st-training --mt-finetune \ 16 | --hubert-model-path checkpoints/hubert_base_ls960.pt \ 17 | --eval-bleu \ 18 | --eval-bleu-args '{"beam": 8}' \ 19 | --eval-bleu-print-samples \ 20 | --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \ 21 | --load-pretrained-mt-encoder-decoder-from checkpoints/en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain/avg_last_10_epoch.pt -------------------------------------------------------------------------------- /cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress.sh: -------------------------------------------------------------------------------- 1 | tgt=$1 2 | exp=en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress 3 | fairseq-train data/mustc/en-$tgt --text-data data/mustc/en-$tgt/binary/ --tgt-lang $tgt \ 4 | --user-dir cress \ 5 | --config-yaml config.yaml --train-subset train --valid-subset dev \ 6 | --save-dir checkpoints/${exp} --num-workers 4 --max-tokens 2000000 --batch-size 32 --max-tokens-text 4096 --max-update 100000 \ 7 | --task speech_and_text_translation --criterion speech_and_text_translation_with_oracle_reg --reg-loss-type jsd --reg-weight 1.0 --label-smoothing 0.1 \ 8 | --arch hubert_transformer_postln --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \ 9 | --use-word-level-oracle --decay-k 15 --use-word-gumbel-noise --gumbel-temperature 1.0 \ 10 | --no-progress-bar --log-format json --log-interval 100 \ 11 | --ddp-backend=legacy_ddp \ 12 | --warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \ 13 | --layernorm-embedding \ 14 | --max-epoch 20 \ 15 | --fp16 \ 16 | --st-training \ 17 | --hubert-model-path checkpoints/hubert_base_ls960.pt \ 18 | --eval-bleu \ 19 | --eval-bleu-args '{"beam": 8}' \ 20 | --eval-bleu-print-samples \ 21 | --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \ 22 | 
--load-pretrained-mt-encoder-decoder-from checkpoints/en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain/avg_last_10_epoch.pt -------------------------------------------------------------------------------- /cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress_adaptive.sh: -------------------------------------------------------------------------------- 1 | tgt=$1 2 | exp=en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress_adaptive 3 | fairseq-train data/mustc/en-$tgt --text-data data/mustc/en-$tgt/binary/ --tgt-lang $tgt \ 4 | --user-dir cress \ 5 | --config-yaml config.yaml --train-subset train --valid-subset dev \ 6 | --save-dir checkpoints/${exp} --num-workers 4 --max-tokens 2000000 --batch-size 32 --max-tokens-text 4096 --max-update 100000 \ 7 | --task speech_and_text_translation --criterion speech_and_text_translation_with_oracle_reg_adaptive --reg-loss-type jsd --reg-weight 1.0 --label-smoothing 0.1 \ 8 | --arch hubert_transformer_postln --optimizer adam --adam-betas '(0.9, 0.98)' --lr 1e-4 --lr-scheduler inverse_sqrt --weight-decay 0.0001 \ 9 | --use-word-level-oracle --decay-k 15 --use-word-gumbel-noise --gumbel-temperature 1.0 \ 10 | --adaptive-base 0.7 --adaptive-scale 0.05 --adaptive-func linear_cosine --adaptive-weight-drop 0.0 \ 11 | --adaptive-st-loss --adaptive-mt-loss --adaptive-reg-loss \ 12 | --no-progress-bar --log-format json --log-interval 100 \ 13 | --ddp-backend=legacy_ddp \ 14 | --warmup-updates 4000 --clip-norm 0.0 --seed 1 --update-freq 2 \ 15 | --layernorm-embedding \ 16 | --patience 10 \ 17 | --fp16 \ 18 | --st-training \ 19 | --hubert-model-path checkpoints/hubert_base_ls960.pt \ 20 | --eval-bleu \ 21 | --eval-bleu-args '{"beam": 8}' \ 22 | --eval-bleu-print-samples \ 23 | --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \ 24 | --restore-file checkpoints/en-$tgt.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress/checkpoint20.pt -------------------------------------------------------------------------------- /cress/criterions/speech_and_text_translation_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
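# Note on the optimizer settings used by the train_scripts above: they all select this
# criterion together with "--lr-scheduler inverse_sqrt" and "--warmup-updates 4000/8000".
# The function below is an illustrative sketch of that schedule (linear warmup, then decay
# proportional to the inverse square root of the update number); it is not the fairseq
# implementation, and the warmup_init_lr default is an assumption.
def _inverse_sqrt_lr_sketch(update_num, lr=1e-4, warmup_updates=4000, warmup_init_lr=1e-7):
    if update_num < warmup_updates:
        # linear warmup from warmup_init_lr up to the peak lr
        return warmup_init_lr + update_num * (lr - warmup_init_lr) / warmup_updates
    # after warmup: peak lr scaled by sqrt(warmup_updates / update_num)
    return lr * (warmup_updates ** 0.5) * (update_num ** -0.5)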
5 | 6 | import math 7 | import random 8 | from dataclasses import dataclass, field 9 | 10 | import torch 11 | from fairseq import metrics, utils 12 | from fairseq.criterions import FairseqCriterion, register_criterion 13 | from fairseq.criterions.label_smoothed_cross_entropy import ( 14 | LabelSmoothedCrossEntropyCriterion, 15 | LabelSmoothedCrossEntropyCriterionConfig, 16 | ) 17 | from fairseq.dataclass import FairseqDataclass 18 | from omegaconf import II 19 | 20 | 21 | @dataclass 22 | class SpeechAndTextTranslationCriterionConfig(LabelSmoothedCrossEntropyCriterionConfig): 23 | mt_finetune: bool = field( 24 | default=False, 25 | metadata={"help": "st + mt multi-task finetune"}, 26 | ) 27 | 28 | @register_criterion( 29 | "speech_and_text_translation", dataclass=SpeechAndTextTranslationCriterionConfig 30 | ) 31 | class SpeechAndTextTranslationCriterion(LabelSmoothedCrossEntropyCriterion): 32 | def __init__( 33 | self, 34 | task, 35 | sentence_avg, 36 | label_smoothing, 37 | ignore_prefix_size=0, 38 | report_accuracy=False, 39 | mt_finetune=False, 40 | ): 41 | super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy) 42 | self.mt_finetune = mt_finetune 43 | 44 | def forward_st(self, model, sample, reduce): 45 | audio_input = { 46 | "src_tokens": sample["net_input"]["audio"], 47 | "src_lengths": sample["net_input"]["audio_lengths"], 48 | "mode": "st", 49 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 50 | } 51 | audio_output = model(**audio_input) 52 | loss, _ = self.compute_loss(model, audio_output, sample, reduce=reduce) 53 | return loss 54 | 55 | def forward_mt(self, model, sample, reduce): 56 | text_input = { 57 | "src_tokens": sample["net_input"]["source"], 58 | "src_lengths": sample["net_input"]["source_lengths"], 59 | "mode": "mt", 60 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 61 | } 62 | text_output = model(**text_input) 63 | loss, _ = self.compute_loss(model, text_output, sample, reduce=reduce) 64 | return loss 65 | 66 | def forward_ext_mt(self, model, sample, reduce): 67 | text_output = model(**sample["net_input"]) 68 | loss, _ = self.compute_loss(model, text_output, sample, reduce=reduce) 69 | return loss 70 | 71 | def forward(self, model, sample, reduce=True): 72 | """Compute the loss for the given sample. 
73 | Returns a tuple with three elements: 74 | 1) the loss 75 | 2) the sample size, which is used as the denominator for the gradient 76 | 3) logging outputs to display while training 77 | """ 78 | st_loss, mt_loss, ext_mt_loss = torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda() 79 | st_size, mt_size, ext_mt_size = 0, 0, 0 80 | 81 | mode = sample["net_input"]["mode"] 82 | if mode == "st": 83 | if self.mt_finetune and self.training: 84 | st_loss = self.forward_st(model, sample, reduce) 85 | mt_loss = self.forward_mt(model, sample, reduce) 86 | loss = st_loss + mt_loss 87 | st_size = mt_size = sample_size = sample["ntokens"] 88 | else: 89 | loss = st_loss = self.forward_st(model, sample, reduce) 90 | st_size = sample_size = sample["ntokens"] 91 | elif mode == "ext_mt": 92 | loss = ext_mt_loss = self.forward_ext_mt(model, sample, reduce) 93 | ext_mt_size = sample_size = sample["ntokens"] 94 | 95 | logging_output = { 96 | "loss": loss.data, 97 | "st_loss": st_loss.data, 98 | "st_sample_size": st_size, 99 | "mt_loss": mt_loss.data, 100 | "mt_sample_size": mt_size, 101 | "ext_mt_loss": ext_mt_loss.data, 102 | "ext_mt_sample_size": ext_mt_size, 103 | "ntokens": sample["ntokens"], 104 | "nsentences": sample["target"].size(0), 105 | "sample_size": sample_size, 106 | } 107 | 108 | return loss, sample_size, logging_output 109 | 110 | @classmethod 111 | def reduce_metrics(cls, logging_outputs) -> None: 112 | """Aggregate logging outputs from data parallel training.""" 113 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 114 | st_loss_sum = sum(log.get("st_loss", 0) for log in logging_outputs) 115 | mt_loss_sum = sum(log.get("mt_loss", 0) for log in logging_outputs) 116 | ext_mt_loss_sum = sum(log.get("ext_mt_loss", 0) for log in logging_outputs) 117 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 118 | st_sample_size = sum(log.get("st_sample_size", 0) for log in logging_outputs) 119 | mt_sample_size = sum(log.get("mt_sample_size", 0) for log in logging_outputs) 120 | ext_mt_sample_size = sum(log.get("ext_mt_sample_size", 0) for log in logging_outputs) 121 | 122 | metrics.log_scalar( 123 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 124 | ) 125 | metrics.log_scalar( 126 | "st_loss", st_loss_sum / st_sample_size / math.log(2) if st_sample_size != 0 else 0, st_sample_size, round=3 127 | ) 128 | metrics.log_scalar( 129 | "mt_loss", mt_loss_sum / mt_sample_size / math.log(2) if mt_sample_size != 0 else 0, mt_sample_size, round=3 130 | ) 131 | metrics.log_scalar( 132 | "ext_mt_loss", ext_mt_loss_sum / ext_mt_sample_size / math.log(2) if ext_mt_sample_size != 0 else 0, ext_mt_sample_size, round=3 133 | ) 134 | 135 | @staticmethod 136 | def logging_outputs_can_be_summed() -> bool: 137 | """ 138 | Whether the logging outputs returned by `forward` can be summed 139 | across workers prior to calling `reduce_metrics`. Setting this 140 | to True will improves distributed training speed. 141 | """ 142 | return True -------------------------------------------------------------------------------- /cress/tasks/speech_to_text_modified.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
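# The reduce_metrics() hook in the criterion file above logs each loss normalized per
# target token and converted to base 2 (bits per token). The helper below is an
# illustrative sketch of that arithmetic, not part of the original task; the numbers in
# the docstring are made up.
def _bits_per_token_sketch(logging_outputs):
    """E.g. [{"st_loss": 3000.0, "sample_size": 1200},
             {"st_loss": 2800.0, "sample_size": 1100}] -> ~3.64 bits per token."""
    import math
    st_loss_sum = sum(log.get("st_loss", 0) for log in logging_outputs)
    sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
    # same normalization as metrics.log_scalar("st_loss", ...) in reduce_metrics
    return st_loss_sum / sample_size / math.log(2)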
5 | 6 | import logging 7 | from pathlib import Path 8 | from argparse import Namespace 9 | 10 | from fairseq.data import Dictionary, encoders 11 | from cress.datasets.speech_to_text_dataset import ( 12 | S2TDataConfig, 13 | SpeechToTextDataset, 14 | SpeechToTextDatasetCreator, 15 | get_features_or_waveform, 16 | ) 17 | from fairseq.tasks import LegacyFairseqTask, register_task 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | @register_task("speech_to_text_modified") 24 | class SpeechToTextTaskModified(LegacyFairseqTask): 25 | @classmethod 26 | def add_args(cls, parser): 27 | parser.add_argument("data", help="manifest root path") 28 | parser.add_argument( 29 | "--config-yaml", 30 | type=str, 31 | default="config.yaml", 32 | help="Configuration YAML filename (under manifest root)", 33 | ) 34 | parser.add_argument( 35 | "--max-source-positions", 36 | default=6000, 37 | type=int, 38 | metavar="N", 39 | help="max number of tokens in the source sequence", 40 | ) 41 | parser.add_argument( 42 | "--max-target-positions", 43 | default=1024, 44 | type=int, 45 | metavar="N", 46 | help="max number of tokens in the target sequence", 47 | ) 48 | 49 | def __init__(self, args, tgt_dict): 50 | super().__init__(args) 51 | self.tgt_dict = tgt_dict 52 | self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) 53 | self.speaker_to_id = self._get_speaker_to_id() 54 | if ( 55 | self.data_cfg.prepend_tgt_lang_tag 56 | and self.data_cfg.prepend_bos_and_append_tgt_lang_tag 57 | ): 58 | raise ValueError( 59 | "Please set only one of the two options to avoid adding target token multiple times" 60 | ) 61 | 62 | def _get_speaker_to_id(self): 63 | speaker_to_id = None 64 | speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") 65 | if speaker_set_filename is not None: 66 | speaker_set_path = Path(self.args.data) / speaker_set_filename 67 | with open(speaker_set_path) as f: 68 | speaker_to_id = {r.strip(): i for i, r in enumerate(f)} 69 | return speaker_to_id 70 | 71 | @classmethod 72 | def setup_task(cls, args, **kwargs): 73 | data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) 74 | dict_path = Path(args.data) / data_cfg.vocab_filename 75 | if not dict_path.is_file(): 76 | raise FileNotFoundError(f"Dict not found: {dict_path.as_posix()}") 77 | tgt_dict = Dictionary.load(dict_path.as_posix()) 78 | logger.info( 79 | f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" 80 | ) 81 | 82 | if getattr(args, "train_subset", None) is not None: 83 | if not all(s.startswith("train") for s in args.train_subset.split(",")): 84 | raise ValueError('Train splits should be named like "train*".') 85 | return cls(args, tgt_dict) 86 | 87 | def build_criterion(self, args): 88 | from fairseq import criterions 89 | 90 | if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1: 91 | raise ValueError( 92 | 'Please set "--ignore-prefix-size 1" since ' 93 | "target language ID token is prepended as BOS." 
94 | ) 95 | return criterions.build_criterion(args, self) 96 | 97 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 98 | is_train_split = split.startswith("train") 99 | pre_tokenizer = self.build_tokenizer(self.args) 100 | bpe_tokenizer = self.build_bpe(self.args) 101 | self.datasets[split] = SpeechToTextDatasetCreator.from_tsv( 102 | self.args.data, 103 | self.data_cfg, 104 | split, 105 | self.tgt_dict, 106 | pre_tokenizer, 107 | bpe_tokenizer, 108 | is_train_split=is_train_split, 109 | epoch=epoch, 110 | seed=self.args.seed, 111 | speaker_to_id=self.speaker_to_id, 112 | ) 113 | 114 | @property 115 | def target_dictionary(self): 116 | return self.tgt_dict 117 | 118 | @property 119 | def source_dictionary(self): 120 | return None 121 | 122 | def max_positions(self): 123 | return self.args.max_source_positions, self.args.max_target_positions 124 | 125 | def build_model(self, args, from_checkpoint=False): 126 | args.input_feat_per_channel = self.data_cfg.input_feat_per_channel 127 | args.input_channels = self.data_cfg.input_channels 128 | args.speaker_to_id = self.speaker_to_id 129 | return super(SpeechToTextTaskModified, self).build_model(args, from_checkpoint) 130 | 131 | def build_generator( 132 | self, 133 | models, 134 | args, 135 | seq_gen_cls=None, 136 | extra_gen_cls_kwargs=None, 137 | ): 138 | if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: 139 | raise ValueError( 140 | 'Please set "--prefix-size 1" since ' 141 | "target language ID token is prepended as BOS." 142 | ) 143 | lang_token_ids = { 144 | i 145 | for s, i in self.tgt_dict.indices.items() 146 | if SpeechToTextDataset.is_lang_tag(s) 147 | } 148 | 149 | if extra_gen_cls_kwargs is None: 150 | extra_gen_cls_kwargs = {} 151 | extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids 152 | 153 | eos_token = ( 154 | args.eos_token 155 | if "eos_token" in args and args.eos_token is not None 156 | else self.data_cfg.config.get("eos_token", None) 157 | ) 158 | 159 | if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token: 160 | raise Warning( 161 | "Please provide --eos_token to replace eos in sequence generator" 162 | ) 163 | 164 | eos_id = self.tgt_dict.index(eos_token) if eos_token else None 165 | extra_gen_cls_kwargs["eos"] = eos_id 166 | 167 | return super().build_generator( 168 | models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs 169 | ) 170 | 171 | def build_tokenizer(self, args): 172 | logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") 173 | return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) 174 | 175 | def build_bpe(self, args): 176 | logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}") 177 | return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) 178 | 179 | def get_interactive_tokens_and_lengths(self, lines, encode_fn): 180 | n_frames = [get_features_or_waveform(p).shape[0] for p in lines] 181 | return lines, n_frames 182 | 183 | def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): 184 | return SpeechToTextDataset( 185 | "interactive", False, self.data_cfg, src_tokens, src_lengths 186 | ) 187 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRESS: Understanding and Bridging the Modality Gap for Speech Translation 2 | 3 | **Qingkai Fang, Yang Feng\* | Institute of Computing Technology, Chinese Academy of Sciences (ICT/CAS)** 4 | 5 | This is a 
PyTorch implementation of the **ACL 2023 main conference paper** [Understanding and Bridging the Modality Gap for Speech Translation](https://arxiv.org/abs/2305.08706). 6 | 7 | 🙌 We provide our **code**, **ST/MT model weights**, and **processed ST/MT data** in this repository. 8 | 9 | 👀 Also see our other works dedicated to **bridging the modality gap** for speech translation (ST): 10 | 11 | - [STEMM (ACL 2022)](https://aclanthology.org/2022.acl-long.486/) 12 | - [CMOT (ACL 2023)](https://arxiv.org/abs/2305.14635) 13 | 14 | 15 | ## Release 16 | 17 | We have released the following assets for **all 8 translation directions of MuST-C**: 18 | 19 | - Processed ST data in `.tsv` format 20 | - Processed external MT data in fairseq binary format 21 | - SentencePiece vocabulary 22 | - Pretrained MT models in both `base` and `expand` settings 23 | - Pretrained CRESS models in both `base` and `expand` settings 24 | 25 | | | Link | Password | 26 | | --------------------- | ----------------------------------------------- | -------- | 27 | | **Processed ST Data** | https://pan.baidu.com/s/1J7BgcbSNwma4SdJfHENRdg | 94wu | 28 | | **Processed MT Data** | https://pan.baidu.com/s/1gDMOU35_pug73y0kd-F3vw | 6tbk | 29 | | **Vocabulary** | https://pan.baidu.com/s/13ucCEVzAdxRu99bdZ2oIdw | nph3 | 30 | | **MT Model (base)** | https://pan.baidu.com/s/1xm6myQfY-wYS4D0_rMBT_g | tm6k | 31 | | **MT Model (expand)** | https://pan.baidu.com/s/1byufAhoYQmgA8DCf9WUZQg | 61g4 | 32 | | **CRESS Model (base)** | https://pan.baidu.com/s/1_KCS_-a_Ss4Bm40dTQc6Vw | ra8j | 33 | | **CRESS Model (expand)** | https://pan.baidu.com/s/1zGJKmJf8TEnwBLzpOmfGYQ | ctyf | 34 | 35 | 36 | 37 | ## Environment Configuration 38 | 39 | 1. Clone this repository: 40 | 41 | ``` 42 | git clone git@github.com:ictnlp/CRESS.git 43 | cd CRESS/ 44 | ``` 45 | 46 | 2. Install `fairseq`: 47 | 48 | ``` 49 | cd fairseq/ 50 | pip install --editable ./ 51 | python setup.py build develop 52 | ``` 53 | 54 | 3. We organize our implementation as fairseq plug-ins in the `cress` directory: 55 | 56 | ``` 57 | . 58 | ├── criterions 59 | │ ├── __init__.py 60 | │ ├── speech_and_text_translation_criterion.py 61 | │ ├── speech_and_text_translation_with_oracle_reg_adaptive_criterion.py 62 | │ └── speech_and_text_translation_with_oracle_reg_criterion.py 63 | ├── datasets 64 | │ ├── audio_utils.py 65 | │ ├── __init__.py 66 | │ ├── speech_and_text_translation_dataset.py 67 | │ └── speech_to_text_dataset.py 68 | ├── __init__.py 69 | ├── models 70 | │ ├── hubert_transformer.py 71 | │ └── __init__.py 72 | ├── tasks 73 | │ ├── __init__.py 74 | │ ├── speech_and_text_translation.py 75 | │ └── speech_to_text_modified.py 76 | ├── test_scripts 77 | │ ├── avg_epoch.sh 78 | │ ├── test.en-x.mt.sh 79 | │ └── test.en-x.st.sh 80 | └── train_scripts 81 | ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress_adaptive.sh 82 | ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress.sh 83 | ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.sh 84 | ├── train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.sh 85 | └── train.en-x.postln.wmt_pretrain.sh 86 | ``` 87 | 88 | You can import our implementation with `--user-dir cress` in fairseq. 89 | 90 | 91 | 92 | ## Data Preparation 93 | 94 | 1. Make directories to store ST (MuST-C) and MT (WMT) datasets. Please specify the target language via `$TGT_LANG`: 95 | 96 | ``` 97 | TGT_LANG=de 98 | MUSTC_ROOT=data/mustc 99 | WMT_ROOT=data/wmt 100 | mkdir -p $MUSTC_ROOT $WMT_ROOT 101 | ``` 102 | 103 | 2. 
Download the [MuST-C v1.0](https://ict.fbk.eu/must-c/) archive to the `$MUSTC_ROOT` directory and uncompress it: 104 | 105 | ``` 106 | cd $MUSTC_ROOT 107 | tar -xzvf MUSTC_v1.0_en-${TGT_LANG}.tar.gz 108 | ``` 109 | 110 | 3. We provide the processed ST data and the SentencePiece vocabulary files. You can download them via the Baidu Netdisk: 111 | 112 | | | Link | Password | 113 | | --------------------- | ----------------------------------------------- | -------- | 114 | | **Processed ST Data** | https://pan.baidu.com/s/1J7BgcbSNwma4SdJfHENRdg | 94wu | 115 | | **Vocabulary** | https://pan.baidu.com/s/13ucCEVzAdxRu99bdZ2oIdw | nph3 | 116 | 117 | Put the downloaded files in the `$MUSTC_ROOT/en-${TGT_LANG}/` directory. It should look like this: 118 | 119 | ``` 120 | . 121 | ├── binary 122 | ├── config.yaml 123 | ├── data 124 | ├── dev.tsv 125 | ├── docs 126 | ├── spm_unigram10000.model 127 | ├── spm_unigram10000.txt 128 | ├── spm_unigram10000.vocab 129 | ├── train.tsv 130 | └── tst-COMMON.tsv 131 | ``` 132 | 133 | 4. For MT pretraining, we need additional MT datasets. We provide the processed MT data in the fairseq binary format. You can download them via the Baidu Netdisk: 134 | 135 | | | Link | Password | 136 | | --------------------- | ----------------------------------------------- | -------- | 137 | | **Processed MT Data** | https://pan.baidu.com/s/1gDMOU35_pug73y0kd-F3vw | 6tbk | 138 | 139 | Put the downloaded files in the `$WMT_ROOT/en-${TGT_LANG}` directory. 140 | 141 | 142 | 143 | ## Model Training 144 | 145 | Model training consists of two steps: MT pretraining and ST finetuning. 146 | 147 | - In the `base` setting, we pretrain the model with `<transcription, translation>` pairs from the MuST-C dataset. 148 | - In the `expand` setting, we first pretrain the model with external MT datasets, and then pretrain the model with `<transcription, translation>` pairs from MuST-C. 149 | 150 | All the training scripts below are configured to run on **4 GPUs**. You can adjust `--update-freq` depending on the number of available GPUs. 151 | 152 | Before training, please download the [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) model and place it at `checkpoints/hubert_base_ls960.pt`. 153 | 154 | ### MT Pretraining 155 | 156 | 1. (Optional) Pretrain the model with the external MT dataset. Please run the script: 157 | 158 | ``` 159 | sh cress/train_scripts/train.en-x.postln.wmt_pretrain.sh $TGT_LANG 160 | ``` 161 | 162 | You should adjust the maximum number of training steps (`--max-update`) based on the size of the training data. 163 | 164 | After training, please average the last 5 checkpoints: 165 | 166 | ``` 167 | python scripts/average_checkpoints.py \ 168 | --inputs checkpoints/en-$tgt.postln.wmt_pretrain \ 169 | --num-epoch-checkpoints 5 \ 170 | --output checkpoints/en-$tgt.postln.wmt_pretrain/avg_last_5_epoch.pt 171 | ``` 172 | 173 | 2. Pretrain the model with `<transcription, translation>` pairs from the MuST-C dataset. Please run the script: 174 | 175 | ``` 176 | sh cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.sh $TGT_LANG 177 | ``` 178 | 179 | After training, please average the last 10 checkpoints. You can use the script `cress/test_scripts/avg_epoch.sh`. The averaged checkpoint will be used to initialize the ST model. 180 | 181 | **To ensure consistent performance, we have released our checkpoints of pretrained MT models in both `base` and `expand` settings.
You can download them via the Baidu Netdisk.** 182 | 183 | | | Link | Password | 184 | | --------------- | ----------------------------------------------- | -------- | 185 | | **MT (base)** | https://pan.baidu.com/s/1xm6myQfY-wYS4D0_rMBT_g | tm6k | 186 | | **MT (expand)** | https://pan.baidu.com/s/1byufAhoYQmgA8DCf9WUZQg | 61g4 | 187 | 188 | ### Multitask Learning 189 | 190 | 1. For multitask learning (the `MTL` baseline in the paper), please run the script: 191 | 192 | ``` 193 | sh cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.sh $TGT_LANG 194 | ``` 195 | 196 | ### Cross-modal Regularization with Scheduled Sampling (CRESS) 197 | 198 | 1. For `CRESS` training, please first run the script below. Note that token-level adaptive training is not used for the first 20 epochs of training. 199 | 200 | ``` 201 | sh cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress.sh $TGT_LANG 202 | ``` 203 | 204 | 2. For the subsequent training epochs, token-level adaptive training will be used. Please run the script: 205 | 206 | ``` 207 | sh cress/train_scripts/train.en-x.postln.wmt_pretrain.mustc_mt_pretrain.mustc_st+mt.cress_adaptive.sh $TGT_LANG 208 | ``` 209 | 210 | We have also released checkpoints of CRESS. You can download and evaluate them. 211 | 212 | | | Link | Password | 213 | | --------------- | ----------------------------------------------- | -------- | 214 | | **CRESS (base)** | https://pan.baidu.com/s/1_KCS_-a_Ss4Bm40dTQc6Vw | ra8j | 215 | | **CRESS (expand)** | https://pan.baidu.com/s/1zGJKmJf8TEnwBLzpOmfGYQ | ctyf | 216 | 217 | 218 | 219 | ## Evaluation 220 | 221 | For evaluation, please first average the last 10 checkpoints using the `cress/test_scripts/avg_epoch.sh` script. Next, please use the scripts below to evaluate the ST/MT performance of the averaged checkpoint. 222 | 223 | The values of `--lenpen` vary across different target languages as follows: 224 | 225 | | TGT_LANG | De | Fr | Es | Ro | Ru | It | Pt | Nl | 226 | | ---------- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | 227 | | `--lenpen` | 1.2 | 1.8 | 0.6 | 1.4 | 0.8 | 1.0 | 1.4 | 1.0 | 228 | 229 | ### ST Evaluation 230 | 231 | To evaluate the ST performance of the model, please use the `cress/test_scripts/test.en-x.st.sh` script: 232 | 233 | ``` 234 | sh cress/test_scripts/test.en-x.st.sh $CKPT $TGT_LANG $LENPEN 235 | ``` 236 | 237 | ### MT Evaluation 238 | 239 | To evaluate the MT performance of the model, please use the `cress/test_scripts/test.en-x.mt.sh` script: 240 | 241 | ``` 242 | sh cress/test_scripts/test.en-x.mt.sh $CKPT $TGT_LANG $LENPEN 243 | ``` 244 | 245 | 246 | 247 | ## Citation 248 | 249 | If you find this repository useful, please cite it as: 250 | 251 | ``` 252 | @inproceedings{fang-and-feng-2023-understanding, 253 | title = {Understanding and Bridging the Modality Gap for Speech Translation}, 254 | author = {Fang, Qingkai and Feng, Yang}, 255 | booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics}, 256 | year = {2023}, 257 | } 258 | ``` 259 | -------------------------------------------------------------------------------- /cress/datasets/audio_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
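# Usage sketch for the helpers defined below (not part of the original module): read a
# 16-bit WAV with get_waveform(), downmixing to mono and resampling to 16 kHz, which is
# the sample rate expected by the HuBERT-Base encoder used in the training scripts. The
# file path is hypothetical; soundfile and torchaudio must be installed. get_waveform()
# is defined later in this file, so the name resolves when the helper is actually called.
def _example_load_16khz_mono(path="example.wav"):
    waveform, sample_rate = get_waveform(
        path, mono=True, always_2d=True, output_sample_rate=16000
    )
    assert sample_rate == 16000  # resampled via torchaudio sox effects
    return waveform  # numpy array of shape (channels=1, num_samples)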
5 | 6 | 7 | import mmap 8 | from pathlib import Path 9 | from typing import BinaryIO, List, Optional, Tuple, Union 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import torchaudio 15 | 16 | SF_AUDIO_FILE_EXTENSIONS = {".wav", ".flac", ".ogg"} 17 | FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS = {".npy", ".wav", ".flac", ".ogg"} 18 | 19 | 20 | def convert_waveform( 21 | waveform: Union[np.ndarray, torch.Tensor], 22 | sample_rate: int, 23 | normalize_volume: bool = False, 24 | to_mono: bool = False, 25 | to_sample_rate: Optional[int] = None, 26 | ) -> Tuple[Union[np.ndarray, torch.Tensor], int]: 27 | """convert a waveform: 28 | - to a target sample rate 29 | - from multi-channel to mono channel 30 | - volume normalization 31 | 32 | Args: 33 | waveform (numpy.ndarray or torch.Tensor): 2D original waveform 34 | (channels x length) 35 | sample_rate (int): original sample rate 36 | normalize_volume (bool): perform volume normalization 37 | to_mono (bool): convert to mono channel if having multiple channels 38 | to_sample_rate (Optional[int]): target sample rate 39 | Returns: 40 | waveform (numpy.ndarray): converted 2D waveform (channels x length) 41 | sample_rate (float): target sample rate 42 | """ 43 | try: 44 | import torchaudio.sox_effects as ta_sox 45 | except ImportError: 46 | raise ImportError("Please install torchaudio: pip install torchaudio") 47 | 48 | effects = [] 49 | if normalize_volume: 50 | effects.append(["gain", "-n"]) 51 | if to_sample_rate is not None and to_sample_rate != sample_rate: 52 | effects.append(["rate", f"{to_sample_rate}"]) 53 | if to_mono and waveform.shape[0] > 1: 54 | effects.append(["channels", "1"]) 55 | if len(effects) > 0: 56 | is_np_input = isinstance(waveform, np.ndarray) 57 | _waveform = torch.from_numpy(waveform) if is_np_input else waveform 58 | converted, converted_sample_rate = ta_sox.apply_effects_tensor( 59 | _waveform, sample_rate, effects 60 | ) 61 | if is_np_input: 62 | converted = converted.numpy() 63 | return converted, converted_sample_rate 64 | return waveform, sample_rate 65 | 66 | 67 | def get_waveform( 68 | path_or_fp: Union[str, BinaryIO], 69 | normalization: bool = True, 70 | mono: bool = True, 71 | frames: int = -1, 72 | start: int = 0, 73 | always_2d: bool = True, 74 | output_sample_rate: Optional[int] = None, 75 | normalize_volume: bool = False, 76 | ) -> Tuple[np.ndarray, int]: 77 | """Get the waveform and sample rate of a 16-bit WAV/FLAC/OGG Vorbis audio. 78 | 79 | Args: 80 | path_or_fp (str or BinaryIO): the path or file-like object 81 | normalization (bool): normalize values to [-1, 1] (Default: True) 82 | mono (bool): convert multi-channel audio to mono-channel one 83 | frames (int): the number of frames to read. (-1 for reading all) 84 | start (int): Where to start reading. A negative value counts from the end. 
85 | always_2d (bool): always return 2D array even for mono-channel audios 86 | output_sample_rate (Optional[int]): output sample rate 87 | normalize_volume (bool): normalize volume 88 | Returns: 89 | waveform (numpy.ndarray): 1D or 2D waveform (channels x length) 90 | sample_rate (float): sample rate 91 | """ 92 | if isinstance(path_or_fp, str): 93 | ext = Path(path_or_fp).suffix 94 | if ext not in SF_AUDIO_FILE_EXTENSIONS: 95 | raise ValueError(f"Unsupported audio format: {ext}") 96 | 97 | try: 98 | import soundfile as sf 99 | except ImportError: 100 | raise ImportError("Please install soundfile: pip install soundfile") 101 | 102 | waveform, sample_rate = sf.read( 103 | path_or_fp, dtype="float32", always_2d=True, frames=frames, start=start 104 | ) 105 | waveform = waveform.T # T x C -> C x T 106 | waveform, sample_rate = convert_waveform( 107 | waveform, 108 | sample_rate, 109 | normalize_volume=normalize_volume, 110 | to_mono=mono, 111 | to_sample_rate=output_sample_rate, 112 | ) 113 | 114 | if not normalization: 115 | waveform *= 2**15 # denormalized to 16-bit signed integers 116 | if not always_2d: 117 | waveform = waveform.squeeze(axis=0) 118 | return waveform, sample_rate 119 | 120 | def get_segment_waveform( 121 | path_or_fp, offset, n_frames, normalization=True 122 | ): 123 | if isinstance(path_or_fp, str): 124 | ext = Path(path_or_fp).suffix 125 | if ext not in {".flac", ".wav", ".mp3"}: 126 | raise ValueError(f"Unsupported audio format: {ext}") 127 | 128 | waveform, sample_rate = torchaudio.load(path_or_fp, frame_offset=offset, num_frames=n_frames) 129 | if not normalization: 130 | waveform *= 2 ** 15 131 | return waveform, sample_rate 132 | 133 | def _get_kaldi_fbank( 134 | waveform: np.ndarray, sample_rate: int, n_bins=80 135 | ) -> Optional[np.ndarray]: 136 | """Get mel-filter bank features via PyKaldi.""" 137 | try: 138 | from kaldi.feat.fbank import Fbank, FbankOptions 139 | from kaldi.feat.mel import MelBanksOptions 140 | from kaldi.feat.window import FrameExtractionOptions 141 | from kaldi.matrix import Vector 142 | 143 | mel_opts = MelBanksOptions() 144 | mel_opts.num_bins = n_bins 145 | frame_opts = FrameExtractionOptions() 146 | frame_opts.samp_freq = sample_rate 147 | opts = FbankOptions() 148 | opts.mel_opts = mel_opts 149 | opts.frame_opts = frame_opts 150 | fbank = Fbank(opts=opts) 151 | features = fbank.compute(Vector(waveform.squeeze()), 1.0).numpy() 152 | return features 153 | except ImportError: 154 | return None 155 | 156 | 157 | def _get_torchaudio_fbank( 158 | waveform: np.ndarray, sample_rate, n_bins=80 159 | ) -> Optional[np.ndarray]: 160 | """Get mel-filter bank features via TorchAudio.""" 161 | try: 162 | import torchaudio.compliance.kaldi as ta_kaldi 163 | 164 | waveform = torch.from_numpy(waveform) 165 | features = ta_kaldi.fbank( 166 | waveform, num_mel_bins=n_bins, sample_frequency=sample_rate 167 | ) 168 | return features.numpy() 169 | except ImportError: 170 | return None 171 | 172 | 173 | def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray: 174 | """Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi 175 | (faster CPP implementation) to TorchAudio (Python implementation). 
Note that 176 | Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the 177 | waveform should not be normalized.""" 178 | waveform, sample_rate = get_waveform(path_or_fp, normalization=False) 179 | 180 | features = _get_kaldi_fbank(waveform, sample_rate, n_bins) 181 | if features is None: 182 | features = _get_torchaudio_fbank(waveform, sample_rate, n_bins) 183 | if features is None: 184 | raise ImportError( 185 | "Please install pyKaldi or torchaudio to enable " 186 | "online filterbank feature extraction" 187 | ) 188 | 189 | return features 190 | 191 | 192 | def is_npy_data(data: bytes) -> bool: 193 | return data[0] == 147 and data[1] == 78 194 | 195 | 196 | def is_sf_audio_data(data: bytes) -> bool: 197 | is_wav = data[0] == 82 and data[1] == 73 and data[2] == 70 198 | is_flac = data[0] == 102 and data[1] == 76 and data[2] == 97 199 | is_ogg = data[0] == 79 and data[1] == 103 and data[2] == 103 200 | return is_wav or is_flac or is_ogg 201 | 202 | 203 | def mmap_read(path: str, offset: int, length: int) -> bytes: 204 | with open(path, "rb") as f: 205 | with mmap.mmap(f.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_o: 206 | data = mmap_o[offset : offset + length] 207 | return data 208 | 209 | 210 | def read_from_stored_zip(zip_path: str, offset: int, length: int) -> bytes: 211 | return mmap_read(zip_path, offset, length) 212 | 213 | 214 | def parse_path(path: str) -> Tuple[str, List[int]]: 215 | """Parse data path which is either a path to 216 | 1. a .npy/.wav/.flac/.ogg file 217 | 2. a stored ZIP file with slicing info: "[zip_path]:[offset]:[length]" 218 | 219 | Args: 220 | path (str): the data path to parse 221 | 222 | Returns: 223 | file_path (str): the file path 224 | slice_ptr (list of int): empty in case 1; 225 | byte offset and length for the slice in case 2 226 | """ 227 | 228 | if Path(path).suffix in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: 229 | _path, slice_ptr = path, [] 230 | else: 231 | _path, *slice_ptr = path.split(":") 232 | if not Path(_path).is_file(): 233 | raise FileNotFoundError(f"File not found: {_path}") 234 | assert len(slice_ptr) in {0, 2}, f"Invalid path: {path}" 235 | slice_ptr = [int(i) for i in slice_ptr] 236 | return _path, slice_ptr 237 | 238 | 239 | def get_window(window_fn: callable, n_fft: int, win_length: int) -> torch.Tensor: 240 | padding = n_fft - win_length 241 | assert padding >= 0 242 | return F.pad(window_fn(win_length), (padding // 2, padding - padding // 2)) 243 | 244 | 245 | def get_fourier_basis(n_fft: int) -> torch.Tensor: 246 | basis = np.fft.fft(np.eye(n_fft)) 247 | basis = np.vstack( 248 | [np.real(basis[: n_fft // 2 + 1, :]), np.imag(basis[: n_fft // 2 + 1, :])] 249 | ) 250 | return torch.from_numpy(basis).float() 251 | 252 | 253 | def get_mel_filters( 254 | sample_rate: int, n_fft: int, n_mels: int, f_min: float, f_max: float 255 | ) -> torch.Tensor: 256 | try: 257 | import librosa 258 | except ImportError: 259 | raise ImportError("Please install librosa: pip install librosa") 260 | basis = librosa.filters.mel(sample_rate, n_fft, n_mels, f_min, f_max) 261 | return torch.from_numpy(basis).float() 262 | 263 | 264 | class TTSSpectrogram(torch.nn.Module): 265 | def __init__( 266 | self, 267 | n_fft: int, 268 | win_length: int, 269 | hop_length: int, 270 | window_fn: callable = torch.hann_window, 271 | return_phase: bool = False, 272 | ) -> None: 273 | super(TTSSpectrogram, self).__init__() 274 | self.n_fft = n_fft 275 | self.hop_length = hop_length 276 | self.return_phase = return_phase 277 | 278 | basis = 
get_fourier_basis(n_fft).unsqueeze(1) 279 | basis *= get_window(window_fn, n_fft, win_length) 280 | self.register_buffer("basis", basis) 281 | 282 | def forward( 283 | self, waveform: torch.Tensor 284 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: 285 | padding = (self.n_fft // 2, self.n_fft // 2) 286 | x = F.pad(waveform.unsqueeze(1), padding, mode="reflect") 287 | x = F.conv1d(x, self.basis, stride=self.hop_length) 288 | real_part = x[:, : self.n_fft // 2 + 1, :] 289 | imag_part = x[:, self.n_fft // 2 + 1 :, :] 290 | magnitude = torch.sqrt(real_part**2 + imag_part**2) 291 | if self.return_phase: 292 | phase = torch.atan2(imag_part, real_part) 293 | return magnitude, phase 294 | return magnitude 295 | 296 | 297 | class TTSMelScale(torch.nn.Module): 298 | def __init__( 299 | self, n_mels: int, sample_rate: int, f_min: float, f_max: float, n_stft: int 300 | ) -> None: 301 | super(TTSMelScale, self).__init__() 302 | basis = get_mel_filters(sample_rate, (n_stft - 1) * 2, n_mels, f_min, f_max) 303 | self.register_buffer("basis", basis) 304 | 305 | def forward(self, specgram: torch.Tensor) -> torch.Tensor: 306 | return torch.matmul(self.basis, specgram) 307 | -------------------------------------------------------------------------------- /cress/criterions/speech_and_text_translation_with_oracle_reg_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import numpy as np 8 | from dataclasses import dataclass, field 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from fairseq import metrics, utils 13 | from fairseq.criterions import FairseqCriterion, register_criterion 14 | from fairseq.criterions.label_smoothed_cross_entropy import ( 15 | LabelSmoothedCrossEntropyCriterion, 16 | LabelSmoothedCrossEntropyCriterionConfig, 17 | label_smoothed_nll_loss, 18 | ) 19 | from fairseq.dataclass import FairseqDataclass 20 | from omegaconf import II 21 | 22 | 23 | @dataclass 24 | class SpeechAndTextTranslationOracleRegCriterionConfig(LabelSmoothedCrossEntropyCriterionConfig): 25 | reg_weight: float = field( 26 | default=1.0, 27 | metadata={"help": "weight of regularization loss"}, 28 | ) 29 | reg_loss_type: str = field( 30 | default="jsd", 31 | metadata={"help": "loss type of regularization (e.g. 
jsd, l1)"}, 32 | ) 33 | use_word_level_oracle: bool = field( 34 | default=False, 35 | metadata={"help": "use word level oracles"}, 36 | ) 37 | decay_k: float = field( 38 | default=15, 39 | metadata={"help": "decay hyper-paramter k"}, 40 | ) 41 | use_word_gumbel_noise: bool = field( 42 | default=False, 43 | metadata={"help": "select word with gumbel noise"}, 44 | ) 45 | gumbel_temperature: float = field( 46 | default=1.0, 47 | metadata={"help": "temperature of gumbel max in word oracles"}, 48 | ) 49 | 50 | @register_criterion( 51 | "speech_and_text_translation_with_oracle_reg", dataclass=SpeechAndTextTranslationOracleRegCriterionConfig 52 | ) 53 | class SpeechAndTextTranslatioOracleRegCriterion(LabelSmoothedCrossEntropyCriterion): 54 | def __init__( 55 | self, 56 | task, 57 | sentence_avg, 58 | label_smoothing, 59 | ignore_prefix_size=0, 60 | report_accuracy=False, 61 | reg_weight=1.0, 62 | reg_loss_type="jsd", 63 | use_word_level_oracle=False, 64 | decay_k=15, 65 | use_word_gumbel_noise=False, 66 | gumbel_temperature=1.0, 67 | ): 68 | super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy) 69 | self.reg_weight = reg_weight 70 | self.padding_idx = task.target_dictionary.pad() 71 | self.tgt_dict = task.target_dictionary 72 | self.bpe_tokenizer = task.bpe_tokenizer 73 | self.reg_loss_type = reg_loss_type 74 | self.use_word_level_oracle = use_word_level_oracle 75 | self.decay_k = decay_k 76 | self.use_word_gumbel_noise = use_word_gumbel_noise 77 | self.gumbel_temperature = gumbel_temperature 78 | 79 | def decay_prob(self, epoch): 80 | k = self.decay_k 81 | return k / (k + np.exp(epoch / k)) 82 | 83 | def get_word_oracle_tokens(self, pred_logits, prev_output_tokens, epoch, epsilon=1e-6): 84 | bsz, _ = prev_output_tokens.size() 85 | if self.use_word_gumbel_noise: 86 | uniform = torch.Tensor(pred_logits.size()).to(pred_logits.device).float().uniform_(0, 1) 87 | gumbel = -torch.log(-torch.log(uniform + epsilon) + epsilon) 88 | pred_logits = (pred_logits + gumbel.to(pred_logits.device)) / self.gumbel_temperature 89 | pred_tokens = torch.max(pred_logits, dim=-1)[1] 90 | bos_idx = prev_output_tokens[0, 0].repeat(bsz, 1).to(pred_tokens) 91 | pred_tokens = torch.cat([bos_idx, pred_tokens], dim=1)[:, :-1] 92 | sample_gold_prob = self.decay_prob(epoch) 93 | sample_gold_prob = sample_gold_prob * torch.ones_like(prev_output_tokens, dtype=torch.float32) 94 | sample_gold_mask = torch.bernoulli(sample_gold_prob).long() 95 | 96 | return prev_output_tokens * sample_gold_mask + pred_tokens * (1 - sample_gold_mask) 97 | 98 | def forward_st(self, model, sample, reduce, word_oracle=False): 99 | audio_input = { 100 | "src_tokens": sample["net_input"]["audio"], 101 | "src_lengths": sample["net_input"]["audio_lengths"], 102 | "mode": "st", 103 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 104 | } 105 | audio_encoder_out = model.encoder( 106 | audio_input["src_tokens"], 107 | audio_input["src_lengths"], 108 | audio_input["mode"] 109 | ) 110 | prev_output_tokens = audio_input["prev_output_tokens"] 111 | with torch.no_grad(): 112 | if word_oracle: 113 | audio_output = model.decoder( 114 | prev_output_tokens, 115 | audio_encoder_out, 116 | ) 117 | prev_output_tokens = self.get_word_oracle_tokens( 118 | audio_output[0].detach(), 119 | prev_output_tokens, 120 | model.epoch, 121 | ) 122 | audio_output = model.decoder( 123 | prev_output_tokens, 124 | audio_encoder_out, 125 | ) 126 | loss, _, lprobs, target = self.compute_loss_with_lprobs(model, audio_output, sample, 
reduce=reduce) 127 | return loss, lprobs, target 128 | 129 | def forward_mt(self, model, sample, reduce, word_oracle=False): 130 | text_input = { 131 | "src_tokens": sample["net_input"]["source"], 132 | "src_lengths": sample["net_input"]["source_lengths"], 133 | "mode": "mt", 134 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 135 | } 136 | text_encoder_out = model.encoder( 137 | text_input["src_tokens"], 138 | text_input["src_lengths"], 139 | text_input["mode"] 140 | ) 141 | prev_output_tokens = text_input["prev_output_tokens"] 142 | with torch.no_grad(): 143 | if word_oracle: 144 | text_output = model.decoder( 145 | prev_output_tokens, 146 | text_encoder_out, 147 | ) 148 | prev_output_tokens = self.get_word_oracle_tokens( 149 | text_output[0].detach(), 150 | prev_output_tokens, 151 | model.epoch, 152 | ) 153 | text_output = model.decoder( 154 | prev_output_tokens, 155 | text_encoder_out, 156 | ) 157 | loss, _, lprobs, target = self.compute_loss_with_lprobs(model, text_output, sample, reduce=reduce) 158 | return loss, lprobs, target 159 | 160 | def forward_ext_mt(self, model, sample, reduce): 161 | text_output = model(**sample["net_input"]) 162 | loss, _ = self.compute_loss(model, text_output, sample, reduce=reduce) 163 | return loss 164 | 165 | def forward(self, model, sample, reduce=True): 166 | """Compute the loss for the given sample. 167 | Returns a tuple with three elements: 168 | 1) the loss 169 | 2) the sample size, which is used as the denominator for the gradient 170 | 3) logging outputs to display while training 171 | """ 172 | st_loss, mt_loss, ext_mt_loss, reg_loss = torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda() 173 | st_size, mt_size, ext_mt_size, reg_size = 0, 0, 0, 0 174 | 175 | mode = sample["net_input"]["mode"] 176 | if mode == "st": 177 | if self.training: 178 | word_oracle = self.use_word_level_oracle 179 | st_loss, st_lprobs, st_target = self.forward_st(model, sample, reduce, word_oracle) 180 | mt_loss, mt_lprobs, mt_target = self.forward_mt(model, sample, reduce, word_oracle) 181 | reg_loss = self.compute_reg_loss(st_lprobs, mt_lprobs, st_target, mt_target) 182 | loss = st_loss + mt_loss + self.reg_weight * reg_loss 183 | st_size = mt_size = sample_size = reg_size = sample["ntokens"] 184 | else: 185 | st_loss, _, _ = self.forward_st(model, sample, reduce) 186 | loss = st_loss 187 | st_size = sample_size = sample["ntokens"] 188 | elif mode == "ext_mt": 189 | loss = ext_mt_loss = self.forward_ext_mt(model, sample, reduce) 190 | ext_mt_size = sample_size = sample["ntokens"] 191 | 192 | logging_output = { 193 | "loss": loss.data, 194 | "st_loss": st_loss.data, 195 | "st_sample_size": st_size, 196 | "mt_loss": mt_loss.data, 197 | "mt_sample_size": mt_size, 198 | "ext_mt_loss": ext_mt_loss.data, 199 | "ext_mt_sample_size": ext_mt_size, 200 | "reg_loss": reg_loss.data, 201 | "reg_sample_size": reg_size, 202 | "ntokens": sample["ntokens"], 203 | "nsentences": sample["target"].size(0), 204 | "sample_size": sample_size, 205 | } 206 | 207 | return loss, sample_size, logging_output 208 | 209 | def compute_loss_with_lprobs(self, model, net_output, sample, reduce=True): 210 | lprobs, target = self.get_lprobs_and_target(model, net_output, sample) 211 | loss, nll_loss = label_smoothed_nll_loss( 212 | lprobs, 213 | target, 214 | self.eps, 215 | ignore_index=self.padding_idx, 216 | reduce=reduce, 217 | ) 218 | return loss, nll_loss, lprobs, target 219 | 220 | def compute_jsd_loss(self, st_lprobs, mt_lprobs, 
st_target, mt_target, ignore_index): 221 | kl_loss_st = F.kl_div(mt_lprobs, st_lprobs, log_target=True, reduction="none").sum(-1) 222 | kl_loss_mt = F.kl_div(st_lprobs, mt_lprobs, log_target=True, reduction="none").sum(-1) 223 | pad_mask = st_target.eq(ignore_index) 224 | kl_loss_st.masked_fill_(pad_mask, 0.0) 225 | pad_mask = mt_target.eq(ignore_index) 226 | kl_loss_mt.masked_fill_(pad_mask, 0.0) 227 | kl_loss_st = kl_loss_st.sum() 228 | kl_loss_mt = kl_loss_mt.sum() 229 | kl_loss = (kl_loss_st + kl_loss_mt) / 2.0 230 | return kl_loss 231 | 232 | def compute_reg_loss(self, st_lprobs, mt_lprobs, st_target, mt_target): 233 | if self.reg_loss_type == "jsd": 234 | return self.compute_jsd_loss(st_lprobs, mt_lprobs, st_target, mt_target, self.padding_idx) 235 | else: 236 | raise NotImplementedError 237 | 238 | @classmethod 239 | def reduce_metrics(cls, logging_outputs) -> None: 240 | """Aggregate logging outputs from data parallel training.""" 241 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 242 | st_loss_sum = sum(log.get("st_loss", 0) for log in logging_outputs) 243 | mt_loss_sum = sum(log.get("mt_loss", 0) for log in logging_outputs) 244 | ext_mt_loss_sum = sum(log.get("ext_mt_loss", 0) for log in logging_outputs) 245 | reg_loss_sum = sum(log.get("reg_loss", 0) for log in logging_outputs) 246 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 247 | st_sample_size = sum(log.get("st_sample_size", 0) for log in logging_outputs) 248 | mt_sample_size = sum(log.get("mt_sample_size", 0) for log in logging_outputs) 249 | ext_mt_sample_size = sum(log.get("ext_mt_sample_size", 0) for log in logging_outputs) 250 | reg_sample_size = sum(log.get("reg_sample_size", 0) for log in logging_outputs) 251 | 252 | metrics.log_scalar( 253 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 254 | ) 255 | metrics.log_scalar( 256 | "st_loss", st_loss_sum / st_sample_size / math.log(2) if st_sample_size != 0 else 0, st_sample_size, round=3 257 | ) 258 | metrics.log_scalar( 259 | "mt_loss", mt_loss_sum / mt_sample_size / math.log(2) if mt_sample_size != 0 else 0, mt_sample_size, round=3 260 | ) 261 | metrics.log_scalar( 262 | "ext_mt_loss", ext_mt_loss_sum / ext_mt_sample_size / math.log(2) if ext_mt_sample_size != 0 else 0, ext_mt_sample_size, round=3 263 | ) 264 | metrics.log_scalar( 265 | "reg_loss", reg_loss_sum / reg_sample_size / math.log(2) if reg_sample_size != 0 else 0, reg_sample_size, round=3 266 | ) 267 | 268 | @staticmethod 269 | def logging_outputs_can_be_summed() -> bool: 270 | """ 271 | Whether the logging outputs returned by `forward` can be summed 272 | across workers prior to calling `reduce_metrics`. Setting this 273 | to True will improves distributed training speed. 274 | """ 275 | return True -------------------------------------------------------------------------------- /cress/criterions/speech_and_text_translation_with_oracle_reg_adaptive_criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
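# This criterion reuses the scheduled-sampling pieces defined in the file above
# (decay_prob and get_word_oracle_tokens): gold prefix tokens are mixed with the model's
# own predictions ("word-level oracles"), and the probability of keeping a gold token
# decays as p(epoch) = k / (k + exp(epoch / k)), with k set by --decay-k (15 in the
# train_scripts). The helper below is a standalone sketch of that mixing, not part of
# the original criterion; the token ids are made up and shapes are illustrative only.
import numpy as np
import torch

def _word_oracle_mixing_sketch(epoch=20, k=15.0):
    p = k / (k + np.exp(epoch / k))           # ~0.93 at epoch 1, ~0.80 at epoch 20
    gold = torch.tensor([[2, 11, 12, 13]])    # hypothetical gold prefix (bos first)
    oracle = torch.tensor([[2, 11, 40, 13]])  # hypothetical greedy model predictions
    keep_gold = torch.bernoulli(torch.full(gold.shape, float(p))).long()
    # mixed prefix fed to the decoder as prev_output_tokens during training
    return gold * keep_gold + oracle * (1 - keep_gold)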
5 | 6 | import math 7 | import numpy as np 8 | from dataclasses import dataclass, field 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from fairseq import metrics, utils 13 | from fairseq.criterions import FairseqCriterion, register_criterion 14 | from fairseq.criterions.label_smoothed_cross_entropy import ( 15 | LabelSmoothedCrossEntropyCriterion, 16 | LabelSmoothedCrossEntropyCriterionConfig, 17 | label_smoothed_nll_loss, 18 | ) 19 | from fairseq.dataclass import FairseqDataclass 20 | from omegaconf import II 21 | 22 | 23 | @dataclass 24 | class SpeechAndTextTranslationOracleRegAdaptiveCriterionConfig(LabelSmoothedCrossEntropyCriterionConfig): 25 | reg_weight: float = field( 26 | default=1.0, 27 | metadata={"help": "weight of regularization loss"}, 28 | ) 29 | reg_loss_type: str = field( 30 | default="jsd", 31 | metadata={"help": "loss type of regularization (e.g. jsd, l1)"}, 32 | ) 33 | use_word_level_oracle: bool = field( 34 | default=False, 35 | metadata={"help": "use word level oracles"}, 36 | ) 37 | decay_k: float = field( 38 | default=15, 39 | metadata={"help": "decay hyper-paramter k"}, 40 | ) 41 | use_word_gumbel_noise: bool = field( 42 | default=False, 43 | metadata={"help": "select word with gumbel noise"}, 44 | ) 45 | gumbel_temperature: float = field( 46 | default=1.0, 47 | metadata={"help": "temperature of gumbel max in word oracles"}, 48 | ) 49 | adaptive_base: float = field( 50 | default=1.0, 51 | metadata={"help": "adaptive weight: base + scale * F()"}, 52 | ) 53 | adaptive_scale: float = field( 54 | default=1.0, 55 | metadata={"help": "adaptive weight: base + scale * F()"}, 56 | ) 57 | adaptive_func: str = field( 58 | default="linear_cosine", 59 | metadata={"help": "adaptive weight: base + scale * F(), choice: \ 60 | linear_cosine: F() = 1 - cosine"}, 61 | ) 62 | adaptive_st_loss: bool = field( 63 | default=False, 64 | metadata={"help": "using adaptive weight for st loss"}, 65 | ) 66 | adaptive_mt_loss: bool = field( 67 | default=False, 68 | metadata={"help": "using adaptive weight for mt loss"}, 69 | ) 70 | adaptive_reg_loss: bool = field( 71 | default=False, 72 | metadata={"help": "using adaptive weight for regularization loss"}, 73 | ) 74 | adaptive_weight_drop: float = field( 75 | default=0.0, 76 | metadata={"help": "weight drop for adaptive training"}, 77 | ) 78 | 79 | @register_criterion( 80 | "speech_and_text_translation_with_oracle_reg_adaptive", dataclass=SpeechAndTextTranslationOracleRegAdaptiveCriterionConfig 81 | ) 82 | class SpeechAndTextTranslatioOracleRegAdaptiveCriterion(LabelSmoothedCrossEntropyCriterion): 83 | def __init__( 84 | self, 85 | task, 86 | sentence_avg, 87 | label_smoothing, 88 | ignore_prefix_size=0, 89 | report_accuracy=False, 90 | reg_weight=1.0, 91 | reg_loss_type="jsd", 92 | use_word_level_oracle=False, 93 | decay_k=15, 94 | use_word_gumbel_noise=False, 95 | gumbel_temperature=1.0, 96 | adaptive_base=1.0, 97 | adaptive_scale=1.0, 98 | adaptive_func="linear_cosine", 99 | adaptive_st_loss=False, 100 | adaptive_mt_loss=False, 101 | adaptive_reg_loss=False, 102 | adaptive_weight_drop=0.0, 103 | ): 104 | super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy) 105 | self.reg_weight = reg_weight 106 | self.padding_idx = task.target_dictionary.pad() 107 | self.tgt_dict = task.target_dictionary 108 | self.bpe_tokenizer = task.bpe_tokenizer 109 | self.reg_loss_type = reg_loss_type 110 | self.use_word_level_oracle = use_word_level_oracle 111 | self.decay_k = decay_k 112 | 
self.use_word_gumbel_noise = use_word_gumbel_noise 113 | self.gumbel_temperature = gumbel_temperature 114 | self.adaptive_base = adaptive_base 115 | self.adaptive_scale = adaptive_scale 116 | self.adaptive_func = adaptive_func 117 | self.adaptive_st_loss = adaptive_st_loss 118 | self.adaptive_mt_loss = adaptive_mt_loss 119 | self.adaptive_reg_loss = adaptive_reg_loss 120 | self.adaptive_weight_drop = adaptive_weight_drop 121 | 122 | def decay_prob(self, epoch): 123 | k = self.decay_k 124 | return k / (k + np.exp(epoch / k)) 125 | 126 | def get_word_oracle_tokens(self, pred_logits, prev_output_tokens, epoch, epsilon=1e-6): 127 | bsz, _ = prev_output_tokens.size() 128 | if self.use_word_gumbel_noise: 129 | uniform = torch.Tensor(pred_logits.size()).to(pred_logits.device).float().uniform_(0, 1) 130 | gumbel = -torch.log(-torch.log(uniform + epsilon) + epsilon) 131 | pred_logits = (pred_logits + gumbel.to(pred_logits.device)) / self.gumbel_temperature 132 | pred_tokens = torch.max(pred_logits, dim=-1)[1] 133 | bos_idx = prev_output_tokens[0, 0].repeat(bsz, 1).to(pred_tokens) 134 | pred_tokens = torch.cat([bos_idx, pred_tokens], dim=1)[:, :-1] 135 | sample_gold_prob = self.decay_prob(epoch) 136 | sample_gold_prob = sample_gold_prob * torch.ones_like(prev_output_tokens, dtype=torch.float32) 137 | sample_gold_mask = torch.bernoulli(sample_gold_prob).long() 138 | 139 | return prev_output_tokens * sample_gold_mask + pred_tokens * (1 - sample_gold_mask) 140 | 141 | def forward_st(self, model, sample, reduce, word_oracle=False): 142 | audio_input = { 143 | "src_tokens": sample["net_input"]["audio"], 144 | "src_lengths": sample["net_input"]["audio_lengths"], 145 | "mode": "st", 146 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 147 | } 148 | audio_encoder_out = model.encoder( 149 | audio_input["src_tokens"], 150 | audio_input["src_lengths"], 151 | audio_input["mode"] 152 | ) 153 | prev_output_tokens = audio_input["prev_output_tokens"] 154 | with torch.no_grad(): 155 | if word_oracle: 156 | audio_output = model.decoder( 157 | prev_output_tokens, 158 | audio_encoder_out, 159 | ) 160 | prev_output_tokens = self.get_word_oracle_tokens( 161 | audio_output[0].detach(), 162 | prev_output_tokens, 163 | model.epoch, 164 | ) 165 | x, extra = model.decoder.extract_features_scriptable( 166 | prev_output_tokens, 167 | audio_encoder_out, 168 | ) 169 | x = model.decoder.output_layer(x) 170 | return x, extra 171 | 172 | def forward_mt(self, model, sample, reduce, word_oracle=False): 173 | text_input = { 174 | "src_tokens": sample["net_input"]["source"], 175 | "src_lengths": sample["net_input"]["source_lengths"], 176 | "mode": "mt", 177 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 178 | } 179 | text_encoder_out = model.encoder( 180 | text_input["src_tokens"], 181 | text_input["src_lengths"], 182 | text_input["mode"] 183 | ) 184 | prev_output_tokens = text_input["prev_output_tokens"] 185 | with torch.no_grad(): 186 | if word_oracle: 187 | text_output = model.decoder( 188 | prev_output_tokens, 189 | text_encoder_out, 190 | ) 191 | prev_output_tokens = self.get_word_oracle_tokens( 192 | text_output[0].detach(), 193 | prev_output_tokens, 194 | model.epoch, 195 | ) 196 | x, extra = model.decoder.extract_features_scriptable( 197 | prev_output_tokens, 198 | text_encoder_out, 199 | ) 200 | x = model.decoder.output_layer(x) 201 | return x, extra 202 | 203 | def forward_ext_mt(self, model, sample, reduce): 204 | text_output = model(**sample["net_input"]) 205 | loss, _ = 
self.compute_loss(model, text_output, sample, reduce=reduce) 206 | return loss 207 | 208 | def forward(self, model, sample, reduce=True): 209 | """Compute the loss for the given sample. 210 | Returns a tuple with three elements: 211 | 1) the loss 212 | 2) the sample size, which is used as the denominator for the gradient 213 | 3) logging outputs to display while training 214 | """ 215 | st_loss, mt_loss, ext_mt_loss, reg_loss = torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda(), torch.Tensor([0]).cuda() 216 | st_size, mt_size, ext_mt_size, reg_size = 0, 0, 0, 0 217 | 218 | mode = sample["net_input"]["mode"] 219 | if mode == "st": 220 | if self.training: 221 | word_oracle = self.use_word_level_oracle 222 | st_output = self.forward_st(model, sample, reduce, word_oracle) 223 | mt_output = self.forward_mt(model, sample, reduce, word_oracle) 224 | st_loss, mt_loss, reg_loss = self.compute_adaptive_loss(model, sample, st_output, mt_output, self.padding_idx, self.eps) 225 | loss = st_loss + mt_loss + self.reg_weight * reg_loss 226 | st_size = mt_size = sample_size = reg_size = sample["ntokens"] 227 | else: 228 | st_output = self.forward_st(model, sample, reduce) 229 | st_loss, _ = self.compute_loss(model, st_output, sample, reduce=reduce) 230 | loss = st_loss 231 | st_size = sample_size = sample["ntokens"] 232 | elif mode == "ext_mt": 233 | loss = ext_mt_loss = self.forward_ext_mt(model, sample, reduce) 234 | ext_mt_size = sample_size = sample["ntokens"] 235 | 236 | logging_output = { 237 | "loss": loss.data, 238 | "st_loss": st_loss.data, 239 | "st_sample_size": st_size, 240 | "mt_loss": mt_loss.data, 241 | "mt_sample_size": mt_size, 242 | "ext_mt_loss": ext_mt_loss.data, 243 | "ext_mt_sample_size": ext_mt_size, 244 | "reg_loss": reg_loss.data, 245 | "reg_sample_size": reg_size, 246 | "ntokens": sample["ntokens"], 247 | "nsentences": sample["target"].size(0), 248 | "sample_size": sample_size, 249 | } 250 | 251 | return loss, sample_size, logging_output 252 | 253 | def get_adaptive_weight(self, st_output, mt_output): 254 | if self.adaptive_func == "linear_cosine": 255 | st_decoder_state = st_output[1]["inner_states"][-1].detach() 256 | mt_decoder_state = mt_output[1]["inner_states"][-1].detach() 257 | cosine = F.cosine_similarity(st_decoder_state, mt_decoder_state, dim=-1) 258 | weight = 1.0 - cosine 259 | else: 260 | raise NotImplementedError 261 | return self.adaptive_base + self.adaptive_scale * weight 262 | 263 | def compute_adaptive_loss(self, model, sample, st_output, mt_output, ignore_index, epsilon): 264 | st_lprobs, target = self.get_lprobs_and_target(model, st_output, sample) 265 | mt_lprobs, _ = self.get_lprobs_and_target(model, mt_output, sample) 266 | target = target.unsqueeze(-1) 267 | # get weight 268 | weight = self.get_adaptive_weight(st_output, mt_output).view(-1, 1) 269 | drop_p = self.adaptive_weight_drop * torch.ones_like(weight) 270 | drop_mask = torch.bernoulli(drop_p).bool() 271 | weight.masked_fill_(drop_mask, 1.0) 272 | # st loss 273 | st_nll_loss = -st_lprobs.gather(dim=-1, index=target) 274 | st_smooth_loss = -st_lprobs.sum(dim=-1, keepdim=True) 275 | if self.adaptive_st_loss: 276 | st_nll_loss *= weight 277 | pad_mask = target.eq(ignore_index) 278 | st_nll_loss.masked_fill_(pad_mask, 0.0) 279 | st_smooth_loss.masked_fill_(pad_mask, 0.0) 280 | st_nll_loss = st_nll_loss.sum() 281 | st_smooth_loss = st_smooth_loss.sum() 282 | eps_i = epsilon / (st_lprobs.size(-1) - 1) 283 | st_loss = (1.0 - epsilon - eps_i) * st_nll_loss + eps_i * st_smooth_loss 
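        # (Editorial comment, not in the original source.) The two lines above follow
        # fairseq's label_smoothed_nll_loss: with V = vocab size and eps_i = epsilon / (V - 1),
        #     loss = (1 - epsilon - eps_i) * nll + eps_i * smooth,
        # where `smooth` sums -lprobs over the whole vocabulary (gold token included),
        # which is why eps_i is also subtracted from the nll coefficient. When
        # --adaptive-st-loss is set, the per-token weight scales only the nll term;
        # the smoothing term is left unweighted. The mt branch below mirrors this.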
284 | # mt loss 285 | mt_nll_loss = -mt_lprobs.gather(dim=-1, index=target) 286 | mt_smooth_loss = -mt_lprobs.sum(dim=-1, keepdim=True) 287 | if self.adaptive_mt_loss: 288 | mt_nll_loss *= weight 289 | pad_mask = target.eq(ignore_index) 290 | mt_nll_loss.masked_fill_(pad_mask, 0.0) 291 | mt_smooth_loss.masked_fill_(pad_mask, 0.0) 292 | mt_nll_loss = mt_nll_loss.sum() 293 | mt_smooth_loss = mt_smooth_loss.sum() 294 | eps_i = epsilon / (mt_lprobs.size(-1) - 1) 295 | mt_loss = (1.0 - epsilon - eps_i) * mt_nll_loss + eps_i * mt_smooth_loss 296 | # reg loss 297 | if self.reg_loss_type == "jsd": 298 | kl_loss_st = F.kl_div(mt_lprobs, st_lprobs, log_target=True, reduction="none").sum(-1, keepdim=True) 299 | kl_loss_mt = F.kl_div(st_lprobs, mt_lprobs, log_target=True, reduction="none").sum(-1, keepdim=True) 300 | if self.adaptive_reg_loss: 301 | kl_loss_st *= weight 302 | kl_loss_mt *= weight 303 | pad_mask = target.eq(ignore_index) 304 | kl_loss_st.masked_fill_(pad_mask, 0.0) 305 | kl_loss_mt.masked_fill_(pad_mask, 0.0) 306 | kl_loss_st = kl_loss_st.sum() 307 | kl_loss_mt = kl_loss_mt.sum() 308 | reg_loss = (kl_loss_st + kl_loss_mt) / 2.0 309 | else: 310 | raise NotImplementedError 311 | return st_loss, mt_loss, reg_loss 312 | 313 | @classmethod 314 | def reduce_metrics(cls, logging_outputs) -> None: 315 | """Aggregate logging outputs from data parallel training.""" 316 | loss_sum = sum(log.get("loss", 0) for log in logging_outputs) 317 | st_loss_sum = sum(log.get("st_loss", 0) for log in logging_outputs) 318 | mt_loss_sum = sum(log.get("mt_loss", 0) for log in logging_outputs) 319 | ext_mt_loss_sum = sum(log.get("ext_mt_loss", 0) for log in logging_outputs) 320 | reg_loss_sum = sum(log.get("reg_loss", 0) for log in logging_outputs) 321 | sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) 322 | st_sample_size = sum(log.get("st_sample_size", 0) for log in logging_outputs) 323 | mt_sample_size = sum(log.get("mt_sample_size", 0) for log in logging_outputs) 324 | ext_mt_sample_size = sum(log.get("ext_mt_sample_size", 0) for log in logging_outputs) 325 | reg_sample_size = sum(log.get("reg_sample_size", 0) for log in logging_outputs) 326 | 327 | metrics.log_scalar( 328 | "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 329 | ) 330 | metrics.log_scalar( 331 | "st_loss", st_loss_sum / st_sample_size / math.log(2) if st_sample_size != 0 else 0, st_sample_size, round=3 332 | ) 333 | metrics.log_scalar( 334 | "mt_loss", mt_loss_sum / mt_sample_size / math.log(2) if mt_sample_size != 0 else 0, mt_sample_size, round=3 335 | ) 336 | metrics.log_scalar( 337 | "ext_mt_loss", ext_mt_loss_sum / ext_mt_sample_size / math.log(2) if ext_mt_sample_size != 0 else 0, ext_mt_sample_size, round=3 338 | ) 339 | metrics.log_scalar( 340 | "reg_loss", reg_loss_sum / reg_sample_size / math.log(2) if reg_sample_size != 0 else 0, reg_sample_size, round=3 341 | ) 342 | 343 | @staticmethod 344 | def logging_outputs_can_be_summed() -> bool: 345 | """ 346 | Whether the logging outputs returned by `forward` can be summed 347 | across workers prior to calling `reduce_metrics`. Setting this 348 | to True will improves distributed training speed. 349 | """ 350 | return True -------------------------------------------------------------------------------- /cress/datasets/speech_and_text_translation_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import csv 7 | import io 8 | import logging 9 | import re 10 | from collections import defaultdict 11 | from pathlib import Path 12 | from typing import Dict, List, Optional 13 | from dataclasses import dataclass 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | from fairseq.data import ( 19 | ConcatDataset, 20 | Dictionary, 21 | FairseqDataset, 22 | ResamplingDataset, 23 | data_utils as fairseq_data_utils, 24 | ) 25 | from cress.datasets.speech_to_text_dataset import ( 26 | _collate_frames, 27 | get_features_or_waveform, 28 | ) 29 | from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform 30 | from fairseq.data.audio.data_cfg import S2TDataConfig 31 | 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | 36 | @dataclass 37 | class SpeechAndTextTranslationDatasetItem(object): 38 | index: int 39 | audio: torch.Tensor 40 | source: torch.Tensor 41 | target: torch.Tensor 42 | speaker_id: Optional[int] = None 43 | 44 | 45 | class SpeechAndTextTranslationDataset(FairseqDataset): 46 | LANG_TAG_TEMPLATE = "" 47 | 48 | def __init__( 49 | self, 50 | split: str, 51 | is_train_split: bool, 52 | cfg: S2TDataConfig, 53 | audio_paths: List[str], 54 | n_frames: List[int], 55 | src_texts: Optional[List[str]] = None, 56 | tgt_texts: Optional[List[str]] = None, 57 | speakers: Optional[List[str]] = None, 58 | src_langs: Optional[List[str]] = None, 59 | tgt_langs: Optional[List[str]] = None, 60 | ids: Optional[List[str]] = None, 61 | tgt_dict: Optional[Dictionary] = None, 62 | pre_tokenizer=None, 63 | bpe_tokenizer=None, 64 | n_frames_per_step=1, 65 | speaker_to_id=None, 66 | append_eos=True, 67 | ): 68 | self.split, self.is_train_split = split, is_train_split 69 | self.cfg = cfg 70 | self.audio_paths, self.n_frames = audio_paths, n_frames 71 | self.n_samples = len(audio_paths) 72 | assert len(n_frames) == self.n_samples > 0 73 | assert src_texts is None or len(src_texts) == self.n_samples 74 | assert tgt_texts is None or len(tgt_texts) == self.n_samples 75 | assert speakers is None or len(speakers) == self.n_samples 76 | assert src_langs is None or len(src_langs) == self.n_samples 77 | assert tgt_langs is None or len(tgt_langs) == self.n_samples 78 | assert ids is None or len(ids) == self.n_samples 79 | assert (tgt_dict is None and tgt_texts is None) or ( 80 | tgt_dict is not None and tgt_texts is not None 81 | ) 82 | self.src_texts, self.tgt_texts = src_texts, tgt_texts 83 | self.src_langs, self.tgt_langs = src_langs, tgt_langs 84 | self.speakers = speakers 85 | self.tgt_dict = tgt_dict 86 | self.check_tgt_lang_tag() 87 | self.ids = ids 88 | self.shuffle = cfg.shuffle if is_train_split else False 89 | 90 | self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( 91 | self.cfg.get_feature_transforms(split, is_train_split) 92 | ) 93 | 94 | self.pre_tokenizer = pre_tokenizer 95 | self.bpe_tokenizer = bpe_tokenizer 96 | self.n_frames_per_step = n_frames_per_step 97 | self.speaker_to_id = speaker_to_id 98 | 99 | self.src_lens = self.get_src_lens_and_check_oov() 100 | self.tgt_lens = self.get_tgt_lens_and_check_oov() 101 | self.append_eos = append_eos 102 | 103 | logger.info(self.__repr__()) 104 | 105 | def get_src_lens_and_check_oov(self): 106 | if self.src_texts is None: 107 | return [0 for _ in range(self.n_samples)] 108 | src_lens = [] 109 | n_tokens, n_oov_tokens = 0, 0 110 | for i 
in range(self.n_samples): 111 | tokenized = self.get_tokenized_src_text(i).split(" ") 112 | oov_tokens = [ 113 | t 114 | for t in tokenized 115 | if self.tgt_dict.index(t) == self.tgt_dict.unk_index 116 | ] 117 | n_tokens += len(tokenized) 118 | n_oov_tokens += len(oov_tokens) 119 | src_lens.append(len(tokenized)) 120 | logger.info(f"'{self.split}-src' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") 121 | return src_lens 122 | 123 | def get_tgt_lens_and_check_oov(self): 124 | if self.tgt_texts is None: 125 | return [0 for _ in range(self.n_samples)] 126 | tgt_lens = [] 127 | n_tokens, n_oov_tokens = 0, 0 128 | for i in range(self.n_samples): 129 | tokenized = self.get_tokenized_tgt_text(i).split(" ") 130 | oov_tokens = [ 131 | t 132 | for t in tokenized 133 | if self.tgt_dict.index(t) == self.tgt_dict.unk_index 134 | ] 135 | n_tokens += len(tokenized) 136 | n_oov_tokens += len(oov_tokens) 137 | tgt_lens.append(len(tokenized)) 138 | logger.info(f"'{self.split}-tgt' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") 139 | return tgt_lens 140 | 141 | def __repr__(self): 142 | return ( 143 | self.__class__.__name__ 144 | + f'(split="{self.split}", n_samples={self.n_samples:_}, ' 145 | f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " 146 | f"shuffle={self.shuffle}, transforms={self.feature_transforms}, " 147 | f"n_frames_per_step={self.n_frames_per_step}" 148 | ) 149 | 150 | @classmethod 151 | def is_lang_tag(cls, token): 152 | pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") 153 | return re.match(pattern, token) 154 | 155 | def check_tgt_lang_tag(self): 156 | if self.cfg.prepend_tgt_lang_tag: 157 | assert self.tgt_langs is not None and self.tgt_dict is not None 158 | tgt_lang_tags = [ 159 | self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) 160 | ] 161 | assert all(t in self.tgt_dict for t in tgt_lang_tags) 162 | 163 | @classmethod 164 | def tokenize(cls, tokenizer, text: str): 165 | return text if tokenizer is None else tokenizer.encode(text) 166 | 167 | def get_tokenized_src_text(self, index: int): 168 | text = self.tokenize(self.pre_tokenizer, self.src_texts[index]) 169 | text = self.tokenize(self.bpe_tokenizer, text) 170 | return text 171 | 172 | def get_tokenized_tgt_text(self, index: int): 173 | text = self.tokenize(self.pre_tokenizer, self.tgt_texts[index]) 174 | text = self.tokenize(self.bpe_tokenizer, text) 175 | return text 176 | 177 | def pack_frames(self, feature: torch.Tensor): 178 | if self.n_frames_per_step == 1: 179 | return feature 180 | n_packed_frames = feature.shape[0] // self.n_frames_per_step 181 | feature = feature[: self.n_frames_per_step * n_packed_frames] 182 | return feature.reshape(n_packed_frames, -1) 183 | 184 | @classmethod 185 | def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): 186 | lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) 187 | assert lang_tag_idx != dictionary.unk() 188 | return lang_tag_idx 189 | 190 | def _get_source_audio(self, index: int) -> torch.Tensor: 191 | source = get_features_or_waveform( 192 | self.audio_paths[index], 193 | need_waveform=self.cfg.use_audio_input, 194 | use_sample_rate=self.cfg.use_sample_rate, 195 | ) 196 | if self.cfg.use_audio_input: 197 | if self.cfg.standardize_audio: 198 | with torch.no_grad(): 199 | source = F.layer_norm(source, source.shape) 200 | else: 201 | if self.feature_transforms is not None: 202 | source = self.feature_transforms(source) 203 | source = torch.from_numpy(source).float() 204 | return source 205 | 206 | def __getitem__(self, index: int) -> 
SpeechAndTextTranslationDatasetItem: 207 | audio = self._get_source_audio(index) 208 | audio = self.pack_frames(audio) 209 | 210 | tokenized = self.get_tokenized_src_text(index) 211 | source = self.tgt_dict.encode_line( 212 | tokenized, add_if_not_exist=False, append_eos=False 213 | ).long() 214 | 215 | tokenized = self.get_tokenized_tgt_text(index) 216 | target = self.tgt_dict.encode_line( 217 | tokenized, add_if_not_exist=False, append_eos=self.append_eos 218 | ).long() 219 | if self.cfg.prepend_tgt_lang_tag: 220 | lang_tag_idx = self.get_lang_tag_idx( 221 | self.tgt_langs[index], self.tgt_dict 222 | ) 223 | target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) 224 | 225 | speaker_id = None 226 | if self.speaker_to_id is not None: 227 | speaker_id = self.speaker_to_id[self.speakers[index]] 228 | return SpeechAndTextTranslationDatasetItem( 229 | index=index, audio=audio, source=source, target=target, speaker_id=speaker_id 230 | ) 231 | 232 | def __len__(self): 233 | return self.n_samples 234 | 235 | def collater( 236 | self, samples: List[SpeechAndTextTranslationDatasetItem], return_order: bool = False 237 | ) -> Dict: 238 | if len(samples) == 0: 239 | return {} 240 | indices = torch.tensor([x.index for x in samples], dtype=torch.long) 241 | frames = _collate_frames([x.audio for x in samples], self.cfg.use_audio_input) 242 | # sort samples by descending number of frames 243 | n_frames = torch.tensor([x.audio.size(0) for x in samples], dtype=torch.long) 244 | n_frames, order = n_frames.sort(descending=True) 245 | indices = indices.index_select(0, order) 246 | frames = frames.index_select(0, order) 247 | 248 | source = fairseq_data_utils.collate_tokens( 249 | [x.source for x in samples], 250 | self.tgt_dict.pad(), 251 | self.tgt_dict.eos(), 252 | left_pad=False, 253 | move_eos_to_beginning=False, 254 | ) 255 | source = source.index_select(0, order) 256 | source_lengths = torch.tensor( 257 | [x.source.size(0) for x in samples], dtype=torch.long 258 | ).index_select(0, order) 259 | 260 | target = fairseq_data_utils.collate_tokens( 261 | [x.target for x in samples], 262 | self.tgt_dict.pad(), 263 | self.tgt_dict.eos(), 264 | left_pad=False, 265 | move_eos_to_beginning=False, 266 | ) 267 | target = target.index_select(0, order) 268 | target_lengths = torch.tensor( 269 | [x.target.size(0) for x in samples], dtype=torch.long 270 | ).index_select(0, order) 271 | prev_output_tokens = fairseq_data_utils.collate_tokens( 272 | [x.target for x in samples], 273 | self.tgt_dict.pad(), 274 | self.tgt_dict.eos(), 275 | left_pad=False, 276 | move_eos_to_beginning=True, 277 | ) 278 | prev_output_tokens = prev_output_tokens.index_select(0, order) 279 | ntokens = sum(x.target.size(0) for x in samples) 280 | 281 | speaker = None 282 | if self.speaker_to_id is not None: 283 | speaker = ( 284 | torch.tensor([s.speaker_id for s in samples], dtype=torch.long) 285 | .index_select(0, order) 286 | .view(-1, 1) 287 | ) 288 | 289 | net_input = { 290 | "audio": frames, 291 | "audio_lengths": n_frames, 292 | "source": source, 293 | "source_lengths": source_lengths, 294 | "prev_output_tokens": prev_output_tokens, 295 | } 296 | out = { 297 | "id": indices, 298 | "net_input": net_input, 299 | "speaker": speaker, 300 | "target": target, 301 | "target_lengths": target_lengths, 302 | "ntokens": ntokens, 303 | "nsentences": len(samples), 304 | } 305 | if return_order: 306 | out["order"] = order 307 | return out 308 | 309 | def num_tokens(self, index): 310 | return self.n_frames[index] 311 | 312 | def size(self, 
index): 313 | return self.n_frames[index], self.src_lens[index], self.tgt_lens[index] 314 | 315 | @property 316 | def sizes(self): 317 | return np.array(self.n_frames) 318 | 319 | @property 320 | def can_reuse_epoch_itr_across_epochs(self): 321 | return True 322 | 323 | def ordered_indices(self): 324 | if self.shuffle: 325 | order = [np.random.permutation(len(self))] 326 | else: 327 | order = [np.arange(len(self))] 328 | # first by descending order of # of frames then by original/random order 329 | order.append([-n for n in self.n_frames]) 330 | return np.lexsort(order) 331 | 332 | def prefetch(self, indices): 333 | raise False 334 | 335 | 336 | class SpeechAndTextTranslationDatasetCreator(object): 337 | # mandatory columns 338 | KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" 339 | KEY_TGT_TEXT = "tgt_text" 340 | # optional columns 341 | KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" 342 | KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" 343 | # default values 344 | DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" 345 | 346 | @classmethod 347 | def _from_list( 348 | cls, 349 | split_name: str, 350 | is_train_split, 351 | samples: List[Dict], 352 | cfg: S2TDataConfig, 353 | tgt_dict, 354 | pre_tokenizer, 355 | bpe_tokenizer, 356 | n_frames_per_step, 357 | speaker_to_id, 358 | ) -> SpeechAndTextTranslationDataset: 359 | audio_root = Path(cfg.audio_root) 360 | ids = [s[cls.KEY_ID] for s in samples] 361 | audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] 362 | n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] 363 | tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] 364 | src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] 365 | speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] 366 | src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] 367 | tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] 368 | return SpeechAndTextTranslationDataset( 369 | split_name, 370 | is_train_split, 371 | cfg, 372 | audio_paths, 373 | n_frames, 374 | src_texts=src_texts, 375 | tgt_texts=tgt_texts, 376 | speakers=speakers, 377 | src_langs=src_langs, 378 | tgt_langs=tgt_langs, 379 | ids=ids, 380 | tgt_dict=tgt_dict, 381 | pre_tokenizer=pre_tokenizer, 382 | bpe_tokenizer=bpe_tokenizer, 383 | n_frames_per_step=n_frames_per_step, 384 | speaker_to_id=speaker_to_id, 385 | ) 386 | 387 | @classmethod 388 | def get_size_ratios( 389 | cls, datasets: List[SpeechAndTextTranslationDataset], alpha: float = 1.0 390 | ) -> List[float]: 391 | """Size ratios for temperature-based sampling 392 | (https://arxiv.org/abs/1907.05019)""" 393 | 394 | id_to_lp, lp_to_sz = {}, defaultdict(int) 395 | for ds in datasets: 396 | lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)} 397 | assert len(lang_pairs) == 1 398 | lang_pair = list(lang_pairs)[0] 399 | id_to_lp[ds.split] = lang_pair 400 | lp_to_sz[lang_pair] += sum(ds.n_frames) 401 | 402 | sz_sum = sum(v for v in lp_to_sz.values()) 403 | lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()} 404 | lp_to_tgt_prob = {k: v ** alpha for k, v in lp_to_prob.items()} 405 | prob_sum = sum(v for v in lp_to_tgt_prob.values()) 406 | lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()} 407 | lp_to_sz_ratio = { 408 | k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items() 409 | } 410 | size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets] 411 | 412 | p_formatted = { 413 | k: 
f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz 414 | } 415 | logger.info(f"sampling probability balancing: {p_formatted}") 416 | sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)} 417 | logger.info(f"balanced sampling size ratio: {sr_formatted}") 418 | return size_ratio 419 | 420 | @classmethod 421 | def _load_samples_from_tsv(cls, root: str, split: str): 422 | tsv_path = Path(root) / f"{split}.tsv" 423 | if not tsv_path.is_file(): 424 | raise FileNotFoundError(f"Dataset not found: {tsv_path}") 425 | with open(tsv_path) as f: 426 | reader = csv.DictReader( 427 | f, 428 | delimiter="\t", 429 | quotechar=None, 430 | doublequote=False, 431 | lineterminator="\n", 432 | quoting=csv.QUOTE_NONE, 433 | ) 434 | samples = [dict(e) for e in reader] 435 | if len(samples) == 0: 436 | raise ValueError(f"Empty manifest: {tsv_path}") 437 | return samples 438 | 439 | @classmethod 440 | def _from_tsv( 441 | cls, 442 | root: str, 443 | cfg: S2TDataConfig, 444 | split: str, 445 | tgt_dict, 446 | is_train_split: bool, 447 | pre_tokenizer, 448 | bpe_tokenizer, 449 | n_frames_per_step, 450 | speaker_to_id, 451 | ) -> SpeechAndTextTranslationDataset: 452 | samples = cls._load_samples_from_tsv(root, split) 453 | return cls._from_list( 454 | split, 455 | is_train_split, 456 | samples, 457 | cfg, 458 | tgt_dict, 459 | pre_tokenizer, 460 | bpe_tokenizer, 461 | n_frames_per_step, 462 | speaker_to_id, 463 | ) 464 | 465 | @classmethod 466 | def from_tsv( 467 | cls, 468 | root: str, 469 | cfg: S2TDataConfig, 470 | splits: str, 471 | tgt_dict, 472 | pre_tokenizer, 473 | bpe_tokenizer, 474 | is_train_split: bool, 475 | epoch: int, 476 | seed: int, 477 | n_frames_per_step: int = 1, 478 | speaker_to_id=None, 479 | ) -> SpeechAndTextTranslationDataset: 480 | datasets = [ 481 | cls._from_tsv( 482 | root, 483 | cfg, 484 | split, 485 | tgt_dict, 486 | is_train_split, 487 | pre_tokenizer, 488 | bpe_tokenizer, 489 | n_frames_per_step, 490 | speaker_to_id, 491 | ) 492 | for split in splits.split(",") 493 | ] 494 | 495 | if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: 496 | # temperature-based sampling 497 | size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) 498 | datasets = [ 499 | ResamplingDataset( 500 | d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) 501 | ) 502 | for r, d in zip(size_ratios, datasets) 503 | ] 504 | 505 | return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] -------------------------------------------------------------------------------- /cress/models/hubert_transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import math 5 | from pathlib import Path 6 | from typing import Dict, List, Optional, OrderedDict, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | from fairseq import checkpoint_utils, tasks, utils 13 | from fairseq.data.data_utils import lengths_to_padding_mask 14 | from fairseq.models import ( 15 | FairseqEncoder, 16 | FairseqEncoderDecoderModel, 17 | register_model, 18 | register_model_architecture, 19 | ) 20 | from fairseq.models.speech_to_text.hub_interface import S2THubInterface 21 | from fairseq.models.speech_to_text.s2t_transformer import Conv1dSubsampler, TransformerDecoderScriptable 22 | from fairseq.models.hubert import HubertModel 23 | from fairseq.models.transformer import Embedding 24 | from fairseq.modules import ( 25 | FairseqDropout, 26 
| LayerNorm, 27 | PositionalEmbedding, 28 | TransformerEncoderLayer, 29 | ) 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | @register_model("hubert_transformer") 34 | class HubertTransformerModel(FairseqEncoderDecoderModel): 35 | """Adapted Transformer model (https://arxiv.org/abs/1706.03762) for 36 | speech-to-text tasks. The Transformer encoder/decoder remains the same. 37 | A trainable input subsampler is prepended to the Transformer encoder to 38 | project inputs into the encoder dimension as well as downsample input 39 | sequence for computational efficiency.""" 40 | 41 | @classmethod 42 | def hub_models(cls): 43 | base_url = "http://dl.fbaipublicfiles.com/fairseq/s2t" 44 | model_ids = [ 45 | "s2t_transformer_s-en-asr-librispeech", 46 | "s2t_transformer_m-en-asr-librispeech", 47 | "s2t_transformer_l-en-asr-librispeech", 48 | ] 49 | return {i: f"{base_url}/{i}.tar.gz" for i in model_ids} 50 | 51 | @classmethod 52 | def from_pretrained( 53 | cls, 54 | model_name_or_path, 55 | checkpoint_file="model.pt", 56 | data_name_or_path=".", 57 | config_yaml="config.yaml", 58 | **kwargs, 59 | ): 60 | from fairseq import hub_utils 61 | 62 | x = hub_utils.from_pretrained( 63 | model_name_or_path, 64 | checkpoint_file, 65 | data_name_or_path, 66 | archive_map=cls.hub_models(), 67 | config_yaml=config_yaml, 68 | **kwargs, 69 | ) 70 | return S2THubInterface(x["args"], x["task"], x["models"][0]) 71 | 72 | def __init__(self, encoder, decoder): 73 | super().__init__(encoder, decoder) 74 | self.epoch = 1 75 | 76 | def set_epoch(self, epoch): 77 | self.epoch = epoch 78 | 79 | @staticmethod 80 | def add_args(parser): 81 | """Add model-specific arguments to the parser.""" 82 | # Transformer 83 | parser.add_argument( 84 | "--activation-fn", 85 | type=str, 86 | default="relu", 87 | choices=utils.get_available_activation_fns(), 88 | help="activation function to use", 89 | ) 90 | parser.add_argument( 91 | "--dropout", type=float, metavar="D", help="dropout probability" 92 | ) 93 | parser.add_argument( 94 | "--attention-dropout", 95 | type=float, 96 | metavar="D", 97 | help="dropout probability for attention weights", 98 | ) 99 | parser.add_argument( 100 | "--activation-dropout", 101 | "--relu-dropout", 102 | type=float, 103 | metavar="D", 104 | help="dropout probability after activation in FFN.", 105 | ) 106 | parser.add_argument( 107 | "--encoder-embed-dim", 108 | type=int, 109 | metavar="N", 110 | help="encoder embedding dimension", 111 | ) 112 | parser.add_argument( 113 | "--encoder-ffn-embed-dim", 114 | type=int, 115 | metavar="N", 116 | help="encoder embedding dimension for FFN", 117 | ) 118 | parser.add_argument( 119 | "--encoder-layers", type=int, metavar="N", help="num encoder layers" 120 | ) 121 | parser.add_argument( 122 | "--encoder-attention-heads", 123 | type=int, 124 | metavar="N", 125 | help="num encoder attention heads", 126 | ) 127 | parser.add_argument( 128 | "--encoder-normalize-before", 129 | action="store_true", 130 | help="apply layernorm before each encoder block", 131 | ) 132 | parser.add_argument( 133 | "--decoder-embed-dim", 134 | type=int, 135 | metavar="N", 136 | help="decoder embedding dimension", 137 | ) 138 | parser.add_argument( 139 | "--decoder-ffn-embed-dim", 140 | type=int, 141 | metavar="N", 142 | help="decoder embedding dimension for FFN", 143 | ) 144 | parser.add_argument( 145 | "--decoder-layers", type=int, metavar="N", help="num decoder layers" 146 | ) 147 | parser.add_argument( 148 | "--decoder-attention-heads", 149 | type=int, 150 | metavar="N", 151 | help="num 
decoder attention heads", 152 | ) 153 | parser.add_argument( 154 | "--decoder-normalize-before", 155 | action="store_true", 156 | help="apply layernorm before each decoder block", 157 | ) 158 | parser.add_argument( 159 | "--share-decoder-input-output-embed", 160 | action="store_true", 161 | help="share decoder input and output embeddings", 162 | ) 163 | parser.add_argument( 164 | "--layernorm-embedding", 165 | action="store_true", 166 | help="add layernorm to embedding", 167 | ) 168 | parser.add_argument( 169 | "--no-scale-embedding", 170 | action="store_true", 171 | help="if True, dont scale embeddings", 172 | ) 173 | parser.add_argument( 174 | "--load-pretrained-encoder-from", 175 | type=str, 176 | metavar="STR", 177 | help="model to take encoder weights from (for initialization)", 178 | ) 179 | # hubert arguments 180 | parser.add_argument( 181 | "--hubert-model-path", 182 | type=str, 183 | metavar="STR", 184 | help="path/to/hubert/model" 185 | ) 186 | parser.add_argument( 187 | "--freeze-hubert", 188 | action="store_true", 189 | help="if we want to freeze the hubert features" 190 | ) 191 | # subsampler arguments 192 | parser.add_argument( 193 | "--conv-kernel-sizes", 194 | type=str, 195 | help="kernel sizes of Conv1d subsampling layers", 196 | ) 197 | parser.add_argument( 198 | "--conv-channels", 199 | type=int, 200 | help="# of channels in Conv1d subsampling layers", 201 | ) 202 | # pretrain 203 | parser.add_argument( 204 | "--load-pretrained-mt-encoder-decoder-from", 205 | type=str, 206 | help="model to take mt encoder/decoder weight from (for initialization)", 207 | ) 208 | 209 | @classmethod 210 | def build_encoder(cls, args, task=None, embed_tokens=None): 211 | return HubertTransformerEncoder(args, task.target_dictionary, embed_tokens) 212 | 213 | @classmethod 214 | def build_decoder(cls, args, task, embed_tokens): 215 | return TransformerDecoderScriptable(args, task.target_dictionary, embed_tokens) 216 | 217 | @classmethod 218 | def build_model(cls, args, task): 219 | """Build a new model instance.""" 220 | 221 | # make sure all arguments are present in older models 222 | base_architecture(args) 223 | 224 | def build_embedding(dictionary, embed_dim): 225 | num_embeddings = len(dictionary) 226 | padding_idx = dictionary.pad() 227 | return Embedding(num_embeddings, embed_dim, padding_idx) 228 | 229 | decoder_embed_tokens = build_embedding( 230 | task.target_dictionary, args.decoder_embed_dim 231 | ) 232 | encoder_embed_tokens = decoder_embed_tokens 233 | encoder = cls.build_encoder(args, task, encoder_embed_tokens) 234 | decoder = cls.build_decoder(args, task, decoder_embed_tokens) 235 | # load pretrained mt models 236 | mt_pretrained_path = getattr(args, "load_pretrained_mt_encoder_decoder_from", None) 237 | if mt_pretrained_path is not None and Path(mt_pretrained_path).exists(): 238 | state_dict = checkpoint_utils.load_checkpoint_to_cpu(mt_pretrained_path)["model"] 239 | mt_encoder_state_dict = OrderedDict() 240 | mt_decoder_state_dict = OrderedDict() 241 | for key in state_dict.keys(): 242 | if "hubert" in key or "subsampler" in key: 243 | continue 244 | if key.startswith("encoder"): 245 | subkey = key[len("encoder") + 1 :] 246 | mt_encoder_state_dict[subkey] = state_dict[key] 247 | if key.startswith("decoder"): 248 | subkey = key[len("decoder") + 1 :] 249 | mt_decoder_state_dict[subkey] = state_dict[key] 250 | encoder.load_state_dict(mt_encoder_state_dict, strict=False) 251 | decoder.load_state_dict(mt_decoder_state_dict, strict=False) 252 | 253 | return cls(encoder, decoder) 
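    # (Editorial comment, not in the original source.) build_model above can warm-start
    # the MT branch from a text-only checkpoint via --load-pretrained-mt-encoder-decoder-from:
    # keys containing "hubert" or "subsampler" are skipped, the leading "encoder."/"decoder."
    # prefixes are stripped, and the resulting state dicts are loaded with strict=False so
    # speech-specific modules keep their fresh initialization. For example, a checkpoint key
    # "decoder.layers.0.fc1.weight" is loaded into the decoder as "layers.0.fc1.weight".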
254 | 255 | def get_normalized_probs( 256 | self, 257 | net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]], 258 | log_probs: bool, 259 | sample: Optional[Dict[str, Tensor]] = None, 260 | ): 261 | # net_output['encoder_out'] is a (B, T, D) tensor 262 | lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample) 263 | lprobs.batch_first = True 264 | return lprobs 265 | 266 | def forward(self, src_tokens, src_lengths, mode, prev_output_tokens): 267 | """ 268 | The forward method inherited from the base class has a **kwargs 269 | argument in its input, which is not supported in torchscript. This 270 | method overwrites the forward method definition without **kwargs. 271 | """ 272 | encoder_out = self.encoder(src_tokens=src_tokens, src_lengths=src_lengths, mode=mode) 273 | decoder_out = self.decoder( 274 | prev_output_tokens=prev_output_tokens, encoder_out=encoder_out 275 | ) 276 | return decoder_out 277 | 278 | 279 | class HubertTransformerEncoder(FairseqEncoder): 280 | """Speech-to-text Transformer encoder that consists of input subsampler and 281 | Transformer encoder.""" 282 | 283 | def __init__(self, args, dictionary=None, embed_tokens=None): 284 | super().__init__(None) 285 | 286 | self.num_updates = 0 287 | 288 | self.dropout_module = FairseqDropout( 289 | p=args.dropout, module_name=self.__class__.__name__ 290 | ) 291 | self.embed_scale = math.sqrt(args.encoder_embed_dim) 292 | if args.no_scale_embedding: 293 | self.embed_scale = 1.0 294 | self.padding_idx = dictionary.pad() 295 | 296 | # load hubert 297 | self.hubert_model_path = getattr(args, "hubert_model_path", None) 298 | self.freeze_hubert = getattr(args, "freeze_hubert", False) 299 | assert self.hubert_model_path is not None 300 | ckpt = checkpoint_utils.load_checkpoint_to_cpu(self.hubert_model_path) 301 | hubert_args = ckpt["cfg"] 302 | task = tasks.setup_task(hubert_args.task) 303 | if "task_state" in ckpt: 304 | task.load_state_dict(ckpt["task_state"]) 305 | self.hubert_model = task.build_model(hubert_args.model) 306 | self.hubert_model.load_state_dict(ckpt["model"]) 307 | self.hubert_model.remove_pretraining_modules() 308 | if self.freeze_hubert: 309 | for param in self.hubert_model.parameters(): 310 | param.requires_grad = False 311 | 312 | # speech subsample 313 | if args.conv_kernel_sizes: 314 | self.subsampler = Conv1dSubsampler( 315 | hubert_args.model.encoder_embed_dim, 316 | args.conv_channels, 317 | args.encoder_embed_dim, 318 | [int(k) for k in args.conv_kernel_sizes.split(",")], 319 | ) 320 | else: 321 | self.subsampler = None 322 | self.dim_proj = nn.Linear(hubert_args.model.encoder_embed_dim, args.encoder_embed_dim) 323 | 324 | # embedding 325 | self.embed_tokens = embed_tokens 326 | export = getattr(args, "export", False) 327 | if getattr(args, "layernorm_embedding", False): 328 | self.layernorm_embedding = LayerNorm(embed_tokens.embedding_dim, export=export) 329 | else: 330 | self.layernorm_embedding = None 331 | self.embed_positions = PositionalEmbedding( 332 | args.max_source_positions, 333 | args.encoder_embed_dim, 334 | self.padding_idx, 335 | ) 336 | 337 | # transformer encoder 338 | self.transformer_layers = nn.ModuleList( 339 | [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)] 340 | ) 341 | if args.encoder_normalize_before: 342 | self.layer_norm = LayerNorm(args.encoder_embed_dim) 343 | else: 344 | self.layer_norm = None 345 | 346 | def _get_hubert_features(self, src_tokens, src_lengths): 347 | padding_mask = lengths_to_padding_mask(src_lengths) 348 | 
hubert_args = { 349 | "source": src_tokens, 350 | "padding_mask": padding_mask, 351 | "mask": False, 352 | } 353 | x, padding_mask = self.hubert_model.extract_features(**hubert_args) 354 | output_length = (1 - padding_mask.int()).sum(dim=1) 355 | return x, padding_mask, output_length 356 | 357 | def forward_embedding( 358 | self, src_tokens, token_embedding: Optional[torch.Tensor] = None 359 | ): 360 | # embed tokens and positions 361 | if token_embedding is None: 362 | token_embedding = self.embed_tokens(src_tokens) 363 | x = embed = self.embed_scale * token_embedding 364 | if self.embed_positions is not None: 365 | x = embed + self.embed_positions(src_tokens) 366 | if self.layernorm_embedding is not None: 367 | x = self.layernorm_embedding(x) 368 | x = self.dropout_module(x) 369 | return x, embed 370 | 371 | def _forward(self, src_tokens, src_lengths, mode, return_all_hiddens=False): 372 | if mode == "st": 373 | x, encoder_padding_mask, input_lengths = self._get_hubert_features(src_tokens, src_lengths) 374 | if self.subsampler is not None: 375 | x, input_lengths = self.subsampler(x, input_lengths) 376 | encoder_padding_mask = lengths_to_padding_mask(input_lengths) 377 | x = x.transpose(0, 1) # T x B x C -> B x T x C 378 | else: 379 | x = self.dim_proj(x) 380 | if self.layernorm_embedding is not None: 381 | x = self.layernorm_embedding(x) 382 | x = self.dropout_module(x) 383 | else: 384 | encoder_padding_mask = src_tokens.eq(self.padding_idx) 385 | has_pads = src_tokens.device.type == "xla" or encoder_padding_mask.any() 386 | x, _ = self.forward_embedding(src_tokens) 387 | if has_pads: 388 | x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) 389 | 390 | encoder_embedding = x 391 | x = x.transpose(0, 1) # B x T x C -> T x B x C 392 | 393 | encoder_states = [] 394 | if return_all_hiddens: 395 | encoder_states.append(x) 396 | 397 | for layer in self.transformer_layers: 398 | x = layer(x, encoder_padding_mask) 399 | if return_all_hiddens: 400 | encoder_states.append(x) 401 | 402 | if self.layer_norm is not None: 403 | x = self.layer_norm(x) 404 | 405 | return { 406 | "encoder_out": [x], # T x B x C 407 | "encoder_padding_mask": [encoder_padding_mask], # B x T 408 | "encoder_embedding": [encoder_embedding], # B x T x C 409 | "encoder_states": encoder_states, # List[T x B x C] 410 | "src_tokens": [], 411 | "src_lengths": [], 412 | } 413 | 414 | def forward(self, src_tokens, src_lengths, mode, return_all_hiddens=False): 415 | x = self._forward( 416 | src_tokens, src_lengths, mode, return_all_hiddens=return_all_hiddens 417 | ) 418 | return x 419 | 420 | def reorder_encoder_out(self, encoder_out, new_order): 421 | new_encoder_out = ( 422 | [] 423 | if len(encoder_out["encoder_out"]) == 0 424 | else [x.index_select(1, new_order) for x in encoder_out["encoder_out"]] 425 | ) 426 | 427 | new_encoder_padding_mask = ( 428 | [] 429 | if len(encoder_out["encoder_padding_mask"]) == 0 430 | else [ 431 | x.index_select(0, new_order) 432 | for x in encoder_out["encoder_padding_mask"] 433 | ] 434 | ) 435 | 436 | new_encoder_embedding = ( 437 | [] 438 | if len(encoder_out["encoder_embedding"]) == 0 439 | else [ 440 | x.index_select(0, new_order) for x in encoder_out["encoder_embedding"] 441 | ] 442 | ) 443 | 444 | encoder_states = encoder_out["encoder_states"] 445 | if len(encoder_states) > 0: 446 | for idx, state in enumerate(encoder_states): 447 | encoder_states[idx] = state.index_select(1, new_order) 448 | 449 | return { 450 | "encoder_out": new_encoder_out, # T x B x C 451 | "encoder_padding_mask": 
new_encoder_padding_mask, # B x T 452 | "encoder_embedding": new_encoder_embedding, # B x T x C 453 | "encoder_states": encoder_states, # List[T x B x C] 454 | "src_tokens": [], # B x T 455 | "src_lengths": [], # B x 1 456 | } 457 | 458 | def set_num_updates(self, num_updates): 459 | super().set_num_updates(num_updates) 460 | self.num_updates = num_updates 461 | 462 | @register_model_architecture(model_name="hubert_transformer", arch_name="hubert_transformer") 463 | def base_architecture(args): 464 | # subsampler 465 | args.conv_kernel_sizes = getattr(args, "conv_kernel_sizes", "5,5") 466 | args.conv_channels = getattr(args, "conv_channels", 1024) 467 | # Transformer 468 | args.encoder_layers = getattr(args, "encoder_layers", 6) 469 | args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) 470 | args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) 471 | args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8) 472 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", True) 473 | args.decoder_layers = getattr(args, "decoder_layers", 6) 474 | args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512) 475 | args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048) 476 | args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8) 477 | args.decoder_normalize_before = getattr(args, "decoder_normalize_before", True) 478 | args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) 479 | args.dropout = getattr(args, "dropout", 0.1) 480 | args.attention_dropout = getattr(args, "attention_dropout", args.dropout) 481 | args.activation_dropout = getattr(args, "activation_dropout", args.dropout) 482 | args.activation_fn = getattr(args, "activation_fn", "relu") 483 | args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None) 484 | args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0) 485 | args.share_decoder_input_output_embed = getattr( 486 | args, "share_decoder_input_output_embed", False 487 | ) 488 | args.no_token_positional_embeddings = getattr( 489 | args, "no_token_positional_embeddings", False 490 | ) 491 | args.adaptive_input = getattr(args, "adaptive_input", False) 492 | args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0) 493 | args.decoder_output_dim = getattr( 494 | args, "decoder_output_dim", args.decoder_embed_dim 495 | ) 496 | args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim) 497 | args.no_scale_embedding = getattr(args, "no_scale_embedding", False) 498 | args.quant_noise_pq = getattr(args, "quant_noise_pq", 0) 499 | 500 | @register_model_architecture(model_name="hubert_transformer", arch_name="hubert_transformer_postln") 501 | def hubert_transformer_postln(args): 502 | args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False) 503 | args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) 504 | base_architecture(args) -------------------------------------------------------------------------------- /cress/datasets/speech_to_text_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
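# NOTE (editorial sketch, not part of the original repository): the helper below is a
# minimal, hypothetical illustration of how _collate_frames (defined later in this file)
# pads a batch of variable-length inputs. With is_audio_input=True the inputs are raw
# 1-D waveforms; otherwise they are T x feature_dim fbank matrices. All values are toy
# data and the function is never called by the dataset code.
def _collate_frames_sketch():
    import torch

    waveforms = [torch.randn(16000), torch.randn(12000)]      # hypothetical 1.0s / 0.75s clips
    batch = _collate_frames(waveforms, is_audio_input=True)   # -> shape (2, 16000), zero-padded
    fbanks = [torch.randn(120, 80), torch.randn(90, 80)]      # hypothetical 80-dim fbank features
    feats = _collate_frames(fbanks, is_audio_input=False)     # -> shape (2, 120, 80)
    return batch.shape, feats.shape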
5 | 6 | import csv 7 | import io 8 | import logging 9 | import re 10 | from collections import defaultdict 11 | from pathlib import Path 12 | from typing import Dict, List, Optional 13 | from dataclasses import dataclass 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | from fairseq.data import ( 19 | ConcatDataset, 20 | Dictionary, 21 | FairseqDataset, 22 | ResamplingDataset, 23 | data_utils as fairseq_data_utils, 24 | ) 25 | from cress.datasets.audio_utils import ( 26 | get_fbank, 27 | get_waveform, 28 | get_segment_waveform, 29 | read_from_stored_zip, 30 | is_npy_data, 31 | is_sf_audio_data, 32 | parse_path, 33 | FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS, 34 | ) 35 | from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform 36 | from fairseq.data.audio.data_cfg import S2TDataConfig 37 | 38 | 39 | logger = logging.getLogger(__name__) 40 | 41 | 42 | def get_features_from_npy_or_audio(path): 43 | ext = Path(path).suffix 44 | if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS: 45 | raise ValueError(f'Unsupported file format for "{path}"') 46 | return np.load(path) if ext == ".npy" else get_fbank(path) 47 | 48 | 49 | def get_features_or_waveform_from_stored_zip( 50 | path, 51 | byte_offset, 52 | byte_size, 53 | need_waveform=False, 54 | use_sample_rate=None, 55 | ): 56 | assert path.endswith(".zip") 57 | data = read_from_stored_zip(path, byte_offset, byte_size) 58 | f = io.BytesIO(data) 59 | if is_npy_data(data): 60 | features_or_waveform = np.load(f) 61 | elif is_sf_audio_data(data): 62 | features_or_waveform = ( 63 | get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0] 64 | if need_waveform 65 | else get_fbank(f) 66 | ) 67 | else: 68 | raise ValueError(f'Unknown file format for "{path}"') 69 | return features_or_waveform 70 | 71 | def get_raw_waveform_from_audio( 72 | path, byte_offset, byte_size): 73 | return get_segment_waveform(path, byte_offset, byte_size)[0].squeeze(0) 74 | 75 | def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None): 76 | """Get speech features from .npy file or waveform from .wav/.flac file. 77 | The file may be inside an uncompressed ZIP file and is accessed via byte 78 | offset and length. 79 | Args: 80 | path (str): File path in the format of "<.npy/.wav/.flac path>" or 81 | "::". 82 | need_waveform (bool): return waveform instead of features. 83 | use_sample_rate (int): change sample rate for the input wave file 84 | Returns: 85 | features_or_waveform (numpy.ndarray): speech features or waveform. 86 | """ 87 | _path, slice_ptr = parse_path(path) 88 | if len(slice_ptr) == 0: 89 | if need_waveform: 90 | return get_waveform( 91 | _path, always_2d=False, output_sample_rate=use_sample_rate 92 | )[0] 93 | return get_features_from_npy_or_audio(_path) 94 | elif len(slice_ptr) == 2: 95 | if _path.endswith(".zip"): 96 | features_or_waveform = get_features_or_waveform_from_stored_zip( 97 | _path, slice_ptr[0], slice_ptr[1], need_waveform=need_waveform 98 | ) 99 | else: 100 | features_or_waveform = get_raw_waveform_from_audio( 101 | _path, slice_ptr[0], slice_ptr[1] 102 | ) 103 | else: 104 | raise ValueError(f"Invalid path: {path}") 105 | 106 | return features_or_waveform 107 | 108 | 109 | def _collate_frames( 110 | frames: List[torch.Tensor], is_audio_input: bool = False 111 | ) -> torch.Tensor: 112 | """ 113 | Convert a list of 2D frames into a padded 3D tensor 114 | Args: 115 | frames (list): list of 2D frames of size L[i]*f_dim. 
Where L[i] is 116 | length of i-th frame and f_dim is static dimension of features 117 | Returns: 118 | 3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i] 119 | """ 120 | max_len = max(frame.size(0) for frame in frames) 121 | if is_audio_input: 122 | out = frames[0].new_zeros((len(frames), max_len)) 123 | else: 124 | out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1))) 125 | for i, v in enumerate(frames): 126 | out[i, : v.size(0)] = v 127 | return out 128 | 129 | 130 | @dataclass 131 | class SpeechToTextDatasetItem(object): 132 | index: int 133 | source: torch.Tensor 134 | target: Optional[torch.Tensor] = None 135 | speaker_id: Optional[int] = None 136 | 137 | 138 | class SpeechToTextDataset(FairseqDataset): 139 | LANG_TAG_TEMPLATE = "" 140 | 141 | def __init__( 142 | self, 143 | split: str, 144 | is_train_split: bool, 145 | cfg: S2TDataConfig, 146 | audio_paths: List[str], 147 | n_frames: List[int], 148 | src_texts: Optional[List[str]] = None, 149 | tgt_texts: Optional[List[str]] = None, 150 | speakers: Optional[List[str]] = None, 151 | src_langs: Optional[List[str]] = None, 152 | tgt_langs: Optional[List[str]] = None, 153 | ids: Optional[List[str]] = None, 154 | tgt_dict: Optional[Dictionary] = None, 155 | pre_tokenizer=None, 156 | bpe_tokenizer=None, 157 | n_frames_per_step=1, 158 | speaker_to_id=None, 159 | append_eos=True, 160 | ): 161 | self.split, self.is_train_split = split, is_train_split 162 | self.cfg = cfg 163 | self.audio_paths, self.n_frames = audio_paths, n_frames 164 | self.n_samples = len(audio_paths) 165 | assert len(n_frames) == self.n_samples > 0 166 | assert src_texts is None or len(src_texts) == self.n_samples 167 | assert tgt_texts is None or len(tgt_texts) == self.n_samples 168 | assert speakers is None or len(speakers) == self.n_samples 169 | assert src_langs is None or len(src_langs) == self.n_samples 170 | assert tgt_langs is None or len(tgt_langs) == self.n_samples 171 | assert ids is None or len(ids) == self.n_samples 172 | assert (tgt_dict is None and tgt_texts is None) or ( 173 | tgt_dict is not None and tgt_texts is not None 174 | ) 175 | self.src_texts, self.tgt_texts = src_texts, tgt_texts 176 | self.src_langs, self.tgt_langs = src_langs, tgt_langs 177 | self.speakers = speakers 178 | self.tgt_dict = tgt_dict 179 | self.check_tgt_lang_tag() 180 | self.ids = ids 181 | self.shuffle = cfg.shuffle if is_train_split else False 182 | 183 | self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict( 184 | self.cfg.get_feature_transforms(split, is_train_split) 185 | ) 186 | 187 | self.pre_tokenizer = pre_tokenizer 188 | self.bpe_tokenizer = bpe_tokenizer 189 | self.n_frames_per_step = n_frames_per_step 190 | self.speaker_to_id = speaker_to_id 191 | 192 | self.tgt_lens = self.get_tgt_lens_and_check_oov() 193 | self.append_eos = append_eos 194 | 195 | logger.info(self.__repr__()) 196 | 197 | def get_tgt_lens_and_check_oov(self): 198 | if self.tgt_texts is None: 199 | return [0 for _ in range(self.n_samples)] 200 | tgt_lens = [] 201 | n_tokens, n_oov_tokens = 0, 0 202 | for i in range(self.n_samples): 203 | tokenized = self.get_tokenized_tgt_text(i).split(" ") 204 | oov_tokens = [ 205 | t 206 | for t in tokenized 207 | if self.tgt_dict.index(t) == self.tgt_dict.unk_index 208 | ] 209 | n_tokens += len(tokenized) 210 | n_oov_tokens += len(oov_tokens) 211 | tgt_lens.append(len(tokenized)) 212 | logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV") 213 | return tgt_lens 214 | 215 | def 
__repr__(self): 216 | return ( 217 | self.__class__.__name__ 218 | + f'(split="{self.split}", n_samples={self.n_samples:_}, ' 219 | f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, " 220 | f"shuffle={self.shuffle}, transforms={self.feature_transforms}, " 221 | f"n_frames_per_step={self.n_frames_per_step}" 222 | ) 223 | 224 | @classmethod 225 | def is_lang_tag(cls, token): 226 | pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)") 227 | return re.match(pattern, token) 228 | 229 | def check_tgt_lang_tag(self): 230 | if self.cfg.prepend_tgt_lang_tag: 231 | assert self.tgt_langs is not None and self.tgt_dict is not None 232 | tgt_lang_tags = [ 233 | self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs) 234 | ] 235 | assert all(t in self.tgt_dict for t in tgt_lang_tags) 236 | 237 | @classmethod 238 | def tokenize(cls, tokenizer, text: str): 239 | return text if tokenizer is None else tokenizer.encode(text) 240 | 241 | def get_tokenized_tgt_text(self, index: int): 242 | text = self.tokenize(self.pre_tokenizer, self.tgt_texts[index]) 243 | text = self.tokenize(self.bpe_tokenizer, text) 244 | return text 245 | 246 | def pack_frames(self, feature: torch.Tensor): 247 | if self.n_frames_per_step == 1: 248 | return feature 249 | n_packed_frames = feature.shape[0] // self.n_frames_per_step 250 | feature = feature[: self.n_frames_per_step * n_packed_frames] 251 | return feature.reshape(n_packed_frames, -1) 252 | 253 | @classmethod 254 | def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary): 255 | lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang)) 256 | assert lang_tag_idx != dictionary.unk() 257 | return lang_tag_idx 258 | 259 | def _get_source_audio(self, index: int) -> torch.Tensor: 260 | source = get_features_or_waveform( 261 | self.audio_paths[index], 262 | need_waveform=self.cfg.use_audio_input, 263 | use_sample_rate=self.cfg.use_sample_rate, 264 | ) 265 | if self.cfg.use_audio_input: 266 | if self.cfg.standardize_audio: 267 | with torch.no_grad(): 268 | source = F.layer_norm(source, source.shape) 269 | else: 270 | if self.feature_transforms is not None: 271 | source = self.feature_transforms(source) 272 | source = torch.from_numpy(source).float() 273 | return source 274 | 275 | def __getitem__(self, index: int) -> SpeechToTextDatasetItem: 276 | source = self._get_source_audio(index) 277 | source = self.pack_frames(source) 278 | 279 | target = None 280 | if self.tgt_texts is not None: 281 | tokenized = self.get_tokenized_tgt_text(index) 282 | target = self.tgt_dict.encode_line( 283 | tokenized, add_if_not_exist=False, append_eos=self.append_eos 284 | ).long() 285 | if self.cfg.prepend_tgt_lang_tag: 286 | lang_tag_idx = self.get_lang_tag_idx( 287 | self.tgt_langs[index], self.tgt_dict 288 | ) 289 | target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0) 290 | 291 | speaker_id = None 292 | if self.speaker_to_id is not None: 293 | speaker_id = self.speaker_to_id[self.speakers[index]] 294 | return SpeechToTextDatasetItem( 295 | index=index, source=source, target=target, speaker_id=speaker_id 296 | ) 297 | 298 | def __len__(self): 299 | return self.n_samples 300 | 301 | def collater( 302 | self, samples: List[SpeechToTextDatasetItem], return_order: bool = False 303 | ) -> Dict: 304 | if len(samples) == 0: 305 | return {} 306 | indices = torch.tensor([x.index for x in samples], dtype=torch.long) 307 | frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input) 308 | # sort samples by descending number of frames 309 | n_frames = 
torch.tensor([x.source.size(0) for x in samples], dtype=torch.long) 310 | n_frames, order = n_frames.sort(descending=True) 311 | indices = indices.index_select(0, order) 312 | frames = frames.index_select(0, order) 313 | 314 | target, target_lengths = None, None 315 | prev_output_tokens = None 316 | ntokens = None 317 | if self.tgt_texts is not None: 318 | target = fairseq_data_utils.collate_tokens( 319 | [x.target for x in samples], 320 | self.tgt_dict.pad(), 321 | self.tgt_dict.eos(), 322 | left_pad=False, 323 | move_eos_to_beginning=False, 324 | ) 325 | target = target.index_select(0, order) 326 | target_lengths = torch.tensor( 327 | [x.target.size(0) for x in samples], dtype=torch.long 328 | ).index_select(0, order) 329 | prev_output_tokens = fairseq_data_utils.collate_tokens( 330 | [x.target for x in samples], 331 | self.tgt_dict.pad(), 332 | self.tgt_dict.eos(), 333 | left_pad=False, 334 | move_eos_to_beginning=True, 335 | ) 336 | prev_output_tokens = prev_output_tokens.index_select(0, order) 337 | ntokens = sum(x.target.size(0) for x in samples) 338 | 339 | speaker = None 340 | if self.speaker_to_id is not None: 341 | speaker = ( 342 | torch.tensor([s.speaker_id for s in samples], dtype=torch.long) 343 | .index_select(0, order) 344 | .view(-1, 1) 345 | ) 346 | 347 | net_input = { 348 | "src_tokens": frames, 349 | "src_lengths": n_frames, 350 | "mode": "st", 351 | "prev_output_tokens": prev_output_tokens, 352 | } 353 | out = { 354 | "id": indices, 355 | "net_input": net_input, 356 | "speaker": speaker, 357 | "target": target, 358 | "target_lengths": target_lengths, 359 | "ntokens": ntokens, 360 | "nsentences": len(samples), 361 | } 362 | if return_order: 363 | out["order"] = order 364 | return out 365 | 366 | def num_tokens(self, index): 367 | return self.n_frames[index] 368 | 369 | def size(self, index): 370 | return self.n_frames[index], self.tgt_lens[index] 371 | 372 | @property 373 | def sizes(self): 374 | return np.array(self.n_frames) 375 | 376 | @property 377 | def can_reuse_epoch_itr_across_epochs(self): 378 | return True 379 | 380 | def ordered_indices(self): 381 | if self.shuffle: 382 | order = [np.random.permutation(len(self))] 383 | else: 384 | order = [np.arange(len(self))] 385 | # first by descending order of # of frames then by original/random order 386 | order.append([-n for n in self.n_frames]) 387 | return np.lexsort(order) 388 | 389 | def prefetch(self, indices): 390 | raise False 391 | 392 | 393 | class SpeechToTextDatasetCreator(object): 394 | # mandatory columns 395 | KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames" 396 | KEY_TGT_TEXT = "tgt_text" 397 | # optional columns 398 | KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text" 399 | KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang" 400 | # default values 401 | DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = "" 402 | 403 | @classmethod 404 | def _from_list( 405 | cls, 406 | split_name: str, 407 | is_train_split, 408 | samples: List[Dict], 409 | cfg: S2TDataConfig, 410 | tgt_dict, 411 | pre_tokenizer, 412 | bpe_tokenizer, 413 | n_frames_per_step, 414 | speaker_to_id, 415 | ) -> SpeechToTextDataset: 416 | audio_root = Path(cfg.audio_root) 417 | ids = [s[cls.KEY_ID] for s in samples] 418 | audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] 419 | n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] 420 | tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] 421 | src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] 422 | speakers = 
[s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] 423 | src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] 424 | tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] 425 | return SpeechToTextDataset( 426 | split_name, 427 | is_train_split, 428 | cfg, 429 | audio_paths, 430 | n_frames, 431 | src_texts=src_texts, 432 | tgt_texts=tgt_texts, 433 | speakers=speakers, 434 | src_langs=src_langs, 435 | tgt_langs=tgt_langs, 436 | ids=ids, 437 | tgt_dict=tgt_dict, 438 | pre_tokenizer=pre_tokenizer, 439 | bpe_tokenizer=bpe_tokenizer, 440 | n_frames_per_step=n_frames_per_step, 441 | speaker_to_id=speaker_to_id, 442 | ) 443 | 444 | @classmethod 445 | def get_size_ratios( 446 | cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0 447 | ) -> List[float]: 448 | """Size ratios for temperature-based sampling 449 | (https://arxiv.org/abs/1907.05019)""" 450 | 451 | id_to_lp, lp_to_sz = {}, defaultdict(int) 452 | for ds in datasets: 453 | lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)} 454 | assert len(lang_pairs) == 1 455 | lang_pair = list(lang_pairs)[0] 456 | id_to_lp[ds.split] = lang_pair 457 | lp_to_sz[lang_pair] += sum(ds.n_frames) 458 | 459 | sz_sum = sum(v for v in lp_to_sz.values()) 460 | lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()} 461 | lp_to_tgt_prob = {k: v ** alpha for k, v in lp_to_prob.items()} 462 | prob_sum = sum(v for v in lp_to_tgt_prob.values()) 463 | lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()} 464 | lp_to_sz_ratio = { 465 | k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items() 466 | } 467 | size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets] 468 | 469 | p_formatted = { 470 | k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz 471 | } 472 | logger.info(f"sampling probability balancing: {p_formatted}") 473 | sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)} 474 | logger.info(f"balanced sampling size ratio: {sr_formatted}") 475 | return size_ratio 476 | 477 | @classmethod 478 | def _load_samples_from_tsv(cls, root: str, split: str): 479 | tsv_path = Path(root) / f"{split}.tsv" 480 | if not tsv_path.is_file(): 481 | raise FileNotFoundError(f"Dataset not found: {tsv_path}") 482 | with open(tsv_path) as f: 483 | reader = csv.DictReader( 484 | f, 485 | delimiter="\t", 486 | quotechar=None, 487 | doublequote=False, 488 | lineterminator="\n", 489 | quoting=csv.QUOTE_NONE, 490 | ) 491 | samples = [dict(e) for e in reader] 492 | if len(samples) == 0: 493 | raise ValueError(f"Empty manifest: {tsv_path}") 494 | return samples 495 | 496 | @classmethod 497 | def _from_tsv( 498 | cls, 499 | root: str, 500 | cfg: S2TDataConfig, 501 | split: str, 502 | tgt_dict, 503 | is_train_split: bool, 504 | pre_tokenizer, 505 | bpe_tokenizer, 506 | n_frames_per_step, 507 | speaker_to_id, 508 | ) -> SpeechToTextDataset: 509 | samples = cls._load_samples_from_tsv(root, split) 510 | return cls._from_list( 511 | split, 512 | is_train_split, 513 | samples, 514 | cfg, 515 | tgt_dict, 516 | pre_tokenizer, 517 | bpe_tokenizer, 518 | n_frames_per_step, 519 | speaker_to_id, 520 | ) 521 | 522 | @classmethod 523 | def from_tsv( 524 | cls, 525 | root: str, 526 | cfg: S2TDataConfig, 527 | splits: str, 528 | tgt_dict, 529 | pre_tokenizer, 530 | bpe_tokenizer, 531 | is_train_split: bool, 532 | epoch: int, 533 | seed: int, 534 | n_frames_per_step: int = 1, 535 | speaker_to_id=None, 536 | ) -> SpeechToTextDataset: 537 | datasets = [ 538 | 
cls._from_tsv( 539 | root, 540 | cfg, 541 | split, 542 | tgt_dict, 543 | is_train_split, 544 | pre_tokenizer, 545 | bpe_tokenizer, 546 | n_frames_per_step, 547 | speaker_to_id, 548 | ) 549 | for split in splits.split(",") 550 | ] 551 | 552 | if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: 553 | # temperature-based sampling 554 | size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) 555 | datasets = [ 556 | ResamplingDataset( 557 | d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) 558 | ) 559 | for r, d in zip(size_ratios, datasets) 560 | ] 561 | 562 | return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0] -------------------------------------------------------------------------------- /cress/tasks/speech_and_text_translation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from email.policy import default 7 | import torch 8 | import json 9 | import os 10 | import itertools 11 | import logging 12 | import numpy as np 13 | from pathlib import Path 14 | from argparse import Namespace 15 | from typing_extensions import Concatenate 16 | 17 | from fairseq import utils, metrics 18 | from fairseq.data import Dictionary, encoders 19 | from fairseq.data.iterators import GroupedEpochBatchIterator 20 | from fairseq.data import ( 21 | AppendTokenDataset, 22 | ConcatDataset, 23 | LanguagePairDataset, 24 | PrependTokenDataset, 25 | StripTokenDataset, 26 | TruncateDataset, 27 | data_utils, 28 | encoders, 29 | indexed_dataset, 30 | ) 31 | from fairseq.data.audio.multi_modality_dataset import ( 32 | MultiModalityDataset, 33 | ModalityDatasetItem, 34 | ) 35 | from fairseq.data.audio.speech_to_text_dataset import ( 36 | S2TDataConfig, 37 | SpeechToTextDataset, 38 | get_features_or_waveform, 39 | ) 40 | from cress.datasets.speech_and_text_translation_dataset import ( 41 | SpeechAndTextTranslationDataset, 42 | SpeechAndTextTranslationDatasetCreator, 43 | ) 44 | from fairseq.tasks import LegacyFairseqTask, register_task 45 | from fairseq import search 46 | 47 | 48 | logger = logging.getLogger(__name__) 49 | 50 | EVAL_BLEU_ORDER = 4 51 | 52 | @register_task("speech_and_text_translation") 53 | class SpeechAndTextTranslationTask(LegacyFairseqTask): 54 | @classmethod 55 | def add_args(cls, parser): 56 | parser.add_argument("data", help="manifest root path") 57 | parser.add_argument("--text-data", help="manifest root path for text data") 58 | parser.add_argument( 59 | "--config-yaml", 60 | type=str, 61 | default="config.yaml", 62 | help="Configuration YAML filename (under manifest root)", 63 | ) 64 | parser.add_argument( 65 | "--max-audio-positions", 66 | default=900000, 67 | type=int, 68 | metavar="N", 69 | help="max number of tokens in the source sequence", 70 | ) 71 | parser.add_argument( 72 | "--max-source-positions", 73 | default=512, 74 | type=int, 75 | metavar="N", 76 | help="max number of tokens in the source sequence", 77 | ) 78 | parser.add_argument( 79 | "--max-target-positions", 80 | default=512, 81 | type=int, 82 | metavar="N", 83 | help="max number of tokens in the target sequence", 84 | ) 85 | parser.add_argument( 86 | "--max-tokens-text", 87 | type=int, 88 | metavar="N", 89 | help="maximum tokens for encoder text input ", 90 | ) 91 | parser.add_argument( 92 | "--batch-size-text", 93 | type=int, 94 | 
metavar="N", 95 | help="batch size for encoder text input ", 96 | ) 97 | parser.add_argument( 98 | "--st-training", 99 | action="store_true", 100 | help="speech translation training" 101 | ) 102 | parser.add_argument( 103 | "--ext-mt-training", 104 | action="store_true", 105 | help="external machine transaltion training", 106 | ) 107 | parser.add_argument( 108 | "--tgt-lang", 109 | type=str, 110 | help="target language" 111 | ) 112 | parser.add_argument( 113 | "--st-sample-ratio", 114 | type=float, 115 | default=1.0, 116 | help="sample ratio of st dataset" 117 | ) 118 | parser.add_argument( 119 | "--mt-sample-ratio", 120 | type=float, 121 | default=1.0, 122 | help="sample ratio of ext mt dataset" 123 | ) 124 | parser.add_argument( 125 | "--update-mix-data", 126 | action="store_true", 127 | help="use mixed data in one update when update-freq > 1", 128 | ) 129 | # options for reporting BLEU during validation 130 | parser.add_argument( 131 | "--eval-bleu", 132 | action="store_true", 133 | help="evaluation with BLEU scores", 134 | ) 135 | parser.add_argument( 136 | "--eval-bleu-detok", 137 | type=str, 138 | default="space", 139 | help="detokenize before computing BLEU (e.g., 'moses'); " 140 | "required if using --eval-bleu; use 'space' to " 141 | "disable detokenization; see fairseq.data.encoders " 142 | "for other options", 143 | ) 144 | parser.add_argument( 145 | "--eval-bleu-detok-args", 146 | type=str, 147 | metavar="JSON", 148 | help="args for building the tokenizer, if needed", 149 | ) 150 | parser.add_argument( 151 | "--eval-tokenized-bleu", 152 | action="store_true", 153 | default=False, 154 | help="compute tokenized BLEU instead of sacrebleu", 155 | ) 156 | parser.add_argument( 157 | "--eval-bleu-remove-bpe", 158 | nargs="?", 159 | const="@@ ", 160 | default=None, 161 | help="remove BPE before computing BLEU", 162 | ) 163 | parser.add_argument( 164 | "--eval-bleu-args", 165 | type=str, 166 | metavar="JSON", 167 | help="generation args for BLUE scoring, " 168 | "e.g., '{\"beam\": 4, \"lenpen\": 0.6}'", 169 | ) 170 | parser.add_argument( 171 | "--eval-bleu-print-samples", 172 | action="store_true", 173 | help="print sample generations during validation", 174 | ) 175 | parser.add_argument( 176 | "--eval-bleu-bpe", 177 | type=str, 178 | metavar="BPE", 179 | default=None, 180 | help="args for building the bpe, if needed", 181 | ) 182 | parser.add_argument( 183 | "--eval-bleu-bpe-path", 184 | type=str, 185 | metavar='BPE', 186 | help="args for building the bpe, if needed", 187 | ) 188 | 189 | def __init__(self, args, src_dict, tgt_dict): 190 | super().__init__(args) 191 | self.src_dict = src_dict 192 | self.tgt_dict = tgt_dict 193 | self.data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) 194 | self.speaker_to_id = self._get_speaker_to_id() 195 | 196 | self.pre_tokenizer = self.build_tokenizer(self.args) 197 | self.bpe_tokenizer = self.build_bpe(self.args) 198 | 199 | def _get_speaker_to_id(self): 200 | speaker_to_id = None 201 | speaker_set_filename = self.data_cfg.config.get("speaker_set_filename") 202 | if speaker_set_filename is not None: 203 | speaker_set_path = Path(self.args.data) / speaker_set_filename 204 | with open(speaker_set_path) as f: 205 | speaker_to_id = {r.strip(): i for i, r in enumerate(f)} 206 | return speaker_to_id 207 | 208 | @classmethod 209 | def setup_task(cls, args, **kwargs): 210 | data_cfg = S2TDataConfig(Path(args.data) / args.config_yaml) 211 | dict_path = Path(args.data) / data_cfg.vocab_filename 212 | if not dict_path.is_file(): 213 | raise 
FileNotFoundError(f"Dict not found: {dict_path.as_posix()}") 214 | src_dict = tgt_dict = Dictionary.load(dict_path.as_posix()) 215 | logger.info( 216 | f"dictionary size ({data_cfg.vocab_filename}): " f"{len(tgt_dict):,}" 217 | ) 218 | 219 | if getattr(args, "train_subset", None) is not None: 220 | if not all(s.startswith("train") for s in args.train_subset.split(",")): 221 | raise ValueError('Train splits should be named like "train*".') 222 | return cls(args, src_dict, tgt_dict) 223 | 224 | def build_criterion(self, args): 225 | from fairseq import criterions 226 | 227 | if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1: 228 | raise ValueError( 229 | 'Please set "--ignore-prefix-size 1" since ' 230 | "target language ID token is prepended as BOS." 231 | ) 232 | return criterions.build_criterion(args, self) 233 | 234 | def load_dataset(self, split, epoch=1, combine=False, **kwargs): 235 | is_train_split = split.startswith("train") 236 | concat_dataset = [] 237 | assert self.args.st_training or self.args.ext_mt_training 238 | if self.args.st_training: 239 | st_dataset = self.load_st_dataset(split, epoch) 240 | concat_dataset.append(ModalityDatasetItem( 241 | "st", 242 | st_dataset, 243 | [self.args.max_audio_positions, self.args.max_source_positions, self.args.max_target_positions], 244 | self.args.max_tokens, 245 | self.args.batch_size, 246 | )) 247 | if self.args.ext_mt_training and (is_train_split or not self.args.st_training): 248 | mt_dataset = self.load_mt_dataset(split) 249 | concat_dataset.append(ModalityDatasetItem( 250 | "ext_mt", 251 | mt_dataset, 252 | [self.args.max_source_positions, self.args.max_target_positions], 253 | self.args.max_tokens_text, 254 | self.args.batch_size_text, 255 | )) 256 | 257 | self.datasets[split] = MultiModalityDataset(concat_dataset) 258 | 259 | def load_st_dataset(self, split, epoch): 260 | is_train_split = split.startswith("train") 261 | return SpeechAndTextTranslationDatasetCreator.from_tsv( 262 | self.args.data, 263 | self.data_cfg, 264 | split, 265 | self.tgt_dict, 266 | self.pre_tokenizer, 267 | self.bpe_tokenizer, 268 | is_train_split=is_train_split, 269 | epoch=epoch, 270 | seed=self.args.seed, 271 | speaker_to_id=self.speaker_to_id, 272 | ) 273 | 274 | def load_mt_dataset(self, split): 275 | if split == "dev": 276 | split = "valid" 277 | return load_langpair_dataset( 278 | self.args.text_data, 279 | split, 280 | "en", 281 | self.src_dict, 282 | self.args.tgt_lang, 283 | self.tgt_dict, 284 | combine=True, 285 | dataset_impl=None, 286 | upsample_primary=1, 287 | left_pad_source=False, 288 | left_pad_target=False, 289 | remove_eos_from_source=True, 290 | max_source_positions=self.args.max_source_positions, 291 | max_target_positions=self.args.max_target_positions, 292 | load_alignments=False, 293 | truncate_source=False, 294 | ) 295 | 296 | @property 297 | def target_dictionary(self): 298 | return self.tgt_dict 299 | 300 | @property 301 | def source_dictionary(self): 302 | return self.src_dict 303 | 304 | def max_positions(self): 305 | return self.args.max_audio_positions, self.args.max_source_positions, self.args.max_target_positions 306 | 307 | def build_model(self, args): 308 | args.input_feat_per_channel = self.data_cfg.input_feat_per_channel 309 | args.input_channels = self.data_cfg.input_channels 310 | args.speaker_to_id = self.speaker_to_id 311 | model = super(SpeechAndTextTranslationTask, self).build_model(args) 312 | 313 | if getattr(args, "eval_bleu", False): 314 | assert getattr(args, "eval_bleu_detok", None) is 
not None, ( 315 | "--eval-bleu-detok is required if using --eval-bleu; " 316 | "try --eval-bleu-detok=moses (or --eval-bleu-detok=space " 317 | "to disable detokenization, e.g., when using sentencepiece)" 318 | ) 319 | gen_args = json.loads(getattr(args, "eval_bleu_args", "{}") or "{}") 320 | self.sequence_generator = self.build_generator([model], Namespace(**gen_args)) 321 | 322 | return model 323 | 324 | def build_tokenizer(self, args): 325 | logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}") 326 | return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer)) 327 | 328 | def build_bpe(self, args): 329 | logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}") 330 | return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer)) 331 | 332 | def get_interactive_tokens_and_lengths(self, lines, encode_fn): 333 | n_frames = [get_features_or_waveform(p).shape[0] for p in lines] 334 | return lines, n_frames 335 | 336 | def build_dataset_for_inference(self, src_tokens, src_lengths, **kwargs): 337 | return SpeechAndTextTranslationDataset( 338 | "interactive", False, self.data_cfg, src_tokens, src_lengths 339 | ) 340 | 341 | def begin_epoch(self, epoch, model): 342 | model.set_epoch(epoch) 343 | 344 | def get_batch_iterator( 345 | self, 346 | dataset, 347 | max_tokens=None, 348 | max_sentences=None, 349 | max_positions=None, 350 | ignore_invalid_inputs=False, 351 | required_batch_size_multiple=1, 352 | seed=1, 353 | num_shards=1, 354 | shard_id=0, 355 | num_workers=0, 356 | epoch=0, 357 | data_buffer_size=0, 358 | disable_iterator_cache=False, 359 | skip_remainder_batch=False, 360 | grouped_shuffling=False, 361 | update_epoch_batch_itr=False, 362 | ): 363 | num_dataset = len(dataset.datasets) 364 | if num_dataset == 1: 365 | mult_ratio = [1.0] 366 | elif num_dataset == 2: 367 | mult_ratio = [ 368 | self.args.st_sample_ratio, 369 | self.args.mt_sample_ratio, 370 | ] 371 | 372 | # initialize the dataset with the correct starting epoch 373 | dataset.set_epoch(epoch) 374 | 375 | batch_samplers = dataset.get_batch_samplers( 376 | mult_ratio, required_batch_size_multiple, seed 377 | ) 378 | 379 | # return a reusable, sharded iterator 380 | epoch_iter = GroupedEpochBatchIterator( 381 | dataset=dataset, 382 | collate_fn=dataset.collater, 383 | batch_samplers=batch_samplers, 384 | seed=seed, 385 | num_shards=num_shards, 386 | shard_id=shard_id, 387 | num_workers=num_workers, 388 | epoch=epoch, 389 | # mult_rate=1 if self.args.update_mix_data else max(self.args.update_freq), 390 | mult_rate=1, 391 | buffer_size=data_buffer_size, 392 | skip_remainder_batch=skip_remainder_batch, 393 | ) 394 | self.dataset_to_epoch_iter[dataset] = {} # refresh it every epoch 395 | return epoch_iter 396 | 397 | def build_generator( 398 | self, 399 | models, 400 | args, 401 | seq_gen_cls=None, 402 | extra_gen_cls_kwargs=None, 403 | debug=False, 404 | ): 405 | if getattr(self.args, "debug_task", "st") == "st": 406 | if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1: 407 | raise ValueError( 408 | 'Please set "--prefix-size 1" since ' 409 | "target language ID token is prepended as BOS." 
410 | ) 411 | lang_token_ids = { 412 | i 413 | for s, i in self.tgt_dict.indices.items() 414 | if SpeechToTextDataset.is_lang_tag(s) 415 | } 416 | 417 | if extra_gen_cls_kwargs is None: 418 | extra_gen_cls_kwargs = {} 419 | extra_gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids 420 | 421 | eos_token = ( 422 | args.eos_token 423 | if "eos_token" in args and args.eos_token is not None 424 | else self.data_cfg.config.get("eos_token", None) 425 | ) 426 | 427 | if self.data_cfg.prepend_bos_and_append_tgt_lang_tag and not eos_token: 428 | raise Warning( 429 | "Please provide --eos_token to replace eos in sequence generator" 430 | ) 431 | 432 | eos_id = self.tgt_dict.index(eos_token) if eos_token else None 433 | extra_gen_cls_kwargs["eos"] = eos_id 434 | 435 | if debug: 436 | return self.build_generator_debug( 437 | models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs 438 | ) 439 | else: 440 | return super().build_generator( 441 | models, args, seq_gen_cls=None, extra_gen_cls_kwargs=extra_gen_cls_kwargs 442 | ) 443 | 444 | def build_generator_debug( 445 | self, 446 | models, 447 | args, 448 | seq_gen_cls=None, 449 | extra_gen_cls_kwargs=None, 450 | prefix_allowed_tokens_fn=None, 451 | ): 452 | """ 453 | Build a :class:`~fairseq.SequenceGenerator` instance for this 454 | task. 455 | 456 | Args: 457 | models (List[~fairseq.models.FairseqModel]): ensemble of models 458 | args (fairseq.dataclass.configs.GenerationConfig): 459 | configuration object (dataclass) for generation 460 | extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass 461 | through to SequenceGenerator 462 | prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): 463 | If provided, this function constrains the beam search to 464 | allowed tokens only at each step. The provided function 465 | should take 2 arguments: the batch ID (`batch_id: int`) 466 | and a unidimensional tensor of token ids (`inputs_ids: 467 | torch.Tensor`). It has to return a `List[int]` with the 468 | allowed tokens for the next generation step conditioned 469 | on the previously generated tokens (`inputs_ids`) and 470 | the batch ID (`batch_id`). This argument is useful for 471 | constrained generation conditioned on the prefix, as 472 | described in "Autoregressive Entity Retrieval" 473 | (https://arxiv.org/abs/2010.00904) and 474 | https://github.com/facebookresearch/GENRE. 475 | """ 476 | if getattr(args, "score_reference", False): 477 | from fairseq.sequence_scorer import SequenceScorer 478 | 479 | return SequenceScorer( 480 | self.target_dictionary, 481 | compute_alignment=getattr(args, "print_alignment", False), 482 | ) 483 | 484 | from fairseq.sequence_generator_debug import ( 485 | SequenceGenerator, 486 | SequenceGeneratorWithAlignment, 487 | ) 488 | 489 | # Choose search strategy. Defaults to Beam Search. 
490 | sampling = getattr(args, "sampling", False) 491 | sampling_topk = getattr(args, "sampling_topk", -1) 492 | sampling_topp = getattr(args, "sampling_topp", -1.0) 493 | diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) 494 | diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) 495 | match_source_len = getattr(args, "match_source_len", False) 496 | diversity_rate = getattr(args, "diversity_rate", -1) 497 | constrained = getattr(args, "constraints", False) 498 | if prefix_allowed_tokens_fn is None: 499 | prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) 500 | if ( 501 | sum( 502 | int(cond) 503 | for cond in [ 504 | sampling, 505 | diverse_beam_groups > 0, 506 | match_source_len, 507 | diversity_rate > 0, 508 | ] 509 | ) 510 | > 1 511 | ): 512 | raise ValueError("Provided Search parameters are mutually exclusive.") 513 | assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" 514 | assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" 515 | 516 | if sampling: 517 | search_strategy = search.Sampling( 518 | self.target_dictionary, sampling_topk, sampling_topp 519 | ) 520 | elif diverse_beam_groups > 0: 521 | search_strategy = search.DiverseBeamSearch( 522 | self.target_dictionary, diverse_beam_groups, diverse_beam_strength 523 | ) 524 | elif match_source_len: 525 | # this is useful for tagging applications where the output 526 | # length should match the input length, so we hardcode the 527 | # length constraints for simplicity 528 | search_strategy = search.LengthConstrainedBeamSearch( 529 | self.target_dictionary, 530 | min_len_a=1, 531 | min_len_b=0, 532 | max_len_a=1, 533 | max_len_b=0, 534 | ) 535 | elif diversity_rate > -1: 536 | search_strategy = search.DiverseSiblingsSearch( 537 | self.target_dictionary, diversity_rate 538 | ) 539 | elif constrained: 540 | search_strategy = search.LexicallyConstrainedBeamSearch( 541 | self.target_dictionary, args.constraints 542 | ) 543 | elif prefix_allowed_tokens_fn: 544 | search_strategy = search.PrefixConstrainedBeamSearch( 545 | self.target_dictionary, prefix_allowed_tokens_fn 546 | ) 547 | else: 548 | search_strategy = search.BeamSearch(self.target_dictionary) 549 | 550 | extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} 551 | if seq_gen_cls is None: 552 | if getattr(args, "print_alignment", False): 553 | seq_gen_cls = SequenceGeneratorWithAlignment 554 | extra_gen_cls_kwargs["print_alignment"] = args.print_alignment 555 | else: 556 | seq_gen_cls = SequenceGenerator 557 | 558 | return seq_gen_cls( 559 | models, 560 | self.target_dictionary, 561 | beam_size=getattr(args, "beam", 5), 562 | max_len_a=getattr(args, "max_len_a", 0), 563 | max_len_b=getattr(args, "max_len_b", 200), 564 | min_len=getattr(args, "min_len", 1), 565 | normalize_scores=(not getattr(args, "unnormalized", False)), 566 | len_penalty=getattr(args, "lenpen", 1), 567 | unk_penalty=getattr(args, "unkpen", 0), 568 | temperature=getattr(args, "temperature", 1.0), 569 | match_source_len=getattr(args, "match_source_len", False), 570 | no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), 571 | search_strategy=search_strategy, 572 | **extra_gen_cls_kwargs, 573 | ) 574 | 575 | def inference_step( 576 | self, generator, models, sample, prefix_tokens=None, constraints=None 577 | ): 578 | if getattr(self.args, "debug_task", "st") == "st" and "audio" in sample["net_input"]: 579 | net_input = { 580 | "src_tokens": sample["net_input"]["audio"], 581 | "src_lengths": 
sample["net_input"]["audio_lengths"], 582 | "mode": "st", 583 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 584 | } 585 | sample["net_input"] = net_input 586 | elif getattr(self.args, "debug_task", "st") == "mt": 587 | net_input = { 588 | "src_tokens": sample["net_input"]["source"], 589 | "src_lengths": sample["net_input"]["source_lengths"], 590 | "mode": "mt", 591 | "prev_output_tokens": sample["net_input"]["prev_output_tokens"], 592 | } 593 | sample["net_input"] = net_input 594 | with torch.no_grad(): 595 | return generator.generate( 596 | models, sample, prefix_tokens=prefix_tokens, constraints=constraints 597 | ) 598 | 599 | def valid_step(self, sample, model, criterion): 600 | loss, sample_size, logging_output = super().valid_step(sample, model, criterion) 601 | if self.args.eval_bleu: 602 | bleu = self._inference_with_bleu(self.sequence_generator, sample, model) 603 | logging_output["_bleu_sys_len"] = bleu.sys_len 604 | logging_output["_bleu_ref_len"] = bleu.ref_len 605 | # we split counts into separate entries so that they can be 606 | # summed efficiently across workers using fast-stat-sync 607 | assert len(bleu.counts) == EVAL_BLEU_ORDER 608 | for i in range(EVAL_BLEU_ORDER): 609 | logging_output["_bleu_counts_" + str(i)] = bleu.counts[i] 610 | logging_output["_bleu_totals_" + str(i)] = bleu.totals[i] 611 | return loss, sample_size, logging_output 612 | 613 | def reduce_metrics(self, logging_outputs, criterion): 614 | super().reduce_metrics(logging_outputs, criterion) 615 | if self.args.eval_bleu: 616 | def sum_logs(key): 617 | if key in logging_outputs[0]: 618 | return sum(log[key].cpu().numpy() for log in logging_outputs) 619 | return sum(log.get(key, 0) for log in logging_outputs) 620 | 621 | counts, totals = [], [] 622 | for i in range(EVAL_BLEU_ORDER): 623 | counts.append(sum_logs("_bleu_counts_" + str(i))) 624 | totals.append(sum_logs("_bleu_totals_" + str(i))) 625 | 626 | if max(totals) > 0: 627 | # log counts as numpy arrays -- log_scalar will sum them correctly 628 | metrics.log_scalar("_bleu_counts", np.array(counts)) 629 | metrics.log_scalar("_bleu_totals", np.array(totals)) 630 | metrics.log_scalar("_bleu_sys_len", sum_logs("_bleu_sys_len")) 631 | metrics.log_scalar("_bleu_ref_len", sum_logs("_bleu_ref_len")) 632 | 633 | def compute_bleu(meters): 634 | import inspect 635 | import sacrebleu 636 | 637 | fn_sig = inspect.getfullargspec(sacrebleu.compute_bleu)[0] 638 | if "smooth_method" in fn_sig: 639 | smooth = {"smooth_method": "exp"} 640 | else: 641 | smooth = {"smooth": "exp"} 642 | bleu = sacrebleu.compute_bleu( 643 | correct=meters["_bleu_counts"].sum, 644 | total=meters["_bleu_totals"].sum, 645 | sys_len=meters["_bleu_sys_len"].sum, 646 | ref_len=meters["_bleu_ref_len"].sum, 647 | **smooth 648 | ) 649 | return round(bleu.score, 2) 650 | 651 | metrics.log_derived("bleu", compute_bleu) 652 | 653 | def _inference_with_bleu(self, generator, sample, model): 654 | import sacrebleu 655 | 656 | def decode(toks, escape_unk=False): 657 | s = self.tgt_dict.string( 658 | toks.int().cpu(), 659 | self.args.eval_bleu_remove_bpe, 660 | unk_string=("UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"), 661 | ) 662 | if self.bpe_tokenizer is not None: 663 | s = self.bpe_tokenizer.decode(s) 664 | if self.pre_tokenizer is not None: 665 | s = self.pre_tokenizer.decode(s) 666 | return s 667 | 668 | gen_out = self.inference_step(generator, [model], sample, prefix_tokens=None) 669 | hyps, refs = [], [] 670 | for i in range(len(gen_out)): 671 | hyp = 
decode(gen_out[i][0]["tokens"]) 672 | ref = decode( 673 | utils.strip_pad(sample["target"][i], self.tgt_dict.pad()), 674 | escape_unk=True, # don't count as matches to the hypo 675 | ) 676 | hyps.append(hyp) 677 | refs.append(ref) 678 | 679 | if self.args.eval_bleu_print_samples: 680 | logger.info("example hypothesis: " + hyps[0]) 681 | logger.info("example reference: " + refs[0]) 682 | if self.args.eval_tokenized_bleu: 683 | return sacrebleu.corpus_bleu(hyps, [refs], tokenize="none") 684 | else: 685 | return sacrebleu.corpus_bleu(hyps, [refs]) 686 | 687 | 688 | def load_langpair_dataset( 689 | data_path, 690 | split, 691 | src, 692 | src_dict, 693 | tgt, 694 | tgt_dict, 695 | combine, 696 | dataset_impl, 697 | upsample_primary, 698 | left_pad_source, 699 | left_pad_target, 700 | remove_eos_from_source, 701 | max_source_positions, 702 | max_target_positions, 703 | prepend_bos=False, 704 | load_alignments=False, 705 | truncate_source=False, 706 | append_source_id=False, 707 | num_buckets=0, 708 | shuffle=True, 709 | pad_to_multiple=1, 710 | prepend_bos_src=None, 711 | ): 712 | def split_exists(split, src, tgt, lang, data_path): 713 | filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang)) 714 | return indexed_dataset.dataset_exists(filename, impl=dataset_impl) 715 | 716 | src_datasets = [] 717 | tgt_datasets = [] 718 | 719 | for k in itertools.count(): 720 | split_k = split + (str(k) if k > 0 else "") 721 | 722 | # infer langcode 723 | if split_exists(split_k, src, tgt, src, data_path): 724 | prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt)) 725 | elif split_exists(split_k, tgt, src, src, data_path): 726 | prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src)) 727 | else: 728 | if k > 0: 729 | break 730 | else: 731 | raise FileNotFoundError( 732 | "Dataset not found: {} ({})".format(split, data_path) 733 | ) 734 | 735 | src_dataset = data_utils.load_indexed_dataset( 736 | prefix + src, src_dict, dataset_impl 737 | ) 738 | if truncate_source: 739 | src_dataset = AppendTokenDataset( 740 | TruncateDataset( 741 | StripTokenDataset(src_dataset, src_dict.eos()), 742 | max_source_positions - 1, 743 | ), 744 | src_dict.eos(), 745 | ) 746 | src_datasets.append(src_dataset) 747 | 748 | tgt_dataset = data_utils.load_indexed_dataset( 749 | prefix + tgt, tgt_dict, dataset_impl 750 | ) 751 | if tgt_dataset is not None: 752 | tgt_datasets.append(tgt_dataset) 753 | 754 | logger.info( 755 | "{} {} {}-{} {} examples".format( 756 | data_path, split_k, src, tgt, len(src_datasets[-1]) 757 | ) 758 | ) 759 | 760 | if not combine: 761 | break 762 | 763 | assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0 764 | 765 | if len(src_datasets) == 1: 766 | src_dataset = src_datasets[0] 767 | tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None 768 | else: 769 | sample_ratios = [1] * len(src_datasets) 770 | sample_ratios[0] = upsample_primary 771 | src_dataset = ConcatDataset(src_datasets, sample_ratios) 772 | if len(tgt_datasets) > 0: 773 | tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios) 774 | else: 775 | tgt_dataset = None 776 | 777 | if prepend_bos: 778 | assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index") 779 | src_dataset = PrependTokenDataset(src_dataset, src_dict.bos()) 780 | if tgt_dataset is not None: 781 | tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos()) 782 | elif prepend_bos_src is not None: 783 | logger.info(f"prepending src bos: {prepend_bos_src}") 784 | src_dataset = 
PrependTokenDataset(src_dataset, prepend_bos_src) 785 | 786 | eos = None 787 | if append_source_id: 788 | src_dataset = AppendTokenDataset( 789 | src_dataset, src_dict.index("[{}]".format(src)) 790 | ) 791 | if tgt_dataset is not None: 792 | tgt_dataset = AppendTokenDataset( 793 | tgt_dataset, tgt_dict.index("[{}]".format(tgt)) 794 | ) 795 | eos = tgt_dict.index("[{}]".format(tgt)) 796 | 797 | align_dataset = None 798 | if load_alignments: 799 | align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt)) 800 | if indexed_dataset.dataset_exists(align_path, impl=dataset_impl): 801 | align_dataset = data_utils.load_indexed_dataset( 802 | align_path, None, dataset_impl 803 | ) 804 | 805 | tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None 806 | return LanguagePairDataset( 807 | src_dataset, 808 | src_dataset.sizes, 809 | src_dict, 810 | tgt_dataset, 811 | tgt_dataset_sizes, 812 | tgt_dict, 813 | left_pad_source=left_pad_source, 814 | left_pad_target=left_pad_target, 815 | remove_eos_from_source=remove_eos_from_source, 816 | align_dataset=align_dataset, 817 | eos=eos, 818 | num_buckets=num_buckets, 819 | shuffle=shuffle, 820 | pad_to_multiple=pad_to_multiple, 821 | ) 822 | --------------------------------------------------------------------------------
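
Editor's note: the size-ratio arithmetic in SpeechToTextDatasetCreator.get_size_ratios (cress/datasets/speech_to_text_dataset.py above) is easy to misread inline, so here is a minimal, self-contained sketch of the same computation. It is not code from this repository, and the corpus sizes below are invented for illustration.

# Standalone sketch of the temperature-based sampling arithmetic used by
# SpeechToTextDatasetCreator.get_size_ratios() (https://arxiv.org/abs/1907.05019).
def size_ratios(lp_to_sz, alpha=0.5):
    # sampling probability of each language pair, proportional to its corpus size
    sz_sum = sum(lp_to_sz.values())
    lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
    # temperature-smoothed target probabilities (alpha < 1 upsamples small pairs)
    lp_to_tgt_prob = {k: p ** alpha for k, p in lp_to_prob.items()}
    prob_sum = sum(lp_to_tgt_prob.values())
    lp_to_tgt_prob = {k: p / prob_sum for k, p in lp_to_tgt_prob.items()}
    # resampling ratio per pair: target share of the total size vs. its true size
    return {k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()}

print(size_ratios({"en->de": 9_000_000, "en->ro": 1_000_000}, alpha=0.5))
# {'en->de': 0.833..., 'en->ro': 2.5}: the large pair is sampled at ~0.83x its
# size and the small pair is upsampled 2.5x, matching the ratios that
# from_tsv() passes to ResamplingDataset when cfg.sampling_alpha != 1.0.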