├── .editorconfig ├── .flake8 ├── .gitignore ├── CHANGELOG.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README-NeurST.md ├── README.md ├── configs ├── inference.yml ├── pretrain.yml └── tune-on-en2de.yml ├── data ├── DataProcess.md ├── clean_data.py ├── data source.md ├── dataset │ ├── iwslt.sh │ └── opus-100.sh ├── downsample.py ├── get_parallel.py └── sentencePiece_train.py ├── examples ├── README.md ├── ctnmt │ ├── README.md │ └── example_configs │ │ ├── asy_distillation.yaml │ │ ├── dynamic_switch.yaml │ │ └── rate_schedule.yaml ├── iwslt21 │ ├── OFFLINE.md │ ├── SIMUL_TRANS.md │ └── scripts │ │ ├── evaluate_cascade.sh │ │ ├── evaluate_e2e.sh │ │ └── evaluate_mt.sh ├── prune_tune │ ├── README.md │ ├── scripts │ │ ├── prediction_args.yml │ │ ├── prepare-target-dataset-wp.sh │ │ ├── training_args.yml │ │ └── validation_args.yml │ └── src │ │ ├── mask_sequence_generator.py │ │ ├── partial_trainer.py │ │ └── partial_tuning_optimizer.py ├── quantization │ ├── README.md │ └── example_view_quant_weight.py ├── simultaneous_translation │ └── README.md ├── speech_transformer │ ├── augmented_librispeech │ │ ├── 01-download.sh │ │ ├── 02-audio_feature_extraction.sh │ │ ├── 03-preprocess.sh │ │ ├── README.md │ │ ├── RESULTS.md │ │ ├── asr_prediction_args.yml │ │ ├── asr_training_args.yml │ │ ├── asr_validation_args.yml │ │ ├── mt_prediction_args.yml │ │ ├── mt_training_args.yml │ │ ├── mt_validation_args.yml │ │ ├── st_prediction_args.yml │ │ ├── st_training_args.yml │ │ └── st_validation_args.yml │ └── must-c │ │ ├── 01-download.sh │ │ ├── 02-audio_feature_extraction.sh │ │ ├── 03-preprocess.sh │ │ ├── 03-preprocess_alone.sh │ │ ├── README.md │ │ ├── RESULTS.md │ │ ├── asr_prediction_args.yml │ │ ├── asr_training_args.yml │ │ ├── asr_validation_args.yml │ │ ├── mt_prediction_args.yml │ │ ├── mt_training_args.yml │ │ ├── mt_validation_args.yml │ │ ├── st_prediction_args.yml │ │ ├── st_training_args.yml │ │ └── st_validation_args.yml ├── translation │ ├── README.md │ ├── download_wmt14en2de.py │ ├── prediction_args.yml │ ├── prepare-wmt14en2de-bpe.sh │ ├── prepare-wmt14en2de-wp.sh │ ├── training_args.yml │ └── validation_args.yml └── weight_pruning │ └── README.md ├── neurst ├── __init__.py ├── __version__.py ├── cli │ ├── README.md │ ├── __init__.py │ ├── analysis │ │ ├── __init__.py │ │ ├── audio_tfrecord_analysis.py │ │ └── audio_transcript_length_ratio_analysis.py │ ├── avg_checkpoint.py │ ├── cascade_st.py │ ├── convert_checkpoint.py │ ├── create_tfrecords.py │ ├── extract_audio_transcripts.py │ ├── generate_vocab.py │ ├── inspect_checkpoint.py │ ├── process_text.py │ ├── run_exp.py │ ├── simuleval_cli.py │ ├── text_metric.py │ ├── view_registry.py │ └── view_tfrecord.py ├── criterions │ ├── __init__.py │ ├── criterion.py │ ├── joint_criterion.py │ ├── label_smoothed_cross_entropy.py │ └── label_smoothed_cross_entropy_with_kd.py ├── data │ ├── __init__.py │ ├── audio │ │ ├── __init__.py │ │ ├── feature_extractor.py │ │ ├── float_identity.py │ │ └── log_mel_fbank.py │ ├── data_pipelines │ │ ├── __init__.py │ │ ├── bert_data_pipeline.py │ │ ├── data_pipeline.py │ │ ├── gpt2_data_pipeline.py │ │ ├── multilingual_text_data_pipeline.py │ │ └── text_data_pipeline.py │ ├── dataset_utils.py │ ├── datasets │ │ ├── __init__.py │ │ ├── audio │ │ │ ├── __init__.py │ │ │ ├── audio_dataset.py │ │ │ ├── aug_librispeech.py │ │ │ ├── common_voice.py │ │ │ ├── iwslt.py │ │ │ ├── iwslt_tst.py │ │ │ ├── librispeech.py │ │ │ ├── mustc.py │ │ │ └── tedlium.py │ │ ├── data_sampler │ │ │ ├── __init__.py │ │ │ ├── 
data_sampler.py │ │ │ └── temperature_sampler.py │ │ ├── dataset.py │ │ ├── mixed_speech_text_dataset.py │ │ ├── mixed_train_dataset.py │ │ ├── mono_text_dataset.py │ │ ├── multilingual_translation_dataset.py │ │ ├── multiple_dataset.py │ │ ├── parallel_text_dataset.py │ │ └── text_gen_dataset.py │ └── text │ │ ├── __init__.py │ │ ├── bpe.py │ │ ├── character.py │ │ ├── huggingface_tokenizer.py │ │ ├── jieba_segment.py │ │ ├── moses_tokenizer.py │ │ ├── spm.py │ │ ├── subtokenizer.py │ │ ├── thai_tokenizer.py │ │ ├── tokenizer.py │ │ └── vocab.py ├── exps │ ├── __init__.py │ ├── base_experiment.py │ ├── evaluator.py │ ├── sequence_evaluator.py │ ├── sequence_generator.py │ ├── sequence_generator_savedmodel.py │ ├── trainer.py │ └── validation.py ├── layers │ ├── __init__.py │ ├── adapters │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── adapterEmb.py │ │ ├── adapterLayer.py │ │ └── adapterSerial.py │ ├── attentions │ │ ├── __init__.py │ │ ├── light_convolution_layer.py │ │ └── multi_head_attention.py │ ├── auto_pretrained_layer.py │ ├── common_layers.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── light_convolution_decoder.py │ │ ├── transformer_decoder.py │ │ └── transformer_decoder_CIAT.py │ ├── encoders │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── light_convolution_encoder.py │ │ ├── transformer_encoder.py │ │ └── transformer_encoder_CIAT.py │ ├── layer_utils.py │ ├── metric_layers │ │ ├── __init__.py │ │ ├── metric_layer.py │ │ └── token_metric_layers.py │ ├── modalities │ │ ├── __init__.py │ │ ├── audio_modalities.py │ │ └── text_modalities.py │ ├── quantization │ │ ├── __init__.py │ │ ├── quant_dense_layer.py │ │ └── quant_layers.py │ ├── search │ │ ├── __init__.py │ │ ├── beam_search.py │ │ ├── sampling.py │ │ └── sequence_search.py │ └── transformer_layers.py ├── metrics │ ├── __init__.py │ ├── bleu.py │ ├── compound_split_bleu.py │ ├── metric.py │ └── wer.py ├── models │ ├── __init__.py │ ├── bert.py │ ├── ctnmt_transformer.py │ ├── encoder_decoder_ensemble_model.py │ ├── encoder_decoder_model.py │ ├── gpt2.py │ ├── light_convolution_model.py │ ├── model.py │ ├── model_utils.py │ ├── speech_transformer.py │ ├── transformer.py │ ├── transformer_CIAT.py │ ├── waitk_transformer.py │ └── wav2vec2.py ├── optimizers │ ├── __init__.py │ ├── rate_schedule_optimizer.py │ └── schedules │ │ ├── __init__.py │ │ ├── inverse_sqrt_schedule.py │ │ ├── noam_schedule.py │ │ └── piecewise_schedule.py ├── sparsity │ ├── __init__.py │ ├── pruning_optimizer.py │ └── pruning_schedule.py ├── tasks │ ├── __init__.py │ ├── language_model.py │ ├── multilingual_translation.py │ ├── seq2seq.py │ ├── speech2text.py │ ├── task.py │ ├── translation.py │ └── waitk_translation.py ├── training │ ├── __init__.py │ ├── callbacks.py │ ├── criterion_validator.py │ ├── distribution_utils.py │ ├── gradaccum_keras_model.py │ ├── hvd_utils.py │ ├── revised_dynamic_loss_scale.py │ ├── seq_generation_validator.py │ ├── training_utils.py │ └── validator.py └── utils │ ├── __init__.py │ ├── activations.py │ ├── audio_lib.py │ ├── checkpoints.py │ ├── compat.py │ ├── configurable.py │ ├── converters │ ├── __init__.py │ ├── converter.py │ ├── fairseq_transformer.py │ ├── fairseq_transformer2.py │ ├── fairseq_wav2vec2.py │ ├── google_bert.py │ └── openai_gpt2.py │ ├── flags_core.py │ ├── hparams_sets.py │ ├── misc.py │ ├── registry.py │ ├── simuleval_agents │ ├── __init__.py │ └── simul_trans_text_agent.py │ └── userdef │ └── __init__.py ├── neurst_pt ├── __init__.py ├── layers │ ├── __init__.py │ ├── attentions 
│ │ ├── __init__.py │ │ └── multi_head_attention.py │ ├── common_layers.py │ ├── decoders │ │ ├── __init__.py │ │ ├── decoder.py │ │ └── transformer_decoder.py │ ├── encoders │ │ ├── __init__.py │ │ ├── encoder.py │ │ └── transformer_encoder.py │ ├── layer_utils.py │ └── modalities │ │ ├── __init__.py │ │ ├── audio_modalities.py │ │ └── text_modalities.py ├── models │ ├── __init__.py │ ├── encoder_decoder_model.py │ ├── model.py │ ├── model_utils.py │ ├── speech_transformer.py │ └── transformer.py └── utils │ ├── __init__.py │ └── activations.py ├── requirements.apt.txt ├── requirements.txt ├── run.sh ├── run_cli.sh ├── setup.py ├── tests ├── __init__.py ├── examples │ ├── codes.bpe4k.en │ ├── codes.bpe4k.zh │ ├── dev.example.en.txt │ ├── dev.example.zh.txt │ ├── example_create_seq2seq_tfrecrods.yml │ ├── example_eval_seq2seq.yml │ ├── example_predict_seq2seq.yml │ ├── example_train_gpt2.yml │ ├── example_train_seq2seq.yml │ ├── example_validator_gpt2.yml │ ├── example_validator_seq2seq.yml │ ├── train.example.en.tok.bpe.txt │ ├── train.example.zh.jieba.bpe.txt │ ├── train.tfrecords-00000-of-00004 │ ├── train.tfrecords-00001-of-00004 │ ├── train.tfrecords-00002-of-00004 │ ├── train.tfrecords-00003-of-00004 │ ├── vocab.en │ └── vocab.zh ├── neurst │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── text │ │ │ ├── __init__.py │ │ │ ├── bpe_test.py │ │ │ ├── jieba_segment_test.py │ │ │ ├── moses_tokenizer_test.py │ │ │ └── vocab_test.py │ │ └── text_data_pipeline_test.py │ ├── layers │ │ ├── __init__.py │ │ ├── attentions │ │ │ ├── __init__.py │ │ │ ├── multi_head_attention_pt_test.py │ │ │ └── multi_head_attention_test.py │ │ ├── common_layers_test.py │ │ ├── decoders │ │ │ ├── __init__.py │ │ │ └── transformer_decoder_test.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── transformer_encoder_test.py │ │ ├── modalities_test.py │ │ └── search │ │ │ ├── __init__.py │ │ │ └── beam_search_test.py │ ├── models │ │ ├── __init__.py │ │ ├── gpt2_test.py │ │ └── transformer_test.py │ └── utils │ │ └── __init__.py └── neurst_pt │ ├── __init__.py │ ├── decoders │ ├── __init__.py │ └── transformer_decoder_test.py │ ├── encoders │ ├── __init__.py │ └── transformer_encoder_test.py │ ├── layers │ ├── __init__.py │ ├── common_layers_test.py │ └── layer_utils_test.py │ ├── modalities │ ├── __init__.py │ ├── audio_modalities_test.py │ └── text_modalities_test.py │ └── models │ ├── __init__.py │ ├── speech_transformer_test.py │ └── transformer_test.py └── tox.ini /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | end_of_line = lf 5 | insert_final_newline = true 6 | max_line_length = 120 7 | 8 | [Makefile] 9 | indent_style = tab 10 | 11 | [*.py] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | [*.{js,ts,html}] 16 | indent_style = space 17 | indent_size = 2 18 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length=120 3 | exclude = 4 | venv/, 5 | venv_py/, 6 | .eggs, 7 | .tox, 8 | ignore = D400,D300,D205,D200,D105,D100,D101,D103,D107,W503,E129 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.swp 3 | *.pyc 4 | *.egg-info/ 5 | .eggs/ 6 | .idea/ 7 | .tox/ 8 | .pytest_cache/* 9 | venv/ 10 | venv_py/ 11 | .mypy_cache/ 12 | 
__pycache__/ 13 | /build 14 | /dist 15 | .debug/* 16 | .my_config/ -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | All notable changes to this project will be documented in this file. 3 | 4 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 5 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 6 | 7 | ## [Unreleased] 8 | ### Added 9 | - Instruction for 10 | - CTNMT (Yang et al., 2020) training 11 | - Prune-Tune (Liang et al., 2021) 12 | - dataset for IWSLT offline ST task 13 | - language model task and GPT-2 pretraining 14 | 15 | 16 | ### Changed 17 | 18 | 19 | ### Fixed 20 | 21 | 22 | ## [0.1.1] - 28th March, 2021 23 | ### Added 24 | - PyTorch version Transformer & SpeechTransformer model. 25 | - Audio extraction for CommonVoice/IWSLT. 26 | - Data sampler and dataset for multilingual machine translation 27 | - Mixed training dataset with data sampler. 28 | - Multilingual Translation task 29 | - Instruction for 30 | - training transformer models on WMT14 EN->DE 31 | - weight pruning 32 | - quantization aware training for transformer model 33 | 34 | ### Fixed 35 | - Compat with TensorFlow v2.4 36 | 37 | 38 | 39 | ## [0.1.0] - 25th Dec., 2020 40 | ### Added 41 | - Basic code structure for Encoder, Decoder, Model, DataPipeline, Tokenizer, Experiment, Metric, and Dataset. 42 | - (Model) Adds implementation of pre-norm/post-norm Transformer, Speech Transformer, BERT, GPT-2, and Wav2Vec2.0. 43 | - (Task) Adds implementation of sequence to sequence task and speech to text task (ASR, ST). 44 | - (DataPipeline, Tokenizer) Adds wrappers for commonly used tokenizers: moses, bpe, jieba, character, sentencepiece, etc. 45 | - (Dataset) Adds support for reading parallel corpus, speech corpora (libri-trans, MuST-C, and LibriSpeech), and TFRecords. 46 | - (Experiment) Adds implementation of common training procedure with mixed precision training and various distributed strategies (`MirroredStrategy`, `Horovod`, `Byteps`). 47 | - (Metric) Adds implementation of BLEU and WER metrics. 48 | - (Converter) Adds implementation of converting checkpoints from google BERT, OpenAI GPT-2, fairseq Transformer, and fairseq Wav2Vec2.0. 49 | - Add support for converting checkpoints from publicly 50 | - Beam search decoding and top-k/p sampling. 51 | - Supports averaging checkpoints, TFRecord generation, model restoring (see [cli/README.md](/neurst/cli/README.md)). 52 | - Step-by-step recipes for training an end-to-end speech translation model (see [examples/speech_to_text](/examples/speech_transformer)). 53 | 54 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL = bash 2 | 3 | all: install_dev isort isort_check lint test 4 | check: isort_check lint test 5 | 6 | install_dev: 7 | @pip install -e .[dev] >/dev/null 2>&1 8 | 9 | isort: 10 | @isort -s venv -s venv_py -s .tox -rc --atomic . 11 | 12 | isort_check: 13 | @isort -rc -s venv -s venv_py -s .tox -c . 
14 | 15 | lint: 16 | @flake8 17 | 18 | test: 19 | @tox 20 | 21 | clean: 22 | @rm -rf .pytest_cache .tox bytedmypackage.egg-info 23 | @rm -rf tests/*.pyc tests/__pycache__ 24 | 25 | .IGNORE: install_dev 26 | .PHONY: all check install_dev isort isort_check lint test 27 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Counter-Interference Adapter for Multilingual Machine Translation 2 | 3 | ### Overview 4 | 5 | This is the (early-access) implementation of the EMNLP'21 Findings paper [Counter-Interference Adapter for Multilingual Machine Translation](https://arxiv.org/abs/2104.08154). 6 | 7 | The code is based on the [NeurST toolkit](https://github.com/bytedance/neurst.git). 8 | 9 | #### Data preparation 10 | 11 | Please follow `data/DataProcess.md` to download and preprocess the data. 12 | 13 | #### Training and Inference 14 | 15 | The main model class of the paper is `TransformerCIAT` in `neurst/models/transformer_CIAT.py`. 16 | 17 | The example config files are in `configs/`; please modify the file paths in those YAML files. 18 | 19 | Pretraining: 20 | ```bash 21 | python3 -m neurst.cli.run_exp --config configs/pretrain.yml 22 | ``` 23 | 24 | Tune on En-De with adapters: 25 | ```bash 26 | python3 -m neurst.cli.run_exp --config configs/tune-on-en2de.yml 27 | ``` 28 | Inference: 29 | 30 | ```bash 31 | python3 -m neurst.cli.run_exp --config configs/inference.yml 32 | ``` 33 | -------------------------------------------------------------------------------- /configs/inference.yml: -------------------------------------------------------------------------------- 1 | entry.class: SequenceGenerator 2 | entry.params: 3 | output_file: ${output path of hypo file} 4 | metric.class: BLEU 5 | search_method.class: beam_search 6 | search_method.params: 7 | beam_size: 4 8 | length_penalty: -1 9 | extra_decode_length: 20 10 | maximum_decode_length: 50 11 | 12 | dataset.class: parallel_text 13 | dataset.params: 14 | src_file: ${your source eval file} 15 | trg_file: ${your target eval file, optional} 16 | 17 | task.params: 18 | batch_size: 32 19 | -------------------------------------------------------------------------------- /configs/pretrain.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 10000000 4 | initial_global_step: 0 5 | save_checkpoint_steps: 500 6 | summary_steps: 500 7 | criterion.class: LabelSmoothedCrossEntropy 8 | criterion.params: 9 | label_smoothing: 0.1 10 | 11 | model.class: CIAT 12 | model.params: 13 | is_pretrain: yes 14 | hparams_set: CIAT_big 15 | 16 | task.class: Translation 17 | task.params: 18 | batch_size: 32000 19 | batch_by_tokens: true 20 | max_src_len: 120 21 | max_trg_len: 120 22 | src_data_pipeline.class: TextDataPipeline 23 | src_data_pipeline.params: 24 | vocab_path: {your spm.vocab} 25 | subtokenizer: spm 26 | subtokenizer_codes: {your spm.model} 27 | trg_data_pipeline.class: TextDataPipeline 28 | trg_data_pipeline.params: 29 | vocab_path: {your spm.vocab} 30 | subtokenizer: spm 31 | subtokenizer_codes: {your spm.model} 32 | 33 | dataset.class: ParallelTextDataset 34 | dataset.params: 35 | src_file: ${your source file of entire corpus} 36 | trg_file: ${your target file of entire corpus} 37 | data_is_processed: false 38 | 39 | 40 | model_dir: ${your model dir} 41 | 42 | 43 | validator.class: SeqGenerationValidator 44 | validator.params: 45 | eval_dataset: parallel_text
46 | eval_dataset.params: 47 | src_file: ${your source eval file} 48 | trg_file: ${your target eval file} 49 | eval_steps: 500 50 | eval_start_at: 5000 51 | eval_criterion.class: label_smoothed_cross_entropy 52 | eval_search_method.class: beam_search 53 | eval_search_method.params: 54 | beam_size: 8 55 | length_penalty: 0.6 56 | extra_decode_length: 50 57 | maximum_decode_length: 200 58 | eval_metric.class: tok_bleu 59 | eval_estop_patience: 10 60 | 61 | -------------------------------------------------------------------------------- /configs/tune-on-en2de.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 10000000 4 | initial_global_step: 0 5 | save_checkpoint_steps: 500 6 | summary_steps: 500 7 | criterion.class: LabelSmoothedCrossEntropy 8 | criterion.params: 9 | label_smoothing: 0.1 10 | 11 | model.class: CIAT 12 | model.params: 13 | is_pretrain: no 14 | hparams_set: CIAT_big 15 | 16 | task.class: Translation 17 | task.params: 18 | batch_size: 32000 19 | batch_by_tokens: true 20 | max_src_len: 120 21 | max_trg_len: 120 22 | src_data_pipeline.class: TextDataPipeline 23 | src_data_pipeline.params: 24 | vocab_path: {your spm.vocab} 25 | subtokenizer: spm 26 | subtokenizer_codes: {your spm.model} 27 | trg_data_pipeline.class: TextDataPipeline 28 | trg_data_pipeline.params: 29 | vocab_path: {your spm.vocab} 30 | subtokenizer: spm 31 | subtokenizer_codes: {your spm.model} 32 | 33 | dataset.class: ParallelTextDataset 34 | dataset.params: 35 | src_file: ${your source file of en2de data} 36 | trg_file: ${your target file of en2de data} 37 | data_is_processed: false 38 | 39 | pretrain_model: ${your pretrained model dir} 40 | model_dir: ${your en-de dir} 41 | 42 | 43 | validator.class: SeqGenerationValidator 44 | validator.params: 45 | eval_dataset: parallel_text 46 | eval_dataset.params: 47 | src_file: ${your source eval file} 48 | trg_file: ${your target eval file} 49 | eval_steps: 500 50 | eval_start_at: 5000 51 | eval_criterion.class: label_smoothed_cross_entropy 52 | eval_search_method.class: beam_search 53 | eval_search_method.params: 54 | beam_size: 8 55 | length_penalty: 0.6 56 | extra_decode_length: 50 57 | maximum_decode_length: 200 58 | eval_metric.class: tok_bleu 59 | eval_estop_patience: 10 60 | 61 | -------------------------------------------------------------------------------- /data/clean_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def main(args): 6 | if args.output_path is None: 7 | args.output_path = args.input_path+"_clean" 8 | if not os.path.isdir(args.output_path): 9 | print("mkdir -p %s"%args.output_path) 10 | file_list = os.listdir(args.input_path) 11 | if args.suffix is not None: 12 | file_list = list(filter(lambda x: x.endswith(args.suffix), file_list)) 13 | 14 | normalize = os.path.join(args.moses_path, "scripts/tokenizer/normalize-punctuation.perl") 15 | rmnprint = os.path.join(args.moses_path, "scripts/tokenizer/remove-non-printing-char.perl") 16 | unescape = os.path.join(args.moses_path, "scripts/tokenizer/deescape-special-chars.perl") 17 | 18 | for file in file_list: 19 | mode, language, inout = file.split(".") 20 | if inout == "in": 21 | language=language.split("2")[0] 22 | elif inout == "out": 23 | language = language.split("2")[-1] 24 | input_file = os.path.join(args.input_path, file) 25 | output_file = os.path.join(args.output_path, file) 26 | print("perl %s -l %s < %s | 
"%(normalize, language, input_file), 27 | "perl %s -l %s | "%(rmnprint, language), 28 | "perl %s -l %s > %s" % (unescape, language, output_file), 29 | ) 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--moses_path", type=str, 34 | default="", 35 | help="path for Moses") 36 | parser.add_argument("--input_path", type=str, 37 | default="", 38 | help="input file path") 39 | parser.add_argument("--output_path", type=str, 40 | default=None, 41 | help="output file path") 42 | parser.add_argument("--suffix", type=str, 43 | default=None, 44 | help="filter files with given suffix (for test set process)") 45 | args = parser.parse_args() 46 | main(args) 47 | -------------------------------------------------------------------------------- /data/data source.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | ### IWSLT source 4 | We follow the preprocessing scripts of [Multilingual neural machine translation with knowledge distillation](https://openreview.net/forum?id=S1gUsoR9YX), (https://github.com/RayeRen/multilingual-kd-pytorch/blob/master/data/iwslt/raw/prepare-iwslt14.sh),except that we disable their lowercase option. 5 | 6 | ### OPUS-100 source 7 | 8 | Download on [official website]([http://opus.nlpl.eu/opus-100.php](http://opus.nlpl.eu/opus-100.php)), alternatively, you can use `opus-100.sh`. 9 | 10 | ### WMT source 11 | 12 | 13 | Download on [official website](http://www.statmt.org/wmt21/translation-task.html). 14 | -------------------------------------------------------------------------------- /data/downsample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | 5 | def get_file_len(fname): 6 | with open(fname) as f: 7 | size = len([0 for _ in f]) 8 | return size 9 | 10 | 11 | def normalize_length(lang_length_dict, total_samples=10000000, power=0.2): 12 | sum_len = 0 13 | sum_len_power = 0 14 | min_len = float("inf") 15 | for language in lang_length_dict.keys(): 16 | sum_len += lang_length_dict[language] 17 | sum_len_power += lang_length_dict[language] ** power 18 | min_len = min(min_len, lang_length_dict[language]) 19 | 20 | for language in lang_length_dict.keys(): 21 | lang_length_dict[language] = int((lang_length_dict[language] ** power / sum_len_power) * total_samples) 22 | return lang_length_dict 23 | 24 | 25 | def main(args): 26 | file_list = os.listdir(args.input_path) 27 | lang_length_dict = {} 28 | lang_file_dict = {} 29 | for file_name in file_list: 30 | mode, language, inout = file_name.split(".") 31 | if inout == "in": 32 | language = language.split("2")[0] 33 | elif inout == "out": 34 | language = language.split("2")[-1] 35 | input_file = os.path.join(args.input_path, file_name) 36 | file_len = get_file_len(input_file) 37 | if language not in lang_length_dict.keys(): 38 | lang_length_dict[language] = file_len 39 | lang_file_dict[language] = [input_file] 40 | else: 41 | lang_length_dict[language] += file_len 42 | lang_file_dict[language] += [input_file] 43 | lang_length_dict = normalize_length(lang_length_dict, total_samples=args.total_samples) 44 | for language in lang_length_dict.keys(): 45 | print("cat %s | shuf -n %s > tmp_%s" % (" ".join(lang_file_dict[language]), lang_length_dict[language], language)) 46 | print("cat tmp_* > samples") 47 | for language in lang_length_dict.keys(): 48 | print("rm tmp_%s"%language) 49 | 50 | 51 | if __name__ == '__main__': 52 | parser = 
argparse.ArgumentParser() 53 | parser.add_argument("--input_path", type=str, 54 | default="", 55 | help="input file path") 56 | parser.add_argument("--total_samples", type=int, 57 | default=10000000, 58 | ) 59 | 60 | args = parser.parse_args() 61 | main(args) 62 | -------------------------------------------------------------------------------- /data/sentencePiece_train.py: -------------------------------------------------------------------------------- 1 | # https://github.com/google/sentencepiece/blob/master/python/README.md 2 | # https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb 3 | import sentencepiece as spm 4 | import argparse 5 | import os 6 | 7 | def main(args): 8 | india_lang_set = ",,,,,,,,,,,,,,,," 9 | all_lang_set = ",,,,,,,,,,,,,,,,,,," \ 10 | ",,,,,,,,,,,,,,,
,,,," \ 11 | ",,,,,,,,,,,,,,,,,,," \ 12 | ",,,,," 13 | spm.SentencePieceTrainer.train( 14 | input=args.input_file, 15 | model_prefix=os.path.join(args.output_path, "bpe"), 16 | vocab_size=args.vocab_size, 17 | model_type="bpe", 18 | input_sentence_size=args.input_sentence_size, 19 | user_defined_symbols=all_lang_set, 20 | ) 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--input_file", type=str, required=True, 26 | help="input file") 27 | parser.add_argument("--output_path", type=str, required=True, 28 | help="output src file") 29 | parser.add_argument("--input_sentence_size", type=int, required=False, default=1000000, 30 | help="input_sentence_size") 31 | parser.add_argument("--vocab_size", type=int, required=False, default=32000, 32 | help="vocab_size") 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Official NeurST Examples and Modules 2 | 3 | The folder contains example (re-)implementations of selected research papers and benchmarks, with released model checkpoints. 4 | 5 | ## Speech-to-Text Translation 6 | 7 | ### 2021 8 | 9 | - [ACL demo] Zhao et al. NeurST: Neural Speech Translation Toolkit. [[paper](https://aclanthology.org/2021.acl-demo.7/)] 10 | - E2E ST benchmark: [libri-trans](/examples/speech_transformer/augmented_librispeech), [must-c](/examples/speech_transformer/must-c) 11 | 12 | - [IWLST 2021 System] Zhao et al. The Volctrans Neural Speech Translation System for IWSLT 2021. [[paper](https://aclanthology.org/2021.iwslt-1.6/)] 13 | - [Offline ST](/examples/iwslt21/OFFLINE.md) 14 | - [Simultaneous Translation](/examples/iwslt21/SIMUL_TRANS.md) 15 | 16 | 17 | ## Neural Machine Translation 18 | 19 | ### 2021 20 | 21 | - [AAAI] Liang et al. Finding Sparse Structures for Domain Specific Neural Machine Translation. [[paper](https://arxiv.org/abs/2012.10586)][[example](/examples/prune_tune)] 22 | 23 | 24 | ### 2020 25 | 26 | - [AAAI] Yang et al. Towards Making the Most of BERT in Neural Machine Translation. [[paper](https://arxiv.org/abs/1908.05672)][[example](/examples/ctnmt)] 27 | 28 | ### 2019 29 | 30 | - [ICLR] Wu et al. Pay Less Attention With Lightweight and Dynamic Convolutions. [[paper](https://arxiv.org/pdf/1901.10430.pdf)][code only] 31 | 32 | 33 | ### 2017 34 | 35 | - [NIPS] Vaswani et al. Attention Is All You Need. 
[[paper](https://arxiv.org/pdf/1706.03762.pdf)] 36 | - MT benchmark: [WMT14 EN->DE](/examples/translation) 37 | 38 | ## Neural Network Techniques 39 | 40 | - Weight Pruning 41 | - unstructured pruning [[example](/examples/weight_pruning)] 42 | - Quantization 43 | - quantization aware training [[ref](https://arxiv.org/abs/1712.05877)][[example](/examples/quantization)] 44 | -------------------------------------------------------------------------------- /examples/ctnmt/example_configs/asy_distillation.yaml: -------------------------------------------------------------------------------- 1 | model_dir: /tmp/asy_distillation_model 2 | 3 | entry.class: trainer 4 | entry.params: 5 | train_steps: 1000000 6 | save_checkpoint_steps: 1000 7 | summary_steps: 200 8 | criterion.class: LabelSmoothedCrossEntropyWithKd 9 | criterion.params: 10 | kd_weight: 0.01 11 | optimizer.class: adam 12 | optimizer.params: 13 | epsilon: 1.e-9 14 | beta_1: 0.9 15 | beta_2: 0.98 16 | lr_schedule.class: noam 17 | lr_schedule.params: 18 | initial_factor: 1.0 19 | dmodel: 1024 20 | warmup_steps: 4000 21 | pretrain_model: 22 | - path: https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip 23 | model_name: google_bert 24 | from_prefix: bert 25 | to_prefix: ctnmt/bert 26 | dataset.class: ParallelTextDataset 27 | dataset.params: 28 | src_file: /tmp/data/train.en2de.in 29 | trg_file: /tmp/data/train.en2de.out 30 | data_is_processed: false 31 | 32 | task.class: Translation 33 | task.params: 34 | src_data_pipeline.class: BertDataPipeline 35 | src_data_pipeline.params: 36 | language: de 37 | name: bert-large-cased 38 | vocab_path: /tmp/vocab.txt 39 | trg_data_pipeline.class: TextDataPipeline 40 | trg_data_pipeline.params: 41 | subtokenizer: spm 42 | subtokenizer_codes: /tmp/spm.model 43 | vocab_path: /tmp/spm.vocab 44 | batch_size_per_gpu: 8000 45 | batch_by_tokens: true 46 | max_src_len: 100 47 | max_trg_len: 100 48 | 49 | hparams_set: ctnmt_big 50 | model.class: CtnmtTransformer 51 | model.params: 52 | bert_mode: bert_distillation 53 | 54 | validator.class: SeqGenerationValidator 55 | validator.params: 56 | eval_dataset: parallel_text 57 | eval_dataset.params: 58 | src_file: /tmp/data/test.en2de.in 59 | trg_file: /tmp/data/test.en2de.out 60 | eval_batch_size: 32 61 | eval_start_at: 1000 62 | eval_steps: 1000 63 | eval_criterion.class: label_smoothed_cross_entropy 64 | eval_search_method: beam_search 65 | eval_search_method.params: 66 | beam_size: 8 67 | length_penalty: 0.6 68 | extra_decode_length: 50 69 | maximum_decode_length: 200 70 | eval_metric: CompoundSplitBleu 71 | eval_top_checkpoints_to_keep: 10 72 | eval_auto_average_checkpoints: true 73 | eval_estop_patience: 30 74 | -------------------------------------------------------------------------------- /examples/ctnmt/example_configs/dynamic_switch.yaml: -------------------------------------------------------------------------------- 1 | model_dir: /tmp/dynamic_switch 2 | 3 | entry.class: trainer 4 | entry.params: 5 | train_steps: 1000000 6 | save_checkpoint_steps: 1000 7 | summary_steps: 200 8 | criterion.class: LabelSmoothedCrossEntropy 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 1.0 17 | dmodel: 1024 18 | warmup_steps: 4000 19 | pretrain_model: 20 | - path: https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip 21 | model_name: google_bert 22 | from_prefix: bert 23 | to_prefix: 
ctnmt/bert 24 | 25 | hparams_set: ctnmt_big 26 | model.class: CtnmtTransformer 27 | model.params: 28 | bert_mode: dynamic_switch 29 | 30 | validator.class: SeqGenerationValidator 31 | validator.params: 32 | eval_dataset: parallel_text 33 | eval_dataset.params: 34 | src_file: /tmp/data/test.en2de.in 35 | trg_file: /tmp/data/test.en2de.out 36 | eval_batch_size: 32 37 | eval_start_at: 1000 38 | eval_steps: 1000 39 | eval_criterion.class: label_smoothed_cross_entropy 40 | eval_search_method: beam_search 41 | eval_search_method.params: 42 | beam_size: 8 43 | length_penalty: 0.6 44 | extra_decode_length: 50 45 | maximum_decode_length: 200 46 | eval_metric: CompoundSplitBleu 47 | eval_top_checkpoints_to_keep: 10 48 | eval_auto_average_checkpoints: true 49 | eval_estop_patience: 30 50 | 51 | dataset.class: ParallelTextDataset 52 | dataset.params: 53 | src_file: /tmp/data/train.en2de.in 54 | trg_file: /tmp/data/train.en2de.out 55 | data_is_processed: false 56 | 57 | task.class: Translation 58 | task.params: 59 | src_data_pipeline.class: BertDataPipeline 60 | src_data_pipeline.params: 61 | language: de 62 | name: bert-large-cased 63 | vocab_path: /tmp/vocab.txt 64 | trg_data_pipeline.class: TextDataPipeline 65 | trg_data_pipeline.params: 66 | subtokenizer: spm 67 | subtokenizer_codes: /tmp/spm.model 68 | vocab_path: /tmp/spm.vocab 69 | batch_size_per_gpu: 8000 70 | batch_by_tokens: true 71 | max_src_len: 120 72 | max_trg_len: 120 73 | -------------------------------------------------------------------------------- /examples/ctnmt/example_configs/rate_schedule.yaml: -------------------------------------------------------------------------------- 1 | model_dir: /tmp/rate_schedule 2 | entry.class: trainer 3 | entry.params: 4 | train_steps: 100000000 5 | save_checkpoint_steps: 1000 6 | summary_steps: 200 7 | criterion.class: LabelSmoothedCrossEntropy 8 | optimizer.class: adam 9 | optimizer_controller: RateScheduledOptimizer 10 | optimizer_controller_args: 11 | warm_steps: 10000 12 | freeze_steps: 20000 13 | controlled_varname_pattern: bert 14 | optimizer.params: 15 | epsilon: 1.e-9 16 | beta_1: 0.9 17 | beta_2: 0.98 18 | lr_schedule.class: noam 19 | lr_schedule.params: 20 | initial_factor: 1.0 21 | dmodel: 1024 22 | warmup_steps: 4000 23 | pretrain_model: 24 | - path: https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip 25 | model_name: google_bert 26 | from_prefix: bert 27 | to_prefix: ctnmt/bert 28 | 29 | dataset.class: ParallelTextDataset 30 | dataset.params: 31 | src_file: /tmp/data/train.en2de.in 32 | trg_file: /tmp/data/train.en2de.out 33 | data_is_processed: false 34 | 35 | task.class: Translation 36 | task.params: 37 | src_data_pipeline.class: BertDataPipeline 38 | src_data_pipeline.params: 39 | language: de 40 | name: bert-large-cased 41 | vocab_path: /tmp/vocab.txt 42 | trg_data_pipeline.class: TextDataPipeline 43 | trg_data_pipeline.params: 44 | subtokenizer: spm 45 | subtokenizer_codes: /tmp/spm.model 46 | vocab_path: /tmp/spm.vocab 47 | batch_size_per_gpu: 8000 48 | batch_by_tokens: true 49 | max_src_len: 120 50 | max_trg_len: 120 51 | 52 | hparams_set: ctnmt_big 53 | model.class: CtnmtTransformer 54 | # use BERT as encoder if you use rate scheduler only 55 | model.params: 56 | bert_mode: bert_as_encoder 57 | validator.class: SeqGenerationValidator 58 | validator.params: 59 | eval_dataset: parallel_text 60 | eval_dataset.params: 61 | src_file: /tmp/data/test.en2de.in 62 | trg_file: /tmp/data/test.en2de.out 63 | eval_batch_size: 32 64 | eval_start_at: 1000 65 
| eval_steps: 1000 66 | eval_criterion.class: label_smoothed_cross_entropy 67 | eval_search_method: beam_search 68 | eval_search_method.params: 69 | beam_size: 8 70 | length_penalty: 0.6 71 | extra_decode_length: 50 72 | maximum_decode_length: 200 73 | eval_metric: CompoundSplitBleu 74 | eval_top_checkpoints_to_keep: 10 75 | eval_auto_average_checkpoints: true 76 | eval_estop_patience: 30 77 | -------------------------------------------------------------------------------- /examples/iwslt21/scripts/evaluate_e2e.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | if [[ ! -n "$3" ]] ;then 6 | echo "Usage: ./evaluate_cascade.sh TEST_SET MODEL_DIR OUTPUT_PATH" 7 | exit 1; 8 | fi 9 | 10 | TEST_SET=$1 11 | MODEL_DIR=$2 12 | OUTPUT_PATH=$3 13 | 14 | URL_PREFIX="http://sf3-ttcdn-tos.pstatp.com/obj/nlp-opensource/neurst/iwslt21/offline" 15 | CFG_URL="${URL_PREFIX}/cfgs/st_prediction_args.yml" 16 | DATA_URL_PREFIX="${URL_PREFIX}/devtests" 17 | 18 | case $TEST_SET in 19 | "mustc-v1-dev") 20 | DATA_FILE="mustc_v1.0_en-de.dev.tfrecords-00000-of-00001" 21 | ;; 22 | 23 | "mustc-v1-tst") 24 | DATA_FILE="mustc_v1.0_en-de.tst-COMMON.tfrecords-00000-of-00001" 25 | ;; 26 | 27 | "mustc-v2-dev") 28 | DATA_FILE="mustc_v2.0_en-de.dev.tfrecords-00000-of-00001" 29 | ;; 30 | 31 | "mustc-v2-tst") 32 | DATA_FILE="mustc_v2.0_en-de.tst-COMMON.tfrecords-00000-of-00001" 33 | ;; 34 | 35 | "tst2020") 36 | DATA_FILE="iwslt-slt.tst2020.tfrecords-00000-of-00001" 37 | ;; 38 | 39 | "tst2021") 40 | DATA_FILE="iwslt-slt.tst2021.tfrecords-00000-of-00001" 41 | ;; 42 | 43 | *) echo "Unknown ${TEST_SET}" 44 | ;; 45 | esac 46 | 47 | TST_FILE_PREFIX=${DATA_FILE%.*} 48 | OUTPUT_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.hypo.txt 49 | OUTPUT_NOTAG_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.hypo.notag.txt 50 | LOCAL_REF_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.ref.txt 51 | LOCAL_REF_NOTAG_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.ref.notag.txt 52 | 53 | # run e2e st 54 | python3 -m neurst.cli.run_exp --config_paths $CFG_URL --data_path $DATA_URL_PREFIX/$DATA_FILE --output_file $OUTPUT_FILE --model_dir $MODEL_DIR 1>/dev/null 2>&1 55 | 56 | perl -pe 's/\([^\)]+\)//g;' $OUTPUT_FILE | tr -s " " > $OUTPUT_NOTAG_FILE 57 | 58 | if [[ $TEST_SET == tst20* ]]; 59 | then 60 | # clean hypothesis 61 | echo >/dev/null 2>&1 62 | else 63 | python3 -c """ 64 | from neurst.data.datasets import build_dataset 65 | 66 | ds = build_dataset({ 67 | 'class': 'AudioTripleTFRecordDataset', 68 | 'params': { 69 | 'data_path': '$DATA_URL_PREFIX/$DATA_FILE' 70 | } 71 | }) 72 | 73 | with open('$LOCAL_REF_FILE', 'w') as fw_mt: 74 | for x in ds.build_iterator()(): 75 | fw_mt.write(x['translation']+'\n') 76 | """ 1>/dev/null 2>&1 77 | 78 | # clean hypothesis 79 | perl -pe 's/\([^\)]+\)//g;' $LOCAL_REF_FILE | tr -s " " > $LOCAL_REF_NOTAG_FILE 80 | 81 | echo "============================= Evaluation (including tags) =============================" > $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 82 | cat $OUTPUT_FILE | sacrebleu -l en-de $LOCAL_REF_FILE >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 83 | echo "" >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 84 | 85 | echo "============================= Evaluation (no tags) =============================" >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 86 | cat $OUTPUT_NOTAG_FILE | sacrebleu -l en-de $LOCAL_REF_NOTAG_FILE >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 87 | 88 | fi 89 | -------------------------------------------------------------------------------- 
/examples/iwslt21/scripts/evaluate_mt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | if [[ ! -n "$3" ]] ;then 6 | echo "Usage: ./evaluate_mt.sh TEST_SET MODEL_DIR OUTPUT_PATH" 7 | exit 1; 8 | fi 9 | 10 | TEST_SET=$1 11 | MODEL_DIR=$2 12 | OUTPUT_PATH=$3 13 | 14 | URL_PREFIX="http://sf3-ttcdn-tos.pstatp.com/obj/nlp-opensource/neurst/iwslt21/offline" 15 | CFG_URL="${URL_PREFIX}/cfgs/mt_prediction_args.yml" 16 | DATA_URL_PREFIX="${URL_PREFIX}/devtests" 17 | 18 | case $TEST_SET in 19 | "mustc-v1-dev") 20 | SRC_FILE="mustc_v1.0_en-de.dev.tagen.txt" 21 | TRG_FILE="mustc_v1.0_en-de.dev.de.txt" 22 | ;; 23 | 24 | "mustc-v1-tst") 25 | SRC_FILE="mustc_v1.0_en-de.tst-COMMON.tagen.txt" 26 | TRG_FILE="mustc_v1.0_en-de.tst-COMMON.de.txt" 27 | ;; 28 | 29 | "mustc-v2-dev") 30 | SRC_FILE="mustc_v2.0_en-de.dev.tagen.txt" 31 | TRG_FILE="mustc_v2.0_en-de.dev.de.txt" 32 | ;; 33 | 34 | "mustc-v2-tst") 35 | SRC_FILE="mustc_v2.0_en-de.tst-COMMON.tagen.txt" 36 | TRG_FILE="mustc_v2.0_en-de.tst-COMMON.de.txt" 37 | ;; 38 | 39 | "mustc-v1-dev-tc") 40 | SRC_FILE="mustc_v1.0_en-de.dev.en.txt" 41 | TRG_FILE="mustc_v1.0_en-de.dev.de.txt" 42 | ;; 43 | 44 | "mustc-v1-tst-tc") 45 | SRC_FILE="mustc_v1.0_en-de.tst-COMMON.en.txt" 46 | TRG_FILE="mustc_v1.0_en-de.tst-COMMON.de.txt" 47 | ;; 48 | 49 | "mustc-v2-dev-tc") 50 | SRC_FILE="mustc_v2.0_en-de.dev.en.txt" 51 | TRG_FILE="mustc_v2.0_en-de.dev.de.txt" 52 | ;; 53 | 54 | "mustc-v2-tst-tc") 55 | SRC_FILE="mustc_v2.0_en-de.tst-COMMON.en.txt" 56 | TRG_FILE="mustc_v2.0_en-de.tst-COMMON.de.txt" 57 | ;; 58 | 59 | *) echo "Unknown ${TEST_SET}" 60 | ;; 61 | esac 62 | 63 | TST_FILE_PREFIX=${SRC_FILE%.*} 64 | TST_FILE_PREFIX=${TST_FILE_PREFIX%.*} 65 | OUTPUT_MT_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.hypo.txt 66 | OUTPUT_MT_NOTAG_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.hypo.notag.txt 67 | LOCAL_REF_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.ref.txt 68 | LOCAL_REF_NOTAG_FILE=$OUTPUT_PATH/$TST_FILE_PREFIX.de.ref.notag.txt 69 | 70 | curl $DATA_URL_PREFIX/$TRG_FILE -o $LOCAL_REF_FILE 71 | 72 | # inference 73 | python3 -m neurst.cli.run_exp --config_paths $CFG_URL --src_file $DATA_URL_PREFIX/$SRC_FILE --output_file $OUTPUT_MT_FILE --model_dir $MODEL_DIR 1>/dev/null 2>&1 74 | 75 | # clean hypothesis 76 | perl -pe 's/\([^\)]+\)//g;' $LOCAL_REF_FILE | tr -s " " > $LOCAL_REF_NOTAG_FILE 77 | perl -pe 's/\([^\)]+\)//g;' $OUTPUT_MT_FILE | tr -s " " > $OUTPUT_MT_NOTAG_FILE 78 | 79 | 80 | echo "============================= Evaluation (including tags) =============================" > $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 81 | cat $OUTPUT_MT_FILE | sacrebleu -l en-de $LOCAL_REF_FILE >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 82 | echo "" >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 83 | 84 | echo "============================= Evaluation (no tags) =============================" >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 85 | cat $OUTPUT_MT_NOTAG_FILE | sacrebleu -l en-de $LOCAL_REF_NOTAG_FILE >> $OUTPUT_PATH/$TST_FILE_PREFIX.bleu.txt 86 | -------------------------------------------------------------------------------- /examples/prune_tune/scripts/prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 32 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: 0.6 7 | maximum_decode_length: 160 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | 
dataset.class: ParallelTextDataset 15 | dataset.params: 16 | src_file: DEV_SRC 17 | trg_file: DEV_TRG 18 | test: 19 | dataset.class: ParallelTextDataset 20 | dataset.params: 21 | src_file: TEST_SRC 22 | trg_file: TEST_TRG 23 | -------------------------------------------------------------------------------- /examples/prune_tune/scripts/training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 1000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | -------------------------------------------------------------------------------- /examples/prune_tune/scripts/validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: ParallelTextDataset 4 | eval_dataset.params: 5 | src_file: DEV_SRC 6 | trg_file: DEV_TRG 7 | eval_batch_size: 64 8 | eval_start_at: STR_EVL 9 | eval_steps: EVL_STEP 10 | eval_criterion: label_smoothed_cross_entropy 11 | eval_search_method: beam_search 12 | eval_search_method.params: 13 | beam_size: 4 14 | length_penalty: 0.6 15 | maximum_decode_length: 160 16 | extra_decode_length: 50 17 | eval_metric: bleu 18 | eval_top_checkpoints_to_keep: 5 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/quantization/example_view_quant_weight.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
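#
# Example: inspect the quantized weights of a model trained with quantization-aware training.
# The script below loads the model configs and checkpoint from the given model_dir,
# re-initializes the global quantization settings via QuantLayer.global_init, reconstructs
# the symmetric clip range from the traced kernel clip_max, and prints the (int8-)quantized
# kernel of the first FFN in encoder layer 0.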
14 | import sys 15 | 16 | import tensorflow as tf 17 | 18 | from neurst.layers.quantization import QuantLayer 19 | from neurst.models.transformer import Transformer 20 | from neurst.tasks import build_task 21 | from neurst.utils.checkpoints import restore_checkpoint_if_possible 22 | from neurst.utils.configurable import ModelConfigs 23 | 24 | model_dir = sys.argv[1] 25 | model_configs = ModelConfigs.load(model_dir) 26 | QuantLayer.global_init(model_configs["enable_quant"], **model_configs["quant_params"]) 27 | task = build_task(model_configs) 28 | model: Transformer = task.build_model(model_configs) 29 | restore_checkpoint_if_possible(model, model_dir) 30 | 31 | clip_max = model._encoder._stacking_layers[0][1]._layer._conv1.traced["kernel"].clip_max 32 | 33 | weight_clip_max = tf.maximum(clip_max, 0.0) 34 | weight_clip_max = tf.cast(weight_clip_max, tf.float32) 35 | bits_tmp = float(2 ** (QuantLayer.quant_bits - 1)) 36 | weight_clip_min = -weight_clip_max * bits_tmp / (bits_tmp - 1) 37 | 38 | print("The quantized weight of encoder layer0's first ffn") 39 | print(tf.quantization.quantize(model._encoder._stacking_layers[0][1]._layer._conv1.kernel, 40 | weight_clip_min, clip_max, tf.qint8)) 41 | -------------------------------------------------------------------------------- /examples/simultaneous_translation/README.md: -------------------------------------------------------------------------------- 1 | # Simultaneous Translation 2 | 3 | This README contains instructions for training and evaluating a wait-k based simultaneous translation system with [SimulEval](https://github.com/facebookresearch/SimulEval). For more example models, see [iwslt21-simul-trans](/examples/iwslt21/SIMUL_TRANS.md). 4 | 5 | ## Requirements 6 | **SimulEval** 7 | ```bash 8 | git clone https://github.com/facebookresearch/SimulEval.git 9 | cd SimulEval/ 10 | pip install -e . 11 | ``` 12 | It is worth noting that Python multiprocessing conflicts with TensorFlow's CUDA initialization, so we make some changes to `SimulEval/simuleval/cli.py` and actually use [`neurst/cli/simuleval_cli.py`](/neurst/cli/simuleval_cli.py) instead. 13 | 14 | The changes are as follows: 15 | ```python 16 | # add init method to import tensorflow and restrict memory usage 17 | def init(): 18 | global tf 19 | import tensorflow as tf 20 | tf.config.experimental.set_memory_growth( 21 | tf.config.experimental.list_physical_devices('GPU')[0], 22 | True 23 | ) 24 | 25 | # set init method as the initializer of multiprocessing.Pool 26 | # with Pool(args.num_processes) as p: 27 | with Pool(args.num_processes, initializer=init) as p: 28 | ``` 29 | 30 | ## Wait-k Training 31 | Following [examples/translation](/examples/translation/README.md), we can train a wait-k based transformer model with extra options: 32 | ```bash 33 | python3 -m neurst.cli.run_exp \ 34 | --config_paths wmt14_en_de/training_args.yml,wmt14_en_de/translation_bpe.yml \ 35 | --hparams_set waitk_transformer_base \ 36 | --model_dir /wmt14_en_de/waitk_benchmark_base \ 37 | --task WaitkTranslation \ # overwrite the task.class in wmt14_en_de/translation_bpe.yml 38 | --wait_k 3 # the wait k lagging 39 | ``` 40 | 41 | - The self-attention in the encoder is monotonic. 42 | - To enable multi-path training (a different `k` for each training batch), we can set `--wait_k '[3,5,7]'`. 43 | - For validation within the training process, only the first `k` is used for evaluation. 
For a single validation process, one can override the `k` used for evaluation with `--eval_task_args "{'wait_k':5}"` 44 | 45 | For more details, see [waitk_transformer.py](/neurst/models/waitk_transformer.py) and [waitk_translation.py](/neurst/tasks/waitk_translation.py). 46 | 47 | ## Evaluating Latency with SimulEval 48 | As mentioned above, the original SimulEval conflicts with TensorFlow when using multiprocessing, so here we use the adapted version. 49 | ```bash 50 | python3 -m neurst.cli.simuleval_cli \ 51 | --agent simul_trans_text_agent \ 52 | --data-type text \ 53 | --source path_to_src_file \ 54 | --target path_to_trg_file \ 55 | --model-dir path_to_model_dir \ 56 | --num-processes 12 \ 57 | --wait-k 7 \ 58 | --output temp 59 | ``` 60 | 61 | - Here [simul_trans_text_agent](/neurst/utils/simuleval_agents/simul_trans_text_agent.py) is the standard agent implementation for the wait-k transformer. 62 | - We can increase `--num-processes` to speed up decoding, at the cost of a substantial increase in GPU memory usage. 63 | - We can also change `--wait-k` to trade off BLEU against average latency. 64 | 65 | 66 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/01-download.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | #!/usr/bin/env bash 15 | 16 | set -e 17 | 18 | if [[ ! -n "$1" ]] ;then 19 | echo "Usage: ./01-download.sh SAVE_PATH" 20 | exit 1 21 | else 22 | DATA_PATH="$1" 23 | fi 24 | 25 | DATA_PATH=$DATA_PATH/raw/ 26 | 27 | mkdir -p $DATA_PATH 28 | 29 | # Download from 30 | # https://github.com/alicank/Translation-Augmented-LibriSpeech-Corpus 31 | # and acquire the following zip files: 32 | # - train_100h.zip 33 | # - dev.zip 34 | # - test.zip 35 | 36 | echo "Downloading Augmented LibriSpeech dataset..." 
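# NOTE: no automatic download is performed by this script as shown; fetch the zip
# files listed above from the GitHub page and place them under $DATA_PATH
# (i.e., SAVE_PATH/raw/) before moving on to 02-audio_feature_extraction.sh.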
37 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/RESULTS.md: -------------------------------------------------------------------------------- 1 | # Results on Argumented LibriSpeech 2 | 3 | 4 | ### Comparison with counterparts (speech_transformer_s) 5 | test, case-insensitive 6 | 7 | |Model|tok|detok| 8 | |---|---|---| 9 | |Transformer ST + ASR PT (1)| - |15.5| 10 | |Transformer ST + ASR/MT PT (1)| - |16.2| 11 | |Transformer ST + ASR/MT PT + SpecAug (1) | - |16.7| 12 | |Transformer ST ensemble 3 models (1) | - | 17.4| 13 | |Transformer ST + ASR/MT PT (2)| 14.3 | - | 14 | |Transformer ST + ASR/MT PT + KD (2) | 17.0 | - | 15 | |Transformer ST + ASR PT + SpecAug (3) | 16.9 | - | 16 | |Transformer ST + ASR PT + curriculum pre-training + SpecAug (3) | 18.0 | - | 17 | |Transformer ST + ASR PT (4) | 15.3 | - | 18 | |Transformer ST + triple supervision (TED) (4) | 18.3 | - | 19 | |**NeurST** Transformer ST + ASR PT | 17.9 | 16.5 | 20 | |**NeurST** Transformer ST + ASR PT + SpecAug | 18.7 | 17.2 | 21 | |**NeurST** Transformer ST ensemble 2 models | **19.2** | **17.7**| 22 | 23 | (1) Espnet-ST (Inaguma et al., 2020) with additional techniques: speed perturbation, pre-trained MT decoder and CTC loss for ASR pretrain; 24 | 25 | (2) Liu et al. (2019) with the proposed knowledge distillation; 26 | 27 | (3) Wang et al. (2020) with additional ASR corpora and curriculum pre-training; 28 | 29 | (4) Dong et al. (2020) with CTC loss and a pre-trained BERT encoder as supervision with external ASR data; 30 | 31 | 32 | ### ASR (dmodel=256, WER) 33 | 34 | |Framework|Model|Dev|Test| | 35 | |---|---|---|---|---| 36 | |NeurST|Transformer ASR |8.8|8.8| pure end-to-end, beam=4, no length penalty | 37 | |Espnet (Inaguma et al., 2020)| Transformer ASR + ctc | 6.5 | 6.4 | multi-task training with ctc loss | 38 | 39 | 40 | ### MT and ST (dmodel=256, case-sensitive, tokenized BLEU/detokenized BLEU) 41 | 42 | |Framework|Model|Dev|Test| 43 | |---|---|---|---| 44 | |NeurST|Transformer MT |20.8 / 19.3 | 19.3 / 17.6 | 45 | |NeurST|cascade ST (Transformer ASR -> Transformer MT) | 18.3 / 17.0| 17.4 / 16.0 | 46 | |NeurST|end2end Transformer ST + ASR pretrain | 18.3 / 16.9 | 16.9 / 15.5 | 47 | |NeurST|end2end Transformer ST + ASR pretrain + SpecAug | 19.3 / 17.8 | 17.8 / 16.3 | 48 | |NeurST|end2end Transformer ST ensemble above 2 models | 19.3 / 18.0 | 18.3 / 16.8 | 49 | 50 | ### MT and Cascade ST (dmodel=256, case-insensitive, tokenized BLEU/detokenized BLEU) 51 | 52 | |Framework|Model|Dev|Test| 53 | |---|---|---|---| 54 | |NeurST|Transformer MT | 21.7 / 20.2 | 20.2 / 18.5 | 55 | |Espnet (Inaguma et al., 2020)| Transformer MT| ---- / 19.6 | ---- / 18.1 | 56 | |NeurST|cascade ST (Transformer ASR -> Transformer MT) | 19.2 / 17.8 | 18.2 / 16.8 | 57 | |Espnet (Inaguma et al., 2020)| cascade ST (Transformer ASR + ctc -> Transformer MT) | ---- / ---- | ---- / 17.0 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/asr_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 150 8 | metric: wer 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: AudioTFRecordDataset 15 | 
dataset.params: 16 | data_path: DATA_PATH/devtest/dev.tfrecords-00000-of-00001 17 | feature_key: audio 18 | transcript_key: transcript 19 | test: 20 | dataset.class: AudioTFRecordDataset 21 | dataset.params: 22 | data_path: DATA_PATH/devtest/test.tfrecords-00000-of-00001 23 | feature_key: audio 24 | transcript_key: transcript 25 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/asr_training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 2000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 3.5 17 | end_factor: 2.0 18 | dmodel: 256 19 | warmup_steps: 25000 20 | start_decay_at: 50000 21 | decay_steps: 50000 22 | 23 | dataset.class: AudioTFRecordDataset 24 | dataset.params: 25 | data_path: DATA_PATH/asr_st/train/ 26 | shuffle_dataset: True 27 | feature_key: audio 28 | transcript_key: transcript 29 | 30 | task.class: SpeechToText 31 | task.params: 32 | audio_feature_dim: 80 33 | audio_feature_channels: 1 34 | transcript_data_pipeline.class: TranscriptDataPipeline 35 | transcript_data_pipeline.params: 36 | remove_punctuation: True 37 | lowercase: True 38 | language: en 39 | tokenizer: moses 40 | subtokenizer: bpe 41 | subtokenizer_codes: DATA_PATH/asr_st/codes.bpe 42 | vocab_path: DATA_PATH/asr_st/vocab.en 43 | batch_by_frames: True 44 | batch_size: 120000 45 | max_src_len: 3000 46 | max_trg_len: 120 47 | truncate_src: True 48 | experimental_frame_transcript_ratio: 21 49 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/asr_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: AudioTFRecordDataset 4 | eval_dataset.params: 5 | data_path: DATA_PATH/devtest/dev.tfrecords-00000-of-00001 6 | feature_key: audio 7 | transcript_key: transcript 8 | eval_batch_size: 64 9 | eval_start_at: 6000 10 | eval_steps: 2000 11 | eval_criterion: label_smoothed_cross_entropy 12 | eval_search_method: beam_search 13 | eval_search_method.params: 14 | beam_size: 4 15 | length_penalty: -1 16 | maximum_decode_length: 150 17 | eval_metric: wer 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/mt_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 180 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: ParallelTextDataset 15 | dataset.params: 16 | src_file: DATA_PATH/transcripts/dev.en.txt 17 | trg_file: DATA_PATH/transcripts/dev.fr.txt 18 | test: 19 | dataset.class: ParallelTextDataset 20 | dataset.params: 21 | src_file: DATA_PATH/transcripts/test.en.txt 22 | trg_file: 
DATA_PATH/transcripts/test.fr.txt 23 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/mt_training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 1000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 1.0 17 | dmodel: 512 18 | warmup_steps: 4000 19 | 20 | dataset.class: ParallelTextDataset 21 | dataset.params: 22 | src_file: DATA_PATH/mt/train/train.en.bpe.txt 23 | trg_file: DATA_PATH/mt/train/train.fr.tok.bpe.txt 24 | data_is_processed: True 25 | 26 | task.class: seq2seq 27 | task.params: 28 | batch_by_tokens: True 29 | batch_size: 25000 30 | max_src_len: 120 31 | max_trg_len: 150 32 | src_data_pipeline.class: TranscriptDataPipeline 33 | src_data_pipeline.params: 34 | remove_punctuation: True 35 | lowercase: True 36 | language: en 37 | tokenizer: moses 38 | subtokenizer: bpe 39 | subtokenizer_codes: DATA_PATH/mt/codes.bpe 40 | vocab_path: DATA_PATH/mt/vocab.en 41 | trg_data_pipeline.class: TextDataPipeline 42 | trg_data_pipeline.params: 43 | language: fr 44 | tokenizer: moses 45 | subtokenizer: bpe 46 | subtokenizer_codes: DATA_PATH/mt/codes.bpe 47 | vocab_path: DATA_PATH/mt/vocab.fr 48 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/mt_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: ParallelTextDataset 4 | eval_dataset.params: 5 | src_file: DATA_PATH/transcripts/dev.en.txt 6 | trg_file: DATA_PATH/transcripts/dev.fr.txt 7 | eval_batch_size: 64 8 | eval_start_at: 5000 9 | eval_steps: 1000 10 | eval_criterion: label_smoothed_cross_entropy 11 | eval_search_method: beam_search 12 | eval_search_method.params: 13 | beam_size: 4 14 | length_penalty: -1 15 | maximum_decode_length: 180 16 | extra_decode_length: 50 17 | eval_metric: tok_bleu 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/st_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 180 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: AudioTFRecordDataset 15 | dataset.params: 16 | data_path: DATA_PATH/devtest/dev.tfrecords-00000-of-00001 17 | feature_key: audio 18 | transcript_key: translation 19 | test: 20 | dataset.class: AudioTFRecordDataset 21 | dataset.params: 22 | data_path: DATA_PATH/devtest/test.tfrecords-00000-of-00001 23 | feature_key: audio 24 | transcript_key: translation 25 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/st_training_args.yml: 
-------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 2000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 3.5 17 | end_factor: 1.5 18 | dmodel: 256 19 | warmup_steps: 25000 20 | start_decay_at: 50000 21 | decay_steps: 50000 22 | 23 | dataset.class: AudioTFRecordDataset 24 | dataset.params: 25 | data_path: DATA_PATH/asr_st/train/ 26 | shuffle_dataset: True 27 | feature_key: audio 28 | transcript_key: translation 29 | 30 | task.class: SpeechToText 31 | task.params: 32 | audio_feature_dim: 80 33 | transcript_data_pipeline.class: TranscriptDataPipeline 34 | transcript_data_pipeline.params: 35 | remove_punctuation: False 36 | lowercase: False 37 | language: fr 38 | tokenizer: moses 39 | subtokenizer: bpe 40 | subtokenizer_codes: DATA_PATH/asr_st/codes.bpe 41 | vocab_path: DATA_PATH/asr_st/vocab.fr 42 | batch_by_frames: True 43 | batch_size: 80000 44 | max_src_len: 3000 45 | max_trg_len: 150 46 | truncate_src: True 47 | experimental_frame_transcript_ratio: 12 48 | -------------------------------------------------------------------------------- /examples/speech_transformer/augmented_librispeech/st_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: AudioTFRecordDataset 4 | eval_dataset.params: 5 | data_path: DATA_PATH/devtest/dev.tfrecords-00000-of-00001 6 | feature_key: audio 7 | transcript_key: translation 8 | eval_batch_size: 64 9 | eval_start_at: 6000 10 | eval_steps: 2000 11 | eval_criterion: label_smoothed_cross_entropy 12 | eval_search_method: beam_search 13 | eval_search_method.params: 14 | beam_size: 4 15 | length_penalty: -1 16 | maximum_decode_length: 180 17 | eval_metric: bleu 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/01-download.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | #!/usr/bin/env bash 15 | 16 | set -e 17 | 18 | if [[ ! 
-n "$1" ]] ;then 19 | echo "Usage: ./01-download.sh SAVE_PATH" 20 | exit 1 21 | else 22 | DATA_PATH="$1" 23 | fi 24 | 25 | DATA_PATH=$DATA_PATH/raw/ 26 | 27 | mkdir -p $DATA_PATH 28 | 29 | # Download from 30 | # https://ict.fbk.eu/must-c/ 31 | # and get following tgz files: 32 | # - MUSTC_v1.0_en-de.tar.gz 33 | # - MUSTC_v1.0_en-es.tar.gz 34 | # - MUSTC_v1.0_en-fr.tar.gz 35 | # - MUSTC_v1.0_en-it.tar.gz 36 | # - MUSTC_v1.0_en-nl.tar.gz 37 | # - MUSTC_v1.0_en-pt.tar.gz 38 | # - MUSTC_v1.0_en-ro.tar.gz 39 | # - MUSTC_v1.0_en-ru.tar.gz 40 | 41 | echo "Downloading MuST-C dataset..." 42 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/RESULTS.md: -------------------------------------------------------------------------------- 1 | # Results on MuST-C 2 | 3 | 4 | ### Comparison with counterparts (speech_transformer_s) 5 | test-COMMON, case-sensitive, detokenized BLEU 6 | 7 | |Model|DE|ES|FR|IT|NL|PT|RO|RU|avg.| 8 | |---|---|---|---|---|---|---|---|---|---| 9 | |Transformer ST + ASR PT (1) | 21.8 | 26.4 | 31.6 | 21.5 | 25.2 | 26.8 | 20.5 | 14.3 | 23.5 | 10 | | Transformer ST + ASR/MT PT (1) | 22.3 | 27.8 | 31.5 | 22.8 | 26.9 | 27.3 | 20.9 | 15.3 | 24.4| \\ 11 | | Transformer ST + ASR/MT PT + SpecAug (1) | **22.9** | **28.0** | 32.8 | **23.8** | **27.4** | 28.0 | 21.9 | **15.8** | **25.1** | 12 | | Transformer ST + ASR PT + SpecAug (2) | 22.7 | 27.2 | 32.9 | 22.7 | 27.3 | 28.1 | 21.9 | 15.3 | 24.8| 13 | | Transformer ST + adaptive feature selection (3) | 22.4 | 26.9 | 31.6 | 23.0 | 24.9 | 26.3 | 21.0 | 14.7 | 23.9| 14 | |**NeurST** Transformer ST + ASR PT | 21.9 | 26.8 | 32.3 | 22.2 | 26.4 | 27.6 | 20.9 | 15.2 | 24.2| 15 | |**NeurST** Transformer ST + ASR PT + SpecAug | 22.8 | 27.4 | **33.3** | 22.9 | 27.2 | **28.7** | **22.2** | 15.1 | 24.9| 16 | 17 | 18 | (1) Espnet-ST (Inaguma et al., 2020) with additional techniques: speed perturbation, pre-trained MT decoder and CTC loss for ASR pretrain; 19 | 20 | (2) fairseq-ST (Wang et al., 2020) with the same setting as NeurST; 21 | 22 | (3) Zhang et al. 
(2020) with the proposed adaptive feature selection method 23 | 24 | ### ASR (dmodel=256, WER) 25 | test-COMMON 26 | 27 | |Framework|Model|DE|ES|FR|IT|NL|PT|RO|RU| 28 | |---|---|---|---|---|---|---|---|---|---| 29 | |NeurST|Transformer ASR |13.6|13|12.9|13.5|13.8|14.4|13.7|13.4| 30 | |Espnet (Inaguma et al., 2020)| Transformer ASR + ctc |12.7|12.1|12|12.4|12.1|13.4|12.6|12.3| 31 | |fairseq-ST (Wang et al., 2020)| Transformer ASR|18.2|17.7|17.2|17.9|17.6|19.1|18.1|17.7| 32 | 33 | 34 | ### MT and ST (dmodel=256, case-sensitive, tokenized BLEU/detokenized BLEU) 35 | test-COMMON 36 | 37 | |Framework|Model|DE|ES|FR|IT|NL|PT|RO|RU| 38 | |---|---|---|---|---|---|---|---|---|---| 39 | |NeurST|Transformer MT |27.9/27.8|32.9/32.8|42.2/40.2|29.0/28.5|32.9/32.7|34.4/34.0|27.5/26.4|19.3/19.1| 40 | |NeurST|cascade ST (Transformer ASR -> Transformer MT) |23.5/23.4|28.1/28.0|35.8/33.9|24.3/23.8|27.3/27.1|28.6/28.3|23.3/22.2|16.2/16.0| 41 | |NeurST|end2end Transformer ST + ASR pretrain |21.9/21.9|26.9/26.8|34.2/32.3|22.6/22.2|26.5/26.4|27.8/27.6|21.9/20.9|15.0/15.2| 42 | |NeurST|end2end Transformer ST + ASR pretrain + SpecAug |22.8/22.8|27.5/27.4|35.2/33.3|23.4/22.9|27.4/27.2|29.0/28.7|23.2/22.2|15.2/15.1| 43 | 44 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/asr_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 150 8 | metric: wer 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: AudioTFRecordDataset 15 | dataset.params: 16 | data_path: DATA_PATH/devtest/dev.en-TRG_LANG.tfrecords-00000-of-00001 17 | feature_key: audio 18 | transcript_key: transcript 19 | tst-COM: 20 | dataset.class: AudioTFRecordDataset 21 | dataset.params: 22 | data_path: DATA_PATH/devtest/tst-COMMON.en-TRG_LANG.tfrecords-00000-of-00001 23 | feature_key: audio 24 | transcript_key: transcript 25 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/asr_training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 2000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 3.5 17 | end_factor: 2.0 18 | dmodel: 256 19 | warmup_steps: 25000 20 | start_decay_at: 50000 21 | decay_steps: 50000 22 | 23 | dataset.class: AudioTFRecordDataset 24 | dataset.params: 25 | data_path: DATA_PATH/asr_st/TRG_LANG/train/ 26 | shuffle_dataset: True 27 | feature_key: audio 28 | transcript_key: transcript 29 | 30 | task.class: SpeechToText 31 | task.params: 32 | audio_feature_dim: 80 33 | audio_feature_channels: 1 34 | transcript_data_pipeline.class: TranscriptDataPipeline 35 | transcript_data_pipeline.params: 36 | remove_punctuation: True 37 | lowercase: True 38 | language: en 39 | tokenizer: moses 40 | subtokenizer: bpe 41 | subtokenizer_codes: DATA_PATH/asr_st/TRG_LANG/codes.bpe 42 | vocab_path: DATA_PATH/asr_st/TRG_LANG/vocab.en 43 | batch_by_frames: True 44 | batch_size: 120000 
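    # Editorial note (not part of the original config, added for clarity): since
    # batch_by_frames is True, the batch_size above counts audio feature frames rather
    # than sentences; with the default 10 ms frame shift of the log-Mel fbank extractor,
    # 120000 frames correspond to roughly 20 minutes of speech per batch. This training
    # config is normally paired with asr_validation_args.yml when launching training,
    # e.g. (the exact run_exp flag names are an assumption, not taken from this file):
    #   python3 -m neurst.cli.run_exp --config_paths asr_training_args.yml,asr_validation_args.yml --model_dir MODEL_DIR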
45 | max_src_len: 3000 46 | max_trg_len: 120 47 | truncate_src: True 48 | experimental_frame_transcript_ratio: 21 49 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/asr_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: AudioTFRecordDataset 4 | eval_dataset.params: 5 | data_path: DATA_PATH/devtest/dev.en-TRG_LANG.tfrecords-00000-of-00001 6 | feature_key: audio 7 | transcript_key: transcript 8 | eval_batch_size: 64 9 | eval_start_at: 6000 10 | eval_steps: 2000 11 | eval_criterion: label_smoothed_cross_entropy 12 | eval_search_method: beam_search 13 | eval_search_method.params: 14 | beam_size: 4 15 | length_penalty: -1 16 | maximum_decode_length: 150 17 | eval_metric: wer 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/mt_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 180 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: ParallelTextDataset 15 | dataset.params: 16 | src_file: DATA_PATH/transcripts/TRG_LANG/dev.en.txt 17 | trg_file: DATA_PATH/transcripts/TRG_LANG/dev.TRG_LANG.txt 18 | tst-COM: 19 | dataset.class: ParallelTextDataset 20 | dataset.params: 21 | src_file: DATA_PATH/transcripts/TRG_LANG/tst-COMMON.en.txt 22 | trg_file: DATA_PATH/transcripts/TRG_LANG/tst-COMMON.TRG_LANG.txt 23 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/mt_training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 120000 4 | summary_steps: 200 5 | save_checkpoint_steps: 1000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 1.0 17 | dmodel: 512 18 | warmup_steps: 4000 19 | 20 | dataset.class: ParallelTextDataset 21 | dataset.params: 22 | src_file: DATA_PATH/mt/TRG_LANG/train/train.en.clean.tok.bpe.txt 23 | trg_file: DATA_PATH/mt/TRG_LANG/train/train.TRG_LANG.tok.bpe.txt 24 | data_is_processed: True 25 | 26 | task.class: seq2seq 27 | task.params: 28 | batch_by_tokens: True 29 | batch_size: 25000 30 | max_src_len: 120 31 | max_trg_len: 150 32 | src_data_pipeline.class: TranscriptDataPipeline 33 | src_data_pipeline.params: 34 | remove_punctuation: True 35 | lowercase: True 36 | language: en 37 | tokenizer: moses 38 | subtokenizer: bpe 39 | subtokenizer_codes: DATA_PATH/mt/TRG_LANG/codes.bpe 40 | vocab_path: DATA_PATH/mt/TRG_LANG/vocab.en 41 | trg_data_pipeline.class: TextDataPipeline 42 | trg_data_pipeline.params: 43 | language: TRG_LANG 44 | tokenizer: moses 45 | subtokenizer: bpe 46 | subtokenizer_codes: DATA_PATH/mt/TRG_LANG/codes.bpe 47 | vocab_path: DATA_PATH/mt/TRG_LANG/vocab.TRG_LANG 48 | -------------------------------------------------------------------------------- 
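(Editorial sketch, not part of the repository: the MuST-C configs above and below are templates in which DATA_PATH and TRG_LANG act as placeholders; they must be substituted for a concrete language pair, e.g. de, es, fr, before the files can be used. A minimal helper for instantiating one template, relying only on the Python standard library, could look like the following.)

import pathlib

def instantiate_config(template, output, trg_lang, data_path):
    """Fills in the DATA_PATH / TRG_LANG placeholders of one YAML template."""
    text = pathlib.Path(template).read_text()
    text = text.replace("DATA_PATH", data_path).replace("TRG_LANG", trg_lang)
    pathlib.Path(output).write_text(text)

# Example: instantiate_config("mt_training_args.yml", "mt_training_args.en-de.yml", "de", "/path/to/mustc")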
/examples/speech_transformer/must-c/mt_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: ParallelTextDataset 4 | eval_dataset.params: 5 | src_file: DATA_PATH/transcripts/TRG_LANG/dev.en.txt 6 | trg_file: DATA_PATH/transcripts/TRG_LANG/dev.TRG_LANG.txt 7 | eval_batch_size: 64 8 | eval_start_at: 5000 9 | eval_steps: 1000 10 | eval_criterion: label_smoothed_cross_entropy 11 | eval_search_method: beam_search 12 | eval_search_method.params: 13 | beam_size: 4 14 | length_penalty: -1 15 | maximum_decode_length: 180 16 | extra_decode_length: 50 17 | eval_metric: tok_bleu 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/st_prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: -1 7 | maximum_decode_length: 180 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: AudioTFRecordDataset 15 | dataset.params: 16 | data_path: DATA_PATH/devtest/dev.en-TRG_LANG.tfrecords-00000-of-00001 17 | feature_key: audio 18 | transcript_key: translation 19 | tst-COM: 20 | dataset.class: AudioTFRecordDataset 21 | dataset.params: 22 | data_path: DATA_PATH/devtest/tst-COMMON.en-TRG_LANG.tfrecords-00000-of-00001 23 | feature_key: audio 24 | transcript_key: translation 25 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/st_training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 2000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | optimizer.class: adam 10 | optimizer.params: 11 | epsilon: 1.e-9 12 | beta_1: 0.9 13 | beta_2: 0.98 14 | lr_schedule.class: noam 15 | lr_schedule.params: 16 | initial_factor: 3.5 17 | end_factor: 1.5 18 | dmodel: 256 19 | warmup_steps: 25000 20 | start_decay_at: 50000 21 | decay_steps: 50000 22 | 23 | dataset.class: AudioTFRecordDataset 24 | dataset.params: 25 | data_path: DATA_PATH/asr_st/TRG_LANG/train/ 26 | shuffle_dataset: True 27 | feature_key: audio 28 | transcript_key: translation 29 | 30 | task.class: SpeechToText 31 | task.params: 32 | audio_feature_dim: 80 33 | transcript_data_pipeline.class: TranscriptDataPipeline 34 | transcript_data_pipeline.params: 35 | remove_punctuation: False 36 | lowercase: False 37 | language: TRG_LANG 38 | tokenizer: moses 39 | subtokenizer: bpe 40 | subtokenizer_codes: DATA_PATH/asr_st/TRG_LANG/codes.bpe 41 | vocab_path: DATA_PATH/asr_st/TRG_LANG/vocab.TRG_LANG 42 | batch_by_frames: True 43 | batch_size: 80000 44 | max_src_len: 3000 45 | max_trg_len: 150 46 | truncate_src: True 47 | experimental_frame_transcript_ratio: 12 48 | -------------------------------------------------------------------------------- /examples/speech_transformer/must-c/st_validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: AudioTFRecordDataset 
4 | eval_dataset.params: 5 | data_path: DATA_PATH/devtest/dev.en-TRG_LANG.tfrecords-00000-of-00001 6 | feature_key: audio 7 | transcript_key: translation 8 | eval_batch_size: 64 9 | eval_start_at: 6000 10 | eval_steps: 2000 11 | eval_criterion: label_smoothed_cross_entropy 12 | eval_search_method: beam_search 13 | eval_search_method.params: 14 | beam_size: 4 15 | length_penalty: -1 16 | maximum_decode_length: 180 17 | eval_metric: bleu 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /examples/translation/prediction_args.yml: -------------------------------------------------------------------------------- 1 | entry: predict 2 | batch_size: 64 3 | search_method: beam_search 4 | search_method.params: 5 | beam_size: 4 6 | length_penalty: 0.6 7 | maximum_decode_length: 160 8 | metric: bleu 9 | 10 | dataset.class: MultipleDataset 11 | dataset.params: 12 | multiple_datasets: 13 | dev: 14 | dataset.class: ParallelTextDataset 15 | dataset.params: 16 | src_file: DEV_SRC 17 | trg_file: DEV_TRG 18 | test: 19 | dataset.class: ParallelTextDataset 20 | dataset.params: 21 | src_file: TEST_SRC 22 | trg_file: TEST_TRG 23 | -------------------------------------------------------------------------------- /examples/translation/training_args.yml: -------------------------------------------------------------------------------- 1 | entry.class: trainer 2 | entry.params: 3 | train_steps: 200000 4 | summary_steps: 200 5 | save_checkpoint_steps: 1000 6 | criterion.class: label_smoothed_cross_entropy 7 | criterion.params: 8 | label_smoothing: 0.1 9 | -------------------------------------------------------------------------------- /examples/translation/validation_args.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset: ParallelTextDataset 4 | eval_dataset.params: 5 | src_file: DEV_SRC 6 | trg_file: DEV_TRG 7 | eval_batch_size: 64 8 | eval_start_at: 5000 9 | eval_steps: 1000 10 | eval_criterion: label_smoothed_cross_entropy 11 | eval_search_method: beam_search 12 | eval_search_method.params: 13 | beam_size: 4 14 | length_penalty: 0.6 15 | maximum_decode_length: 160 16 | extra_decode_length: 50 17 | eval_metric: bleu 18 | eval_top_checkpoints_to_keep: 10 19 | eval_auto_average_checkpoints: True 20 | -------------------------------------------------------------------------------- /neurst/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | 4 | import importlib 5 | 6 | from .__version__ import __version__ # NOQA 7 | 8 | __author__ = "ZhaoChengqi " 9 | 10 | __all__ = [ 11 | "cli", 12 | "data", 13 | "criterions", 14 | "exps", 15 | "layers", 16 | "metrics", 17 | "models", 18 | "tasks", 19 | "utils", 20 | "training", 21 | "optimizers", 22 | "sparsity" 23 | ] 24 | 25 | importlib.import_module("neurst.criterions") 26 | importlib.import_module("neurst.data") 27 | importlib.import_module("neurst.data.audio") 28 | importlib.import_module("neurst.data.data_pipelines") 29 | importlib.import_module("neurst.data.datasets") 30 | importlib.import_module("neurst.data.datasets.audio") 31 | importlib.import_module("neurst.data.text") 32 | importlib.import_module("neurst.exps") 33 | importlib.import_module("neurst.layers") 34 | 
importlib.import_module("neurst.layers.attentions") 35 | importlib.import_module("neurst.layers.decoders") 36 | importlib.import_module("neurst.layers.encoders") 37 | importlib.import_module("neurst.layers.metric_layers") 38 | importlib.import_module("neurst.layers.quantization") 39 | importlib.import_module("neurst.layers.search") 40 | importlib.import_module("neurst.metrics") 41 | importlib.import_module("neurst.models") 42 | importlib.import_module("neurst.optimizers") 43 | importlib.import_module("neurst.optimizers.schedules") 44 | importlib.import_module("neurst.sparsity") 45 | importlib.import_module("neurst.tasks") 46 | importlib.import_module("neurst.training") 47 | importlib.import_module("neurst.utils") 48 | importlib.import_module("neurst.utils.converters") 49 | -------------------------------------------------------------------------------- /neurst/__version__.py: -------------------------------------------------------------------------------- 1 | VERSION = (0, 1, 0) 2 | 3 | __version__ = '.'.join(map(str, VERSION)) 4 | -------------------------------------------------------------------------------- /neurst/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/cli/__init__.py -------------------------------------------------------------------------------- /neurst/cli/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/cli/analysis/__init__.py -------------------------------------------------------------------------------- /neurst/cli/convert_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from absl import app, logging 15 | 16 | import neurst.utils.flags_core as flags_core 17 | from neurst.utils.converters import Converter, build_converter 18 | 19 | FLAG_LIST = [ 20 | flags_core.Flag("from", dtype=flags_core.Flag.TYPE.STRING, default=None, 21 | required=True, help="The path to pretrained model directory " 22 | "or a key indicating the publicly available model name."), 23 | flags_core.Flag("to", dtype=flags_core.Flag.TYPE.STRING, default=None, 24 | required=True, help="The path to save the converted checkpoint."), 25 | flags_core.Flag("model_name", dtype=flags_core.Flag.TYPE.STRING, default=None, 26 | required=True, help="The name of pretrained model, e.g. 
google_bert."), 27 | ] 28 | 29 | 30 | def convert(converter: Converter, from_path, to_path): 31 | assert converter is not None 32 | assert from_path 33 | assert to_path 34 | converter.convert(from_path, to_path) 35 | 36 | 37 | def _main(_): 38 | arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=False) 39 | args, remaining_argv = flags_core.parse_flags(FLAG_LIST, arg_parser) 40 | flags_core.verbose_flags(FLAG_LIST, args, remaining_argv) 41 | converter = build_converter(args["model_name"]) 42 | convert(converter, args["from"], args["to"]) 43 | 44 | 45 | if __name__ == "__main__": 46 | logging.set_verbosity(logging.INFO) 47 | app.run(_main, argv=["pseudo.py"]) 48 | -------------------------------------------------------------------------------- /neurst/cli/extract_audio_transcripts.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | from absl import app, logging 16 | 17 | import neurst.utils.flags_core as flags_core 18 | from neurst.data.datasets import Dataset, build_dataset 19 | from neurst.data.datasets.audio.audio_dataset import RawAudioDataset 20 | 21 | FLAG_LIST = [ 22 | flags_core.Flag("output_transcript_file", dtype=flags_core.Flag.TYPE.STRING, 23 | required=True, help="The path to save transcriptions."), 24 | flags_core.Flag("output_translation_file", dtype=flags_core.Flag.TYPE.STRING, 25 | default=None, help="The path to save transcriptions."), 26 | flags_core.ModuleFlag(Dataset.REGISTRY_NAME, help="The raw dataset."), 27 | ] 28 | 29 | 30 | def main(dataset, output_transcript_file, output_translation_file=None): 31 | assert isinstance(dataset, RawAudioDataset) 32 | transcripts = dataset.transcripts 33 | translations = dataset.translations 34 | assert transcripts, "Fail to extract transcripts." 
35 | with tf.io.gfile.GFile(output_transcript_file, "w") as fw: 36 | fw.write("\n".join(transcripts) + "\n") 37 | if translations and output_translation_file: 38 | with tf.io.gfile.GFile(output_translation_file, "w") as fw: 39 | fw.write("\n".join(translations) + "\n") 40 | 41 | 42 | def _main(_): 43 | # define and parse program flags 44 | arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=True) 45 | args, remaining_argv = flags_core.intelligent_parse_flags(FLAG_LIST, arg_parser) 46 | flags_core.verbose_flags(FLAG_LIST, args, remaining_argv) 47 | dataset = build_dataset(args) 48 | if dataset is None: 49 | raise ValueError("dataset must be provided.") 50 | main(dataset=dataset, 51 | output_transcript_file=args["output_transcript_file"], 52 | output_translation_file=args["output_translation_file"]) 53 | 54 | 55 | if __name__ == "__main__": 56 | logging.set_verbosity(logging.INFO) 57 | app.run(_main, argv=["pseudo.py"]) 58 | -------------------------------------------------------------------------------- /neurst/cli/process_text.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | from absl import app, logging 16 | 17 | import neurst.utils.flags_core as flags_core 18 | from neurst.data.data_pipelines.data_pipeline import lowercase_and_remove_punctuations 19 | from neurst.data.text import Tokenizer, build_tokenizer 20 | 21 | FLAG_LIST = [ 22 | flags_core.Flag("input", dtype=flags_core.Flag.TYPE.STRING, default=None, 23 | help="The path to the input text file."), 24 | flags_core.Flag("output", dtype=flags_core.Flag.TYPE.STRING, default=None, 25 | help="The path to the output text file."), 26 | flags_core.Flag("lowercase", dtype=flags_core.Flag.TYPE.BOOLEAN, default=None, 27 | help="Whether to lowercase."), 28 | flags_core.Flag("remove_punctuation", dtype=flags_core.Flag.TYPE.BOOLEAN, default=None, 29 | help="Whether to remove the punctuations."), 30 | flags_core.ModuleFlag(Tokenizer.REGISTRY_NAME, help="The tokenizer."), 31 | ] 32 | 33 | 34 | def _main(_): 35 | arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=False) 36 | args, remaining_argv = flags_core.intelligent_parse_flags(FLAG_LIST, arg_parser) 37 | flags_core.verbose_flags(FLAG_LIST, args, remaining_argv) 38 | 39 | tokenizer = build_tokenizer(args) 40 | with tf.io.gfile.GFile(args["input"]) as fp: 41 | with tf.io.gfile.GFile(args["output"], "w") as fw: 42 | for line in fp: 43 | line = lowercase_and_remove_punctuations(tokenizer.language, line.strip(), 44 | args["lowercase"], args["remove_punctuation"]) 45 | fw.write(tokenizer.tokenize(line, return_str=True) + "\n") 46 | 47 | 48 | if __name__ == "__main__": 49 | logging.set_verbosity(logging.INFO) 50 | app.run(_main, argv=["pseudo.py"]) 51 | -------------------------------------------------------------------------------- /neurst/cli/text_metric.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | from absl import app, logging 16 | 17 | import neurst.utils.flags_core as flags_core 18 | from neurst.metrics import Metric, build_metric 19 | from neurst.utils.misc import flatten_string_list 20 | 21 | FLAG_LIST = [ 22 | flags_core.Flag("hypo_file", dtype=flags_core.Flag.TYPE.STRING, default=None, 23 | help="The path to hypothesis file."), 24 | flags_core.Flag("ref_file", dtype=flags_core.Flag.TYPE.STRING, default=None, multiple=True, 25 | help="The path to reference file. "), 26 | flags_core.ModuleFlag(Metric.REGISTRY_NAME, help="The metric for evaluation."), 27 | ] 28 | 29 | 30 | def evaluate(metric, hypo_file, ref_file): 31 | assert metric is not None 32 | assert hypo_file 33 | assert ref_file 34 | with tf.io.gfile.GFile(hypo_file) as fp: 35 | hypo = [line.strip() for line in fp] 36 | 37 | ref_list = [] 38 | for one_ref_file in flatten_string_list(ref_file): 39 | with tf.io.gfile.GFile(one_ref_file) as fp: 40 | ref = [line.strip() for line in fp] 41 | ref_list.append(ref) 42 | 43 | metric_result = (metric(hypo, ref_list) if len(ref_list) > 1 44 | else metric(hypo, ref_list[0])) 45 | for k, v in metric_result.items(): 46 | logging.info("Evaluation result: %s=%.2f", k, v) 47 | 48 | 49 | def _main(_): 50 | arg_parser = flags_core.define_flags(FLAG_LIST, with_config_file=False) 51 | args, remaining_argv = flags_core.intelligent_parse_flags(FLAG_LIST, arg_parser) 52 | flags_core.verbose_flags(FLAG_LIST, args, remaining_argv) 53 | metric = build_metric(args) 54 | evaluate(metric, args["hypo_file"], args["ref_file"]) 55 | 56 | 57 | if __name__ == "__main__": 58 | logging.set_verbosity(logging.INFO) 59 | app.run(_main, argv=["pseudo.py"]) 60 | -------------------------------------------------------------------------------- /neurst/cli/view_tfrecord.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
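# Editorial note: for the speech TFRecords referenced by the example configs (feature key
# "audio", transcript keys "transcript"/"translation"), the summary printed by this tool
# would typically look like the following (illustrative, not captured output):
#   elements: {
#     "audio": float32
#     "transcript": bytes (str)
#     "translation": bytes (str)
#   }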
14 | import sys 15 | 16 | import tensorflow as tf 17 | 18 | from neurst.data.dataset_utils import glob_tfrecords 19 | 20 | 21 | def cli_main(): 22 | if len(sys.argv) == 1 or (len(sys.argv) == 2 and (sys.argv[1] in ["help", "--help", "-h"])): 23 | print("Usage: ") 24 | print(" >> python3 -m neurst.cli.view_tfrecord path") 25 | print(" Show examples and types of TF Record elements.") 26 | exit() 27 | 28 | print("===================== Examine elements =====================") 29 | for x in tf.data.TFRecordDataset(glob_tfrecords(sys.argv[1])).take(1): 30 | example = tf.train.Example() 31 | example.ParseFromString(x.numpy()) 32 | print(example) 33 | print("elements: {") 34 | for name in example.features.feature: 35 | if len(example.features.feature[name].bytes_list.value) > 0: 36 | print(f" \"{name}\": bytes (str)") 37 | elif len(example.features.feature[name].int64_list.value) > 0: 38 | print(f" \"{name}\": int64") 39 | elif len(example.features.feature[name].float_list.value) > 0: 40 | print(f" \"{name}\": float32") 41 | print("}") 42 | 43 | 44 | if __name__ == "__main__": 45 | cli_main() 46 | -------------------------------------------------------------------------------- /neurst/criterions/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.criterions.criterion import Criterion 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_criterion, register_criterion = setup_registry(Criterion.REGISTRY_NAME, base_class=Criterion, 8 | verbose_creation=True) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.criterions.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/criterions/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | 19 | @six.add_metaclass(ABCMeta) 20 | class Criterion(object): 21 | REGISTRY_NAME = "criterion" 22 | 23 | def __init__(self): 24 | self._model = None 25 | 26 | @staticmethod 27 | def class_or_method_args(): 28 | """ Returns a list of args for flag definition. """ 29 | return [] 30 | 31 | @abstractmethod 32 | def __call__(self, model_inp, model_out): 33 | """ Calculates according to model inputs and model outputs. 34 | 35 | Returns a list of tensors. 36 | """ 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def reduce_loss(self, model_inp, model_out): 41 | """ Reduces loss tensor for training according to the model inputs 42 | and outputs. 43 | 44 | Returns: A float tensor. 
45 | """ 46 | raise NotImplementedError 47 | 48 | @abstractmethod 49 | def reduce_metrics(self, eval_res_list): 50 | """ Reduces the metrics according to a list of returned value from `eval`. 51 | 52 | Args: 53 | eval_res_list: A list of tuples of numpy.ndarray generated by `self.__call__` 54 | and model.__call__. 55 | 56 | Returns: 57 | A dict of reduced metrics for evaluation. 58 | """ 59 | raise NotImplementedError 60 | 61 | def reduce_sample_metrics(self, eval_res): 62 | """ Reduces the metrics at sample level. 63 | 64 | Args: 65 | eval_res: A tuple of numpy.ndarray or tensors generated by `self.__call__`. 66 | 67 | Returns: 68 | A list of dict of reduced metrics for evaluation. 69 | """ 70 | raise NotImplementedError 71 | 72 | @abstractmethod 73 | def as_metric(self): 74 | """ Returns a wrapper class of Metric. """ 75 | raise NotImplementedError 76 | 77 | def set_model(self, model): 78 | self._model = model 79 | 80 | @property 81 | def model(self): 82 | return self._model 83 | -------------------------------------------------------------------------------- /neurst/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | importlib.import_module("neurst.data.audio") 4 | importlib.import_module("neurst.data.data_pipelines") 5 | importlib.import_module("neurst.data.datasets") 6 | importlib.import_module("neurst.data.text") 7 | -------------------------------------------------------------------------------- /neurst/data/audio/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.data.audio.feature_extractor import FeatureExtractor 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_feature_extractor, register_feature_extractor = setup_registry( 8 | "feature_extractor", base_class=FeatureExtractor, verbose_creation=True) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.data.audio.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/data/audio/feature_extractor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | 19 | @six.add_metaclass(ABCMeta) 20 | class FeatureExtractor(object): 21 | """ Abstract feature extractor for extracting audio features. """ 22 | REGISTRY_NAME = "feature_extractor" 23 | 24 | @property 25 | @abstractmethod 26 | def feature_dim(self): 27 | """ Returns the dimension of the feature. 
""" 28 | raise NotImplementedError 29 | 30 | @abstractmethod 31 | def seconds(self, feature): 32 | """ Returns the time seconds of this sample. """ 33 | raise NotImplementedError 34 | 35 | @abstractmethod 36 | def __call__(self, signal, rate): 37 | raise NotImplementedError 38 | 39 | @staticmethod 40 | def class_or_method_args(): 41 | return [] 42 | -------------------------------------------------------------------------------- /neurst/data/audio/float_identity.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy 15 | 16 | from neurst.data.audio import FeatureExtractor, register_feature_extractor 17 | 18 | 19 | @register_feature_extractor 20 | class FloatIdentity(FeatureExtractor): 21 | 22 | def __init__(self, args): 23 | _ = args 24 | 25 | def seconds(self, feature): 26 | # by default: sample rate=16000 27 | return len(feature) / 16000. 28 | 29 | @property 30 | def feature_dim(self): 31 | return 1 32 | 33 | def __call__(self, signal, rate): 34 | if isinstance(signal[0], (float, numpy.float32, numpy.float64)): 35 | return numpy.array(signal) 36 | return numpy.array(signal) / 32768. 37 | -------------------------------------------------------------------------------- /neurst/data/audio/log_mel_fbank.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import numpy 15 | 16 | from neurst.data.audio import FeatureExtractor, register_feature_extractor 17 | from neurst.utils.flags_core import Flag 18 | 19 | try: 20 | from python_speech_features import logfbank 21 | except ImportError: 22 | pass 23 | 24 | 25 | @register_feature_extractor("fbank") 26 | class LogMelFbank(FeatureExtractor): 27 | 28 | def __init__(self, args): 29 | self._nfilt = args["nfilt"] 30 | self._winlen = args["winlen"] 31 | self._winstep = args["winstep"] 32 | try: 33 | from python_speech_features import logfbank 34 | _ = logfbank 35 | except ImportError: 36 | raise ImportError('Please install python_speech_features with: pip3 install python_speech_features') 37 | 38 | @staticmethod 39 | def class_or_method_args(): 40 | return [ 41 | Flag("nfilt", dtype=Flag.TYPE.INTEGER, default=80, 42 | help="The number of frames in the filterbank."), 43 | Flag("winlen", dtype=Flag.TYPE.FLOAT, default=0.025, 44 | help="The length of the analysis window in seconds. Default is 0.025s."), 45 | Flag("winstep", dtype=Flag.TYPE.FLOAT, default=0.01, 46 | help="The step between successive windows in seconds. Default is 0.01s.") 47 | ] 48 | 49 | @property 50 | def feature_dim(self): 51 | return self._nfilt 52 | 53 | def seconds(self, feature): 54 | return (numpy.shape(feature)[0] - 1.) * self._winstep + self._winlen 55 | 56 | def __call__(self, signal, rate): 57 | inp = logfbank(signal, samplerate=rate, nfilt=self._nfilt, 58 | winlen=self._winlen, winstep=self._winstep).astype(numpy.float32) 59 | inp = (inp - numpy.mean(inp)) / numpy.std(inp) 60 | return inp 61 | -------------------------------------------------------------------------------- /neurst/data/data_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.data.data_pipelines.data_pipeline import DataPipeline 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_data_pipeline, register_data_pipeline = setup_registry(DataPipeline.REGISTRY_NAME, base_class=DataPipeline) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.data.data_pipelines.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/data/data_pipelines/data_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
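# Worked example (editorial): lowercase_and_remove_punctuations() defined below lowercases
# the text, applies Moses punctuation normalization for non-CJK languages, optionally strips
# punctuation, and collapses whitespace. Assuming sacremoses is installed:
#
#   lowercase_and_remove_punctuations("en", "Hello, World!", lowercase=True, remove_punctuation=True)
#   # -> "hello world"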
14 | import re 15 | from abc import ABCMeta, abstractmethod 16 | 17 | import six 18 | from sacremoses import MosesPunctNormalizer 19 | 20 | from neurst.utils.configurable import extract_constructor_params 21 | 22 | PUNC_PATTERN = re.compile(r"[,\.\!\(\);:、\?\-\+=\"/><《》\[\],。:;「」【】{}`@#\$%\^&\*]") 23 | PUNC_NORMERS = dict() 24 | 25 | 26 | def lowercase_and_remove_punctuations(language, text, lowercase=True, remove_punctuation=True): 27 | if lowercase: 28 | text = text.lower() 29 | if language not in ["zh", "ja"]: 30 | if language not in PUNC_NORMERS: 31 | PUNC_NORMERS[language] = MosesPunctNormalizer(lang=language) 32 | text = PUNC_NORMERS[language].normalize(text) 33 | text = text.replace("' s ", "'s ").replace( 34 | "' ve ", "'ve ").replace("' m ", "'m ").replace("' t ", "'t ").replace("' re ", "'re ") 35 | if remove_punctuation: 36 | text = PUNC_PATTERN.sub(" ", text) 37 | text = " ".join(text.strip().split()) 38 | return text 39 | 40 | 41 | @six.add_metaclass(ABCMeta) 42 | class DataPipeline(object): 43 | REGISTRY_NAME = "data_pipeline" 44 | PUNC_NORMERS = dict() 45 | 46 | def __init__(self, **kwargs): 47 | self._params = extract_constructor_params(locals(), verbose=False) 48 | 49 | def get_config(self) -> dict: 50 | return self._params 51 | 52 | @property 53 | @abstractmethod 54 | def meta(self) -> dict: 55 | """ The meta data. """ 56 | return {} 57 | 58 | @abstractmethod 59 | def recover(self, input): 60 | """ Recovers one data sample. """ 61 | raise NotImplementedError 62 | 63 | @abstractmethod 64 | def process(self, input, is_processed=False): 65 | """ Processes one data sample. """ 66 | raise NotImplementedError 67 | 68 | def text_pre_normalize(self, language, input, is_processed=False): 69 | if is_processed: 70 | return input 71 | output = lowercase_and_remove_punctuations(language, input, 72 | self._params.get("lowercase", False), 73 | self._params.get("remove_punctuation", False)) 74 | return output 75 | -------------------------------------------------------------------------------- /neurst/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.data.datasets.dataset import Dataset, TFRecordDataset 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_dataset, register_dataset = setup_registry(Dataset.REGISTRY_NAME, base_class=Dataset, 8 | verbose_creation=True) 9 | _ = TFRecordDataset 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.data.datasets.' 
+ model_name) 16 | 17 | importlib.import_module("neurst.data.datasets.audio") 18 | importlib.import_module("neurst.data.datasets.data_sampler") 19 | -------------------------------------------------------------------------------- /neurst/data/datasets/audio/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | models_dir = os.path.dirname(__file__) 5 | for file in os.listdir(models_dir): 6 | path = os.path.join(models_dir, file) 7 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 8 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 9 | module = importlib.import_module('neurst.data.datasets.audio.' + model_name) 10 | -------------------------------------------------------------------------------- /neurst/data/datasets/data_sampler/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.data.datasets.data_sampler.data_sampler import DataSampler 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_data_sampler, register_data_sampler = setup_registry(DataSampler.REGISTRY_NAME, base_class=DataSampler, 8 | verbose_creation=True) 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.data.datasets.data_sampler.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/data/datasets/data_sampler/data_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
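# Editorial note: the DataSampler below converts the per-item ratios returned by
# get_sample_ratios() into cumulative boundaries on [0, 1] and draws one item per call
# with probability proportional to its ratio. For normalized weights
# {"en-de": 0.75, "en-fr": 0.25} the boundaries are [0.75, 1.0]; a uniform draw of 0.8
# exceeds 0.75 and therefore selects "en-fr", while a draw of 0.3 selects "en-de".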
14 | import random 15 | from abc import ABCMeta, abstractmethod 16 | 17 | import numpy 18 | import six 19 | import yaml 20 | 21 | from neurst.utils.flags_core import Flag 22 | 23 | 24 | @six.add_metaclass(ABCMeta) 25 | class DataSampler(object): 26 | REGISTRY_NAME = "data_sampler" 27 | 28 | def __init__(self, args): 29 | if isinstance(args["sample_sizes"], str): 30 | args["sample_sizes"] = yaml.load(args["sample_sizes"], Loader=yaml.FullLoader) 31 | assert isinstance(args["sample_sizes"], dict) and len(args["sample_sizes"]) > 0, ( 32 | "Unknown `sample_sizes`={} with type {}".format(args["sample_sizes"], type(args["sample_sizes"]))) 33 | self._sample_ratios = self.get_sample_ratios(args["sample_sizes"]) 34 | total = sum(self._sample_ratios.values()) 35 | self._normalized_sample_weights = {k: float(v) / total for k, v in self._sample_ratios.items()} 36 | self._sample_items = [] 37 | self._sample_boundaries = [] 38 | for k, v in self._sample_ratios.items(): 39 | self._sample_items.append(k) 40 | if len(self._sample_boundaries) == 0: 41 | self._sample_boundaries.append(float(v) / total) 42 | else: 43 | self._sample_boundaries.append(self._sample_boundaries[-1] + float(v) / total) 44 | self._sample_boundaries = numpy.array(self._sample_boundaries) 45 | 46 | @staticmethod 47 | def class_or_method_args(): 48 | return [ 49 | Flag("sample_sizes", dtype=Flag.TYPE.STRING, 50 | help="A dict. The key is the item name to be sampled, " 51 | "while the value is the corresponding proportion.") 52 | ] 53 | 54 | @property 55 | def normalized_sample_weights(self): 56 | return self._normalized_sample_weights 57 | 58 | @abstractmethod 59 | def get_sample_ratios(self, sample_sizes) -> dict: 60 | raise NotImplementedError 61 | 62 | def __call__(self): 63 | ratio = random.random() 64 | for idx in range(len(self._sample_boundaries) - 1, -1, -1): 65 | if ratio > self._sample_boundaries[idx]: 66 | return self._sample_items[idx + 1] 67 | return self._sample_items[0] 68 | -------------------------------------------------------------------------------- /neurst/data/datasets/data_sampler/temperature_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
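# Worked example (editorial): with temperature T=5 and sample sizes
# {"en-de": 1000000, "en-fr": 10000}, the ratios below are (size/total)**(1/5), i.e.
# about 0.998 and 0.397, which normalize to roughly 0.72 / 0.28 -- much flatter than the
# raw 0.99 / 0.01 proportions, so the low-resource pair is sampled far more often.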
14 | from neurst.data.datasets.data_sampler import DataSampler, register_data_sampler 15 | from neurst.utils.flags_core import Flag 16 | 17 | 18 | @register_data_sampler("temperature") 19 | class TemperatureSampler(DataSampler): 20 | 21 | def __init__(self, args): 22 | self._temperature = args["temperature"] 23 | super(TemperatureSampler, self).__init__(args) 24 | 25 | @staticmethod 26 | def class_or_method_args(): 27 | this_flags = super(TemperatureSampler, TemperatureSampler).class_or_method_args() 28 | this_flags.append( 29 | Flag("temperature", dtype=Flag.TYPE.FLOAT, default=5, 30 | help="The temperature for sampling.")) 31 | return this_flags 32 | 33 | def get_sample_ratios(self, sample_sizes) -> dict: 34 | total_size = sum(sample_sizes.values()) 35 | return {k: (v / total_size) ** (1. / self._temperature) 36 | for k, v in sample_sizes.items()} 37 | -------------------------------------------------------------------------------- /neurst/data/datasets/multiple_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from neurst.data.datasets import build_dataset, register_dataset 15 | from neurst.data.datasets.dataset import Dataset 16 | from neurst.utils.flags_core import Flag 17 | 18 | 19 | @register_dataset("multi_dataset") 20 | class MultipleDataset(Dataset): 21 | 22 | def __init__(self, args): 23 | """ Initializes the multiple dataset. 24 | 25 | Args: 26 | args: containing `multiple_dataset`, which is like 27 | { 28 | "data0": { "dataset.class": "", "dataset.params": ""}, 29 | "data1": { "dataset.class": "", "dataset.params": ""}, 30 | ...... 31 | ] 32 | """ 33 | super(MultipleDataset, self).__init__() 34 | self._datasets = {name: build_dataset(dsargs) 35 | for name, dsargs in args["multiple_datasets"].items()} 36 | self._sample_weights = dict() 37 | if args["sample_weights"]: 38 | assert isinstance(args["sample_weights"], dict) 39 | else: 40 | args["sample_weights"] = {} 41 | sum = 0. 42 | for name in self._datasets: 43 | self._sample_weights[name] = args["sample_weights"].get(name, 1.) 44 | sum += self._sample_weights[name] 45 | for name in self._datasets: 46 | self._sample_weights[name] /= sum 47 | 48 | @staticmethod 49 | def class_or_method_args(): 50 | return [ 51 | Flag("multiple_datasets", dtype=Flag.TYPE.STRING, 52 | help="A dict of dataset class and parameters, " 53 | "where the key is the dataset name and " 54 | "the value is a dict of arguments for one dataset."), 55 | Flag("sample_weights", dtype=Flag.TYPE.FLOAT, 56 | help="A dict of weights for averaging metrics, where the key " 57 | "is the dataset name. 
1.0 for each by default.") 58 | ] 59 | 60 | @property 61 | def status(self): 62 | raise NotImplementedError 63 | 64 | @property 65 | def sample_weights(self): 66 | return self._sample_weights 67 | 68 | @property 69 | def datasets(self): 70 | return self._datasets 71 | 72 | def build(self, *args, **kwargs): 73 | raise NotImplementedError("Call each dataset's build function instead.") 74 | 75 | def build_iterator(self, *args, **kwargs): 76 | raise NotImplementedError("Call each dataset's build function instead.") 77 | -------------------------------------------------------------------------------- /neurst/data/datasets/text_gen_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | from neurst.data.datasets.dataset import Dataset 19 | 20 | 21 | @six.add_metaclass(ABCMeta) 22 | class TextGenDataset(Dataset): 23 | """ The abstract dataset for text generation, which must implement `get_targets` function. """ 24 | 25 | def __init__(self, trg_lang=None): 26 | self._targets = None 27 | self._trg_lang = trg_lang 28 | super(TextGenDataset, self).__init__() 29 | 30 | @property 31 | def trg_lang(self): 32 | return self._trg_lang 33 | 34 | @property 35 | @abstractmethod 36 | def status(self) -> str: 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def build_iterator(self, map_func=None, shard_id=0, total_shards=1): 41 | """ Returns the iterator of the dataset. 42 | 43 | Args: 44 | map_func: A function mapping a dataset element to another dataset element. 45 | shard_id: Generator yields on the `shard_id`-th shard of the whole dataset. 46 | total_shards: The number of total shards. 47 | """ 48 | raise NotImplementedError 49 | 50 | @property 51 | def targets(self): 52 | """ Returns a list of target texts. """ 53 | return self._targets 54 | -------------------------------------------------------------------------------- /neurst/data/text/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.data.text.tokenizer import Tokenizer 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_tokenizer, register_tokenizer = setup_registry(Tokenizer.REGISTRY_NAME, base_class=Tokenizer) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.data.text.' 
+ model_name) 15 | -------------------------------------------------------------------------------- /neurst/data/text/character.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | from neurst.data.text import register_tokenizer 17 | from neurst.data.text.tokenizer import Tokenizer 18 | 19 | 20 | @register_tokenizer("char") 21 | class Character(Tokenizer): 22 | CHAR_COMPILER = re.compile(r"([\u2E80-\u9FFF\uA000-\uA4FF\uAC00-\uD7FF\uF900-\uFAFF])") 23 | 24 | def __init__(self, language, glossaries=None): 25 | super(Character, self).__init__( 26 | language=language, glossaries=glossaries) 27 | 28 | @staticmethod 29 | def is_cjk(language): 30 | return language in ["zh", "ja", "ko"] 31 | 32 | @staticmethod 33 | def to_character(text, language=None): 34 | text = Character._convert_to_str(text) 35 | if Character.is_cjk(language): 36 | return Character.cjk_to_character(text) 37 | return " ".join(text) 38 | 39 | @staticmethod 40 | def cjk_to_character(text): 41 | """ CJK sentence to character-level. 42 | 43 | Args: 44 | text: A list of string tokens or a string. 45 | 46 | Returns: A string. 47 | """ 48 | text = Character._convert_to_str(text) 49 | res = Character.CHAR_COMPILER.sub(r" \1 ", text) 50 | # tokenize period and comma unless preceded by a digit 51 | res = re.sub(r'([^0-9])([\.,])', r'\1 \2 ', res) 52 | 53 | # tokenize period and comma unless followed by a digit 54 | res = re.sub(r'([\.,])([^0-9])', r' \1 \2', res) 55 | 56 | # tokenize dash when preceded by a digit 57 | res = re.sub(r'([0-9])(-)', r'\1 \2 ', res) 58 | 59 | # one space only between words 60 | res = re.sub(r'\s+', r' ', res) 61 | 62 | # no leading space 63 | res = re.sub(r'^\s+', r'', res) 64 | 65 | # no trailing space 66 | res = re.sub(r'\s+$', r'', res) 67 | return res 68 | 69 | def tokenize(self, text, return_str=False): 70 | return self._output_wrapper( 71 | self.to_character(text, language=self.language), 72 | return_str=return_str) 73 | 74 | def detokenize(self, text, return_str=True): 75 | if not self.is_cjk(self.language): 76 | raise NotImplementedError(f"detokenize fn in Character " 77 | f"is not implemented for language={self.language}") 78 | return self._output_wrapper( 79 | self.cjk_deseg(self._convert_to_str(text)), 80 | return_str=return_str) 81 | -------------------------------------------------------------------------------- /neurst/data/text/huggingface_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
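# A minimal usage sketch for the Character tokenizer above, assuming the
# neurst package is importable; "char" is its registered name.
from neurst.data.text.character import Character

char_tok = Character(language="zh")
print(char_tok.tokenize("你好世界", return_str=True))  # each CJK character split out: "你 好 世 界"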
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import traceback 15 | 16 | from absl import logging 17 | 18 | from neurst.data.text import register_tokenizer 19 | from neurst.data.text.tokenizer import Tokenizer 20 | 21 | try: 22 | from transformers import AutoTokenizer 23 | except ImportError: 24 | pass 25 | 26 | 27 | @register_tokenizer("huggingface") 28 | class HuggingFaceTokenizer(Tokenizer): 29 | 30 | def __init__(self, language, glossaries=None, subtokenizer_codes=None, **kwargs): 31 | super(HuggingFaceTokenizer, self).__init__( 32 | language=language, glossaries=glossaries, **kwargs) 33 | try: 34 | from transformers import AutoTokenizer 35 | _ = AutoTokenizer 36 | except ImportError: 37 | raise ImportError('Please install transformers with: pip3 install transformers') 38 | self._built = False 39 | self._codes = subtokenizer_codes 40 | 41 | def init_subtokenizer(self, codes): 42 | """ Lazily initializes huggingface tokenizer. """ 43 | self._codes = codes 44 | 45 | def _lazy_init(self): 46 | codes = self._codes 47 | success = False 48 | fail_times = 0 49 | while not success: 50 | try: 51 | self._tokenizer = AutoTokenizer.from_pretrained(codes) 52 | success = True 53 | except Exception as e: 54 | fail_times += 1 55 | logging.info("AutoTokenizer.from_pretrained fails for {0} times".format(fail_times)) 56 | if fail_times >= 5: 57 | logging.info(traceback.format_exc()) 58 | raise e 59 | 60 | self._built = True 61 | 62 | def tokenize(self, text, return_str=False): 63 | if not self._built: 64 | self._lazy_init() 65 | if not self._built: 66 | raise ValueError("call `init_subtokenizer` at first to initialize the tokenizer.") 67 | return self._output_wrapper( 68 | self._tokenizer.tokenize(self._convert_to_str(text)), return_str=return_str) 69 | 70 | def detokenize(self, text, return_str=True): 71 | if not self._built: 72 | self._lazy_init() 73 | if not self._built: 74 | raise ValueError("call `init_subtokenizer` at first to initialize the tokenizer.") 75 | return self._output_wrapper( 76 | self._tokenizer.convert_tokens_to_string(self._convert_to_list(text)), return_str=return_str) 77 | -------------------------------------------------------------------------------- /neurst/data/text/jieba_segment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
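# A minimal usage sketch for HuggingFaceTokenizer above, assuming the
# `transformers` package is installed and the checkpoint can be fetched;
# "bert-base-uncased" is only an example identifier.
from neurst.data.text.huggingface_tokenizer import HuggingFaceTokenizer

hf_tok = HuggingFaceTokenizer(language="en", subtokenizer_codes="bert-base-uncased")
pieces = hf_tok.tokenize("Hello world", return_str=False)  # the AutoTokenizer is built lazily here
print(hf_tok.detokenize(pieces, return_str=True))          # word pieces joined back to a string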
14 | from absl import logging 15 | 16 | from neurst.data.text import register_tokenizer 17 | from neurst.data.text.tokenizer import Tokenizer 18 | 19 | 20 | @register_tokenizer 21 | class Jieba(Tokenizer): 22 | 23 | def __init__(self, language="zh", glossaries=None): 24 | super(Jieba, self).__init__( 25 | language=language, glossaries=glossaries) 26 | if self._glossaries and len(self._glossaries) > 0: 27 | logging.info("WARNING: now `glossaries` has no effect on Jieba.") 28 | try: 29 | import jieba 30 | self._cut_fn = jieba.lcut 31 | except ImportError: 32 | raise ImportError('Please install jieba with: pip3 install jieba') 33 | 34 | def tokenize(self, text, return_str=False): 35 | return self._output_wrapper(self._cut_fn(self._convert_to_str(text)), 36 | return_str=return_str) 37 | 38 | def detokenize(self, text, return_str=True): 39 | return self._output_wrapper( 40 | self.cjk_deseg(self._convert_to_str(text)), return_str=return_str) 41 | -------------------------------------------------------------------------------- /neurst/data/text/moses_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from neurst.data.text import register_tokenizer 15 | from neurst.data.text.tokenizer import Tokenizer 16 | 17 | 18 | @register_tokenizer("moses") 19 | class MosesTokenizer(Tokenizer): 20 | 21 | def __init__(self, language, glossaries=None, 22 | aggressive_dash_splits=True, escape=False): 23 | super(MosesTokenizer, self).__init__( 24 | language=language, glossaries=glossaries) 25 | self._aggressive_dash_splits = aggressive_dash_splits 26 | self._escape = escape 27 | try: 28 | from sacremoses import MosesDetokenizer as MDetok 29 | from sacremoses import MosesTokenizer as MTok 30 | self._tok = MTok(lang=self.language) 31 | self._detok = MDetok(lang=self.language) 32 | except ImportError: 33 | raise ImportError('Please install Moses tokenizer with: pip3 install sacremoses') 34 | 35 | def tokenize(self, text, return_str=False): 36 | return self._tok.tokenize(self._convert_to_str(text), 37 | aggressive_dash_splits=self._aggressive_dash_splits, 38 | return_str=return_str, 39 | escape=self._escape, 40 | protected_patterns=self._glossaries) 41 | 42 | def detokenize(self, text, return_str=True): 43 | return self._detok.detokenize(self._convert_to_list(text), 44 | return_str=return_str, 45 | unescape=True) 46 | -------------------------------------------------------------------------------- /neurst/data/text/thai_tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
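# A minimal usage sketch for the Jieba and Moses tokenizers above, assuming
# `jieba` and `sacremoses` are installed.
from neurst.data.text.jieba_segment import Jieba
from neurst.data.text.moses_tokenizer import MosesTokenizer

zh_tok = Jieba()
print(zh_tok.tokenize("我爱自然语言处理", return_str=True))    # space-joined Chinese words

en_tok = MosesTokenizer(language="en")
print(en_tok.tokenize("Hello, world!", return_str=True))       # e.g. "Hello , world !"
print(en_tok.detokenize("Hello , world !", return_str=True))   # e.g. "Hello, world!"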
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Defines thai tokenizer.""" 15 | from neurst.data.text import Tokenizer, register_tokenizer 16 | 17 | 18 | @register_tokenizer 19 | class ThaiTokenizer(Tokenizer): 20 | 21 | def __init__(self, language="th", glossaries=None): 22 | """ Initializes. """ 23 | _ = language 24 | language = "th" 25 | try: 26 | from thai_segmenter import tokenize as thai_tokenize 27 | self._thai_tokenize = thai_tokenize 28 | except ImportError: 29 | raise ImportError('Please install Thai tokenizer with: pip install thai-segmenter') 30 | self._thai_tokenize("") 31 | super(ThaiTokenizer, self).__init__(language=language, glossaries=glossaries) 32 | 33 | def tokenize(self, text, return_str=False): 34 | """ Tokenize a text. """ 35 | res = self._thai_tokenize(self._convert_to_str(text)) 36 | if return_str: 37 | res = " ".join(res) 38 | return res 39 | 40 | def detokenize(self, words, return_str=True): 41 | """ Recovers the result of `tokenize(words)`. 42 | 43 | Args: 44 | words: A list of strings, i.e. tokenized text. 45 | return_str: returns a string if True, a list of tokens otherwise. 46 | 47 | Returns: The recovered sentence string. 48 | """ 49 | if isinstance(words, str): 50 | words = words.strip().split() 51 | if return_str: 52 | words = "".join(words) 53 | return words 54 | -------------------------------------------------------------------------------- /neurst/exps/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.exps.base_experiment import BaseExperiment 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_exp, register_exp = setup_registry(BaseExperiment.REGISTRY_NAME, base_class=BaseExperiment, 8 | verbose_creation=True) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.exps.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/exps/base_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
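# A minimal usage sketch for ThaiTokenizer above, assuming the
# `thai-segmenter` package is installed.
from neurst.data.text.thai_tokenizer import ThaiTokenizer

th_tok = ThaiTokenizer()
tokens = th_tok.tokenize("สวัสดีชาวโลก", return_str=False)  # a list of Thai words
print(th_tok.detokenize(tokens, return_str=True))           # joined back without spaces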
14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | 19 | @six.add_metaclass(ABCMeta) 20 | class BaseExperiment(object): 21 | REGISTRY_NAME = "entry" 22 | 23 | def __init__(self, strategy, model, task, custom_dataset, model_dir): 24 | """ Initializes the basic experiment for training, evaluation, etc. """ 25 | self._strategy = strategy 26 | self._model = model 27 | self._model_dir = model_dir 28 | self._task = task 29 | self._custom_dataset = custom_dataset 30 | 31 | @property 32 | def strategy(self): 33 | return self._strategy 34 | 35 | @property 36 | def model(self): 37 | return self._model 38 | 39 | @property 40 | def task(self): 41 | return self._task 42 | 43 | @property 44 | def custom_dataset(self): 45 | return self._custom_dataset 46 | 47 | @property 48 | def model_dir(self): 49 | return self._model_dir 50 | 51 | @abstractmethod 52 | def run(self): 53 | """ Running the method. """ 54 | raise NotImplementedError 55 | -------------------------------------------------------------------------------- /neurst/layers/__init__.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from neurst.layers.attentions.light_convolution_layer import LightConvolutionLayer 4 | from neurst.layers.attentions.multi_head_attention import MultiHeadAttention, MultiHeadSelfAttention 5 | from neurst.layers.common_layers import PrePostProcessingWrapper, TransformerFFN, PrePostProcessingWrapperWithadapter 6 | from neurst.utils.registry import setup_registry 7 | from neurst.layers.adapters import build_adapter 8 | 9 | build_base_layer, register_base_layer = setup_registry("base_layer", base_class=tf.keras.layers.Layer, 10 | verbose_creation=False) 11 | 12 | register_base_layer(MultiHeadSelfAttention) 13 | register_base_layer(MultiHeadAttention) 14 | register_base_layer(TransformerFFN) 15 | register_base_layer(LightConvolutionLayer) 16 | 17 | 18 | def build_transformer_component(layer_args, 19 | dropout_rate, 20 | pre_norm=True, 21 | epsilon=1e-6, 22 | res_conn_factor=1., 23 | name_postfix=None): 24 | base_layer = build_base_layer(layer_args) 25 | return PrePostProcessingWrapper( 26 | layer=base_layer, 27 | dropout_rate=dropout_rate, 28 | epsilon=epsilon, 29 | pre_norm=pre_norm, 30 | res_conn_factor=res_conn_factor, 31 | name=base_layer.name + (name_postfix or "_prepost_wrapper")) 32 | 33 | def build_transformer_component_base(layer_args, dropout_rate): 34 | base_layer = build_base_layer(layer_args) 35 | return PrePostProcessingWrapper( 36 | layer=base_layer, 37 | dropout_rate=dropout_rate, 38 | name=base_layer.name + "_prepost_wrapper") 39 | 40 | def build_transformer_component_with_adapter(layer_args, dropout_rate, adapter_args, is_pretrain, USEADAPTER=True): 41 | base_layer = build_base_layer(layer_args) 42 | adapter_args["adapter.params"]["use_norm"] = False 43 | adapter_layer = build_adapter(adapter_args) 44 | adapter_layer.trainable = not is_pretrain 45 | return PrePostProcessingWrapperWithadapter( 46 | layer=base_layer, 47 | adapter=adapter_layer, 48 | dropout_rate=dropout_rate, 49 | name=base_layer.name + "_" + adapter_layer.name + "_prepost_wrapper_with_adapter", 50 | is_pretrain=is_pretrain, 51 | use_adapter=USEADAPTER, 52 | ) 53 | -------------------------------------------------------------------------------- /neurst/layers/adapters/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | from neurst.layers.adapters.adapter import Adapter 
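# A sketch of how a new experiment could plug into the registry defined in
# neurst/exps/__init__.py above. Both the "noop" name and the class are
# hypothetical; modules placed under neurst/exps/ are imported automatically,
# which is what makes the decorator run and the class visible to `build_exp`.
from neurst.exps import register_exp
from neurst.exps.base_experiment import BaseExperiment

@register_exp("noop")
class NoOpExperiment(BaseExperiment):
    def run(self):
        # Does nothing; only illustrates the registration pattern.
        return None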
4 | from neurst.utils.registry import setup_registry 5 | 6 | build_adapter, register_adapter = setup_registry("adapter", base_class=Adapter) 7 | 8 | models_dir = os.path.dirname(__file__) 9 | for file in os.listdir(models_dir): 10 | path = os.path.join(models_dir, file) 11 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 12 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 13 | module = importlib.import_module('neurst.layers.adapters.' + model_name) 14 | -------------------------------------------------------------------------------- /neurst/layers/adapters/adapter.py: -------------------------------------------------------------------------------- 1 | """ Base Adapter class. """ 2 | from __future__ import absolute_import, division, print_function 3 | 4 | from abc import ABCMeta, abstractmethod 5 | 6 | import six 7 | import tensorflow as tf 8 | from neurst.utils.configurable import extract_constructor_params 9 | 10 | 11 | @six.add_metaclass(ABCMeta) 12 | class Adapter(tf.keras.layers.Layer): 13 | """Base class for Adapter """ 14 | 15 | def __init__(self, name=None, **kwargs): 16 | """ Initializes the parameters of the decoder. """ 17 | self._params = extract_constructor_params(locals(), verbose=False) 18 | super(Adapter, self).__init__(name=name) 19 | 20 | def build(self, input_shape): 21 | super(Adapter, self).build(input_shape) 22 | 23 | def get_config(self): 24 | return self._params 25 | 26 | @abstractmethod 27 | def call(self, inputs, is_training=True): 28 | raise NotImplementedError 29 | -------------------------------------------------------------------------------- /neurst/layers/adapters/adapterEmb.py: -------------------------------------------------------------------------------- 1 | """ Base Adapter class. """ 2 | from __future__ import absolute_import, division, print_function 3 | import tensorflow as tf 4 | 5 | tf.random.Generator = None # Patch for a bug, https://stackoverflow.com/questions/62696815/tensorflow-core-api-v2-random-has-no-attribute-generator 6 | 7 | from neurst.layers.adapters import register_adapter 8 | from neurst.layers.adapters.adapter import Adapter 9 | 10 | 11 | @register_adapter 12 | class AdapterEmb(Adapter): 13 | """Embedding Adapter """ 14 | 15 | def __init__(self, 16 | hidden_size_inner, 17 | hidden_size_outter, 18 | dropout_rate=.3, 19 | use_norm=True, 20 | name="AdapterEmb", ): 21 | """ Initializes the parameters of the Embedding Adapter. 
22 | """ 23 | super(AdapterEmb, self).__init__( 24 | hidden_size_inner=hidden_size_inner, 25 | hidden_size_outter=hidden_size_outter, 26 | dropout_rate=dropout_rate, 27 | name=name, 28 | ) 29 | self.inner_layer = None 30 | self.outter_layer = None 31 | self._norm_layer = None 32 | self._use_norm = use_norm 33 | if self._use_norm: 34 | self._norm_layer = tf.keras.layers.LayerNormalization( 35 | epsilon=1e-6, dtype="float32", name="output_ln") 36 | 37 | def build(self, input_shape): 38 | params = self.get_config() 39 | self.inner_layer = tf.keras.layers.Dense(units=params["hidden_size_inner"], activation=tf.nn.relu) 40 | self.outter_layer = tf.keras.layers.Dense(units=params["hidden_size_outter"], activation=None) 41 | super(Adapter, self).build(input_shape) 42 | 43 | def call(self, inputs, is_training=True): 44 | params = self.get_config() 45 | if self._use_norm: 46 | z = self._norm_layer(inputs) 47 | else: 48 | z = inputs 49 | 50 | z = self.inner_layer(z) 51 | z = tf.cast(z, inputs.dtype) 52 | if is_training: 53 | z = tf.nn.dropout(z, params["dropout_rate"]) 54 | h_out = tf.cast(self.outter_layer(z), inputs.dtype) 55 | return h_out 56 | -------------------------------------------------------------------------------- /neurst/layers/adapters/adapterLayer.py: -------------------------------------------------------------------------------- 1 | """ Base Adapter class. """ 2 | from __future__ import absolute_import, division, print_function 3 | import tensorflow as tf 4 | 5 | # tf.random.Generator = None # Patch for a bug, https://stackoverflow.com/questions/62696815/tensorflow-core-api-v2-random-has-no-attribute-generator 6 | # import tensorflow_addons as tfa 7 | 8 | from neurst.layers.adapters import register_adapter 9 | from neurst.layers.adapters.adapter import Adapter 10 | 11 | 12 | @register_adapter 13 | class AdapterLayer(Adapter): 14 | """Layer Adapter (Parallel) """ 15 | 16 | def __init__(self, 17 | hidden_size_inner, 18 | hidden_size_outter, 19 | dropout_rate=.3, 20 | use_norm=True, 21 | name="AdapterLayer", ): 22 | """ Initializes the parameters of the Layer Adapter. 23 | """ 24 | super(AdapterLayer, self).__init__( 25 | hidden_size_inner=hidden_size_inner, 26 | hidden_size_outter=hidden_size_outter, 27 | dropout_rate=dropout_rate, 28 | name=name, 29 | ) 30 | self.inner_layer = None 31 | self.outter_layer = None 32 | self._norm_layer = None 33 | self._use_norm = use_norm 34 | if self._use_norm: 35 | self._norm_layer = tf.keras.layers.LayerNormalization( 36 | epsilon=1e-6, dtype="float32", name="output_ln") 37 | 38 | def build(self, input_shape): 39 | params = self.get_config() 40 | self.inner_layer = tf.keras.layers.Dense(units=params["hidden_size_inner"], activation=tf.nn.relu) 41 | self.outter_layer = tf.keras.layers.Dense(units=params["hidden_size_outter"], activation=None) 42 | super(Adapter, self).build(input_shape) 43 | 44 | def call(self, inputs, is_training=True): 45 | params = self.get_config() 46 | if self._use_norm: 47 | z = self._norm_layer(inputs) 48 | else: 49 | z = inputs 50 | z = self.inner_layer(z) 51 | z = tf.cast(z, inputs.dtype) 52 | if is_training: 53 | z = tf.nn.dropout(z, params["dropout_rate"]) 54 | h_out = tf.cast(self.outter_layer(z), inputs.dtype) 55 | return h_out 56 | -------------------------------------------------------------------------------- /neurst/layers/adapters/adapterSerial.py: -------------------------------------------------------------------------------- 1 | """ Base Adapter class. 
""" 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import tensorflow as tf 5 | 6 | # tf.random.Generator = None # Patch for a bug, https://stackoverflow.com/questions/62696815/tensorflow-core-api-v2-random-has-no-attribute-generator 7 | 8 | from neurst.layers.adapters import register_adapter 9 | from neurst.layers.adapters.adapter import Adapter 10 | 11 | 12 | @register_adapter 13 | class AdapterSerial(Adapter): 14 | """Serial Adapter 15 | reimplement of Simple, scalable adaptation for neural machine translation (EMNLP-IJCNLP 2019 ) 16 | """ 17 | 18 | def __init__(self, 19 | hidden_size_inner, 20 | hidden_size_outter, 21 | dropout_rate=.3, 22 | name="SerialAdapter", ): 23 | """ Initializes the parameters of the Serial Adapter. 24 | """ 25 | super(AdapterSerial, self).__init__( 26 | hidden_size_inner=hidden_size_inner, 27 | hidden_size_outter=hidden_size_outter, 28 | dropout_rate=dropout_rate, 29 | name=name, 30 | ) 31 | self.inner_layer = None 32 | self.outter_layer = None 33 | self._norm_layer = tf.keras.layers.LayerNormalization( 34 | epsilon=1e-6, dtype="float32", name="output_ln") 35 | 36 | def build(self, input_shape): 37 | params = self.get_config() 38 | self.inner_layer = tf.keras.layers.Dense(units=params["hidden_size_inner"], activation=tf.nn.relu) 39 | self.outter_layer = tf.keras.layers.Dense(units=params["hidden_size_outter"], activation=None) 40 | super(Adapter, self).build(input_shape) 41 | 42 | def call(self, inputs, is_training=True): 43 | params = self.get_config() 44 | z = self._norm_layer(inputs) 45 | z = self.inner_layer(z) 46 | z = tf.cast(z, inputs.dtype) 47 | if is_training: 48 | z = tf.nn.dropout(z, params["dropout_rate"]) 49 | h_out = tf.cast(self.outter_layer(z), inputs.dtype) + inputs 50 | return h_out 51 | -------------------------------------------------------------------------------- /neurst/layers/attentions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/layers/attentions/__init__.py -------------------------------------------------------------------------------- /neurst/layers/auto_pretrained_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | 16 | 17 | class AutoPretrainedLayer(tf.keras.layers.Layer): 18 | 19 | def __init__(self, pretrained_model_name_or_path, name=None): 20 | """ 21 | 22 | Args: 23 | pretrained_model_name_or_path: 24 | (1) a string with the `shortcut name` of a pre-trained model to load from cache 25 | or download, e.g.: ``bert-base-uncased``. 26 | (2) a string with the `identifier name` of a pre-trained model that was user-uploaded 27 | to our S3, e.g.: ``dbmdz/bert-base-german-cased``. 
28 | (3) a path to a `directory` containing model weights saved using 29 | :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``. 30 | (4) a path or url to a `TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). 31 | In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration 32 | object should be provided as ``config`` argument. 33 | name: 34 | """ 35 | super(AutoPretrainedLayer, self).__init__(name=name) 36 | self._pretrained_model_name_or_path = pretrained_model_name_or_path 37 | 38 | def get_config(self): 39 | return dict( 40 | pretrained_model_name_or_path=self._pretrained_model_name_or_path, 41 | name=self.name) 42 | 43 | def build(self, input_shape): 44 | _ = input_shape 45 | try: 46 | from transformers import TFAutoModel 47 | _ = TFAutoModel 48 | except ImportError: 49 | raise ImportError('Please install transformers with: pip3 install transformers') 50 | self._pretrained_model = TFAutoModel.from_pretrained(self._pretrained_model_name_or_path) 51 | super(AutoPretrainedLayer, self).build(input_shape) 52 | 53 | def call(self, inputs, is_training=False): 54 | return self._pretrained_model(inputs, training=is_training) 55 | -------------------------------------------------------------------------------- /neurst/layers/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.layers.decoders.decoder import Decoder 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_decoder, register_decoder = setup_registry(Decoder.REGISTRY_NAME, base_class=Decoder) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.layers.decoders.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/layers/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.layers.encoders.encoder import Encoder 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_encoder, register_encoder = setup_registry(Encoder.REGISTRY_NAME, base_class=Encoder) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.layers.encoders.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/layers/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
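# A minimal usage sketch for AutoPretrainedLayer above, assuming the
# `transformers` package is installed; "bert-base-uncased" is only an example
# checkpoint, and the input must match whatever the wrapped TFAutoModel
# expects (here, a tensor of already-tokenized ids).
import tensorflow as tf

from neurst.layers.auto_pretrained_layer import AutoPretrainedLayer

bert = AutoPretrainedLayer("bert-base-uncased", name="bert_encoder")
token_ids = tf.constant([[101, 7592, 2088, 102]])
outputs = bert(token_ids, is_training=False)  # delegates to the pretrained model's call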
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Base Encoder class. """ 15 | from __future__ import absolute_import, division, print_function 16 | 17 | from abc import ABCMeta, abstractmethod 18 | 19 | import six 20 | 21 | from neurst.layers.quantization.quant_layers import QuantLayer 22 | from neurst.utils.configurable import extract_constructor_params 23 | 24 | 25 | @six.add_metaclass(ABCMeta) 26 | class Encoder(QuantLayer): 27 | """ Base class for encoders. """ 28 | REGISTRY_NAME = "encoder" 29 | 30 | def __init__(self, name=None, **kwargs): 31 | """ Initializes the parameters of the encoders. """ 32 | self._params = extract_constructor_params(locals(), verbose=False) 33 | super(Encoder, self).__init__(name=name) 34 | 35 | def build(self, input_shape): 36 | super(Encoder, self).build(input_shape) 37 | 38 | def get_config(self): 39 | return self._params 40 | 41 | @abstractmethod 42 | def call(self, inputs, inputs_padding, is_training=True): 43 | """ Encodes the inputs. 44 | 45 | Args: 46 | inputs: The embedded input, a float tensor with shape 47 | [batch_size, max_length, embedding_dim]. 48 | inputs_padding: A float tensor with shape [batch_size, max_length], 49 | indicating the padding positions, where 1.0 for padding and 50 | 0.0 for non-padding. 51 | is_training: A bool, whether in training mode or not. 52 | 53 | Returns: 54 | The encoded output with shape [batch_size, max_length, hidden_size] 55 | """ 56 | raise NotImplementedError 57 | -------------------------------------------------------------------------------- /neurst/layers/metric_layers/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | METRIC_REDUCTION = namedtuple( 4 | "metric_reduction", "SUM MEAN")(0, 1) 5 | 6 | REGISTERED_METRICS = dict() 7 | 8 | 9 | def register_metric(name, redution): 10 | if name in REGISTERED_METRICS: 11 | raise ValueError(f"Metric {name} already registered.") 12 | REGISTERED_METRICS[name] = redution 13 | 14 | 15 | def get_metric_reduction(name, default=METRIC_REDUCTION.MEAN): 16 | return REGISTERED_METRICS.get(name, default) 17 | -------------------------------------------------------------------------------- /neurst/layers/metric_layers/metric_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
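# A minimal sketch of a concrete Encoder that follows the interface documented
# above (illustrative only; the real encoders in this repo live under
# neurst/layers/encoders/ and are created via `build_encoder`).
import tensorflow as tf

from neurst.layers.encoders import register_encoder
from neurst.layers.encoders.encoder import Encoder

@register_encoder
class IdentityEncoder(Encoder):
    """ Returns the embedded inputs unchanged, zeroing out padded positions. """

    def call(self, inputs, inputs_padding, is_training=True):
        # inputs: [batch, max_len, dim]; inputs_padding: [batch, max_len]
        mask = 1.0 - tf.expand_dims(tf.cast(inputs_padding, inputs.dtype), axis=-1)
        return inputs * mask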
14 | from abc import abstractmethod 15 | 16 | import tensorflow as tf 17 | 18 | from neurst.layers.metric_layers import METRIC_REDUCTION, register_metric 19 | 20 | 21 | class MetricLayer(tf.keras.layers.Layer): 22 | """ The base class of metric layer for verbose and """ 23 | 24 | def __init__(self): 25 | super(MetricLayer, self).__init__() 26 | self._layer_metrics = {} 27 | 28 | def build_metric_reduction(self, name, reduction): 29 | register_metric(name, reduction) 30 | if reduction == METRIC_REDUCTION.SUM: 31 | self._layer_metrics[name] = tf.keras.metrics.Sum(name) 32 | elif reduction == METRIC_REDUCTION.MEAN: 33 | self._layer_metrics[name] = tf.keras.metrics.Mean(name) 34 | else: 35 | raise NotImplementedError(f"Unknown reduction name: {reduction}.") 36 | 37 | @abstractmethod 38 | def calculate(self, input, output): 39 | """ Calculates metric values according to model input and output. """ 40 | raise NotImplementedError 41 | 42 | def call(self, inputs): 43 | """ Registers metrics by calling `self.add_metric()` 44 | 45 | Args: 46 | inputs: A list of [model inputs, model outputs] 47 | 48 | Returns: 49 | The model outputs. 50 | """ 51 | model_inp, model_out = inputs 52 | ms = self.calculate(model_inp, model_out) 53 | if not isinstance(ms, dict): 54 | assert len(self._layer_metrics) == 1, "The number of metrics mismatch." 55 | for k, v in self._layer_metrics.items(): 56 | ms = {k: ms} 57 | for name, aggr in self._layer_metrics.items(): 58 | m = aggr(ms[name]) 59 | self.add_metric(m) 60 | return model_out 61 | -------------------------------------------------------------------------------- /neurst/layers/modalities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/layers/modalities/__init__.py -------------------------------------------------------------------------------- /neurst/layers/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | from neurst.layers.quantization.quant_dense_layer import QuantDense 2 | from neurst.layers.quantization.quant_layers import QuantLayer 3 | 4 | _ = QuantDense 5 | _ = QuantLayer 6 | -------------------------------------------------------------------------------- /neurst/layers/quantization/quant_dense_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | 16 | from neurst.layers.quantization.quant_layers import QuantLayer 17 | 18 | 19 | class QuantDense(tf.keras.layers.Dense, QuantLayer): 20 | """ `tf.keras.layers.Dense` with quantization. 
""" 21 | 22 | def __init__(self, activation_quantizer=None, *args, **kwargs): 23 | tf.keras.layers.Dense.__init__(self, *args, **kwargs) 24 | QuantLayer.__init__(self, name=self.name) 25 | self._quant_op = None 26 | if activation_quantizer is not None: 27 | self._quant_op = self.add_activation_quantizer(self.name + "_activ", activation_quantizer) 28 | 29 | def build(self, input_shape): 30 | tf.keras.layers.Dense.build(self, input_shape) 31 | self.add_weight_quantizer(self.kernel) 32 | self.v = self.kernel 33 | self.built = True 34 | 35 | def call(self, inputs): 36 | self.kernel = tf.cast(self.quant_weight(self.v), inputs.dtype) 37 | return tf.keras.layers.Dense.call(self, inputs) 38 | 39 | def __call__(self, *args, **kwargs): 40 | output = tf.keras.layers.Dense.__call__(self, *args, **kwargs) 41 | if self._quant_op is None: 42 | return output 43 | return self._quant_op(output) 44 | -------------------------------------------------------------------------------- /neurst/layers/search/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.layers.search.sequence_search import SequenceSearch 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_search_layer, register_search_layer = setup_registry( 8 | SequenceSearch.REGISTRY_NAME, base_class=SequenceSearch, verbose_creation=True) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.layers.search.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/layers/search/sequence_search.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | import tensorflow as tf 18 | 19 | 20 | @six.add_metaclass(ABCMeta) 21 | class SequenceSearch(tf.keras.layers.Layer): 22 | REGISTRY_NAME = "search_method" 23 | 24 | def __init__(self): 25 | """ Initializes. 26 | 27 | Args: 28 | model: The model for generation. 
29 | """ 30 | self._model = None 31 | super(SequenceSearch, self).__init__() 32 | 33 | def set_model(self, model): 34 | self._model = model 35 | 36 | @staticmethod 37 | def class_or_method_args(): 38 | return [] 39 | 40 | def build(self, input_shape): 41 | super(SequenceSearch, self).build(input_shape) 42 | 43 | @abstractmethod 44 | def call(self, parsed_inp, **kwargs) -> dict: 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /neurst/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.metrics.metric import Metric 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_metric, register_metric = setup_registry(Metric.REGISTRY_NAME, base_class=Metric) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.metrics.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/metrics/compound_split_bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import re 15 | 16 | import numpy as np 17 | 18 | from neurst.metrics import register_metric 19 | from neurst.metrics.bleu import BLEU 20 | 21 | 22 | @register_metric 23 | class CompoundSplitBleu(BLEU): 24 | 25 | def __init__(self, *args, **kwargs): 26 | """ Initializes. 27 | 28 | Args: 29 | language: The language. 30 | """ 31 | _ = args 32 | _ = kwargs 33 | super(CompoundSplitBleu, self).__init__(*args, **kwargs) 34 | 35 | @staticmethod 36 | def _tokenize(ss, tok_fn, lc=False): 37 | res = super(CompoundSplitBleu, CompoundSplitBleu)._tokenize(ss, tok_fn, lc) 38 | if isinstance(res[0], str): 39 | return [re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", x) for x in res] 40 | return [[re.sub(r"(\S)-(\S)", r"\1 ##AT##-##AT## \2", x) for x in xx] for xx in res] 41 | 42 | def get_value(self, result): 43 | if isinstance(result, (float, np.float32, np.float64)): 44 | return result 45 | if self._flag in result: 46 | return result[self._flag] 47 | if self._flag.lower() in result: 48 | return result[self._flag.lower()] 49 | return result["compound_split_bleu"] 50 | 51 | def call(self, hypothesis, groundtruth=None): 52 | """ Returns the BLEU result dict. 
""" 53 | return { 54 | "compound_split_bleu": self.tok_bleu(hypothesis, groundtruth), 55 | "uncased_compound_split_bleu": self.tok_bleu(hypothesis, groundtruth, lc=True)} 56 | -------------------------------------------------------------------------------- /neurst/metrics/metric.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | 19 | @six.add_metaclass(ABCMeta) 20 | class Metric(object): 21 | REGISTRY_NAME = "metric" 22 | 23 | def __init__(self, *args, **kwargs): 24 | self._flag = self.__class__.__name__ 25 | 26 | @property 27 | def flag(self): 28 | return self._flag 29 | 30 | @flag.setter 31 | def flag(self, flag_name): 32 | """ Sets the flag metric name if the result of `__call__` is a dict. """ 33 | self._flag = flag_name 34 | 35 | def set_groundtruth(self, groundtruth): 36 | raise NotImplementedError 37 | 38 | def greater_or_eq(self, result1, result2): 39 | """ Compare the two metric value result and return True if v1>=v2. """ 40 | return self.get_value(result1) >= self.get_value(result2) 41 | 42 | def get_value(self, result): 43 | """ Gets a float value from the metric result (if is a dict). """ 44 | if isinstance(result, dict) and self._flag in result: 45 | return result[self._flag] 46 | return result 47 | 48 | def __call__(self, hypothesis, groundtruth=None) -> dict: 49 | """ Returns a dict of metric values. """ 50 | res = self.call(hypothesis, groundtruth=groundtruth) 51 | if not isinstance(res, dict): 52 | res = {self.flag: res} 53 | return res 54 | 55 | @abstractmethod 56 | def call(self, hypothesis, groundtruth=None): 57 | """ Returns the metric value (float) or a dict of metric values. """ 58 | raise NotImplementedError 59 | 60 | 61 | class MetricWrapper(Metric): 62 | """ A wrapper class for easy-use of metric. 
""" 63 | 64 | def __init__(self, flag, greater_is_better=True): 65 | super(MetricWrapper, self).__init__() 66 | self._flag = flag 67 | self._greater_is_better = greater_is_better 68 | 69 | def call(self, hypothesis, groundtruth=None): 70 | raise NotImplementedError("No need to call `__call__` in MetricWrapper.") 71 | 72 | def greater_or_eq(self, result1, result2): 73 | if self._greater_is_better: 74 | return super(MetricWrapper, self).greater_or_eq(result1, result2) 75 | return not super(MetricWrapper, self).greater_or_eq(result1, result2) 76 | -------------------------------------------------------------------------------- /neurst/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.models.model import BaseModel 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_model, register_model = setup_registry(BaseModel.REGISTRY_NAME, base_class=BaseModel, create_fn="new", 8 | verbose_creation=True) 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.models.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | import tensorflow as tf 18 | 19 | 20 | @six.add_metaclass(ABCMeta) 21 | class BaseModel(tf.keras.Model): 22 | REGISTRY_NAME = "model" 23 | 24 | def __init__(self, args: dict, name=None): 25 | self._args = args 26 | super(BaseModel, self).__init__(name=name) 27 | 28 | @property 29 | def args(self): 30 | return self._args 31 | 32 | @staticmethod 33 | def class_or_method_args(): 34 | return [] 35 | 36 | @classmethod 37 | def new(cls, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def call(self, inputs, is_training=True): 42 | """ Forward pass of the model. 43 | 44 | Args: 45 | inputs: A dict of model inputs. 46 | is_training: A bool, whether in training mode or not. 47 | 48 | Returns: 49 | The model output. 
50 | """ 51 | raise NotImplementedError 52 | -------------------------------------------------------------------------------- /neurst/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | import tensorflow as tf 5 | import yaml 6 | from absl import logging 7 | 8 | from neurst.utils.registry import get_registered_class, setup_registry 9 | 10 | OPTIMIZER_REGISTRY_NAME = "optimizer" 11 | build_optimizer, register_optimizer = setup_registry(OPTIMIZER_REGISTRY_NAME, base_class=tf.keras.optimizers.Optimizer, 12 | verbose_creation=True) 13 | 14 | models_dir = os.path.dirname(__file__) 15 | for file in os.listdir(models_dir): 16 | path = os.path.join(models_dir, file) 17 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 18 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 19 | module = importlib.import_module('neurst.optimizers.' + model_name) 20 | 21 | Adam = tf.keras.optimizers.Adam 22 | Adagrad = tf.keras.optimizers.Adagrad 23 | Adadelta = tf.keras.optimizers.Adadelta 24 | SGD = tf.keras.optimizers.SGD 25 | register_optimizer(Adam) 26 | register_optimizer(Adagrad) 27 | register_optimizer(Adadelta) 28 | register_optimizer(SGD) 29 | 30 | 31 | def controlling_optimizer(optimizer, controller, controller_args): 32 | """ Wrap the optimizer with controller. """ 33 | controller_cls = get_registered_class(controller, OPTIMIZER_REGISTRY_NAME) 34 | if controller_cls is None: 35 | return optimizer 36 | logging.info(f"Wrapper optimizer with controller {controller_cls}") 37 | new_cls = type(optimizer.__class__.__name__, (optimizer.__class__,), 38 | dict(controller_cls.__dict__)) 39 | new_optimizer = new_cls.from_config(optimizer.get_config()) 40 | new_optimizer._HAS_AGGREGATE_GRAD = optimizer._HAS_AGGREGATE_GRAD 41 | if controller_args is None: 42 | controller_args = {} 43 | elif isinstance(controller_args, str): 44 | controller_args = yaml.load(controller_args, Loader=yaml.FullLoader) 45 | assert isinstance(controller_args, dict) 46 | new_optimizer.reset_hparams(controller_args) 47 | return new_optimizer 48 | -------------------------------------------------------------------------------- /neurst/optimizers/schedules/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | import tensorflow as tf 5 | 6 | from neurst.utils.registry import setup_registry 7 | 8 | LR_SCHEDULE_REGISTRY_NAME = "lr_schedule" 9 | build_lr_schedule, register_lr_schedule = setup_registry( 10 | LR_SCHEDULE_REGISTRY_NAME, base_class=tf.keras.optimizers.schedules.LearningRateSchedule, 11 | verbose_creation=True) 12 | 13 | models_dir = os.path.dirname(__file__) 14 | for file in os.listdir(models_dir): 15 | path = os.path.join(models_dir, file) 16 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 17 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 18 | module = importlib.import_module('neurst.optimizers.schedules.' 
+ model_name) 19 | -------------------------------------------------------------------------------- /neurst/sparsity/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | importlib.import_module("neurst.sparsity.pruning_schedule") 4 | -------------------------------------------------------------------------------- /neurst/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.tasks.task import Task 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_task, register_task = setup_registry(Task.REGISTRY_NAME, base_class=Task, verbose_creation=True) 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst.tasks.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst/tasks/waitk_translation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import yaml 15 | 16 | from neurst.tasks import register_task 17 | from neurst.tasks.translation import Translation 18 | from neurst.utils.flags_core import Flag 19 | 20 | 21 | @register_task 22 | class WaitkTranslation(Translation): 23 | """ Defines the translation task. """ 24 | 25 | def __init__(self, args): 26 | super(WaitkTranslation, self).__init__(args) 27 | self._wait_k = args["wait_k"] 28 | if isinstance(self._wait_k, str): 29 | self._wait_k = yaml.load(self._wait_k, Loader=yaml.FullLoader) 30 | assert self._wait_k, "Must provide wait_k as the decode lagging." 
31 | assert isinstance(self._wait_k, list) or isinstance(self._wait_k, int), ( 32 | f"Value error: {self._wait_k}") 33 | 34 | def get_config(self): 35 | cfg = super(WaitkTranslation, self).get_config() 36 | cfg["wait_k"] = self._wait_k 37 | return cfg 38 | 39 | @staticmethod 40 | def class_or_method_args(): 41 | this_args = super(WaitkTranslation, WaitkTranslation).class_or_method_args() 42 | this_args.extend([ 43 | Flag("wait_k", dtype=Flag.TYPE.STRING, default=None, 44 | help="The lagging k.") 45 | ]) 46 | return this_args 47 | 48 | def build_model(self, args, name=None, **kwargs): 49 | return super(WaitkTranslation, self).build_model(args, name=name, 50 | waitk_lagging=self._wait_k, **kwargs) 51 | -------------------------------------------------------------------------------- /neurst/training/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.training.callbacks import (CentralizedCallback, CustomCheckpointCallback, LearningRateScheduler, 5 | MetricReductionCallback) 6 | from neurst.training.validator import Validator 7 | from neurst.utils.registry import setup_registry 8 | 9 | build_validator, register_validator = setup_registry(Validator.REGISTRY_NAME, base_class=Validator, 10 | verbose_creation=True) 11 | 12 | __all__ = [ 13 | "CentralizedCallback", 14 | "CustomCheckpointCallback", 15 | "LearningRateScheduler", 16 | "MetricReductionCallback", 17 | 18 | "Validator", 19 | "register_validator", 20 | "build_validator" 21 | ] 22 | 23 | models_dir = os.path.dirname(__file__) 24 | for file in os.listdir(models_dir): 25 | path = os.path.join(models_dir, file) 26 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 27 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 28 | module = importlib.import_module('neurst.training.' + model_name) 29 | -------------------------------------------------------------------------------- /neurst/training/validator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | 18 | from neurst.training.callbacks import CentralizedCallback 19 | from neurst.utils.flags_core import Flag 20 | 21 | 22 | @six.add_metaclass(ABCMeta) 23 | class Validator(CentralizedCallback): 24 | REGISTRY_NAME = "validator" 25 | 26 | def __init__(self, args): 27 | super(Validator, self).__init__() 28 | self._eval_steps = args["eval_steps"] 29 | self._eval_start_at = args["eval_start_at"] 30 | 31 | @staticmethod 32 | def class_or_method_args(): 33 | return [ 34 | Flag("eval_steps", dtype=Flag.TYPE.INTEGER, default=1000, 35 | help="The number of training steps between two validation runs."), 36 | Flag("eval_start_at", dtype=Flag.TYPE.INTEGER, default=0, 37 | help="The training step at which to start the validation process."), 38 | ] 39 | 40 | @abstractmethod 41 | def build(self, strategy, task, model): 42 | """ Builds the validator and returns self. """ 43 | return self 44 | 45 | @abstractmethod 46 | def validate(self, step): 47 | """ Validation process. """ 48 | raise NotImplementedError 49 | 50 | def custom_on_train_batch_end(self, step, logs=None): 51 | _ = logs 52 | if step >= self._eval_start_at and step % self._eval_steps == 0: 53 | self.validate(step) 54 | -------------------------------------------------------------------------------- /neurst/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/utils/__init__.py -------------------------------------------------------------------------------- /neurst/utils/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy 15 | import tensorflow as tf 16 | 17 | 18 | def gelu(x, non_approximate=False):
 19 | """Gaussian Error Linear Unit. 20 | This is a smoother version of the RELU. 21 | Original paper: https://arxiv.org/abs/1606.08415 22 | Args: 23 | x: float Tensor to perform activation. 24 | non_approximate: if True, use the exact erf-based GELU instead of the tanh approximation. 25 | Returns: 26 | `x` with the GELU activation applied. 27 | """ 28 | if non_approximate: 29 | # TODO: check fp16 30 | # https://github.com/tensorflow/tensorflow/issues/25052 31 | if x.dtype.base_dtype.name == "float16": 32 | fp32_x = tf.cast(x, tf.float32) 33 | else: 34 | fp32_x = x 35 | cdf = 0.5 * (1.0 + tf.math.erf(fp32_x / numpy.sqrt(2.0))) 36 | 37 | if x.dtype.base_dtype.name == "float16": 38 | return x * tf.saturate_cast(cdf, tf.float16) 39 | 40 | return x * cdf 41 | cdf = 0.5 * (1.0 + tf.tanh( 42 | (numpy.sqrt(2 / numpy.pi) * (x + 0.044715 * tf.pow(x, 3))))) 43 | return x * cdf 44 | 45 | 46 | def glu(x): 47 | """ Gated linear unit. 
""" 48 | a, b = tf.split(x, axis=-1, num_or_size_splits=2) 49 | return a * tf.nn.sigmoid(b) 50 | 51 | 52 | def get_activation(activ): 53 | if callable(activ): 54 | return activ 55 | if activ is None: 56 | return None 57 | if activ == "tanh": 58 | return tf.nn.tanh 59 | elif activ == "relu": 60 | return tf.nn.relu 61 | elif activ == "gelu" or activ == "gelu_approx": 62 | return lambda x: gelu(x, non_approximate=False) 63 | elif activ == "gelu_nonapprox": 64 | return lambda x: gelu(x, non_approximate=True) 65 | else: 66 | raise ValueError("Unknown activation: {}".format(activ)) 67 | -------------------------------------------------------------------------------- /neurst/utils/converters/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.utils.converters.converter import Converter 5 | from neurst.utils.registry import setup_registry 6 | 7 | build_converter, register_converter = setup_registry(Converter.REGISTRY_NAME, base_class=Converter, 8 | verbose_creation=False, create_fn="new") 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst.utils.converters.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst/utils/converters/converter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | import tensorflow as tf 18 | 19 | from neurst.utils.configurable import ModelConfigs 20 | 21 | 22 | @six.add_metaclass(ABCMeta) 23 | class Converter(object): 24 | """ Abstract class for converting models and tasks. 
""" 25 | REGISTRY_NAME = "converter" 26 | 27 | @classmethod 28 | def new(cls, *args, **kwargs): 29 | _ = args 30 | _ = kwargs 31 | return cls 32 | 33 | @staticmethod 34 | @abstractmethod 35 | def convert_model_config(path): 36 | raise NotImplementedError 37 | 38 | @staticmethod 39 | @abstractmethod 40 | def convert_task_config(path): 41 | raise NotImplementedError 42 | 43 | @staticmethod 44 | def download(key): 45 | _ = key 46 | return None 47 | 48 | @staticmethod 49 | @abstractmethod 50 | def convert_checkpoint(path, save_path): 51 | raise NotImplementedError 52 | 53 | @classmethod 54 | def convert(cls, from_path, to_path): 55 | if (from_path.startswith("http://") or from_path.startswith("https://") 56 | or (not tf.io.gfile.exists(from_path))): 57 | path = cls.download(from_path) 58 | if path is None: 59 | raise ValueError(f"Fail to find model to download: {from_path}") 60 | from_path = path 61 | try: 62 | cfgs = cls.convert_model_config(from_path) 63 | except NotImplementedError: 64 | cfgs = {} 65 | try: 66 | cfgs.update(cls.convert_task_config(from_path)) 67 | except NotImplementedError: 68 | pass 69 | ModelConfigs.dump(cfgs, to_path) 70 | cls.convert_checkpoint(from_path, to_path) 71 | -------------------------------------------------------------------------------- /neurst/utils/hparams_sets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from absl import logging 15 | 16 | from neurst.utils.registry import REGISTRIES 17 | 18 | 19 | def register_hparams_set(name, backend="tf"): 20 | registry_name = "hparams_set" 21 | if registry_name not in REGISTRIES[backend]: 22 | REGISTRIES[backend][registry_name] = {} 23 | 24 | def register_x_fn(fn_, short_name=None): 25 | names = set() 26 | if short_name: 27 | for n in short_name: 28 | names.add(n.lower()) 29 | names.add(fn_.__name__) 30 | for n in names: 31 | if n in REGISTRIES[backend][registry_name]: 32 | if REGISTRIES[backend][registry_name][n] != fn_: 33 | raise ValueError('Cannot register duplicate {} (under {})'.format(n, registry_name)) 34 | else: 35 | REGISTRIES[backend][registry_name][n] = fn_ 36 | 37 | if isinstance(name, str): 38 | return lambda fn: register_x_fn(fn, [name]) 39 | elif isinstance(name, list): 40 | return lambda c: register_x_fn(c, name) 41 | else: 42 | raise ValueError("Not supported type: {}".format(type(name))) 43 | 44 | 45 | def get_hyper_parameters(name, backend="tf"): 46 | registry_name = "hparams_set" 47 | if name is None: 48 | return {} 49 | if registry_name in REGISTRIES[backend] and name in REGISTRIES[backend][registry_name]: 50 | logging.info("matched the pre-defined hyper-parameters set: {}".format(name)) 51 | return REGISTRIES[backend][registry_name][name]() 52 | for m, mc in REGISTRIES[backend]["model"].items(): 53 | if hasattr(mc, "build_model_args_by_name"): 54 | p = mc.build_model_args_by_name(name) 55 | if p is not None: 56 | return p 57 | return {} 58 | -------------------------------------------------------------------------------- /neurst/utils/simuleval_agents/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.utils.registry import setup_registry 5 | 6 | _, register_agent = setup_registry("simuleval_agent", verbose_creation=False) 7 | 8 | models_dir = os.path.dirname(__file__) 9 | for file in os.listdir(models_dir): 10 | path = os.path.join(models_dir, file) 11 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 12 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 13 | module = importlib.import_module('neurst.utils.simuleval_agents.' 
+ model_name) 14 | -------------------------------------------------------------------------------- /neurst/utils/userdef/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst/utils/userdef/__init__.py -------------------------------------------------------------------------------- /neurst_pt/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | 4 | import importlib 5 | 6 | __author__ = "ZhaoChengqi " 7 | 8 | __all__ = [ 9 | "layers", 10 | "models", 11 | "utils", 12 | ] 13 | 14 | importlib.import_module("neurst_pt.layers") 15 | importlib.import_module("neurst_pt.layers.attentions") 16 | importlib.import_module("neurst_pt.layers.decoders") 17 | importlib.import_module("neurst_pt.layers.encoders") 18 | importlib.import_module("neurst_pt.models") 19 | importlib.import_module("neurst_pt.utils") 20 | -------------------------------------------------------------------------------- /neurst_pt/layers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from neurst.utils.registry import setup_registry 4 | from neurst_pt.layers.attentions.multi_head_attention import MultiHeadAttention, MultiHeadSelfAttention 5 | from neurst_pt.layers.common_layers import PrePostProcessingWrapper, TransformerFFN 6 | 7 | build_base_layer, register_base_layer = setup_registry("base_layer", base_class=nn.Module, 8 | verbose_creation=False, backend="pt") 9 | 10 | register_base_layer(MultiHeadSelfAttention) 11 | register_base_layer(MultiHeadAttention) 12 | register_base_layer(TransformerFFN) 13 | 14 | 15 | def build_transformer_component(layer_args, 16 | norm_shape, 17 | dropout_rate, 18 | pre_norm=True, 19 | epsilon=1e-6): 20 | base_layer = build_base_layer(layer_args) 21 | return PrePostProcessingWrapper( 22 | layer=base_layer, 23 | norm_shape=norm_shape, 24 | dropout_rate=dropout_rate, 25 | epsilon=epsilon, 26 | pre_norm=pre_norm) 27 | -------------------------------------------------------------------------------- /neurst_pt/layers/attentions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst_pt/layers/attentions/__init__.py -------------------------------------------------------------------------------- /neurst_pt/layers/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.utils.registry import setup_registry 5 | from neurst_pt.layers.decoders.decoder import Decoder 6 | 7 | build_decoder, register_decoder = setup_registry(Decoder.REGISTRY_NAME, base_class=Decoder, backend="pt") 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst_pt.layers.decoders.' 
+ model_name) 15 | -------------------------------------------------------------------------------- /neurst_pt/layers/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.utils.registry import setup_registry 5 | from neurst_pt.layers.encoders.encoder import Encoder 6 | 7 | build_encoder, register_encoder = setup_registry(Encoder.REGISTRY_NAME, base_class=Encoder, backend="pt") 8 | 9 | models_dir = os.path.dirname(__file__) 10 | for file in os.listdir(models_dir): 11 | path = os.path.join(models_dir, file) 12 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 13 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 14 | module = importlib.import_module('neurst_pt.layers.encoders.' + model_name) 15 | -------------------------------------------------------------------------------- /neurst_pt/layers/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ Base Encoder class. """ 15 | from __future__ import absolute_import, division, print_function 16 | 17 | from abc import ABCMeta, abstractmethod 18 | 19 | import six 20 | import torch.nn as nn 21 | 22 | from neurst.utils.configurable import extract_constructor_params 23 | 24 | 25 | @six.add_metaclass(ABCMeta) 26 | class Encoder(nn.Module): 27 | """ Base class for encoders. """ 28 | REGISTRY_NAME = "encoder" 29 | 30 | def __init__(self, **kwargs): 31 | """ Initializes the parameters of the encoders. """ 32 | self._params = extract_constructor_params(locals(), verbose=False) 33 | super(Encoder, self).__init__() 34 | 35 | @abstractmethod 36 | def forward(self, inputs, inputs_padding, is_training=True): 37 | """ Encodes the inputs. 38 | 39 | Args: 40 | inputs: The embedded input, a float tensor with shape 41 | [batch_size, max_length, embedding_dim]. 42 | inputs_padding: A float tensor with shape [batch_size, max_length], 43 | indicating the padding positions, where 1.0 for padding and 44 | 0.0 for non-padding. 45 | is_training: A bool, whether in training mode or not. 
46 | 47 | Returns: 48 | The encoded output with shape [batch_size, max_length, hidden_size] 49 | """ 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /neurst_pt/layers/modalities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst_pt/layers/modalities/__init__.py -------------------------------------------------------------------------------- /neurst_pt/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | 4 | from neurst.utils.registry import setup_registry 5 | from neurst_pt.models.model import BaseModel 6 | 7 | build_model, register_model = setup_registry(BaseModel.REGISTRY_NAME, base_class=BaseModel, create_fn="new", 8 | verbose_creation=True, backend="pt") 9 | 10 | models_dir = os.path.dirname(__file__) 11 | for file in os.listdir(models_dir): 12 | path = os.path.join(models_dir, file) 13 | if not file.startswith('_') and not file.startswith('.') and file.endswith('.py'): 14 | model_name = file[:file.find('.py')] if file.endswith('.py') else file 15 | module = importlib.import_module('neurst_pt.models.' + model_name) 16 | -------------------------------------------------------------------------------- /neurst_pt/models/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from abc import ABCMeta, abstractmethod 15 | 16 | import six 17 | import torch.nn as nn 18 | 19 | 20 | @six.add_metaclass(ABCMeta) 21 | class BaseModel(nn.Module): 22 | REGISTRY_NAME = "model" 23 | 24 | def __init__(self, args): 25 | self._args = args 26 | super(BaseModel, self).__init__() 27 | 28 | @property 29 | def args(self): 30 | return self._args 31 | 32 | @staticmethod 33 | def class_or_method_args(): 34 | return [] 35 | 36 | @classmethod 37 | def new(cls, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | @abstractmethod 41 | def forward(self, inputs, is_training=True): 42 | """ Forward pass of the model. 43 | 44 | Args: 45 | inputs: A dict of model inputs. 46 | is_training: A bool, whether in training mode or not. 47 | 48 | Returns: 49 | The model output. 50 | """ 51 | raise NotImplementedError 52 | -------------------------------------------------------------------------------- /neurst_pt/models/model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | 16 | 17 | def input_length_to_nonpadding(lengths, max_len, dtype=None): 18 | """ Creates a bias tensor according to the non-padding tensor for cross entropy. 19 | 20 | Args: 21 | length: A Tensor with shape [batch_size, ], indicating the true length. 22 | max_len: A scalar tensor indicating the maximum length. 23 | 24 | Returns: 25 | A float tensor with shape [batch_size, max_len], 26 | indicating the padding positions, where 0.0 for padding and 27 | 1.0 for non-padding. 28 | """ 29 | row_vector = torch.arange(0, max_len) 30 | matrix = torch.unsqueeze(lengths, dim=-1) 31 | mask = (row_vector < matrix).to(dtype or torch.float) 32 | return mask # 1.0 for non-padding 33 | 34 | 35 | def input_length_to_padding(lengths, max_len, dtype=None): 36 | """ Creates a bias tensor according to the padding tensor for attention. 37 | 38 | Args: 39 | length: A Tensor with shape [batch_size, ], indicating the true length. 40 | max_len: A scalar tensor indicating the maximum length. 41 | 42 | Returns: 43 | A float tensor with shape [batch_size, max_len], 44 | indicating the padding positions, where 1.0 for padding and 45 | 0.0 for non-padding. 46 | """ 47 | return 1. - input_length_to_nonpadding(lengths, max_len, dtype) 48 | -------------------------------------------------------------------------------- /neurst_pt/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/neurst_pt/utils/__init__.py -------------------------------------------------------------------------------- /neurst_pt/utils/activations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch.nn.functional as F 15 | 16 | 17 | def get_activation(activ): 18 | if callable(activ): 19 | return activ 20 | if activ is None: 21 | return lambda x: x 22 | if activ == "tanh": 23 | return F.tanh 24 | elif activ == "relu": 25 | return F.relu 26 | elif activ == "gelu": 27 | return F.gelu 28 | elif activ == "glu": 29 | return lambda x: F.glu(x, -1) 30 | else: 31 | raise ValueError("Unknown activation: {}".format(activ)) 32 | -------------------------------------------------------------------------------- /requirements.apt.txt: -------------------------------------------------------------------------------- 1 | libsndfile1 2 | ffmpeg 3 | libavcodec-extra 4 | sox 5 | 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | tensorflow>=2.3 3 | pyyaml 4 | sacrebleu 5 | sacremoses 6 | jieba 7 | regex 8 | sentencepiece 9 | thai-segmenter 10 | soundfile 11 | python_speech_features 12 | subword-nmt 13 | mecab-python3 14 | ipadic 15 | tensorflow_addons 16 | torch 17 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | THIS_DIR="$( cd "$( dirname "$0" )" && pwd )" 5 | 6 | 7 | if [[ -z ${NEURST_LIB} ]] 8 | then 9 | NEURST_LIB=$THIS_DIR 10 | echo "using default --lib=${NEURST_LIB}" >&2 11 | fi 12 | 13 | if [[ $@ =~ "--enable_xla" ]] 14 | then 15 | export TF_XLA_FLAGS=--tf_xla_cpu_global_jit 16 | fi 17 | 18 | neurst-run "$@" 19 | 20 | -------------------------------------------------------------------------------- /run_cli.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | 4 | THIS_DIR="$( cd "$( dirname "$0" )" && pwd )" 5 | 6 | 7 | if [[ -z ${NEURST_LIB} ]] 8 | then 9 | NEURST_LIB=$THIS_DIR 10 | echo "using default --lib=${NEURST_LIB}" >&2 11 | fi 12 | 13 | pip3 install -e ${NEURST_LIB} --no-deps 14 | 15 | if [[ $@ =~ "--enable_xla" ]] 16 | then 17 | echo "enable XLA" 18 | export TF_XLA_FLAGS=--tf_xla_cpu_global_jit 19 | fi 20 | 21 | python3 -m $@ 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | import io 3 | import os 4 | 5 | from setuptools import find_packages, setup 6 | 7 | NAME = "neurst" 8 | DESCRIPTION = "Neural Speech Translation Toolkit" 9 | URL = "https://github.com/bytedance/neurst" 10 | EMAIL = "zhaochengqi.d@bytedance.com" 11 | AUTHOR = "ZhaoChengqi" 12 | 13 | # TODO: one must manually install following packages if needed 14 | ALTERNATIVE_REQUIRES = [ 15 | "jieba>=0.42.1", # unnecessary for all 16 | "subword-nmt>=0.3.7", # unnecessary for all 17 | "thai-segmenter>=0.4.1", # unnecessary for all 18 | "soundfile>=0.10", # for speech processing 19 | "python_speech_features>=0.6", # for speech processing 20 | "transformers>=3.4.0", # Not necessary for all 21 | "sentencepiece>=0.1.7", 22 | "mecab-python3>=1.0.3", # for sacrebleu[ja] 23 | "ipadic>=1.0.0", # for sacrebleu[ja] 24 | "torch>=1.7.0", # for converting models from fairseq 25 | "fairseq>=0.10.1", # for converting models from fairseq 26 | "tensorflow_addons>=0.11.2", # for group normalization 27 | "pydub>=0.24.1", # for audio processing 28 | "sox>=1.4.1", # for audio processing 29 | ] 30 | 31 | 
REQUIRES = ["six>=1.11.0,<2.0.0", 32 | "pyyaml>=3.13", 33 | "sacrebleu>=1.4.0", 34 | "regex>=2019.1.24", 35 | "sacremoses>=0.0.38", 36 | # "tensorflow>=2.4.0", # ONE must manually install tensorflow 37 | "tqdm>=0.46", 38 | ] 39 | 40 | DEV_REQUIRES = ["flake8>=3.5.0,<4.0.0", 41 | "mypy>=0.620; python_version>='3.6'", 42 | "tox>=3.0.0,<4.0.0", 43 | "isort>=4.0.0,<5.0.0", 44 | "pytest>=4.0.0,<5.0.0"] + REQUIRES + ALTERNATIVE_REQUIRES 45 | 46 | here = os.path.abspath(os.path.dirname(__file__)) 47 | 48 | try: 49 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 50 | long_description = "\n" + f.read() 51 | except IOError: 52 | long_description = DESCRIPTION 53 | 54 | about = {} 55 | with io.open(os.path.join(here, NAME, "__version__.py")) as f: 56 | exec(f.read(), about) 57 | 58 | setup( 59 | name=NAME, 60 | version=about["__version__"], 61 | description=DESCRIPTION, 62 | long_description=long_description, 63 | long_description_content_type="text/markdown", 64 | author=AUTHOR, 65 | author_email=EMAIL, 66 | url=URL, 67 | classifiers=[ 68 | # Trove classifiers 69 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 70 | "Intended Audience :: Developers", 71 | "Programming Language :: Python :: 3.6", 72 | "Programming Language :: Python :: 3.7", 73 | "Programming Language :: Python :: 3.8" 74 | ], 75 | keywords="neurst", 76 | packages=find_packages(exclude=["docs", "tests"]), 77 | install_requires=REQUIRES, 78 | tests_require=[ 79 | "pytest>=4.0.0,<5.0.0" 80 | ], 81 | python_requires=">=3.6", 82 | extras_require={ 83 | "dev": DEV_REQUIRES, 84 | }, 85 | package_data={ 86 | # for PEP484 & PEP561 87 | NAME: ["py.typed", "*.pyi"], 88 | }, 89 | entry_points={ 90 | "console_scripts": [ 91 | "neurst-run = neurst.cli.run_exp:cli_main", 92 | "neurst-view = neurst.cli.view_registry:cli_main", 93 | "neurst-vocab = neurst.cli.generate_vocab:cli_main" 94 | ], 95 | }, 96 | ) 97 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/__init__.py -------------------------------------------------------------------------------- /tests/examples/dev.example.zh.txt: -------------------------------------------------------------------------------- 1 | 你好,妈。—你回来干嘛? 2 | 这么多年了,现在回来干嘛? —这么说,你是玛克辛的女儿… … 3 | 将近, 10年来,我只能和那三个人接触。 4 | 所以,现在真的很高兴— — —能和新人接触,我明白。 5 | 有时我希望我的生活能简单一点。 6 | 我想你回家,阿布。希望咱们能够和好如初。 7 | 这件事已经结束了。 8 | 也有可能是;另一个阴谋的开始。 9 | 老板,你爸因为这家伙死在监狱里。 10 | 我说过,等我准备好,咱们就去会会米奇·麦克迪尔。 11 | 时候到了。—对不起各位,打烊了。 12 | 米奇·麦克迪尔,我自15岁开始就幻想着这一刻。 13 | 知道我来的目的么? —动手吧,你这杂种。 14 | 了结了吧。—我不杀你。 15 | 我要雇你。 16 | 我15岁的时候我就发誓要找到你。 17 | 我是一个说到做到的人。 18 | 我知道阿比和克莱尔在肯塔基州。 19 | 知道塔米在田纳西州办离婚。 20 | 你想怎样? 21 | 我说过了。要雇你。 22 | 我的一个朋友在特区被指控谋杀。 23 | 他叫帕特里克·沃克。我希望你能当他的代理律师。 24 | 你开什么玩笑。 25 | 你从哪儿听出来我是在开玩笑。 26 | 他是我的好朋友。 27 | 他被陷害他没有犯下的罪,我要他无罪释放。 28 | 我不接手涉黑案。特别是集团犯罪。 29 | 你需要律师— — —我很清楚我需要什么! 30 | 他不是我们家族的人。他是一个经济学家。 31 | 读的是麻省理工学院。 32 | 干的是智库的活; 33 | 整天想的是提供就业机会和贸易赤字。 34 | 他杀了谁? —谁也没杀。 35 | 听说,是个女的;不是说她不是人,我只是想说,他是被陷害的。 36 | 乔伊,我就是一普通人。我甚至不认识你父亲。 37 | 我只不过在孟斐斯找到一份工作事情就变了样子。 38 | 我都没作证… … —你帮了条子! 39 | 你扳倒了那家律所,也就扳倒了我父亲。 40 | 你不能把责任推到我身上… … 41 | 抱歉,咱们好像不是来开辩论会的。 42 | 咱俩谁说的算,你还是我? 43 | 我的条件。你来接手这个案子。 44 | 这是你要为我做的。 45 | 只有这样,麦克迪尔和莫罗托家族才算扯平。 46 | 这是我能为你做的。 47 | 为什么找我? —这是你的强项啊: 48 | 揭开真相。 49 | 在孟斐斯… …在诺贝尔保险公司… …你拥有其他律师不具备的慧眼。 50 | 还有,额… … 51 | 大多数的律师都想赢。 52 | 你也得赢。 53 | 好老板知道怎么激励下属。 54 | 如果我拒绝呢? 
—你有两个选择,米奇。 55 | 要么为我做这件事,要么等我们走了,给路易斯·科尔曼打电话开始你的逃亡生活。 56 | 也许我找不着你。但我保证… … 57 | 我肯定不让你消停。 58 | 明早我再来。 59 | 如果你在,那就是接受了。 60 | 要是你不在,那就是拒绝了。 61 | 我真不敢相信他就那么径直走进了办公室。—他们盯着咱好几个月了。 62 | 那为何不早下手?为何要等到现在? —案子。 63 | 他留着我们的命就是要保他朋友出来。 64 | 怎么才能阻止他日后的追杀? —没办法。他说过的。 65 | 咱们可以跑啊。不用警察帮忙,自己想办法。 66 | 咱们总有后备计划的。—前提是大伙在一起。 67 | 咱们现在就得做决定。 68 | 给不给路易斯打电话? —你还想怎样,替他们办事? 69 | 你不是真的在考虑吧? —对啊,我当然在考虑,雷。 70 | 否则就要给我老婆打电话,她刚经历过地狱般生活,却被告知又要再一次逃亡了。 71 | 所以相反,你要告诉她你现在要为乔伊·莫罗托卖命? 72 | 看,如果能让莫罗托从我们生活中消失那么她也许会回来然后我们一家就团聚了。 73 | 但是即使你帮了他,我们也不可能全身而退。 74 | 咱们就先和当事人;见一面。 75 | 莫罗托说他不是黑手党,是一个经济学家。—这主意不好。 76 | 事情一旦有变咱撒腿就跑呗。 77 | 好,你打吧。 78 | 给路易斯打电话吧。 79 | 就见一面。 80 | 很抱歉,我还是得收取你的费用。 81 | 真的么? 82 | 我知道你希望在阳台那尴尬的小插曲能有个折扣啥的,但是… 83 | 额,这样才公平嘛。 84 | 你现在在帮我算打折钱么,我的治疗师? 85 | 你懂的,就个人意见,我觉得你应该换个治疗师。 86 | 我只是说说, 87 | 我— —我不知道如果道德委员会知道你吻了病人他们会做何感想呢。 88 | 我是说,那是你不是吗?那是你吧?你懂的,因为当时挺黑的。 89 | 没错,没错,是我。—所以呢… … ? 90 | 所以… …所以我觉得的你应该马上预约下一次见面。 91 | 明天2点怎么样? 92 | 可以。 93 | 嘿,额我有电话进来了。 94 | 我得挂了。—嗯,挂吧。 95 | 到时见。 96 | 嘿! —嘿。 97 | 真高兴等到你了。—是啊。 98 | 嘿这样,我爸今天带着克莱尔去银行上班。 99 | 我妈邀我吃午饭。 100 | 那么,还顺利么? 101 | 这… …这不太顺。—但是看看今天的情况吧。—我想你了。 102 | 我知道我应该给你些空间,但我就是想… … 103 | 打给你告诉你我想你了。—我也想你。 104 | 看,有客户来了。我再打给你好么? 105 | 好的。—塔米,我挂了。 106 | 再打给你。—可以吃一个么? 107 | 你知道,我不喜欢这些东西可我听说里根喜欢吃。 108 | 他从来都是这么随心所遇,是吧? 109 | 你没跑-就说明接受了? —不全然。 110 | 我们要见当事人。—那你可以省段周折了。 111 | 他是无辜的。 112 | 你怎么知道? —他告诉我的。 113 | 大多数的人都不会对我撒谎。—知道我是怎么想的? 114 | 我想你要是想杀我们早就动手了,但你没有。 115 | 你在等,知道为什么吗?因为这件案子 116 | 这个人,他对你很重要。 117 | 你不是想让我们接这案子,是需要我们接手。 118 | 我要见面。 119 | 安东尼奥? —是,老板? 120 | 去开车。—遵命。 121 | 就是他? 122 | 米奇·麦克迪尔,这是帕特里克·沃克。 123 | 有什么不对么? —我在想。 124 | 在想哪儿不对么? 125 | 穿三个扣子西装,说明你是在跟随潮流。 126 | 我需要一个跟风的律师么? 127 | 帕特里克… … —深灰色让人感觉可靠又好猜。 128 | 难道我需要这样一个律师么? —嘿!你收敛点。 129 | -------------------------------------------------------------------------------- /tests/examples/example_create_seq2seq_tfrecrods.yml: -------------------------------------------------------------------------------- 1 | dataset.class: ParallelTextDataset 2 | dataset.params: 3 | src_file: ./tests/examples/train.example.zh.jieba.bpe.txt 4 | trg_file: ./tests/examples/train.example.en.tok.bpe.txt 5 | data_is_processed: True 6 | 7 | task.class: Seq2Seq 8 | task.params: 9 | src_data_pipeline.class: TextDataPipeline 10 | src_data_pipeline.params: 11 | tokenizer: jieba 12 | subtokenizer: bpe 13 | subtokenizer_codes: ./tests/examples/codes.bpe4k.zh 14 | vocab_path: ./tests/examples/vocab.zh 15 | trg_data_pipeline.class: TextDataPipeline 16 | trg_data_pipeline.params: 17 | tokenizer: moses 18 | subtokenizer: bpe 19 | subtokenizer_codes: ./tests/examples/codes.bpe4k.en 20 | vocab_path: ./tests/examples/vocab.en 21 | 22 | processor_id: 0 23 | num_processors: 1 24 | num_output_shards: 4 25 | output_range_begin: 0 26 | output_range_end: 4 27 | output_template: ./tests/examples/train.tfrecords-%5.5d-of-%5.5d 28 | -------------------------------------------------------------------------------- /tests/examples/example_eval_seq2seq.yml: -------------------------------------------------------------------------------- 1 | model_dir: ./test_models 2 | 3 | entry.class: Evaluator 4 | entry.params: 5 | criterion.class: LabelSmoothedCrossEntropy 6 | 7 | dataset.class: parallel_text 8 | dataset.params: 9 | src_file: ./tests/examples/dev.example.zh.txt 10 | trg_file: ./tests/examples/dev.example.en.txt 11 | 12 | task.params: 13 | batch_size: 32 14 | -------------------------------------------------------------------------------- /tests/examples/example_predict_seq2seq.yml: -------------------------------------------------------------------------------- 1 | 
model_dir: ./test_models 2 | 3 | entry.class: SequenceGenerator 4 | entry.params: 5 | output_file: ./hypothesis.txt 6 | save_metric: ./metric.json 7 | metric.class: BLEU 8 | search_method.class: beam_search 9 | search_method.params: 10 | beam_size: 4 11 | length_penalty: -1 12 | extra_decode_length: 20 13 | maximum_decode_length: 50 14 | 15 | dataset.class: parallel_text 16 | dataset.params: 17 | src_file: ./tests/examples/dev.example.zh.txt 18 | trg_file: ./tests/examples/dev.example.en.txt 19 | 20 | task.params: 21 | batch_size: 32 22 | -------------------------------------------------------------------------------- /tests/examples/example_train_gpt2.yml: -------------------------------------------------------------------------------- 1 | model_dir: ./test_models 2 | 3 | entry.class: trainer 4 | entry.params: 5 | train_steps: 100 6 | save_checkpoint_steps: 50 7 | summary_steps: 10 8 | criterion.class: label_smoothed_cross_entropy 9 | criterion.params: 10 | label_smoothing: 0.1 11 | optimizer.class: adam 12 | optimizer.params: 13 | epsilon: 1.e-9 14 | beta_1: 0.9 15 | beta_2: 0.98 16 | lr_schedule.class: noam 17 | lr_schedule.params: 18 | initial_factor: 1.0 19 | dmodel: 8 20 | warmup_steps: 4000 21 | 22 | dataset.class: mono_text 23 | dataset.params: 24 | data_file: ./tests/examples/train.example.en.tok.bpe.txt 25 | data_is_processed: True 26 | 27 | 28 | task.class: lm 29 | task.params: 30 | data_pipeline.class: TextDataPipeline 31 | data_pipeline.params: 32 | language: en 33 | tokenizer: moses 34 | subtokenizer: bpe 35 | subtokenizer_codes: ./tests/examples/codes.bpe4k.en 36 | vocab_path: ./tests/examples/vocab.en 37 | batch_size: 500 38 | batch_by_tokens: true 39 | max_len: 50 40 | 41 | hparams_set: gpt2_toy 42 | -------------------------------------------------------------------------------- /tests/examples/example_train_seq2seq.yml: -------------------------------------------------------------------------------- 1 | model_dir: ./test_models 2 | 3 | entry.class: trainer 4 | entry.params: 5 | train_steps: 100 6 | save_checkpoint_steps: 50 7 | summary_steps: 10 8 | criterion.class: label_smoothed_cross_entropy 9 | criterion.params: 10 | label_smoothing: 0.1 11 | optimizer.class: adam 12 | optimizer.params: 13 | epsilon: 1.e-9 14 | beta_1: 0.9 15 | beta_2: 0.98 16 | lr_schedule.class: noam 17 | lr_schedule.params: 18 | initial_factor: 1.0 19 | dmodel: 8 20 | warmup_steps: 4000 21 | 22 | dataset.class: ParallelTextDataset 23 | dataset.params: 24 | src_file: ./tests/examples/train.example.zh.jieba.bpe.txt 25 | trg_file: ./tests/examples/train.example.en.tok.bpe.txt 26 | data_is_processed: True 27 | 28 | 29 | task.class: Seq2Seq 30 | task.params: 31 | src_data_pipeline.class: TextDataPipeline 32 | src_data_pipeline.params: 33 | language: zh 34 | tokenizer: jieba 35 | subtokenizer: bpe 36 | subtokenizer_codes: ./tests/examples/codes.bpe4k.zh 37 | vocab_path: ./tests/examples/vocab.zh 38 | trg_data_pipeline.class: TextDataPipeline 39 | trg_data_pipeline.params: 40 | language: en 41 | tokenizer: moses 42 | subtokenizer: bpe 43 | subtokenizer_codes: ./tests/examples/codes.bpe4k.en 44 | vocab_path: ./tests/examples/vocab.en 45 | batch_size: 1000 46 | batch_by_tokens: true 47 | max_src_len: 50 48 | max_trg_len: 50 49 | 50 | model.class: Transformer 51 | model.params: 52 | modality.share_source_target_embedding: false 53 | modality.share_embedding_and_softmax_weights: true 54 | modality.dim: 8 55 | modality.timing: sinusoids 56 | encoder.num_layers: 2 57 | encoder.hidden_size: 8 58 | 
encoder.num_attention_heads: 2 59 | encoder.filter_size: 32 60 | encoder.attention_dropout_rate: 0.1 61 | encoder.attention_type: dot_product 62 | encoder.ffn_activation: relu 63 | encoder.ffn_dropout_rate: 0.1 64 | encoder.layer_postprocess_dropout_rate: 0.1 65 | decoder.num_layers: 2 66 | decoder.hidden_size: 8 67 | decoder.num_attention_heads: 2 68 | decoder.filter_size: 32 69 | decoder.attention_dropout_rate: 0.1 70 | decoder.attention_type: dot_product 71 | decoder.ffn_activation: relu 72 | decoder.ffn_dropout_rate: 0.1 73 | decoder.layer_postprocess_dropout_rate: 0.1 74 | 75 | 76 | -------------------------------------------------------------------------------- /tests/examples/example_validator_gpt2.yml: -------------------------------------------------------------------------------- 1 | validator.class: CriterionValidator 2 | validator.params: 3 | eval_dataset.class: mono_text 4 | eval_dataset.params: 5 | data_file: ./tests/examples/dev.example.en.txt 6 | eval_start_at: 0 7 | eval_steps: 50 8 | eval_criterion.class: label_smoothed_cross_entropy 9 | eval_batch_size: 32 10 | -------------------------------------------------------------------------------- /tests/examples/example_validator_seq2seq.yml: -------------------------------------------------------------------------------- 1 | validator.class: SeqGenerationValidator 2 | validator.params: 3 | eval_dataset.class: ParallelTextDataset 4 | eval_dataset.params: 5 | src_file: ./tests/examples/dev.example.zh.txt 6 | trg_file: ./tests/examples/dev.example.en.txt 7 | eval_start_at: 0 8 | eval_steps: 50 9 | eval_criterion.class: label_smoothed_cross_entropy 10 | eval_search_method.class: beam_search 11 | eval_search_method.params: 12 | beam_size: 4 13 | length_penalty: 0.6 14 | extra_decode_length: 20 15 | maximum_decode_length: 50 16 | eval_metric.class: bleu 17 | eval_top_checkpoints_to_keep: 5 18 | -------------------------------------------------------------------------------- /tests/examples/train.tfrecords-00000-of-00004: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/examples/train.tfrecords-00000-of-00004 -------------------------------------------------------------------------------- /tests/examples/train.tfrecords-00001-of-00004: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/examples/train.tfrecords-00001-of-00004 -------------------------------------------------------------------------------- /tests/examples/train.tfrecords-00002-of-00004: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/examples/train.tfrecords-00002-of-00004 -------------------------------------------------------------------------------- /tests/examples/train.tfrecords-00003-of-00004: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/examples/train.tfrecords-00003-of-00004 -------------------------------------------------------------------------------- /tests/neurst/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/__init__.py -------------------------------------------------------------------------------- /tests/neurst/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/data/__init__.py -------------------------------------------------------------------------------- /tests/neurst/data/text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/data/text/__init__.py -------------------------------------------------------------------------------- /tests/neurst/data/text/bpe_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import os 4 | import tempfile 5 | 6 | import tensorflow as tf 7 | 8 | from neurst.data.text.bpe import BPE 9 | 10 | 11 | def test(): 12 | codes = ["技 术", "发 展"] 13 | tmp_file = tempfile.NamedTemporaryFile(delete=False) 14 | with tf.io.gfile.GFile(tmp_file.name, "w") as fw: 15 | fw.write("version\n") 16 | fw.write("\n".join(codes) + "\n") 17 | bpe = BPE(lang="zh", 18 | glossaries=["迅速", "<-neplhd-hehe>"]) 19 | bpe.init_subtokenizer(tmp_file.name) 20 | 21 | tokens = bpe.tokenize("技术 发展 迅猛", return_str=True) 22 | assert tokens == "技术 发@@ 展 迅@@ 猛" 23 | assert bpe.detokenize(tokens) == "技术 发展 迅猛" 24 | tokens = bpe.tokenize("技术发展迅猛", return_str=True) 25 | assert tokens == "技@@ 术@@ 发@@ 展@@ 迅@@ 猛" 26 | assert bpe.detokenize(tokens) == "技术发展迅猛" 27 | tokens = bpe.tokenize("技术迅速发展迅速 迅速 <-neplhd-hehe>", return_str=True) 28 | assert tokens == "技术@@ 迅速@@ 发@@ 展@@ 迅速 迅速 <-neplhd-hehe>" 29 | assert bpe.detokenize(tokens) == "技术迅速发展迅速 迅速 <-neplhd-hehe>" 30 | 31 | os.remove(tmp_file.name) 32 | 33 | 34 | if __name__ == "__main__": 35 | test() 36 | -------------------------------------------------------------------------------- /tests/neurst/data/text/jieba_segment_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from neurst.data.text.jieba_segment import Jieba 5 | 6 | 7 | def test(): 8 | tok = Jieba() 9 | assert tok.tokenize("他来到了网易杭研大厦", return_str=True) == "他 来到 了 网易 杭研 大厦" 10 | assert tok.detokenize("他 来到 了 网易 杭研 大厦", return_str=True) == "他来到了网易杭研大厦" 11 | 12 | 13 | if __name__ == "__main__": 14 | test() 15 | -------------------------------------------------------------------------------- /tests/neurst/data/text/moses_tokenizer_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | from neurst.data.text.moses_tokenizer import MosesTokenizer 4 | 5 | 6 | def test(): 7 | # lang, origin, tokenized, detokenized 8 | samples = [ 9 | ("zh", False, [], "`啊你 好~! ", "`啊你 好 ~ !", "`啊你好 ~ !"), 10 | ("en", False, [], "Hello p.m. 10, [.", "Hello p.m. 10 , [ .", None), 11 | ("en", True, [], "Hello p.m. 10, [.", "Hello p.m. 10 , [ .", None), 12 | ("en", False, [''], 'Hello p.m. 10, [.', 13 | 'Hello p.m. 10 , [ .', 'Hello p.m. 
10, [.') 14 | 15 | ] 16 | 17 | for lang, escape, gloss, ori, tok, detok in samples: 18 | if not detok: 19 | detok = ori 20 | tokenizer = MosesTokenizer(language=lang, glossaries=gloss) 21 | assert tok == tokenizer.tokenize(ori, return_str=True) 22 | assert detok == tokenizer.detokenize(tok, return_str=True) 23 | 24 | 25 | if __name__ == '__main__': 26 | test() 27 | -------------------------------------------------------------------------------- /tests/neurst/data/text/vocab_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from neurst.data.text.vocab import Vocab 5 | 6 | word_tokens = ["Hello", "World", "yes", "i", "I"] 7 | 8 | 9 | def test(): 10 | vocab = Vocab(word_tokens, 11 | extra_tokens=["UNK", "EOS"]) 12 | assert vocab._token_list == ["Hello", "World", "yes", "i", "I", "UNK", "EOS"] 13 | assert vocab.vocab_size == 7 14 | assert vocab.map_token_to_id(["Hello", "world", "man"], 15 | unknown_default=100) == [0, 100, 100] 16 | assert vocab.map_id_to_token([1, 0, 3]) == ["World", "Hello", "i"] 17 | 18 | vocab = Vocab(word_tokens, 19 | extra_tokens=["UNK", "EOS"], lowercase=True) 20 | assert vocab._token_list == ["hello", "world", "yes", "i", "UNK", "EOS"] 21 | assert vocab.vocab_size == 6 22 | assert vocab.map_token_to_id(["Hello", "world", "man"], 23 | unknown_default=100) == [0, 1, 100] 24 | assert vocab.map_id_to_token([1, 0, 3]) == ["world", "hello", "i"] 25 | 26 | 27 | def test_file(): 28 | vocab_file = tempfile.NamedTemporaryFile(delete=False) 29 | with open(vocab_file.name, "w") as fw: 30 | for t in word_tokens: 31 | fw.write(t + "\t100\n") 32 | vocab = Vocab.load_from_file(vocab_file.name, 33 | extra_tokens=["UNK", "EOS"]) 34 | assert vocab._token_list == ["Hello", "World", "yes", "i", "I", "UNK", "EOS"] 35 | assert vocab.vocab_size == 7 36 | assert vocab.map_token_to_id(["Hello", "world", "man"], 37 | unknown_default=100) == [0, 100, 100] 38 | assert vocab.map_id_to_token([1, 0, 3]) == ["World", "Hello", "i"] 39 | 40 | vocab = Vocab.load_from_file(vocab_file.name, 41 | extra_tokens=["UNK", "EOS"], lowercase=True) 42 | assert vocab._token_list == ["hello", "world", "yes", "i", "UNK", "EOS"] 43 | assert vocab.vocab_size == 6 44 | assert vocab.map_token_to_id(["Hello", "world", "man", "EOS"], 45 | unknown_default=100) == [0, 1, 100, 5] 46 | assert vocab.map_id_to_token([1, 0, 3]) == ["world", "hello", "i"] 47 | os.remove(vocab_file.name) 48 | 49 | 50 | if __name__ == "__main__": 51 | test() 52 | test_file() 53 | -------------------------------------------------------------------------------- /tests/neurst/data/text_data_pipeline_test.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import tensorflow as tf 4 | 5 | from neurst.data.data_pipelines.text_data_pipeline import TextDataPipeline 6 | 7 | 8 | def test(): 9 | vocab_file = tempfile.NamedTemporaryFile(delete=False) 10 | with tf.io.gfile.GFile(vocab_file.name, "w") as fw: 11 | for t in ["技术", "迅@@", "猛"]: 12 | fw.write(t + "\t100\n") 13 | bpe_codes_file = tempfile.NamedTemporaryFile(delete=False) 14 | with tf.io.gfile.GFile(bpe_codes_file.name, "w") as fw: 15 | fw.write("version\n") 16 | fw.write("\n".join(["技 术", "发 展"]) + "\n") 17 | text_dp = TextDataPipeline( 18 | vocab_path=vocab_file.name, 19 | language="zh", 20 | subtokenizer="bpe", 21 | subtokenizer_codes=bpe_codes_file.name) 22 | 23 | assert text_dp.process("技术 发展 迅猛") == [0, 3, 3, 1, 2, 5] 24 | assert text_dp.recover([0, 3, 1, 
2, 5]) == "技术 迅猛" 25 | 26 | 27 | if __name__ == "__main__": 28 | test() 29 | -------------------------------------------------------------------------------- /tests/neurst/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/layers/__init__.py -------------------------------------------------------------------------------- /tests/neurst/layers/attentions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/layers/attentions/__init__.py -------------------------------------------------------------------------------- /tests/neurst/layers/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/layers/decoders/__init__.py -------------------------------------------------------------------------------- /tests/neurst/layers/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/layers/encoders/__init__.py -------------------------------------------------------------------------------- /tests/neurst/layers/modalities_test.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from neurst.layers.modalities.text_modalities import WordEmbeddingSharedWeights 4 | 5 | 6 | def test_word_embedding(): 7 | embedding_layer = WordEmbeddingSharedWeights( 8 | embedding_dim=5, vocab_size=10, 9 | share_softmax_weights=False) 10 | inputs2d = numpy.random.randint(0, 9, size=(3, 4)) 11 | inputs1d = numpy.random.randint(0, 9, size=(3,)) 12 | emb_for_2d = embedding_layer(inputs2d) 13 | emb_for_1d = embedding_layer(inputs1d) 14 | assert len(embedding_layer.get_weights()) == 1 15 | assert numpy.sum( 16 | (embedding_layer.get_weights()[0][inputs1d] - emb_for_1d.numpy()) ** 2) < 1e-9 17 | assert numpy.sum( 18 | (embedding_layer.get_weights()[0][inputs2d] - emb_for_2d.numpy()) ** 2) < 1e-9 19 | assert "emb/weights" in embedding_layer.trainable_weights[0].name 20 | 21 | emb_shared_layer = WordEmbeddingSharedWeights( 22 | embedding_dim=5, vocab_size=10, 23 | share_softmax_weights=True) 24 | emb_for_2d = emb_shared_layer(inputs2d) 25 | emb_for_1d = emb_shared_layer(inputs1d) 26 | logits_for_2d = emb_shared_layer(emb_for_2d, mode="linear") 27 | logits_for_1d = emb_shared_layer(emb_for_1d, mode="linear") 28 | assert len(emb_shared_layer.get_weights()) == 2 29 | for w in emb_shared_layer.trainable_weights: 30 | if "bias" in w.name: 31 | bias = w 32 | else: 33 | weights = w 34 | assert numpy.sum( 35 | (weights.numpy()[inputs1d] - emb_for_1d.numpy()) ** 2) < 1e-9 36 | assert numpy.sum( 37 | (weights.numpy()[inputs2d] - emb_for_2d.numpy()) ** 2) < 1e-9 38 | assert numpy.sum( 39 | (numpy.dot(emb_for_2d.numpy(), numpy.transpose(weights.numpy()) 40 | ) + bias.numpy() - logits_for_2d.numpy()) ** 2) < 1e-9 41 | assert numpy.sum( 42 | (numpy.dot(emb_for_1d.numpy(), numpy.transpose(weights.numpy()) 43 | ) + bias.numpy() - logits_for_1d.numpy()) ** 2) < 1e-9 44 | assert "shared/weights" in weights.name 45 | assert "shared/bias" in bias.name 46 | 47 | 48 | if __name__ == "__main__": 49 | 
test_word_embedding() 50 | -------------------------------------------------------------------------------- /tests/neurst/layers/search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/layers/search/__init__.py -------------------------------------------------------------------------------- /tests/neurst/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/models/__init__.py -------------------------------------------------------------------------------- /tests/neurst/models/gpt2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import tensorflow as tf 15 | 16 | from neurst.tasks import build_task 17 | from neurst.utils.checkpoints import restore_checkpoint_if_possible_v2 18 | from neurst.utils.hparams_sets import get_hyper_parameters 19 | from neurst.utils.misc import assert_equal_numpy 20 | 21 | 22 | def test_openai_gpt2(): 23 | from transformers import GPT2Model, GPT2Tokenizer 24 | 25 | input_text = "Here is some text to encode" 26 | pt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 27 | pt_model = GPT2Model.from_pretrained("gpt2", return_dict=True) 28 | pt_outputs = pt_model(**pt_tokenizer([input_text], return_tensors="pt")) 29 | 30 | task = build_task({ 31 | "class": "lm", 32 | "params": { 33 | "data_pipeline.class": "GPT2DataPipeline", 34 | "max_len": 50, 35 | "begin_of_sentence": "eos" 36 | } 37 | }) 38 | 39 | model_cfgs = get_hyper_parameters("gpt2_117m") 40 | model = task.build_model(model_cfgs) 41 | restore_checkpoint_if_possible_v2(model, "117M", model_name="OpenAIGPT2") 42 | input_ids = task._data_pipeline.process(input_text) 43 | tf_inputs = { 44 | "trg_input": tf.convert_to_tensor([input_ids], tf.int64), 45 | "trg_length": tf.convert_to_tensor([len(input_ids)], tf.int64) 46 | } 47 | _, gen_init = model.get_symbols_to_logits_fn(tf_inputs, is_training=False, is_inference=False) 48 | tf_outputs = model.get_decoder_output(gen_init["decoder_input"], 49 | cache=gen_init["decoder_internal_cache"], 50 | is_training=False) 51 | assert_equal_numpy(pt_outputs.last_hidden_state.detach().numpy(), tf_outputs[:, :-1].numpy(), 5e-4) 52 | 53 | 54 | if __name__ == "__main__": 55 | test_openai_gpt2() 56 | -------------------------------------------------------------------------------- /tests/neurst/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst/utils/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/decoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/decoders/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/encoders/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/layers/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/layers/layer_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from neurst.layers.layer_utils import lower_triangle_attention_bias 15 | from neurst.utils.misc import assert_equal_numpy 16 | from neurst_pt.layers.layer_utils import lower_triangle_attention_bias as pt_lower_triangle_attention_bias 17 | 18 | 19 | def test_lower_triangle_attention_bias(): 20 | assert_equal_numpy(lower_triangle_attention_bias(5).numpy(), 21 | pt_lower_triangle_attention_bias(5).detach().numpy()) 22 | 23 | 24 | if __name__ == "__main__": 25 | test_lower_triangle_attention_bias() 26 | -------------------------------------------------------------------------------- /tests/neurst_pt/modalities/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/modalities/__init__.py -------------------------------------------------------------------------------- /tests/neurst_pt/modalities/text_modalities_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 ByteDance Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy 15 | import tensorflow as tf 16 | import torch 17 | 18 | from neurst.layers.modalities.text_modalities import WordEmbeddingSharedWeights 19 | from neurst.utils.misc import assert_equal_numpy 20 | from neurst_pt.layers.modalities.text_modalities import WordEmbeddingSharedWeights as PTWordEmbeddingSharedWeights 21 | 22 | 23 | def test_emb(): 24 | emb_dim = 5 25 | vocab_size = 10 26 | tf_emb = WordEmbeddingSharedWeights(emb_dim, vocab_size, True) 27 | pt_emb = PTWordEmbeddingSharedWeights(emb_dim, vocab_size, True) 28 | inp_2d = numpy.random.randint(0, 9, [2, 5]) 29 | inp_1d = numpy.random.randint(0, 9, [3, ]) 30 | logits_2d = numpy.random.rand(2, 5) 31 | logits_3d = numpy.random.rand(2, 4, 5) 32 | tf_inp_2d = tf.convert_to_tensor(inp_2d, tf.int32) 33 | tf_inp_1d = tf.convert_to_tensor(inp_1d, tf.int32) 34 | tf_logits_2d = tf.convert_to_tensor(logits_2d, tf.float32) 35 | tf_logits_3d = tf.convert_to_tensor(logits_3d, tf.float32) 36 | pt_inp_2d = torch.IntTensor(inp_2d) 37 | pt_inp_1d = torch.IntTensor(inp_1d) 38 | pt_logits_2d = torch.FloatTensor(logits_2d) 39 | pt_logits_3d = torch.FloatTensor(logits_3d) 40 | _ = tf_emb(tf_logits_2d, mode="linear") 41 | _ = pt_emb(pt_logits_2d, mode="linear") 42 | pt_emb._shared_weights.data = torch.Tensor(tf_emb._shared_weights.numpy()) 43 | pt_emb._bias.data = torch.Tensor(tf_emb._bias.numpy()) 44 | assert_equal_numpy(tf_emb(tf_logits_2d, mode="linear").numpy(), 45 | pt_emb(pt_logits_2d, mode="linear").detach().numpy()) 46 | assert_equal_numpy(tf_emb(tf_logits_3d, mode="linear").numpy(), 47 | pt_emb(pt_logits_3d, mode="linear").detach().numpy()) 48 | assert_equal_numpy(tf_emb(tf_inp_2d).numpy(), pt_emb(pt_inp_2d).detach().numpy()) 49 | assert_equal_numpy(tf_emb(tf_inp_1d).numpy(), pt_emb(pt_inp_1d).detach().numpy()) 50 | 51 | 52 | if __name__ == "__main__": 53 | test_emb() 54 | -------------------------------------------------------------------------------- /tests/neurst_pt/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Yaoming95/CIAT/b07e4673f6c584c7c17212134d941c25b826a790/tests/neurst_pt/models/__init__.py -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py3 3 | 4 | [testenv] 5 | deps = .[dev] 6 | commands = 7 | pytest -sv 8 | install_command = pip install {opts} {packages} 9 | --------------------------------------------------------------------------------
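
A minimal usage sketch (not a file from the repository) illustrating the pattern the modality tests above verify: a single WordEmbeddingSharedWeights layer acting both as the token-embedding lookup and, via mode="linear", as the output projection onto the vocabulary. The constructor arguments and the mode="linear" call follow the tests; the surrounding shapes and variable names are illustrative only.

    # Sketch only: assumes the WordEmbeddingSharedWeights API as exercised in
    # tests/neurst/layers/modalities_test.py above.
    import numpy
    import tensorflow as tf

    from neurst.layers.modalities.text_modalities import WordEmbeddingSharedWeights

    emb = WordEmbeddingSharedWeights(
        embedding_dim=5, vocab_size=10, share_softmax_weights=True)

    # Embedding lookup: integer token ids -> dense vectors, shape (2, 4, 5).
    token_ids = tf.convert_to_tensor(numpy.random.randint(0, 9, size=(2, 4)), tf.int32)
    token_vecs = emb(token_ids)

    # Output projection: decoder states -> vocabulary logits, shape (2, 4, 10),
    # reusing the same weight matrix (plus the shared-softmax bias).
    decoder_states = tf.random.uniform((2, 4, 5))
    logits = emb(decoder_states, mode="linear")

    print(token_vecs.shape, logits.shape)

Sharing the matrix between the input embedding and the pre-softmax projection is what share_softmax_weights=True enables; the tests confirm that the logits equal the dot product with the transposed embedding table plus the bias.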