├── .flake8 ├── .github └── workflows │ └── torchseq.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── configs ├── mbart.json ├── paraphrasing_vae.json ├── patches │ ├── beamsearch_16x8.json │ ├── beamsearch_4x2.json │ ├── beamsearch_4x2_length1.json │ ├── beamsearch_8x4.json │ ├── beamsearch_8x4_length1.json │ ├── dbs.json │ ├── eval_bsz_2.json │ ├── eval_bsz_8.json │ ├── multi_sample.json │ ├── name_vae.json │ ├── newsqa.json │ ├── nq.json │ ├── nucleus05.json │ ├── nucleus075.json │ ├── nucleus09.json │ ├── rerank_backtranslate.json │ ├── rerank_combo.json │ ├── rerank_ngram.json │ ├── rerank_qa.json │ ├── squad.json │ ├── tag_ppeval.json │ ├── udeptoq_squad.json │ └── unfreeze_encdec.json ├── qg_bart.json ├── qg_bert.json └── qg_transformer.json ├── data └── pretrained-vocabs │ ├── ._bart-large-cnn-merges.txt │ ├── ._bart-large-cnn-vocab.json │ ├── ._bart-large-merges.txt │ ├── ._bart-large-vocab.json │ ├── ._bert-base-cased-vocab.txt │ ├── ._bert-base-uncased-vocab.txt │ ├── ._roberta-base-merges.txt │ ├── ._roberta-base-vocab.json │ ├── bart-large-cnn-merges.txt │ ├── bart-large-cnn-vocab.json │ ├── bart-large-merges.txt │ ├── bart-large-vocab.json │ ├── bert-base-cased-vocab.txt │ ├── bert-base-uncased-vocab.txt │ ├── deberta-v3-base-spm.model │ ├── roberta-base-merges.txt │ └── roberta-base-vocab.json ├── docs ├── Makefile ├── _source │ ├── modules.rst │ ├── quickstart.rst │ ├── torchseq.agents.rst │ ├── torchseq.datasets.rst │ ├── torchseq.metric_hooks.rst │ ├── torchseq.models.rerankers.rst │ ├── torchseq.models.rst │ ├── torchseq.models.samplers.rst │ ├── torchseq.pretrained.rst │ ├── torchseq.rst │ └── torchseq.utils.rst ├── conf.py ├── index.rst └── make.bat ├── examples ├── Paraphrasing.ipynb ├── QG_BART_Eval.ipynb └── mBART.ipynb ├── mypy.ini ├── requirements.txt ├── scripts ├── download_data.sh ├── download_models.py ├── generate_3way_wikianswers.py ├── kaggle_prepro.py ├── paranmt_prepro.py ├── submit_job.sh └── train_vq_code_predictor.py ├── setup.py ├── tests ├── __init__.py ├── test_eval.py ├── test_misc.py ├── test_pretrained.py ├── test_regression.py ├── test_tokenisation.py └── utils.py └── torchseq ├── __init__.py ├── agents ├── __init__.py ├── aq_agent.py ├── base.py ├── lm_agent.py ├── model_agent.py ├── retrieval_agent.py └── seq2seq_agent.py ├── args.py ├── datasets ├── __init__.py ├── builder.py ├── json_dataset.py ├── json_loader.py ├── lm_dataset.py ├── lm_loader.py ├── loaders.py ├── paraphrase_dataset.py ├── paraphrase_loader.py ├── paraphrase_pair.py ├── qa_dataset.py ├── qa_loader.py └── qa_triple.py ├── demo ├── para_app.py ├── qg_app.py └── static │ ├── para_demo.htm │ ├── qgen.css │ ├── qgen.js │ ├── qq_demo.htm │ ├── separator.js │ └── spinner.css ├── eval ├── __init__.py ├── args.py ├── cli.py └── recipes │ ├── __init__.py │ ├── base.py │ └── opagg │ ├── extractive_summaries.py │ ├── hiro_post.py │ └── hiro_pre.py ├── main.py ├── metric_hooks ├── __init__.py ├── base.py ├── default.py ├── hrq_agg.py ├── opsumm_cluster_aug.py ├── prevalence_metric.py ├── prevalence_metric_old.py ├── qg_metric.py ├── rouge.py ├── semparse.py ├── sep_ae.py └── textual.py ├── models ├── __init__.py ├── activations.py ├── aq_transformer.py ├── bottleneck.py ├── bottleneck_autoencoder.py ├── contrastive_hrq_loss.py ├── contrastive_loss.py ├── contrastive_triplet_loss.py ├── ctxtans_encoder.py ├── decoder.py ├── encoder.py ├── exemplar_guided_autoencoder.py ├── hrq_vae.py ├── hyperbolic.py ├── hyperbolic_utils.py ├── 
lm_transformer.py ├── lr_schedule.py ├── modular_bottleneck.py ├── multihead_output.py ├── parallel_wrapper.py ├── pooling.py ├── positional_embeddings.py ├── pythae_vq.py ├── ranger.py ├── rerankers │ ├── __init__.py │ ├── backtranslate_reranker.py │ ├── combo.py │ ├── ngram_reranker.py │ ├── qa_reranker.py │ └── topk.py ├── retrieval_model.py ├── samplers │ ├── __init__.py │ ├── beam_search.py │ ├── diverse_beam.py │ ├── greedy.py │ ├── parallel_nucleus.py │ └── teacher_force.py ├── suppression_loss.py ├── transformer.py ├── vmf.py ├── vq_code_predictor.py ├── vq_vae.py └── vq_vae_legacy.py ├── pretrained ├── __init__.py ├── lexical_paraphraser_bert.py ├── lm.py ├── nli.py └── qa.py └── utils ├── __init__.py ├── ari.py ├── bart_score.py ├── cache.py ├── config.py ├── config_migration.py ├── easy_generate.py ├── fleiss.py ├── functions.py ├── logging.py ├── loss_dropper.py ├── mckenzie.py ├── metrics.py ├── model_loader.py ├── optimizer_group.py ├── perplexity.py ├── rouge.py ├── sari.py ├── seed.py ├── singleton.py ├── timer.py ├── tokenizer.py ├── tokenizer_wordlevel.py └── wandb.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203,E266,E501,W503,F403,F401,C901,E402 3 | max-line-length = 127 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 -------------------------------------------------------------------------------- /.github/workflows/torchseq.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: TorchSeq 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.9 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements.txt 27 | python -m nltk.downloader wordnet 28 | python -m nltk.downloader omw-1.4 29 | pip install .
30 | - name: Lint with flake8 31 | run: | 32 | pip install flake8 black 33 | # stop the build if there are Python syntax errors or undefined names 34 | # flake8 ./torchseq --count --select=E9,F63,F7,F82 --show-source --statistics 35 | make syntax 36 | make check 37 | - name: Test with pytest 38 | run: | 39 | pip install pytest pytest-cov codecov 40 | make test 41 | - name: Check types 42 | run: | 43 | pip install mypy 44 | mypy ./torchseq --install-types --non-interactive 45 | # - name: Upload coverage 46 | # run: | 47 | # make coverage 48 | 49 | 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .mypy_cache 4 | 5 | /runs 6 | /runs_bak 7 | /models 8 | /data 9 | /data_bak 10 | /configs 11 | /scripts 12 | /baselines 13 | /plots 14 | /wandb 15 | /cache 16 | 17 | .ipynb_checkpoints/ 18 | *.DS_Store 19 | *.ipynb 20 | 21 | *.sh 22 | 23 | /*env/ 24 | /torchseq/external 25 | .vscode 26 | /notebooks 27 | 28 | .coverage 29 | coverage.xml 30 | 31 | *.egg-info/ 32 | 33 | # Sphinx documentation 34 | docs/_build/ 35 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | # python: 28 | # install: 29 | # - requirements: requirements.txt -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs test coverage syntax types test check 2 | 3 | 4 | # Check that source code meets quality standards 5 | check: formatcheck syntax types test 6 | 7 | # Format source code automatically 8 | format: 9 | black --line-length 119 --target-version py39 tests torchseq 10 | 11 | formatcheck: 12 | black --check --line-length 119 --target-version py39 tests torchseq 13 | 14 | # Check syntax 15 | syntax: 16 | flake8 ./torchseq --count --select=E9,F63,F7,F82 --show-source --statistics 17 | 18 | types: 19 | mypy ./torchseq --install-types --non-interactive 20 | 21 | # Run tests for the library 22 | test: 23 | WANDB_USERNAME='' pytest --cov=./torchseq ./tests 24 | 25 | # Run tests for the library 26 | testall: 27 | WANDB_USERNAME='' RUN_SLOW=1 pytest --cov=./torchseq ./tests 28 | 29 | # Send coverage report to codecov 30 | coverage: 31 | CODECOV_TOKEN="28535f9f-825a-435e-bb4e-e1de2aa63da3" codecov 32 | rm .coverage 33 | rm coverage.xml 34 | 35 | # Build docs 36 | docs: 37 | sphinx-apidoc -f -o ./docs/_source ./torchseq 38 | (cd ./docs && make html) -------------------------------------------------------------------------------- /codecov.yml: 
-------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # basic 6 | target: auto 7 | threshold: 5% 8 | paths: 9 | - "torchseq" 10 | patch: 11 | default: 12 | target: 5% 13 | ignore: 14 | - ".aqenv/*" 15 | - "torchseq/args.py" 16 | - "torchseq/main.py" 17 | - "torchseq/metric_hooks/*" 18 | - "torchseq/utils/config_migration.py" -------------------------------------------------------------------------------- /configs/mbart.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mbart_ae_testcase", 3 | "tag": "test_case", 4 | "group": "example", 5 | "task": "seq2seq", 6 | "training": { 7 | "dataset": "json", 8 | "use_preprocessed_data": false, 9 | "log_interval": 100, 10 | "optimizer": { 11 | "type": "adam", 12 | "lr": 0.002, 13 | "beta1": 0.9, 14 | "beta2": 0.98, 15 | "lr_schedule": true, 16 | "lr_warmup_steps": 1000 17 | }, 18 | "batch_size": 10, 19 | "optim_batch_size": 50, 20 | "clip_gradient": 1, 21 | "num_epochs": 200, 22 | "warmup_epochs": 0, 23 | "suppression_loss_weight": 0.0, 24 | "label_smoothing": 0.1, 25 | "early_stopping_lag": 0, 26 | "early_stopping_patience": 10, 27 | "reset_metrics": true, 28 | "token_dropout": 0.1, 29 | "kl_warmup_steps": 0, 30 | "epoch_steps": 0 31 | }, 32 | "json_dataset": { 33 | "path": "semparse/atis", 34 | "filename": "{split}", 35 | "field_map": [ 36 | { 37 | "type": "copy", 38 | "from": "target", 39 | "to": "target" 40 | }, 41 | { 42 | "type": "copy", 43 | "from": "source", 44 | "to": "source" 45 | } 46 | ] 47 | }, 48 | "eval": { 49 | "eval_batch_size": 8, 50 | "sampler": "beam", 51 | "max_out_len": 400, 52 | "metrics": { 53 | "semparse": { 54 | "run_codepred": false 55 | } 56 | }, 57 | "prepend_langcode": true 58 | }, 59 | "beam_search": { 60 | "beam_width": 5, 61 | "beam_expansion": 5, 62 | "length_alpha": 1.0, 63 | "prevent_repetition": false 64 | }, 65 | "prepro": { 66 | "input_vocab_size": 250054, 67 | "output_vocab_size": 250054, 68 | "sent_window": 0, 69 | "tok_window": 400, 70 | "include_lang_codes": true, 71 | "drop_target_lang_codes": true, 72 | "input_tokenizer": "facebook/mbart-large-50-many-to-many-mmt", 73 | "output_tokenizer": "facebook/mbart-large-50-many-to-many-mmt" 74 | }, 75 | "dropout": 0.1, 76 | "input_raw_embedding_dim": 1024, 77 | "output_raw_embedding_dim": 1024, 78 | "encoder": { 79 | "num_heads": 16, 80 | "dim_feedforward": 4096, 81 | "activation": "relu", 82 | "pretrained_encoder": "facebook/mbart-large-50-many-to-many-mmt", 83 | "embedding_dim": 1024, 84 | "freeze_pretrained": true, 85 | "init_embeds_from_tokenizer": false, 86 | "num_layers": 0 87 | }, 88 | "decoder": { 89 | "num_heads": 16, 90 | "dim_feedforward": 4096, 91 | "activation": "relu", 92 | "pretrained_decoder": "facebook/mbart-large-50-many-to-many-mmt", 93 | "embedding_dim": 1024, 94 | "freeze_pretrained": true, 95 | "init_embeds_from_tokenizer": false, 96 | "num_layers": 0 97 | }, 98 | "freeze_embeddings": false, 99 | "freeze_projection": false, 100 | "directional_masks": false, 101 | "bottleneck": { 102 | "embedding_dim": 1024, 103 | "modular": true, 104 | "num_heads": 8, 105 | "modules": [ 106 | { 107 | "range": [ 108 | 0, 109 | 8 110 | ], 111 | "type": "ae", 112 | "pooling": false 113 | } 114 | ] 115 | } 116 | } -------------------------------------------------------------------------------- /configs/paraphrasing_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": 
"paraphrasing_vae", 3 | "tag": "examples", 4 | "task": "para", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | 10 | }, 11 | "json_dataset": { 12 | "path": "wikianswers-pp", 13 | "field_map": [ 14 | { 15 | "type": "sample", 16 | "from": "qs", 17 | "to": "s1" 18 | }, 19 | { 20 | "type": "sample", 21 | "from": "qs", 22 | "to": "s2" 23 | } 24 | ] 25 | }, 26 | "training": { 27 | "dataset": "json", 28 | "use_preprocessed_data": false, 29 | "log_interval": 100, 30 | "lr_schedule": true, 31 | "lr": 0.01, 32 | "beta1": 0.9, 33 | "beta2": 0.98, 34 | "batch_size": 16, 35 | "optim_batch_size": 64, 36 | "clip_gradient": 5, 37 | "num_epochs": 120, 38 | "opt": "adam", 39 | "warmup_epochs": 30, 40 | "suppression_loss_weight": 0.0, 41 | "label_smoothing": 0.0, 42 | "early_stopping_lag": 0, 43 | "reset_metrics": true, 44 | "token_dropout": 0.2, 45 | "kl_warmup_steps": 10000 46 | }, 47 | "eval": { 48 | "eval_batch_size": 16, 49 | "sampler": "beam", 50 | "max_out_len": 50 51 | }, 52 | "beam_search": { 53 | "beam_width": 4, 54 | "beam_expansion": 2, 55 | "length_alpha": 1.0 56 | }, 57 | "prepro": { 58 | "vocab_size": 30522, 59 | "sent_window": 0, 60 | "tok_window": 40, 61 | "concat_ctxt_ans": false, 62 | "bio_tagging": true, 63 | "tokenizer": "bert-base-uncased" 64 | }, 65 | 66 | "dropout": 0.1, 67 | 68 | "raw_embedding_dim": 768, 69 | "embedding_dim": 768, 70 | "onehot_bio" : false, 71 | "bio_embedding_dim": 8, 72 | "freeze_embeddings": true, 73 | "freeze_projection": true, 74 | "directional_masks": true, 75 | 76 | "encdec": { 77 | "num_encoder_layers": 5, 78 | "num_decoder_layers": 5, 79 | "num_heads": 8, 80 | "dim_feedforward": 2048, 81 | "activation": "relu", 82 | "bert_encoder": false 83 | }, 84 | "encoder": { 85 | "embedding_dim": 768 86 | }, 87 | "decoder": { 88 | "embedding_dim": 768 89 | }, 90 | "bottleneck": { 91 | "embedding_dim": 768, 92 | "variational": true, 93 | "num_similar_heads": 4 94 | } 95 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_16x8.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam16x8", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "diverse_beam": { 8 | "beam_width": 16, 9 | "beam_expansion": 8, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_4x2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam4x2", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 4, 9 | "beam_expansion": 2, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_4x2_length1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam4x2L1", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 4, 9 | "beam_expansion": 2, 10 | "length_alpha": 1.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_8x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam8x4", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 8, 9 | 
"beam_expansion": 4, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_8x4_length1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam8x4L1", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 8, 9 | "beam_expansion": 4, 10 | "length_alpha": 1.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/dbs.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_diversebeam16x8_8g", 3 | "eval": { 4 | "sampler": "diverse_beam" 5 | }, 6 | "diverse_beam": { 7 | "beam_width": 16, 8 | "beam_expansion": 8, 9 | "length_alpha": 2.0, 10 | "num_groups": 8, 11 | "penalty_weight": 0.5 12 | } 13 | } -------------------------------------------------------------------------------- /configs/patches/eval_bsz_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "batch_size": 2 4 | }, 5 | "eval": { 6 | "eval_batch_size": 2 7 | } 8 | } -------------------------------------------------------------------------------- /configs/patches/eval_bsz_8.json: -------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "batch_size": 8 4 | }, 5 | "eval": { 6 | "eval_batch_size": 8 7 | } 8 | } -------------------------------------------------------------------------------- /configs/patches/multi_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_top4", 3 | "eval": { 4 | "topk": 4 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/name_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "vae_%" 3 | } -------------------------------------------------------------------------------- /configs/patches/newsqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_newsqa", 3 | "training": { 4 | "dataset": "newsqa" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nq", 3 | "training": { 4 | "dataset": "naturalquestions" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/nucleus05.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.5", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.5, 9 | "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/nucleus075.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.75", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.75, 9 | "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/nucleus09.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.9", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.9, 9 | "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/rerank_backtranslate.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankbacktrans", 3 | "reranker": { 4 | "strategy": "backtranslate" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_combo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankcombo", 3 | "reranker": { 4 | "strategy": "combo" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_ngram.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankngram", 3 | "reranker": { 4 | "strategy": "ngram" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankqa", 3 | "reranker": { 4 | "strategy": "qa" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_squad", 3 | "training": { 4 | "dataset": "squad" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/tag_ppeval.json: -------------------------------------------------------------------------------- 1 | { 2 | "tag": "ppeval" 3 | } -------------------------------------------------------------------------------- /configs/patches/udeptoq_squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_squadudep", 3 | "training": { 4 | "dataset": "models/squad-udep" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/unfreeze_encdec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_unfreeze_lr1e-6", 3 | "training": { 4 | "batch_size": 2, 5 | "lr": 1e-6 6 | }, 7 | "encdec": { 8 | "freeze_encoder": false, 9 | "freeze_decoder": false 10 | } 11 | } -------------------------------------------------------------------------------- /configs/qg_bart.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_bart", 3 | "tag": "examples", 4 | "task": "aq", 5 | "model": "pretrained_adapter", 6 | "env": { 7 | "cuda": true, 8 | "data_path": "./data", 9 | "gpu_device": 0 10 | 11 | }, 12 | "training": { 13 | "dataset": "squad", 14 | "use_preprocessed_data": false, 15 | "log_interval": 100, 16 | "lr_schedule": false, 17 | "lr": 5e-6, 18 | "beta1": 0.9, 19 | "beta2": 0.98, 20 | "batch_size": 4, 21 | "optim_batch_size": 64, 22 | "clip_gradient": 0.1, 23 | "num_epochs": 30, 24 | "opt": "adam", 25 | "warmup_epochs": 10, 26 | "suppression_loss_weight": 0.01, 27 | "label_smoothing": 0.0, 28 | "early_stopping_lag": 0, 29 | "reset_metrics": true, 30 | "token_dropout": 0.0 31 | }, 32 | "eval": { 33 | 
"eval_batch_size": 4, 34 | "sampler": "beam", 35 | "prepend_eos": true, 36 | "sample_outputs": true 37 | }, 38 | "beam_search": { 39 | "beam_width": 8, 40 | "beam_expansion": 4, 41 | "length_alpha": 2.0 42 | }, 43 | "prepro": { 44 | "vocab_size": 50265, 45 | "sent_window": 1, 46 | "tok_window": 300, 47 | "concat_ctxt_ans": true, 48 | "roberta_style_encoding": true, 49 | "bio_tagging": true, 50 | "tokenizer": "bart-large" 51 | }, 52 | 53 | "dropout": 0.1, 54 | 55 | "raw_embedding_dim": 1024, 56 | "embedding_dim": 1024, 57 | "onehot_bio" : false, 58 | "bio_embedding_dim": 8, 59 | "freeze_embeddings": true, 60 | "freeze_projection": true, 61 | "directional_masks": true, 62 | 63 | "encoder_outputs": { 64 | "c_raw": true, 65 | "a_raw": false, 66 | "c_enc": true, 67 | "c_enc_pool": false, 68 | "a_enc": false, 69 | "a_enc_pool": false, 70 | "c_enc_anspool": false, 71 | "c_ans_labels": false 72 | }, 73 | "encdec": { 74 | "num_encoder_layers": 3, 75 | "num_decoder_layers": 3, 76 | "num_heads": 8, 77 | "dim_feedforward": 2048, 78 | "activation": "relu", 79 | "bert_encoder": false, 80 | "bert_model": "facebook/bart-large", 81 | "bert_warmup_epochs": 50, 82 | "freeze_encoder": false, 83 | "freeze_decoder": true 84 | }, 85 | "encoder": { 86 | "embedding_dim": 1024 87 | }, 88 | "decoder": { 89 | "embedding_dim": 1024 90 | } 91 | } -------------------------------------------------------------------------------- /configs/qg_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_bert", 3 | "tag": "examples", 4 | "task": "aq", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | }, 10 | "training": { 11 | "dataset": "squad", 12 | "use_preprocessed_data": false, 13 | "log_interval": 100, 14 | "lr_schedule": true, 15 | "lr": 0.003, 16 | "beta1": 0.9, 17 | "beta2": 0.98, 18 | "batch_size": 8, 19 | "optim_batch_size": 64, 20 | "clip_gradient": 5, 21 | "num_epochs": 50, 22 | "opt": "adam", 23 | "warmup_epochs": 5, 24 | "suppression_loss_weight": 0.1, 25 | "label_smoothing": 0.0, 26 | "early_stopping_lag": 1, 27 | "token_dropout": 0.2 28 | }, 29 | "eval": { 30 | "eval_batch_size": 6, 31 | "sampler": "beam" 32 | }, 33 | "beam_search": { 34 | "beam_width": 8, 35 | "beam_expansion": 8, 36 | "length_alpha": 2.0 37 | }, 38 | "prepro": { 39 | "vocab_size": 30522, 40 | "sent_window": 0, 41 | "tok_window": 300, 42 | "concat_ctxt_ans": false, 43 | "bio_tagging": true, 44 | "tokenizer": "bert-base-uncased" 45 | }, 46 | "dropout": 0.1, 47 | "raw_embedding_dim": 768, 48 | "embedding_dim": 768, 49 | "onehot_bio": false, 50 | "bio_embedding_dim": 8, 51 | "freeze_embeddings": true, 52 | "freeze_projection": true, 53 | "directional_masks": true, 54 | "encoder_outputs": { 55 | "c_raw": true, 56 | "a_raw": false, 57 | "c_enc": true, 58 | "c_enc_pool": false, 59 | "a_enc": false, 60 | "a_enc_pool": false, 61 | "c_enc_anspool": false, 62 | "c_ans_labels": false 63 | }, 64 | "encdec": { 65 | "num_encoder_layers": 5, 66 | "num_decoder_layers": 5, 67 | "num_heads": 2, 68 | "dim_feedforward": 2048, 69 | "activation": "relu", 70 | "bert_encoder": true, 71 | "bert_model": "bert-base-uncased", 72 | "bert_warmup_epochs": 20, 73 | "bert_finetune": false 74 | }, 75 | "encoder": { 76 | "embedding_dim": 768 77 | }, 78 | "decoder": { 79 | "embedding_dim": 768 80 | } 81 | } -------------------------------------------------------------------------------- /configs/qg_transformer.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_transformer", 3 | "tag": "examples", 4 | "task": "aq", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | }, 10 | "training": { 11 | "dataset": "squad", 12 | "use_preprocessed_data": false, 13 | "log_interval": 100, 14 | "lr_schedule": true, 15 | "lr": 0.003, 16 | "beta1": 0.9, 17 | "beta2": 0.98, 18 | "batch_size": 16, 19 | "optim_batch_size": 64, 20 | "clip_gradient": 5, 21 | "num_epochs": 50, 22 | "opt": "adam", 23 | "warmup_epochs": 5, 24 | "suppression_loss_weight": 0.1, 25 | "label_smoothing": 0.0, 26 | "early_stopping_lag": 1, 27 | "token_dropout": 0.2 28 | }, 29 | "eval": { 30 | "eval_batch_size": 8, 31 | "sampler": "beam" 32 | }, 33 | "beam_search": { 34 | "beam_width": 8, 35 | "beam_expansion": 4, 36 | "length_alpha": 2.0 37 | }, 38 | "prepro": { 39 | "vocab_size": 30522, 40 | "sent_window": 0, 41 | "tok_window": 300, 42 | "concat_ctxt_ans": false, 43 | "bio_tagging": true, 44 | "tokenizer": "bert-base-uncased" 45 | }, 46 | "dropout": 0.1, 47 | "raw_embedding_dim": 768, 48 | "embedding_dim": 768, 49 | "onehot_bio": false, 50 | "bio_embedding_dim": 8, 51 | "freeze_embeddings": true, 52 | "freeze_projection": true, 53 | "directional_masks": true, 54 | "encoder_outputs": { 55 | "c_raw": true, 56 | "a_raw": false, 57 | "c_enc": true, 58 | "c_enc_pool": false, 59 | "a_enc": false, 60 | "a_enc_pool": false, 61 | "c_enc_anspool": false, 62 | "c_ans_labels": false 63 | }, 64 | "encdec": { 65 | "num_encoder_layers": 5, 66 | "num_decoder_layers": 5, 67 | "num_heads": 2, 68 | "dim_feedforward": 2048, 69 | "activation": "relu", 70 | "bert_encoder": false, 71 | "bert_warmup_epochs": 20, 72 | "bert_finetune": false 73 | }, 74 | "encoder": { 75 | "embedding_dim": 768 76 | }, 77 | "decoder": { 78 | "embedding_dim": 768 79 | } 80 | } -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-cnn-merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-cnn-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-cnn-vocab.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-cnn-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-vocab.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bert-base-cased-vocab.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bert-base-cased-vocab.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bert-base-uncased-vocab.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bert-base-uncased-vocab.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._roberta-base-merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._roberta-base-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._roberta-base-vocab.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._roberta-base-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/deberta-v3-base-spm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/deberta-v3-base-spm.model -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_source/modules.rst: -------------------------------------------------------------------------------- 1 | torchseq 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | torchseq 8 | -------------------------------------------------------------------------------- /docs/_source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | 5 | Installation 6 | ------------ 7 | 8 | From a fresh ``venv``, run: 9 | ``` 10 | pip install -r requirements.txt 11 | pip install -e . 12 | 13 | python3 -m nltk.downloader punkt 14 | python3 ./scripts/download_models.py 15 | ``` 16 | 17 | Done! 18 | 19 | Overview 20 | -------- 21 | 22 | TorchSeq is a framework for training and evaluating Seq2Seq models, built in PyTorch. 23 | 24 | 25 | 26 | 27 | CLI Usage 28 | --------- 29 | 30 | TorchSeq installs a CLI - to load a model and evaluate it on the test set, run ``torchseq --test --load /path/to/model``. 
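For example (the paths here are illustrative), you could train one of the bundled example configs and then evaluate the best checkpoint on the test set with ``torchseq --train --reload_after_train --test --config ./configs/qg_bart.json --data_path ./data --output_path ./runs``; the individual options are described below.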
31 | 32 | The CLI options are: 33 | 34 | --train Run training 35 | --validate Run validation (i.e., eval on the dev set) 36 | --validate_train Run eval on the training set 37 | --test Run eval on the test set 38 | --silent Disable verbose output 39 | --reload_after_train Use in conjunction with one of the eval commands to reload the best checkpoint once training completes, and evaluate using that 40 | --load_chkpt /path/to/checkpoint.pt Path to checkpoint file 41 | --data_path /path/to/data/ Path to folder containing datasets 42 | --output_path /path/to/output/ Path to dump output 43 | --config,-c /path/to/config.json Path to config file 44 | --patch,-p /path/to/patch.json Path to 'patch' file(s) 45 | --load /path/to/model/ Path to a full model (checkpoint + config) 46 | --cpu Run on CPU 47 | --debug Enable some extra debugging 48 | 49 | 50 | Scripting 51 | --------- 52 | 53 | You can also invoke TorchSeq from within a script, like this: 54 | 55 | ``` 56 | import json 57 | from torchseq.agents.seq2seq_agent import Seq2SeqAgent 58 | from torchseq.datasets.json_loader import JsonDataLoader 59 | 60 | from torchseq.utils.config import Config 61 | 62 | with open(path_to_model + "/config.json") as f: 63 | cfg_dict = json.load(f) 64 | 65 | config = Config(cfg_dict) 66 | 67 | data_loader = JsonDataLoader(config) 68 | 69 | checkpoint_path = path_to_model + "/model/checkpoint.pt" 70 | 71 | instance = Seq2SeqAgent(config=config, run_id=None, output_path="./runs/demo/", silent=True, verbose=False) 72 | 73 | instance.load_checkpoint(checkpoint_path) 74 | instance.model.eval() 75 | 76 | loss, metrics, output, memory = instance.validate(data_loader, save_model=False) 77 | ``` 78 | 79 | In general, a :class:`torchseq.agents.model_agent.ModelAgent` object is the main controller - once you've created one from a :class:`torchseq.utils.config.Config`, you can train it with its ``train()`` method (see :mod:`torchseq.agents.base`). -------------------------------------------------------------------------------- /docs/_source/torchseq.agents.rst: -------------------------------------------------------------------------------- 1 | torchseq.agents package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.agents.aq\_agent module 8 | -------------------------------- 9 | 10 | .. automodule:: torchseq.agents.aq_agent 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.agents.base module 16 | --------------------------- 17 | 18 | .. automodule:: torchseq.agents.base 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.agents.exemplar\_agent module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.agents.exemplar_agent 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.agents.lm\_agent module 32 | -------------------------------- 33 | 34 | .. automodule:: torchseq.agents.lm_agent 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.agents.meta\_learning\_agent module 40 | -------------------------------------------- 41 | 42 | .. automodule:: torchseq.agents.meta_learning_agent 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.agents.model\_agent module 48 | ----------------------------------- 49 | 50 | .. automodule:: torchseq.agents.model_agent 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.agents.seq2seq\_agent module 56 | ------------------------------------- 57 | 58 | ..
automodule:: torchseq.agents.seq2seq_agent 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: torchseq.agents 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/_source/torchseq.datasets.rst: -------------------------------------------------------------------------------- 1 | torchseq.datasets package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.datasets.builder module 8 | -------------------------------- 9 | 10 | .. automodule:: torchseq.datasets.builder 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.datasets.json\_dataset module 16 | -------------------------------------- 17 | 18 | .. automodule:: torchseq.datasets.json_dataset 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.datasets.json\_loader module 24 | ------------------------------------- 25 | 26 | .. automodule:: torchseq.datasets.json_loader 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.datasets.lm\_dataset module 32 | ------------------------------------ 33 | 34 | .. automodule:: torchseq.datasets.lm_dataset 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.datasets.lm\_loader module 40 | ----------------------------------- 41 | 42 | .. automodule:: torchseq.datasets.lm_loader 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.datasets.loaders module 48 | -------------------------------- 49 | 50 | .. automodule:: torchseq.datasets.loaders 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.datasets.paraphrase\_dataset module 56 | -------------------------------------------- 57 | 58 | .. automodule:: torchseq.datasets.paraphrase_dataset 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.datasets.paraphrase\_loader module 64 | ------------------------------------------- 65 | 66 | .. automodule:: torchseq.datasets.paraphrase_loader 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchseq.datasets.paraphrase\_pair module 72 | ----------------------------------------- 73 | 74 | .. automodule:: torchseq.datasets.paraphrase_pair 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | torchseq.datasets.qa\_dataset module 80 | ------------------------------------ 81 | 82 | .. automodule:: torchseq.datasets.qa_dataset 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | torchseq.datasets.qa\_loader module 88 | ----------------------------------- 89 | 90 | .. automodule:: torchseq.datasets.qa_loader 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | torchseq.datasets.qa\_triple module 96 | ----------------------------------- 97 | 98 | .. automodule:: torchseq.datasets.qa_triple 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | Module contents 104 | --------------- 105 | 106 | .. 
automodule:: torchseq.datasets 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | -------------------------------------------------------------------------------- /docs/_source/torchseq.metric_hooks.rst: -------------------------------------------------------------------------------- 1 | torchseq.metric\_hooks package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.metric\_hooks.base module 8 | ---------------------------------- 9 | 10 | .. automodule:: torchseq.metric_hooks.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.metric\_hooks.default module 16 | ------------------------------------- 17 | 18 | .. automodule:: torchseq.metric_hooks.default 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.metric\_hooks.hrq\_agg module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.metric_hooks.hrq_agg 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.metric\_hooks.qg\_metric module 32 | ---------------------------------------- 33 | 34 | .. automodule:: torchseq.metric_hooks.qg_metric 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.metric\_hooks.rouge module 40 | ----------------------------------- 41 | 42 | .. automodule:: torchseq.metric_hooks.rouge 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.metric\_hooks.semparse module 48 | -------------------------------------- 49 | 50 | .. automodule:: torchseq.metric_hooks.semparse 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.metric\_hooks.sep\_ae module 56 | ------------------------------------- 57 | 58 | .. automodule:: torchseq.metric_hooks.sep_ae 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.metric\_hooks.textual module 64 | ------------------------------------- 65 | 66 | .. automodule:: torchseq.metric_hooks.textual 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: torchseq.metric_hooks 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/_source/torchseq.models.rerankers.rst: -------------------------------------------------------------------------------- 1 | torchseq.models.rerankers package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.models.rerankers.backtranslate\_reranker module 8 | -------------------------------------------------------- 9 | 10 | .. automodule:: torchseq.models.rerankers.backtranslate_reranker 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.models.rerankers.combo module 16 | -------------------------------------- 17 | 18 | .. automodule:: torchseq.models.rerankers.combo 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.models.rerankers.ngram\_reranker module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: torchseq.models.rerankers.ngram_reranker 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.models.rerankers.qa\_reranker module 32 | --------------------------------------------- 33 | 34 | .. 
automodule:: torchseq.models.rerankers.qa_reranker 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.models.rerankers.topk module 40 | ------------------------------------- 41 | 42 | .. automodule:: torchseq.models.rerankers.topk 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: torchseq.models.rerankers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/_source/torchseq.models.samplers.rst: -------------------------------------------------------------------------------- 1 | torchseq.models.samplers package 2 | ================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.models.samplers.beam\_search module 8 | -------------------------------------------- 9 | 10 | .. automodule:: torchseq.models.samplers.beam_search 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.models.samplers.diverse\_beam module 16 | --------------------------------------------- 17 | 18 | .. automodule:: torchseq.models.samplers.diverse_beam 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.models.samplers.greedy module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.models.samplers.greedy 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.models.samplers.parallel\_nucleus module 32 | ------------------------------------------------- 33 | 34 | .. automodule:: torchseq.models.samplers.parallel_nucleus 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.models.samplers.teacher\_force module 40 | ---------------------------------------------- 41 | 42 | .. automodule:: torchseq.models.samplers.teacher_force 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: torchseq.models.samplers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/_source/torchseq.pretrained.rst: -------------------------------------------------------------------------------- 1 | torchseq.pretrained package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.pretrained.lexical\_paraphraser\_bert module 8 | ----------------------------------------------------- 9 | 10 | .. automodule:: torchseq.pretrained.lexical_paraphraser_bert 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.pretrained.lm module 16 | ----------------------------- 17 | 18 | .. automodule:: torchseq.pretrained.lm 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.pretrained.qa module 24 | ----------------------------- 25 | 26 | .. automodule:: torchseq.pretrained.qa 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: torchseq.pretrained 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/_source/torchseq.rst: -------------------------------------------------------------------------------- 1 | torchseq package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
toctree:: 8 | :maxdepth: 4 9 | 10 | torchseq.agents 11 | torchseq.datasets 12 | torchseq.metric_hooks 13 | torchseq.models 14 | torchseq.pretrained 15 | torchseq.utils 16 | 17 | Submodules 18 | ---------- 19 | 20 | torchseq.args module 21 | -------------------- 22 | 23 | .. automodule:: torchseq.args 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | torchseq.main module 29 | -------------------- 30 | 31 | .. automodule:: torchseq.main 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Module contents 37 | --------------- 38 | 39 | .. automodule:: torchseq 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /docs/_source/torchseq.utils.rst: -------------------------------------------------------------------------------- 1 | torchseq.utils package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.utils.cache module 8 | --------------------------- 9 | 10 | .. automodule:: torchseq.utils.cache 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.utils.config module 16 | ---------------------------- 17 | 18 | .. automodule:: torchseq.utils.config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.utils.config\_migration module 24 | --------------------------------------- 25 | 26 | .. automodule:: torchseq.utils.config_migration 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.utils.easy\_generate module 32 | ------------------------------------ 33 | 34 | .. automodule:: torchseq.utils.easy_generate 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.utils.fleiss module 40 | ---------------------------- 41 | 42 | .. automodule:: torchseq.utils.fleiss 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.utils.functions module 48 | ------------------------------- 49 | 50 | .. automodule:: torchseq.utils.functions 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.utils.logging module 56 | ----------------------------- 57 | 58 | .. automodule:: torchseq.utils.logging 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.utils.loss\_dropper module 64 | ----------------------------------- 65 | 66 | .. automodule:: torchseq.utils.loss_dropper 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchseq.utils.mckenzie module 72 | ------------------------------ 73 | 74 | .. automodule:: torchseq.utils.mckenzie 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | torchseq.utils.metrics module 80 | ----------------------------- 81 | 82 | .. automodule:: torchseq.utils.metrics 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | torchseq.utils.model\_loader module 88 | ----------------------------------- 89 | 90 | .. automodule:: torchseq.utils.model_loader 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | torchseq.utils.optimizer\_group module 96 | -------------------------------------- 97 | 98 | .. automodule:: torchseq.utils.optimizer_group 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | torchseq.utils.perplexity module 104 | -------------------------------- 105 | 106 | .. automodule:: torchseq.utils.perplexity 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | torchseq.utils.rouge module 112 | --------------------------- 113 | 114 | .. 
automodule:: torchseq.utils.rouge 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | torchseq.utils.sari module 120 | -------------------------- 121 | 122 | .. automodule:: torchseq.utils.sari 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | torchseq.utils.seed module 128 | -------------------------- 129 | 130 | .. automodule:: torchseq.utils.seed 131 | :members: 132 | :undoc-members: 133 | :show-inheritance: 134 | 135 | torchseq.utils.singleton module 136 | ------------------------------- 137 | 138 | .. automodule:: torchseq.utils.singleton 139 | :members: 140 | :undoc-members: 141 | :show-inheritance: 142 | 143 | torchseq.utils.tokenizer module 144 | ------------------------------- 145 | 146 | .. automodule:: torchseq.utils.tokenizer 147 | :members: 148 | :undoc-members: 149 | :show-inheritance: 150 | 151 | torchseq.utils.tokenizer\_wordlevel module 152 | ------------------------------------------ 153 | 154 | .. automodule:: torchseq.utils.tokenizer_wordlevel 155 | :members: 156 | :undoc-members: 157 | :show-inheritance: 158 | 159 | torchseq.utils.wandb module 160 | --------------------------- 161 | 162 | .. automodule:: torchseq.utils.wandb 163 | :members: 164 | :undoc-members: 165 | :show-inheritance: 166 | 167 | Module contents 168 | --------------- 169 | 170 | .. automodule:: torchseq.utils 171 | :members: 172 | :undoc-members: 173 | :show-inheritance: 174 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../torchseq')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'torchseq' 21 | copyright = '2020, Tom Hosking' 22 | author = 'Tom Hosking' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | 'recommonmark', 32 | 'sphinx.ext.autodoc', 33 | 'sphinx.ext.napoleon' 34 | ] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 
49 | # 50 | html_theme = 'sphinx_rtd_theme' 51 | 52 | # Add any paths that contain custom static files (such as style sheets) here, 53 | # relative to this directory. They are copied after the builtin static files, 54 | # so a file named "default.css" will overwrite the builtin "default.css". 55 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. torchseq documentation master file, created by 2 | sphinx-quickstart on Tue Dec 1 11:20:55 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to torchseq's documentation! 7 | ==================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Getting Started: 12 | 13 | _source/quickstart 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | :caption: Contents: 18 | 19 | _source/modules 20 | 21 | 22 | .. Features 23 | .. -------- 24 | 25 | .. - Be awesome 26 | .. - Make things faster 27 | 28 | .. Installation 29 | .. ------------ 30 | 31 | .. Install $project by running: 32 | 33 | .. install project 34 | 35 | .. Contribute 36 | .. ---------- 37 | 38 | .. - Issue Tracker: github.com/$project/$project/issues 39 | .. - Source Code: github.com/$project/$project 40 | 41 | .. Support 42 | .. ------- 43 | 44 | .. If you are having issues, please let us know. 45 | .. We have a mailing list located at: project@google-groups.com 46 | 47 | .. License 48 | .. ------- 49 | 50 | .. The project is licensed under the BSD license. 51 | 52 | 53 | Indices and tables 54 | ================== 55 | 56 | * :ref:`genindex` 57 | * :ref:`modindex` 58 | * :ref:`search` 59 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/Paraphrasing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "from torchseq.agents.para_agent import ParaphraseAgent\n", 11 | "from torchseq.datasets.json_loader import JsonDataLoader\n", 12 | "from torchseq.utils.config import Config\n", 13 | "from torchseq.metric_hooks.textual import TextualMetricHook\n", 14 | "import torch\n", 15 | "\n", 16 | "model_path = '../models/examples/20210222_152157_paraphrasing_vae/'\n", 17 | "# model_path = '../models/examples/20210503_184659_paraphrasing_vqvae/'\n", 18 | "# model_path = '../models/examples/20210225_112226_paraphrasing_ae/'\n", 19 | "\n", 20 | "\n", 21 | "# Load the config\n", 22 | "with open(model_path + 'config.json') as f:\n", 23 | " cfg_dict = json.load(f)\n", 24 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 25 | "\n", 26 | "\n", 27 | "config = Config(cfg_dict)\n", 28 | "\n", 29 | "# Load the model\n", 30 | "instance = ParaphraseAgent(config=config, run_id=None, output_path=\"./runs/examples/paraphrasing_eval\", silent=False, verbose=False, training_mode=False)\n", 31 | "instance.load_checkpoint(model_path + 'model/checkpoint.pt')\n", 32 | "instance.model.eval()\n", 33 | "\n", 34 | "# Create a dataset\n", 35 | "data_loader = JsonDataLoader(config)\n", 36 | "\n", 37 | "# Run inference on the test split\n", 38 | "# test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader.test_loader, metric_hooks=[TextualMetricHook(config, 's1', 's2')])\n", 39 | "\n", 40 | "# Done!\n", 41 | "# print(all_metrics['ibleu'])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 5, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stderr", 51 | "output_type": "stream", 52 | "text": [ 53 | "Validating after 0 epochs: 100%|██████████| 1/1 [00:00<00:00, 6.07it/s]" 54 | ] 55 | }, 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "['what is the oldest cat in the world?']\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# You can now run your model on your own dataset\n", 73 | "\n", 74 | "examples = [\n", 75 | " {'q': 'Who was the oldest cat in the world?'},\n", 76 | "]\n", 77 | "\n", 78 | "\n", 79 | "cfg_dict[\"json_dataset\"] = {\n", 80 | " \"path\": None,\n", 81 | " \"field_map\": [\n", 82 | " {\n", 83 | " \"type\": \"copy\",\n", 84 | " \"from\": \"q\",\n", 85 | " \"to\": \"s1\"\n", 86 | " },\n", 87 | " {\n", 88 | " \"type\": \"copy\",\n", 89 | " \"from\": \"q\",\n", 90 | " \"to\": \"s2\"\n", 91 | " }\n", 92 | " ]\n", 93 | "}\n", 94 | "\n", 95 | "config = Config(cfg_dict)\n", 96 | "\n", 97 | " \n", 98 | "data_loader_custom = JsonDataLoader(config, test_samples=examples)\n", 99 | "\n", 100 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader_custom.test_loader)\n", 101 | "\n", 102 
| "print(pred_output)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "Python 3", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.8.5" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /examples/QG_BART_Eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Validating after 0 epochs: 100%|██████████| 2970/2970 [58:25<00:00, 1.18s/it] \n" 13 | ] 14 | }, 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "21.065812894587264\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "import json\n", 25 | "from torchseq.agents.aq_agent import AQAgent\n", 26 | "from torchseq.datasets.qa_loader import QADataLoader\n", 27 | "from torchseq.utils.config import Config\n", 28 | "from torchseq.metric_hooks.textual import TextualMetricHook\n", 29 | "import torch\n", 30 | "\n", 31 | "model_path = '../models/examples/20210223_191015_qg_bart/'\n", 32 | "\n", 33 | "\n", 34 | "# Load the config\n", 35 | "with open(model_path + 'config.json') as f:\n", 36 | " cfg_dict = json.load(f)\n", 37 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 38 | "\n", 39 | "config = Config(cfg_dict)\n", 40 | "\n", 41 | "# Load the model\n", 42 | "instance = AQAgent(config=config, run_id=None, output_path=\"./runs/examples/qg_bert_eval\", silent=False, verbose=False, training_mode=False)\n", 43 | "instance.load_checkpoint(model_path + 'model/checkpoint.pt')\n", 44 | "instance.model.eval()\n", 45 | "\n", 46 | "# Create a dataset\n", 47 | "data_loader = QADataLoader(config)\n", 48 | "\n", 49 | "# Run inference on the test split\n", 50 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader.test_loader, metric_hooks=[TextualMetricHook(config, 'c', 'q')])\n", 51 | "\n", 52 | "# Done!\n", 53 | "print(all_metrics['bleu'])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "Validating after 0 epochs: 100%|██████████| 1/1 [00:00<00:00, 2.43it/s]" 66 | ] 67 | }, 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "['Who was the oldest cat?', 'How long did Creme Puff live?']\n" 73 | ] 74 | }, 75 | { 76 | "name": "stderr", 77 | "output_type": "stream", 78 | "text": [ 79 | "\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "# You can now run your model on your own dataset\n", 85 | "\n", 86 | "examples = [\n", 87 | " {'c': 'Creme Puff was the oldest cat.', 'a': 'Creme Puff'},\n", 88 | " {'c': 'Creme Puff lived for 38 years and 3 days', 'a': '38 years and 3 days'},\n", 89 | "]\n", 90 | "\n", 91 | "# The examples need the answer character position, and a placeholder for the 
question\n", 92 | "examples = [\n", 93 | " {**ex, 'a_pos': ex['c'].index(ex['a']), 'q': ''} for ex in examples\n", 94 | "]\n", 95 | " \n", 96 | "data_loader_custom = QADataLoader(config, test_samples=examples)\n", 97 | "\n", 98 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader_custom.test_loader)\n", 99 | "\n", 100 | "print(pred_output)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python 3", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.8.5" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 4 132 | } 133 | -------------------------------------------------------------------------------- /examples/mBART.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "[\"Nous pouvons traduire de l'anglais en français\", 'We can translate from French to English']\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import json\n", 18 | "from torchseq.agents.para_agent import ParaphraseAgent\n", 19 | "from torchseq.datasets.json_loader import JsonDataLoader\n", 20 | "from torchseq.utils.config import Config\n", 21 | "import torch\n", 22 | "\n", 23 | "\n", 24 | "\n", 25 | "with open('../configs/mbart.json') as f:\n", 26 | " cfg_dict = json.load(f)\n", 27 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 28 | "cfg_dict[\"eval\"][\"eval_batch_size\"] = 1\n", 29 | "\n", 30 | "cfg_dict['training'][\"dataset\"] = 'json'\n", 31 | "cfg_dict[\"json_dataset\"] = {\n", 32 | " \"path\": None,\n", 33 | " \"field_map\": [\n", 34 | " {\"type\": \"copy\", \"from\": \"input\", \"to\": \"s2\"},\n", 35 | " {\"type\": \"copy\", \"from\": \"input\", \"to\": \"s1\"},\n", 36 | " ],\n", 37 | "}\n", 38 | "\n", 39 | "config = Config(cfg_dict)\n", 40 | "\n", 41 | "instance = ParaphraseAgent(config=config, run_id=None, output_path=\"./runs/examples/mbart/\", silent=True, verbose=False)\n", 42 | "\n", 43 | "instance.model.eval()\n", 44 | "\n", 45 | "examples = [\n", 46 | " {'input': 'We can translate from English to French', 'src_lang': 'en_XX', 'tgt_lang': 'fr_XX'},\n", 47 | " {'input': 'Nous pouvons traduire de Francais en Anglais', 'src_lang': 'fr_XX', 'tgt_lang': 'en_XX'}\n", 48 | "]\n", 49 | " \n", 50 | "data_loader = JsonDataLoader(config, test_samples=examples)\n", 51 | "\n", 52 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader.test_loader)\n", 53 | "\n", 54 | "print(pred_output)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [] 63 | } 64 | ], 65 | "metadata": { 66 | "kernelspec": { 67 | "display_name": "Python 3", 68 | "language": "python", 69 | "name": "python3" 70 | }, 71 | "language_info": { 72 | "codemirror_mode": { 73 | "name": "ipython", 74 | "version": 3 75 | }, 76 | 
"file_extension": ".py", 77 | "mimetype": "text/x-python", 78 | "name": "python", 79 | "nbconvert_exporter": "python", 80 | "pygments_lexer": "ipython3", 81 | "version": "3.8.5" 82 | } 83 | }, 84 | "nbformat": 4, 85 | "nbformat_minor": 4 86 | } 87 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_return_any = True 3 | warn_unused_configs = True 4 | ignore_missing_imports = True 5 | exclude = (?x)(torchseq/external/) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard==2.15.0 2 | torch==2.2.2 3 | tqdm>=4.62 4 | scipy>=1.5 5 | lightning==2.1.0 6 | 7 | nltk>=3.6.7 8 | transformers==4.39.2 9 | tokenizers==0.15.2 10 | jsonlines>=2 11 | sacrebleu>=2.0 12 | py-rouge 13 | rouge-score 14 | opentsne 15 | sentencepiece==0.1.95 16 | pydantic==1.10.13 17 | truecase==0.0.14 18 | # summac 19 | 20 | compress_json==1.0.10 21 | 22 | matplotlib 23 | 24 | wandb==0.15.12 25 | protobuf<4 26 | 27 | pytest>=6.2 28 | pytest-cov>=2.12 29 | codecov>=2.1 30 | flake8==6.0.0 31 | black==24.3.0 32 | mypy==1.2.0 33 | 34 | sphinx>=3.3 35 | sphinx_rtd_theme 36 | recommonmark -------------------------------------------------------------------------------- /scripts/download_data.sh: -------------------------------------------------------------------------------- 1 | mkdir ../data/ 2 | 3 | mkdir ../data/squad 4 | curl https://raw.githubusercontent.com/tomhosking/squad-du-split/master/train-v1.1.json -o ../data/squad/train-v1.1.json -L -C - 5 | curl https://raw.githubusercontent.com/tomhosking/squad-du-split/master/dev-v1.1.json -o ../data/squad/dev-v1.1.json -L -C - 6 | curl https://raw.githubusercontent.com/tomhosking/squad-du-split/master/test-v1.1.json -o ../data/squad/test-v1.1.json -L -C - 7 | 8 | # mkdir ../runs 9 | # mkdir ../runs/slurmlogs 10 | 11 | # curl https://nlp.stanford.edu/data/glove.6B.zip -o ./data/glove.6B.zip -L -C - 12 | # unzip ./data/glove.6B.zip -d ./data/glove.6B/ 13 | # rm ./data/glove.6B.zip 14 | -------------------------------------------------------------------------------- /scripts/download_models.py: -------------------------------------------------------------------------------- 1 | from transformers import BartModel, BertModel, BertTokenizer, RobertaModel, BertForQuestionAnswering, MBartModel, MBartTokenizer 2 | 3 | 4 | mod = BartModel.from_pretrained('facebook/bart-large') 5 | mod = RobertaModel.from_pretrained('roberta-base') 6 | 7 | mod = BertModel.from_pretrained('bert-base-uncased') 8 | mod = BertModel.from_pretrained('bert-base-cased') 9 | 10 | mod = BertTokenizer.from_pretrained('bert-base-uncased') 11 | mod = BertTokenizer.from_pretrained('bert-base-cased') 12 | 13 | mod = BertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad") 14 | 15 | 16 | import nltk 17 | nltk.download('punkt', force=True) 18 | nltk.download('wordnet', force=True) 19 | 20 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel 21 | 22 | tokenizer = AutoTokenizer.from_pretrained("tomhosking/deberta-v3-base-debiased-nli") 23 | model = AutoModelForSequenceClassification.from_pretrained("tomhosking/deberta-v3-base-debiased-nli") 24 | 25 | 26 | tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base") 27 | model = 
AutoModel.from_pretrained("microsoft/deberta-v3-base") 28 | 29 | 30 | mod = MBartTokenizer.from_pretrained('facebook/mbart-large-50', src_lang='en_XX', tgt_lang='en_XX') 31 | mod = MBartTokenizer.from_pretrained('facebook/mbart-large-50-many-to-many-mmt', src_lang='en_XX', tgt_lang='en_XX') 32 | mod = MBartModel.from_pretrained('facebook/mbart-large-50-many-to-many-mmt') 33 | 34 | 35 | from summac.model_summac import SummaCConv 36 | 37 | model_conv = SummaCConv( 38 | models=["vitc"], 39 | bins="percentile", 40 | granularity="sentence", 41 | nli_labels="e", 42 | device="cpu", 43 | start_file="default", 44 | agg="mean", 45 | ) 46 | 47 | for imager in model_conv.imagers: 48 | imager.load_nli() 49 | 50 | model_conv = SummaCConv( 51 | models=["mnli"], 52 | bins="percentile", 53 | granularity="sentence", 54 | nli_labels="e", 55 | device="cpu", 56 | start_file="default", 57 | agg="mean", 58 | ) 59 | 60 | for imager in model_conv.imagers: 61 | imager.load_nli() -------------------------------------------------------------------------------- /scripts/kaggle_prepro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | 5 | from math import floor 6 | import csv 7 | 8 | filtered_pairs = [] 9 | 10 | included = 0 11 | skipped = 0 12 | 13 | np.random.seed(0) 14 | 15 | DEVTEST_FRACTION = 0.1 16 | 17 | qids_train = set() 18 | qids_dev = set() 19 | qids_test = set() 20 | 21 | with open('/mnt/ext/phd/data/kaggle-questionpairs/questions.csv') as f, \ 22 | open('./data/kaggle/paraphrases.train.txt', 'w') as f_train, \ 23 | open('./data/kaggle/paraphrases.dev.txt', 'w') as f_dev, \ 24 | open('./data/kaggle/paraphrases.test.txt', 'w') as f_test: 25 | csv_reader = csv.reader(f) 26 | for ix, cols in enumerate(csv_reader): 27 | if ix == 0: 28 | continue 29 | 30 | # cols = line.split('\t') 31 | 32 | if len(cols[3]) > 500 or len(cols[4]) > 500: 33 | continue 34 | 35 | if cols[5] == '1': 36 | rand = np.random.random() 37 | 38 | qid1 = cols[1] 39 | qid2 = cols[2] 40 | 41 | if qid1 in qids_train or qid2 in qids_train: 42 | print('Forcing to train') 43 | f_train.write('{:}\t{:}\n'.format(cols[3], cols[4])) 44 | 45 | elif qid1 in qids_dev or qid2 in qids_dev: 46 | print('Forcing to dev') 47 | f_dev.write('{:}\t{:}\n'.format(cols[3], cols[4])) 48 | elif qid1 in qids_test or qid2 in qids_test: 49 | print('Forcing to test') 50 | f_test.write('{:}\t{:}\n'.format(cols[3], cols[4])) 51 | else: 52 | # 0.8-0.9 53 | if rand >= (1-DEVTEST_FRACTION*2) and rand < (1-DEVTEST_FRACTION) and cols[5] == '1': 54 | 55 | f_dev.write('{:}\t{:}\n'.format(cols[3], cols[4])) 56 | qids_dev.add(qid1) 57 | qids_dev.add(qid2) 58 | # 0.9+ 59 | elif rand >= (1-DEVTEST_FRACTION) and cols[5] == '1': 60 | f_test.write('{:}\t{:}\n'.format(cols[3], cols[4])) 61 | qids_test.add(qid1) 62 | qids_test.add(qid2) 63 | # Under 0.8 64 | else: 65 | f_train.write('{:}\t{:}\n'.format(cols[3], cols[4])) 66 | qids_train.add(qid1) 67 | qids_train.add(qid2) 68 | included += 1 69 | else: 70 | skipped += 1 71 | 72 | 73 | print(included, skipped) -------------------------------------------------------------------------------- /scripts/paranmt_prepro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import os 5 | 6 | from math import floor 7 | 8 | filtered_pairs = [] 9 | 10 | included = 0 11 | skipped = 0 12 | 13 | np.random.seed(0) 14 | 15 | TRAIN_FRACTION = 0.97 16 | 17 | # with 
open('./data/paranmt/para-nmt-50m.txt') as f: 18 | # with open('./data/paranmt/paraphrases.train.txt', 'w') as f_train: 19 | # with open('./data/paranmt/paraphrases.dev.txt', 'w') as f_dev: 20 | # for line in f: 21 | # cols = line.strip().split('\t') 22 | 23 | # if line.strip() == '' or len(cols[0]) > 300 or len(cols[1]) > 300 or len(cols[0]) < 30 or len(cols[1]) < 30: 24 | # continue 25 | 26 | # if float(cols[2]) < 0.9 and float(cols[2]) > 0.5: 27 | # if np.random.random() > TRAIN_FRACTION: 28 | # f_dev.write('{:}\t{:}\n'.format(cols[1].strip(), cols[0].strip())) 29 | # else: 30 | # f_train.write('{:}\t{:}\n'.format(cols[1].strip(), cols[0].strip())) 31 | # included += 1 32 | # else: 33 | # skipped += 1 34 | 35 | 36 | # os.makedirs('./data/parabank-qs') 37 | with open('../../data/parabank/parabank-1.0-large-diverse/parabank.50m.tsv') as f: 38 | with open('./data/parabank-qs/paraphrases.train.txt', 'w') as f_train: 39 | with open('./data/parabank-qs/paraphrases.dev.txt', 'w') as f_dev: 40 | for line in f: 41 | cols = line.strip().split('\t') 42 | 43 | if line.strip() == '' or len(cols[0]) > 300 or len(cols[1]) > 300 or len(cols[0]) < 30 or len(cols[1]) < 30: 44 | continue 45 | 46 | if cols[0][-1] != '?' or cols[1][-1] != '?': 47 | continue 48 | 49 | # NOTE: I think parabank is already inverted (mt->orig) but paranmt is (orig->mt) 50 | if np.random.random() > TRAIN_FRACTION: 51 | f_dev.write('{:}\t{:}\n'.format(cols[0].strip(), cols[1].strip())) 52 | else: 53 | f_train.write('{:}\t{:}\n'.format(cols[0].strip(), cols[1].strip())) 54 | included += 1 55 | 56 | 57 | print(included, skipped) -------------------------------------------------------------------------------- /scripts/submit_job.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MCKENZIE_HOOK=~/mckenzie/scripts/hook.sh 4 | 5 | 6 | 7 | jobName="UNK" 8 | jobTag="custom" 9 | 10 | 11 | POSITIONAL=() 12 | while [[ $# -gt 0 ]] 13 | do 14 | key="$1" 15 | 16 | case $key in 17 | --config) 18 | CONFIG="$2" 19 | POSITIONAL+=("$1") # save it in an array for later 20 | POSITIONAL+=("$2") # save it in an array for later 21 | jobName=$(cat $CONFIG | grep \"name\"\: | sed -E 's/.+\"name\": \"(.*)\"\,/\1/') 22 | jobTag=$(cat $CONFIG | grep \"tag\"\: | sed -E 's/.+\"tag\": \"(.*)\"\,/\1/') 23 | shift # past argument 24 | shift # past value 25 | ;; 26 | *) # unknown option 27 | POSITIONAL+=("$1") # save it in an array for later 28 | shift # past argument 29 | ;; 30 | esac 31 | done 32 | set -- "${POSITIONAL[@]}" # restore positional parameters 33 | 34 | 35 | 36 | res=$(sbatch $@) 37 | 38 | 39 | jobId=`echo $res | sed -E 's/Submitted batch job ([0-9]+)/\1/'` 40 | 41 | partition=`scontrol show job $jobId | grep "Partition=([^\s]+)" -Po | sed s/Partition=//` 42 | 43 | if [ "$jobId" != "$res" ] 44 | then 45 | 46 | ${MCKENZIE_HOOK} -a 1 -i $jobId -p $partition > /dev/null 47 | 48 | echo "Batch job ID $jobId -> $jobTag/$jobName" 49 | 50 | ${MCKENZIE_HOOK} -i $jobId -p $partition -n $jobTag/$jobName > /dev/null 51 | sleep 5 52 | else 53 | echo "Error submitting job!" 
54 | echo res 55 | fi -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="torchseq", 8 | version="3.1.0", 9 | author="Tom Hosking", 10 | author_email="code@tomho.sk", 11 | description="A Seq2Seq framework for PyTorch", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/tomhosking/torchseq", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Operating System :: OS Independent", 20 | "Intended Audience :: Science/Research", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | ], 23 | python_requires='>=3.6', 24 | entry_points = { 25 | 'console_scripts': ['torchseq=torchseq.main:main', 'torchseq-eval=torchseq.eval.cli:main'], 26 | }, 27 | install_requires = [ 28 | 'tensorboard==2.15.0', 29 | 'torch==2.2.2', 30 | 'tqdm>=4.62', 31 | 'scipy>=1.5', 32 | 'nltk>=3.6.7', 33 | 'transformers==4.39.2', 34 | 'tokenizers==0.15.2', 35 | 'jsonlines>=2', 36 | 'sacrebleu>=2.0', 37 | 'py-rouge', 38 | 'rouge-score', 39 | 'wandb==0.15.12', 40 | 'matplotlib', 41 | 'opentsne', 42 | 'sentencepiece==0.1.95', 43 | 'protobuf<4', 44 | 'pydantic==1.10.13', 45 | 'truecase==0.0.14', 46 | 'lightning==2.1.0', 47 | # 'summac', 48 | 'compress_json==1.0.10', 49 | ], 50 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | 5 | import torch 6 | 7 | 8 | from . import utils as test_utils 9 | 10 | from torchseq.utils import metrics 11 | from torchseq.utils.rouge import get_jackknife_rouge, get_pairwise_rouge 12 | 13 | from torchseq.utils.fleiss import fleiss 14 | 15 | import numpy as np 16 | 17 | 18 | def test_bleu(): 19 | refs = ["The dog bit the man.", "It was not unexpected.", "The man bit him first."] 20 | sys = ["The dog bit the man.", "It wasn't surprising.", "The man had just bitten him."] 21 | bleu = metrics.bleu_corpus(sys, refs) 22 | assert abs(bleu - 45.0675) < 0.001, "BLEU score for basic examples differs from SacreBLEU reference!" 23 | 24 | 25 | def test_rouge(): 26 | refs = ["The dog bit the man.", "It was not unexpected.", "The man bit him first."] 27 | sys = ["The dog bit the man.", "It wasn't surprising.", "The man had just bitten him."] 28 | rouge = get_pairwise_rouge(sys[0], refs[0]) 29 | assert "rouge2" in rouge 30 | assert "rougeL" in rouge 31 | assert abs(rouge["rouge2"] - 100) < 0.001, "Rouge score for basic examples differs from reference!" 
32 | 33 | refs = [ 34 | ["The dog bit the man.", "The man was bitten by the dog"], 35 | ["It was not unexpected.", "it was not surprising"], 36 | ["The man bit him first."], 37 | ] 38 | rouge = get_jackknife_rouge(sys, refs) 39 | assert "rouge2" in rouge 40 | assert "rougeL" in rouge 41 | 42 | # NOTE: This test is intrinsic - the rouge value was obtained using the same rouge implementation, so is only a check for regression! 43 | assert abs(rouge["rouge2"] - 30.741) < 0.001, "Rouge score for jackknife examples differs from reference!" 44 | 45 | 46 | def test_meteor(): 47 | preds = ["It is a guide to action which ensures that the military always obeys the commands of the party"] 48 | refs = ["It is a guide to action that ensures that the military will forever heed Party commands"] 49 | 50 | assert metrics.meteor_corpus(refs, preds) - 0.7398 < 1e-4 51 | 52 | 53 | def test_f1(): 54 | assert metrics.f1("same", "same") == 1.0, "F1 failed for correct example!" 55 | assert metrics.f1("same", "diff") == 0.0, "F1 failed for wrong example!" 56 | assert metrics.f1("tok1 tok2", "tok1 tok3") == 0.5, "F1 failed for overlapping example!" 57 | 58 | 59 | def test_fleiss(): 60 | bad = [[1, 1], [1, 1], [1, 1]] 61 | 62 | good = [[2, 0], [0, 2], [2, 0]] 63 | 64 | random = [[2, 0], [0, 2], [1, 1], [1, 1]] 65 | 66 | assert fleiss(np.array(bad)) == -1.0 67 | assert fleiss(np.array(good)) == 1 68 | assert fleiss(np.array(random)) == 0 69 | 70 | 71 | def test_misc(): 72 | assert metrics.normalize_answer("a Very oLd cat! ") == "very old cat" 73 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | 5 | import torch 6 | 7 | 8 | from . 
import utils as test_utils 9 | 10 | from torchseq.utils.config import Config, merge_cfg_dicts 11 | from torchseq.utils.singleton import Singleton 12 | import torchseq.utils.functions as tsfunctions 13 | 14 | 15 | def test_config(): 16 | main_cfg_dict = { 17 | "name": "model_name", 18 | "int": 1, 19 | "float": 0.2, 20 | "str": "hello", 21 | "bool": True, 22 | "nested": {"value": "present"}, 23 | } 24 | 25 | mask_cfg_dict = {"name": "%_mod", "int": 2, "nested": {"value": "overwritten", "newval": "added"}} 26 | 27 | cfg_obj = Config(main_cfg_dict) 28 | 29 | merged_obj = Config(merge_cfg_dicts(main_cfg_dict, mask_cfg_dict)) 30 | 31 | assert cfg_obj.int == 1 32 | assert cfg_obj.float == 0.2 33 | assert cfg_obj.str == "hello" 34 | assert cfg_obj.bool is True 35 | assert cfg_obj.nested.value == "present" 36 | 37 | assert cfg_obj.get_path(["nested", "value"]) == "present" 38 | assert cfg_obj.get_path(["nested", "missing"], "alt_default") == "alt_default" 39 | 40 | assert cfg_obj.get("str", "not hello") == "hello" 41 | assert cfg_obj.get("missing", "default_val") == "default_val" 42 | 43 | assert cfg_obj.get_first(["float", "int"]) == 0.2 44 | 45 | assert merged_obj.int == 2 46 | assert merged_obj.nested.value == "overwritten" 47 | assert merged_obj.nested.newval == "added" 48 | assert merged_obj.name == "model_name_mod" 49 | 50 | 51 | def test_singleton(): 52 | class TestClass(metaclass=Singleton): 53 | def __init__(self, val): 54 | self.val = val 55 | 56 | x = TestClass("x") 57 | y = TestClass("y") 58 | 59 | assert x.val == "x" 60 | assert x.val == y.val 61 | 62 | 63 | def test_cache(): 64 | from torchseq.utils.cache import Cache 65 | 66 | cache = Cache() 67 | 68 | data_str = "testing" 69 | cache.save("string", data_str) 70 | assert cache.load("string") == data_str 71 | 72 | data_tensor = torch.rand(4, 7) 73 | cache.save("tensor", data_tensor) 74 | assert cache.load("tensor").equal(data_tensor) 75 | 76 | assert cache.load("missing") is None 77 | 78 | 79 | def test_functions(): 80 | # topk 81 | input_probs = torch.Tensor([[0.1, 0.2, 0.3, 0.4], [0.13, 0.7, 0.17, 0.0]]) 82 | 83 | assert tsfunctions.top_k_top_p_filtering(input_probs, top_k=2).equal( 84 | torch.tensor([[-torch.inf, -torch.inf, 0.3000, 0.4000], [-torch.inf, 0.7000, 0.1700, -torch.inf]]) 85 | ) 86 | 87 | assert tsfunctions.top_k_top_p_filtering(input_probs, top_k=2, filter_value=0).equal( 88 | torch.tensor([[0.0000, 0.0000, 0.3000, 0.4000], [0.0000, 0.7000, 0.1700, 0.0000]]) 89 | ) 90 | 91 | assert tsfunctions.top_k_top_p_filtering(input_probs, top_p=0.5, filter_value=0).equal( 92 | torch.tensor([[0.0000, 0.0000, 0.3000, 0.4000], [0.0000, 0.7000, 0.1700, 0.0000]]) 93 | ) 94 | 95 | x = torch.Tensor([[0.0, 0.0, 0.0, 1.0], [0.13, 0.7, 0.17, 0.0]]) 96 | y = torch.Tensor([[0.0, 0.0, 0.0, 1.0]]) 97 | 98 | assert tsfunctions.cos_sim(x, y).equal(torch.tensor([1.0, 0.0000])) 99 | 100 | test_data = list(range(6)) 101 | 102 | assert list(tsfunctions.batchify(test_data, 4)) == [(0, [0, 1, 2, 3]), (1, [4, 5])] 103 | -------------------------------------------------------------------------------- /tests/test_pretrained.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from datetime import datetime 4 | 5 | import torch 6 | 7 | 8 | from . 
import utils as test_utils 9 | 10 | from torchseq.pretrained.qa import PreTrainedQA 11 | from torchseq.pretrained.lm import PretrainedLM 12 | 13 | 14 | @test_utils.slow 15 | def test_qa(): 16 | use_cuda = torch.cuda.is_available() 17 | 18 | instance = PreTrainedQA(device=("cuda" if use_cuda else "cpu")) 19 | 20 | preds = instance.infer_batch( 21 | ["Who was the oldest cat?", "Who was a nice puppet?"], 22 | ["Creme Puff was the oldest cat.", "This is a distraction. " * 50 + "Creme Puff was the oldest cat."], 23 | ) 24 | 25 | assert len(preds) == 2, "Failed to produce the right number of predictions?!" 26 | assert preds[0] == "Creme Puff", "Short QA test failed - answer is wrong" 27 | assert preds[1] == "Creme Puff", "Long QA test failed - answer is wrong" 28 | 29 | 30 | @test_utils.slow 31 | def test_lm(): 32 | instance = PretrainedLM() 33 | 34 | preds = instance.get_log_prob( 35 | [ 36 | "The first thing", 37 | "Creme Puff was the oldest cat.", 38 | " Variational Autoencoders (VAEs) provide a theoretically-backed and popular framework for deep generative models.", 39 | ], 40 | ) 41 | 42 | print(preds) 43 | 44 | assert len(preds) == 3, "Failed to produce the right number of predictions?!" 45 | assert preds[0] < 9, "Simple sentence score is too high" 46 | assert preds[0] > 8, "Simple sentence score is too low" 47 | assert abs(preds[1] - 8.8) < 0.1, "Normal sentence score out of range" 48 | assert abs(preds[2] - 10.8) < 0.1, "Hard sentence score out of range" 49 | -------------------------------------------------------------------------------- /tests/test_tokenisation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import utils as test_utils 4 | from torchseq.utils.tokenizer import Tokenizer 5 | 6 | 7 | def test_bert_uncased_basic(): 8 | TEST_STRING = "This is a test sentence." 9 | 10 | tokenizer = Tokenizer("bert-base-uncased") 11 | 12 | tokenized = tokenizer.tokenise(TEST_STRING) 13 | decoded = tokenizer.decode(torch.LongTensor([tok["id"] for tok in tokenized])) 14 | 15 | assert [tok["id"] for tok in tokenized] == [ 16 | 101, 17 | 2023, 18 | 2003, 19 | 1037, 20 | 3231, 21 | 6251, 22 | 1012, 23 | 102, 24 | ], "BERT uncased tok ids are wrong for basic example!" 25 | assert decoded == TEST_STRING.lower(), "BERT uncased tokenisation isn't reversible for basic example!" 26 | 27 | 28 | def test_bert_cased_basic(): 29 | TEST_STRING = "This is a test sentence." 30 | 31 | tokenizer = Tokenizer("bert-base-cased") 32 | 33 | tokenized = tokenizer.tokenise(TEST_STRING) 34 | decoded = tokenizer.decode(torch.LongTensor([tok["id"] for tok in tokenized])) 35 | 36 | assert [tok["id"] for tok in tokenized] == [ 37 | 101, 38 | 1188, 39 | 1110, 40 | 170, 41 | 2774, 42 | 5650, 43 | 119, 44 | 102, 45 | ], "BERT cased tok ids are wrong for basic example!" 46 | assert decoded == TEST_STRING, "BERT cased tokenisation isn't reversible for basic example!" 47 | 48 | 49 | def test_roberta_basic(): 50 | tokenizer = Tokenizer("roberta-base") 51 | 52 | TEST_STRING = "This is a test sentence." 53 | 54 | tokenized = tokenizer.tokenise(TEST_STRING) 55 | decoded = tokenizer.decode(torch.LongTensor([tok["id"] for tok in tokenized])) 56 | 57 | assert [tok["id"] for tok in tokenized] == [ 58 | 0, 59 | 713, 60 | 16, 61 | 10, 62 | 1296, 63 | 3645, 64 | 4, 65 | 2, 66 | ], "RoBERTa tok ids are wrong for basic example!" 67 | assert decoded == TEST_STRING, "RoBERTa tokenisation isn't reversible for basic example!" 
68 | 69 | 70 | def test_auto(): 71 | tokenizer = Tokenizer("xlnet-base-cased") 72 | 73 | TEST_STRING = "This is a test sentence." 74 | 75 | tokenized = tokenizer.tokenise(TEST_STRING) 76 | decoded = tokenizer.decode(torch.LongTensor([tok["id"] for tok in tokenized])) 77 | 78 | assert [tok["id"] for tok in tokenized] == [ 79 | 1, 80 | 122, 81 | 27, 82 | 24, 83 | 934, 84 | 3833, 85 | 9, 86 | 4, 87 | 3, 88 | 2, 89 | ], "Auto tok ids are wrong for basic example!" 90 | assert decoded == TEST_STRING, "Auto tokenisation isn't reversible for basic example!" 91 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import unittest 4 | from distutils.util import strtobool 5 | 6 | 7 | def parse_flag_from_env(key, default=False): 8 | try: 9 | value = os.environ[key] 10 | except KeyError: 11 | # KEY isn't set, default to `default`. 12 | _value = default 13 | else: 14 | # KEY is set, convert it to True or False. 15 | try: 16 | _value = strtobool(value) 17 | except ValueError: 18 | # More values are supported, but let's keep the message simple. 19 | raise ValueError("If set, {} must be yes or no.".format(key)) 20 | return _value 21 | 22 | 23 | _run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) 24 | 25 | 26 | def slow(test_case): 27 | """ 28 | Decorator marking a test as slow. 29 | Slow tests are skipped by default. Set the RUN_SLOW environment variable 30 | to a truthy value to run them. 31 | """ 32 | if not _run_slow_tests: 33 | test_case = unittest.skip("test is slow")(test_case) 34 | return test_case 35 | -------------------------------------------------------------------------------- /torchseq/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "3.1.0" 2 | 3 | 4 | from .utils.model_loader import model_from_path 5 | from .utils.easy_generate import generate 6 | -------------------------------------------------------------------------------- /torchseq/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/agents/__init__.py -------------------------------------------------------------------------------- /torchseq/agents/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | # import torch._dynamo 6 | 7 | from torchseq.utils.functions import to_device_unless_marked 8 | 9 | # This project was originally based off this template: 10 | # https://github.com/moemen95/Pytorch-Project-Template 11 | 12 | 13 | class BaseAgent: 14 | """ 15 | This base class will contain the base functions to be overloaded by any agent you will implement. 
16 | """ 17 | 18 | cuda: bool 19 | use_lightning: bool = False 20 | 21 | def __init__(self, config): 22 | self.config = config 23 | self.logger = logging.getLogger("Agent") 24 | 25 | def set_device(self, use_cuda=True): 26 | # set cuda flag 27 | self.cuda_available = torch.cuda.is_available() 28 | if self.cuda_available and not use_cuda: 29 | self.logger.warning("You have a CUDA device, so you should probably enable CUDA") 30 | 31 | if use_cuda and not self.cuda_available: 32 | self.logger.error("Use CUDA is set to true, but not CUDA devices were found!") 33 | raise Exception("No CUDA devices found") 34 | 35 | self.cuda = self.cuda_available & use_cuda 36 | 37 | if not self.model: 38 | raise Exception("You need to define your model before calling set_device!") 39 | 40 | if self.cuda: 41 | self.device = torch.device("cuda") 42 | 43 | self.logger.info("Model will run on *****GPU-CUDA***** ") 44 | 45 | # self.model.to(self.device) 46 | self.model.apply(to_device_unless_marked(self.device)) 47 | 48 | self.loss.to(self.device) 49 | 50 | # TODO: Enable for pytorch 2.0 51 | # torch._dynamo.config.verbose = False 52 | # torch._dynamo.config.log_level = logging.WARN 53 | # torch._dynamo.reset() 54 | # self.model = torch.compile(self.model, dynamic=True, mode='reduce-overhead') #, backend="inductor",fullgraph=True, mode='reduce-overhead',, dynamic=True 55 | 56 | else: 57 | self.device = torch.device("cpu") 58 | 59 | self.logger.info("Model will run on *****CPU*****") 60 | 61 | def load_checkpoint(self, file_name): 62 | """ 63 | Latest checkpoint loader 64 | :param file_name: name of the checkpoint file 65 | :return: 66 | """ 67 | raise NotImplementedError 68 | 69 | def save_checkpoint(self, file_name="checkpoint.pt", is_best=0): 70 | """ 71 | Checkpoint saver 72 | :param file_name: name of the checkpoint file 73 | :param is_best: boolean flag to indicate whether current checkpoint's metric is the best so far 74 | :return: 75 | """ 76 | raise NotImplementedError 77 | 78 | def train(self, data_loader) -> None: 79 | """ 80 | Main training loop 81 | :return: 82 | """ 83 | raise NotImplementedError 84 | 85 | def train_one_epoch(self): 86 | """ 87 | One epoch of training 88 | :return: 89 | """ 90 | raise NotImplementedError 91 | 92 | def validate(self): 93 | """ 94 | One cycle of model validation 95 | :return: 96 | """ 97 | raise NotImplementedError 98 | 99 | def finalize(self): 100 | """ 101 | Finalizes all the operations of the 2 Main classes of the process, the operator and the data loader 102 | :return: 103 | """ 104 | raise NotImplementedError 105 | -------------------------------------------------------------------------------- /torchseq/agents/lm_agent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from torchseq.agents.model_agent import ModelAgent 6 | 7 | 8 | from torchseq.models.lm_transformer import TransformerLanguageModel 9 | from torchseq.utils.functions import gaussian_kl 10 | 11 | 12 | class LangModelAgent(ModelAgent): 13 | def __init__( 14 | self, 15 | config, 16 | run_id, 17 | output_path, 18 | data_path, 19 | silent=False, 20 | training_mode=True, 21 | verbose=True, 22 | cache_root=None, 23 | use_cuda=True, 24 | ): 25 | super().__init__(config, run_id, output_path, data_path, silent, training_mode, verbose, cache_root) 26 | 27 | self.tgt_field = "sent" 28 | 29 | self.src_field = "sent" 30 | 31 | # define loss 32 | self.loss = nn.CrossEntropyLoss( 33 | 
ignore_index=self.output_tokenizer.pad_id, 34 | reduction="none", 35 | label_smoothing=self.config.training.get("label_smoothing", 0.0), 36 | ) 37 | 38 | # define model 39 | self.model = TransformerLanguageModel(self.config, src_field=self.src_field) 40 | 41 | # define optimizer 42 | if training_mode: 43 | self.create_optimizer() 44 | 45 | self.set_device(use_cuda) 46 | 47 | self.create_samplers() 48 | -------------------------------------------------------------------------------- /torchseq/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_args(): 5 | parser = argparse.ArgumentParser( 6 | description="TorchSeq", 7 | ) 8 | 9 | parser.add_argument("-V", "--version", action="store_true", help="Display version") 10 | 11 | # Config stuff 12 | parser.add_argument( 13 | "-c", "--config", type=str, metavar="CONFIG", default="./configs/default.json", help="Path to config file" 14 | ) 15 | parser.add_argument( 16 | "-p", 17 | "--patch", 18 | type=str, 19 | metavar="PATCH", 20 | default=None, 21 | help="Config mask(s) to overwrite main config with", 22 | action="append", 23 | ) 24 | 25 | # Actions 26 | parser.add_argument("--train", action="store_true", help="Run training?") 27 | parser.add_argument("--validate", action="store_true", help="Run eval on dev?") 28 | parser.add_argument("--validate_train", action="store_true", help="Run eval on train?") 29 | parser.add_argument("--test", action="store_true", help="Run eval on test?") 30 | parser.add_argument("--silent", action="store_true", help="Disable logging") 31 | parser.add_argument("--verbose", action="store_true", help="Extra logging") 32 | parser.add_argument( 33 | "--reload_after_train", action="store_true", help="Reload model after training to do a validation run" 34 | ) 35 | parser.add_argument( 36 | "--copy_chkpt", 37 | action="store_true", 38 | help="Save a copy of the checkpoint in current output dir, even if loading from elsewhere", 39 | ) 40 | 41 | # Model loading 42 | parser.add_argument("--load_chkpt", type=str, metavar="CHECKPOINT", default=None, help="Path to checkpoint file") 43 | parser.add_argument("-l", "--load", type=str, metavar="MODEL", default=None, help="Path to model folder") 44 | parser.add_argument("--nocache", action="store_true", help="Disable loading from an old cache") 45 | 46 | # Paths 47 | parser.add_argument("-d", "--data_path", type=str, metavar="DATA", default="./data/", help="Path to data sources") 48 | parser.add_argument( 49 | "-o", "--output_path", type=str, metavar="OUTPUT", default="./runs/", help="Path to output folder" 50 | ) 51 | 52 | # Runtime 53 | parser.add_argument("--cpu", action="store_true", help="Disable CUDA") 54 | parser.add_argument("--debug", action="store_true", help="Enable debug mode") 55 | parser.add_argument("--lightning", action="store_true", help="Enable Lightning") 56 | parser.add_argument("--amp", action="store_true", help="Enable AMP") 57 | 58 | args = parser.parse_args() 59 | 60 | return args 61 | -------------------------------------------------------------------------------- /torchseq/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/datasets/__init__.py -------------------------------------------------------------------------------- /torchseq/datasets/builder.py: -------------------------------------------------------------------------------- 
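# Routes the dataset name in config.training.dataset to the matching data loader (paraphrase, QA, JSON or language-modelling); unrecognised names raise an exception below.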
1 | from torchseq.datasets.paraphrase_loader import ParaphraseDataLoader 2 | from torchseq.datasets.qa_loader import QADataLoader 3 | from torchseq.datasets.json_loader import JsonDataLoader 4 | from torchseq.datasets.lm_loader import LangmodellingDataLoader 5 | 6 | 7 | def dataloader_from_config(config, data_path="./data", train_samples=None, dev_samples=None, test_samples=None): 8 | # define data_loader 9 | if config.training.dataset is None: 10 | data_loader = None 11 | elif ( 12 | config.training.dataset 13 | in [ 14 | "paranmt", 15 | "parabank", 16 | "kaggle", 17 | "parabank-qs", 18 | "para-squad", 19 | "models/squad-udep", 20 | "models/squad-constituency", 21 | "models/squad-udep-deptree", 22 | "models/qdmr-squad", 23 | "models/nq_newsqa-udep", 24 | "models/nq_newsqa-udep-deptree", 25 | "models/squad_nq_newsqa-udep", 26 | "models/squad_nq_newsqa-udep-deptree", 27 | "models/naturalquestions-udep", 28 | "models/newsqa-udep", 29 | "models/naturalquestions-udep-deptree", 30 | "models/newsqa-udep-deptree", 31 | ] 32 | or config.training.dataset[:5] == "qdmr-" 33 | or "kaggle-" in config.training.dataset 34 | ): 35 | data_loader = ParaphraseDataLoader( 36 | config=config, 37 | data_path=data_path, 38 | train_samples=train_samples, 39 | dev_samples=dev_samples, 40 | test_samples=test_samples, 41 | ) 42 | elif ( 43 | config.training.dataset 44 | in [ 45 | "squad", 46 | "newsqa", 47 | "msmarco", 48 | "naturalquestions", 49 | "drop", 50 | "nq_newsqa", 51 | "squad_nq_newsqa", 52 | "inquisitive", 53 | ] 54 | or config.training.dataset[:5] == "squad" 55 | or config.training.dataset[:3] == "qa/" 56 | ): 57 | data_loader = QADataLoader( 58 | config=config, 59 | data_path=data_path, 60 | train_samples=train_samples, 61 | dev_samples=dev_samples, 62 | test_samples=test_samples, 63 | ) 64 | elif config.training.dataset in [ 65 | "json", 66 | ]: 67 | data_loader = JsonDataLoader( 68 | config=config, 69 | data_path=data_path, 70 | train_samples=train_samples, 71 | dev_samples=dev_samples, 72 | test_samples=test_samples, 73 | ) 74 | elif config.training.dataset in ["ptb", "wikitext103"]: 75 | data_loader = LangmodellingDataLoader( 76 | config=config, 77 | data_path=data_path, 78 | train_samples=train_samples, 79 | dev_samples=dev_samples, 80 | test_samples=test_samples, 81 | ) 82 | else: 83 | raise Exception("Unrecognised dataset: {:}".format(config.training.dataset)) 84 | 85 | return data_loader 86 | -------------------------------------------------------------------------------- /torchseq/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import cycle 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import IterableDataset, Dataset 7 | 8 | 9 | from torchseq.utils.tokenizer import Tokenizer 10 | 11 | 12 | # class LangmodellingDataset(IterableDataset): 13 | class LangmodellingDataset(Dataset): 14 | def __init__(self, path, config, dev=False, test=False, repeat=False): 15 | self.config = config 16 | 17 | self.repeat = repeat 18 | 19 | self.path = path 20 | self.variant = "dev" if dev else ("test" if test else "train") 21 | 22 | self.exists = True 23 | 24 | self.length = 0 25 | 26 | if test and not os.path.exists(os.path.join(self.path, "sentences.{:}.txt".format(self.variant))): 27 | self.exists = False 28 | else: 29 | # TODO: Can we get the length without reading the whole file? 
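# (One option would be a lazy count, e.g. `self.length = sum(1 for _ in f)`, which avoids holding every line in memory, though it still makes one full pass over the file.)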
30 | with open(os.path.join(self.path, "sentences.{:}.txt".format(self.variant))) as f: 31 | # for line in f: 32 | # self.length += 1 33 | self.samples = [x.strip() for x in f.readlines()] 34 | 35 | def __len__(self): 36 | return len(self.samples) 37 | 38 | # def __iter__(self): 39 | # return self.generator() 40 | 41 | # def generator(self): 42 | # worker_info = torch.utils.data.get_worker_info() 43 | # if not worker_info: 44 | # worker_id = 0 45 | # num_workers = 1 46 | # else: 47 | # worker_id = worker_info.id 48 | # num_workers = worker_info.num_workers 49 | 50 | # with open(os.path.join(self.path, "sentences.{:}.txt".format(self.variant))) as f: 51 | # num_repeats = 0 52 | # while self.repeat or num_repeats < 1: 53 | # num_repeats += 1 54 | # for ix, line in enumerate(f): 55 | # if num_workers > 1 and ix % num_workers != worker_id: 56 | # continue 57 | # x = line.strip("\n") 58 | 59 | # sample = {"sent": x} 60 | # yield self.to_tensor(sample, tok_window=self.config.prepro.tok_window) 61 | 62 | def __getitem__(self, idx): 63 | return LangmodellingDataset.to_tensor({"sent": self.samples[idx]}, tok_window=self.config.prepro.tok_window) 64 | 65 | @staticmethod 66 | def to_tensor(x, tok_window=64): 67 | parsed = LangmodellingInstance(x["sent"]) 68 | 69 | sample = { 70 | "sent": torch.LongTensor(parsed.sent_as_ids()), 71 | "sent_len": torch.LongTensor([len(parsed._doc)]), 72 | "sent_text": x["sent"], 73 | } 74 | 75 | return sample 76 | 77 | @staticmethod 78 | def pad_and_order_sequences(batch): 79 | keys = batch[0].keys() 80 | max_lens = {k: max(len(x[k]) for x in batch) for k in keys} 81 | 82 | for x in batch: 83 | for k in keys: 84 | if k == "a_pos": 85 | x[k] = F.pad(x[k], (0, max_lens[k] - len(x[k])), value=0) 86 | elif k[-5:] != "_text": 87 | x[k] = F.pad(x[k], (0, max_lens[k] - len(x[k])), value=Tokenizer().pad_id) 88 | 89 | tensor_batch = {} 90 | for k in keys: 91 | if k[-5:] != "_text": 92 | tensor_batch[k] = torch.stack([x[k] for x in batch], 0).squeeze(1) 93 | else: 94 | tensor_batch[k] = [x[k] for x in batch] 95 | 96 | return tensor_batch 97 | 98 | 99 | class LangmodellingInstance: 100 | def __init__(self, sent_text, tok_window=256): 101 | self._doc = Tokenizer().tokenise(sent_text) 102 | 103 | if len(self._doc) > tok_window: 104 | self._doc = self._doc[:tok_window] 105 | 106 | def sent_as_ids(self): 107 | id_list = [tok["id"] for tok in self._doc] 108 | return id_list 109 | -------------------------------------------------------------------------------- /torchseq/datasets/lm_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from torchseq.datasets.lm_dataset import LangmodellingDataset 8 | from torchseq.utils.seed import init_worker 9 | from torchseq.utils.tokenizer import Tokenizer 10 | import torchseq.utils.tokenizer as tokenizer 11 | 12 | import logging 13 | 14 | 15 | class LangmodellingDataLoader: 16 | def __init__(self, config, data_path): 17 | """ 18 | :param config: 19 | """ 20 | self.config = config 21 | self.logger = logging.getLogger("DataLoader") 22 | 23 | raise Exception("LM DataLoader is deprecated! 
Use JsonDataset instead") 24 | 25 | tokenizer.DATA_PATH = data_path 26 | Tokenizer(config.prepro.tokenizer) 27 | 28 | train = LangmodellingDataset( 29 | os.path.join(data_path, self.config.training.dataset), 30 | config=config, 31 | dev=False, 32 | test=False, 33 | repeat=(self.config.training.data.get("epoch_steps", 0) > 0), 34 | ) 35 | valid = LangmodellingDataset( 36 | os.path.join(data_path, self.config.training.dataset), config=config, dev=True, test=False 37 | ) 38 | test = LangmodellingDataset( 39 | os.path.join(data_path, self.config.training.dataset), config=config, dev=False, test=True 40 | ) 41 | 42 | self.len_train_data = len(train) 43 | self.len_valid_data = len(valid) 44 | # self.len_test_data = len(test) 45 | 46 | # TODO: check whether running in silent mode 47 | self.logger.info( 48 | "Loaded {:} training and {:} validation examples from {:}".format( 49 | self.len_train_data, 50 | self.len_valid_data, 51 | os.path.join(data_path, self.config.training.dataset), 52 | ) 53 | ) 54 | 55 | # self.train_iterations = (self.len_train_data + self.config.training.batch_size - 1) // self.config.training.batch_size 56 | # self.valid_iterations = (self.len_valid_data + self.config.training.batch_size - 1) // self.config.training.batch_size 57 | # self.test_iterations = (self.len_test_data + self.config.training.batch_size - 1) // self.config.training.batch_size 58 | 59 | self.train_loader = DataLoader( 60 | train, 61 | batch_size=config.training.batch_size, 62 | shuffle=self.config.training.data.get("shuffle_data", True), 63 | num_workers=0, 64 | collate_fn=LangmodellingDataset.pad_and_order_sequences, 65 | worker_init_fn=init_worker, 66 | ) 67 | 68 | self.valid_loader = DataLoader( 69 | valid, 70 | batch_size=config.eval.eval_batch_size, 71 | shuffle=False, 72 | num_workers=0, 73 | collate_fn=LangmodellingDataset.pad_and_order_sequences, 74 | worker_init_fn=init_worker, 75 | ) 76 | if test.exists: 77 | self.test_loader = DataLoader( 78 | test, 79 | batch_size=config.eval.eval_batch_size, 80 | shuffle=False, 81 | num_workers=0, 82 | collate_fn=LangmodellingDataset.pad_and_order_sequences, 83 | worker_init_fn=init_worker, 84 | ) 85 | -------------------------------------------------------------------------------- /torchseq/datasets/loaders.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | # from nltk.tokenize import TreebankWordTokenizer, sent_tokenize 8 | # import spacy 9 | 10 | 11 | def load_qa_dataset(path, dev=False, test=False, v2=False, filename=None): 12 | expected_version = "v2.0" if v2 else "1.1" 13 | if filename is not None: 14 | pass 15 | elif v2: 16 | filename = "train-v2.0.json" if not dev else "dev-v2.0.json" 17 | elif test and not dev: 18 | filename = "test-v1.1.json" 19 | else: 20 | filename = "train-v1.1.json" if not dev else "dev-v1.1.json" 21 | with open(path + filename) as dataset_file: 22 | dataset_json = json.load(dataset_file) 23 | if "version" in dataset_json and dataset_json["version"] != expected_version and filename is None: 24 | print("Expected SQuAD v-" + expected_version + ", but got dataset with v-" + str(dataset_json["version"])) 25 | dataset = dataset_json["data"] 26 | return dataset 27 | 28 | 29 | def load_squad_triples(path, dev=False, test=False, v2=False, as_dict=False, ans_list=False, filename=None): 30 | raw_data = load_qa_dataset(path, dev=dev, test=test, v2=v2, filename=filename) 31 | triples = [] if not 
as_dict else {} 32 | for doc in raw_data: 33 | for para in doc["paragraphs"]: 34 | for qa in para["qas"]: 35 | id = qa["id"] 36 | # NOTE: this only takes the first answer per question! ToDo handle this more intelligently 37 | if ans_list: 38 | ans_text = [a["text"] for a in qa["answers"]] 39 | ans_pos = [int(a["answer_start"]) for a in qa["answers"]] 40 | else: 41 | ans_count = defaultdict(int) 42 | for ans in qa["answers"]: 43 | ans_count[(ans["text"], int(ans["answer_start"]))] += 1 44 | 45 | ans_text, ans_pos = sorted(ans_count.items(), reverse=True, key=lambda x: x[1])[0][0] 46 | if v2: 47 | if qa["is_impossible"]: 48 | el = ( 49 | para["context"], 50 | qa["question"], 51 | qa["plausible_answers"][0]["text"] if not dev else "", 52 | int(qa["plausible_answers"][0]["answer_start"]) if not dev else None, 53 | True, 54 | ) 55 | else: 56 | el = ( 57 | para["context"], 58 | qa["question"], 59 | qa["answers"][0]["text"], 60 | int(qa["answers"][0]["answer_start"]), 61 | False, 62 | ) 63 | else: 64 | el = (para["context"], qa["question"], ans_text, ans_pos) 65 | if as_dict: 66 | triples[id] = el 67 | else: 68 | triples.append(el) 69 | return triples 70 | -------------------------------------------------------------------------------- /torchseq/datasets/paraphrase_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.utils.data import DataLoader 6 | 7 | from torchseq.datasets.paraphrase_dataset import ParaphraseDataset 8 | from torchseq.utils.seed import init_worker 9 | from torchseq.utils.tokenizer import Tokenizer 10 | import torchseq.utils.tokenizer as tokenizer 11 | import logging 12 | 13 | 14 | class ParaphraseDataLoader: 15 | def __init__(self, config, data_path): 16 | """ 17 | :param config: 18 | """ 19 | self.config = config 20 | self.logger = logging.getLogger("DataLoader") 21 | 22 | raise Exception("Paraphrase DataLoader is deprecated! 
Use JsonDataset instead") 23 | 24 | tokenizer.DATA_PATH = data_path 25 | Tokenizer(config.prepro.tokenizer) 26 | 27 | train = ParaphraseDataset( 28 | os.path.join(data_path, self.config.training.dataset), 29 | config=config, 30 | dev=False, 31 | test=False, 32 | repeat=(self.config.training.data.get("epoch_steps", 0) > 0), 33 | length_limit=self.config.training.get("truncate_dataset", None), 34 | ) 35 | valid = ParaphraseDataset( 36 | os.path.join(data_path, self.config.training.dataset), 37 | config=config, 38 | dev=True, 39 | test=False, 40 | length_limit=self.config.eval.get("truncate_dataset", None), 41 | ) 42 | test = ParaphraseDataset( 43 | os.path.join(data_path, self.config.training.dataset), 44 | config=config, 45 | dev=False, 46 | test=True, 47 | length_limit=self.config.eval.get("truncate_dataset", None), 48 | ) 49 | 50 | self.len_train_data = len(train) 51 | self.len_valid_data = len(valid) 52 | # self.len_test_data = len(test) 53 | 54 | # TODO: check whether running in silent mode 55 | self.logger.info( 56 | "Loaded {:} training and {:} validation examples from {:}".format( 57 | self.len_train_data, 58 | self.len_valid_data, 59 | os.path.join(data_path, self.config.training.dataset), 60 | ) 61 | ) 62 | 63 | # self.train_iterations = (self.len_train_data + self.config.training.batch_size - 1) // self.config.training.batch_size 64 | # self.valid_iterations = (self.len_valid_data + self.config.training.batch_size - 1) // self.config.training.batch_size 65 | # self.test_iterations = (self.len_test_data + self.config.training.batch_size - 1) // self.config.training.batch_size 66 | 67 | self.train_loader = DataLoader( 68 | train, 69 | batch_size=config.training.batch_size, 70 | # shuffle=self.config.training.data.get("shuffle_data", True), 71 | num_workers=0, 72 | collate_fn=ParaphraseDataset.pad_and_order_sequences, 73 | worker_init_fn=init_worker, 74 | ) 75 | 76 | self.valid_loader = DataLoader( 77 | valid, 78 | batch_size=config.eval.eval_batch_size, 79 | # shuffle=False, 80 | num_workers=0, 81 | collate_fn=ParaphraseDataset.pad_and_order_sequences, 82 | worker_init_fn=init_worker, 83 | ) 84 | if test.exists: 85 | self.test_loader = DataLoader( 86 | test, 87 | batch_size=config.eval.eval_batch_size, 88 | # shuffle=False, 89 | num_workers=0, 90 | collate_fn=ParaphraseDataset.pad_and_order_sequences, 91 | worker_init_fn=init_worker, 92 | ) 93 | -------------------------------------------------------------------------------- /torchseq/datasets/paraphrase_pair.py: -------------------------------------------------------------------------------- 1 | from torchseq.utils.tokenizer import Tokenizer 2 | 3 | 4 | class ParaphrasePair: 5 | def __init__(self, sent1_text, sent2_text, template=None, is_paraphrase=True, tok_window=64): 6 | if "artist appear below the euro symbol" in sent2_text: 7 | print("Found the dodgy pair", sent1_text, sent2_text) 8 | 9 | self._s1_doc = Tokenizer().tokenise(sent1_text) 10 | self._s2_doc = Tokenizer().tokenise(sent2_text) 11 | self._template_doc = Tokenizer().tokenise(sent2_text) if template is not None else None 12 | self.is_paraphrase = is_paraphrase 13 | 14 | if "artist appear below the euro symbol" in sent2_text: 15 | print("Dodgy pair cleared tokenising") 16 | 17 | if len(self._s1_doc) > tok_window: 18 | self._s1_doc = self._s1_doc[:tok_window] 19 | if len(self._s2_doc) > tok_window: 20 | self._s2_doc = self._s2_doc[:tok_window] 21 | 22 | def s1_as_ids(self): 23 | id_list = [tok["id"] for tok in self._s1_doc] 24 | return id_list 25 | 26 | def 
s2_as_ids(self): 27 | id_list = [tok["id"] for tok in self._s2_doc] 28 | return id_list 29 | 30 | def template_as_ids(self): 31 | id_list = [tok["id"] for tok in self._template_doc] 32 | return id_list 33 | -------------------------------------------------------------------------------- /torchseq/demo/qg_app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | # sys.path.insert(0, "./src/") 5 | 6 | from flask import Flask, Response, current_app, redirect, request 7 | 8 | from torchseq.agents.aq_agent import AQAgent 9 | from torchseq.datasets.qa_loader import QADataLoader 10 | from torchseq.utils.config import Config 11 | from torchseq.utils.tokenizer import Tokenizer 12 | 13 | 14 | app = Flask(__name__) 15 | 16 | 17 | @app.route("/") 18 | def index(): 19 | return redirect("/static/qg_demo.htm") 20 | 21 | 22 | @app.route("/api/generate") 23 | def generate(): 24 | context = request.args["context"] 25 | answer = request.args["answer"] 26 | a_pos = context.find(answer) 27 | 28 | query = {"c": context, "a": answer, "a_pos": a_pos, "q": ""} 29 | 30 | # res, scores, _ = app.agent.infer(query, reduce_outputs=False) 31 | 32 | data_loader = QADataLoader(app.agent.config, test_samples=[query]) 33 | loss, metrics, (pred_output, gold_output, gold_input), memory = app.agent.inference(data_loader.test_loader) 34 | 35 | # scores = scores.tolist() 36 | 37 | output = pred_output 38 | 39 | return Response(json.dumps(output, indent=2), mimetype="application/json") 40 | 41 | 42 | @app.route("/api/ping") 43 | def ping(): 44 | return "ack" 45 | 46 | 47 | def init(): 48 | # MODEL_SLUG = "20200220_161434_bert_embeds_para_pbkagsq_ft_squad" 49 | 50 | # MODEL_PATH = f'./runs/augmented/{MODEL_SLUG}/' 51 | MODEL_PATH = "./models/examples/20210222_145021_qg_bert/" 52 | 53 | # Get the config 54 | with open(MODEL_PATH + "config.json") as f: 55 | cfg_dict = json.load(f) 56 | 57 | # Override a few bits 58 | cfg_dict["eval"]["topk"] = 1 59 | # cfg_dict["reranker"] = { 60 | # # 'strategy': 'qa' 61 | # "strategy": None 62 | # } 63 | 64 | config = Config(cfg_dict) 65 | 66 | checkpoint_path = MODEL_PATH + "model/checkpoint.pt" 67 | 68 | # Tokenizer(config.prepro.tokenizer) 69 | 70 | app.agent = AQAgent(config=config, run_id=None, output_path="./runs/parademo/", training_mode=True) 71 | 72 | app.agent.load_checkpoint(checkpoint_path) 73 | app.agent.model.eval() 74 | 75 | 76 | def main(): 77 | init() 78 | with app.app_context(): 79 | app.run(host="0.0.0.0", port=5004, processes=1) 80 | 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /torchseq/demo/static/para_demo.htm: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | Separator Demo 10 | 11 | 12 | 13 | 20 | 21 | 22 | 23 | 24 |
[para_demo.htm, continued: the page markup was stripped from this listing, so only the visible text of the demo page survives]
Separator Demo
Connection Status:
An interactive demo of the Separator paraphrasing model. Enter a question and an exemplar, then click generate to rephrase the question into that form.
This should be a valid template for the original question, otherwise the model will definitely produce nonsense :)
Generated Paraphrase:
Click the "Generate" button to get a paraphrase!
67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 93 | 94 | -------------------------------------------------------------------------------- /torchseq/demo/static/qgen.css: -------------------------------------------------------------------------------- 1 | #connection-indicator-wrapper 2 | { 3 | float: right; 4 | padding: 5px; 5 | /* width: 300px; */ 6 | font-size: 14px;: middle; 7 | } 8 | #connection-indicator 9 | { 10 | display: inline-block; 11 | width: 20px; 12 | height: 20px; 13 | border-radius: 10px; 14 | background-color: yellow; 15 | } 16 | #connection-indicator.good 17 | { 18 | background-color: green; 19 | } 20 | #connection-indicator.bad 21 | { 22 | background-color: red; 23 | } 24 | #response 25 | { 26 | border:1px dotted #000; 27 | padding: 5px; 28 | } -------------------------------------------------------------------------------- /torchseq/demo/static/qq_demo.htm: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | AQ Demo 10 | 11 | 12 | 13 | 20 | 21 | 22 | 23 | 24 |
[qq_demo.htm, continued: the page markup was stripped from this listing, so only the visible text of the demo page survives]
Question Generation Demo
Connection Status:
An interactive AQ demo. Enter some text in the box, select a word(s) within the text, and click generate to get a question with that answer.
Load preset document:
This is case sensitive, and must exist within the context. If it appears multiple times, the first occurrence will be used.
Generated Question:
Click the "Generate" button to get a question!
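The page above is the front end for torchseq/demo/qg_app.py (listed earlier): its /api/generate endpoint takes the context and the answer string as query parameters, locates the answer inside the context, and returns the generated question(s) as JSON, while /api/ping reports whether the model has loaded. A minimal sketch of calling it, assuming the app is running locally on the port hard-coded in qg_app.py; the example context and answer are invented:

import requests  # any HTTP client works; requests is assumed to be installed

# /api/ping returns "ack" once the model is up
assert requests.get("http://localhost:5004/api/ping").text == "ack"

resp = requests.get(
    "http://localhost:5004/api/generate",
    params={
        "context": "TorchSeq is a sequence modelling framework built on PyTorch.",
        # the answer must appear verbatim in the context: qg_app.py locates it with context.find(answer)
        "answer": "PyTorch",
    },
)
print(resp.json())  # the predicted question(s) returned by the agent's inference call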
68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 94 | 95 | -------------------------------------------------------------------------------- /torchseq/demo/static/separator.js: -------------------------------------------------------------------------------- 1 | function ping() 2 | { 3 | query = "/api/ping" 4 | $.ajax({ 5 | url: query, 6 | cache: false 7 | }) 8 | .done(function( html ) { 9 | if(html=='ack') 10 | { 11 | $( "#connection-indicator" ).attr('class', 'good'); 12 | } 13 | else { 14 | $( "#connection-indicator" ).attr('class', 'bad'); 15 | } 16 | 17 | }) 18 | .fail(function(){ 19 | $( "#connection-indicator" ).attr('class', 'bad'); 20 | }); 21 | setTimeout(ping, 10000); 22 | } 23 | function getQ() 24 | { 25 | query = "/api/generate?sem_input=" +encodeURIComponent($('#sem_input').val())+ "&template="+$('#template').val() 26 | $( "#response-spinner" ).toggleClass('d-none'); 27 | $( "#response" ).toggleClass('d-none'); 28 | $.ajax({ 29 | url: query, 30 | cache: false 31 | }) 32 | .done(function( html ) { 33 | $( "#response" ).html("

"+html+"

"); 34 | 35 | $( "#response-spinner" ).toggleClass('d-none'); 36 | $( "#response" ).toggleClass('d-none'); 37 | }) 38 | .fail(function(){ 39 | $( "#response" ).html("

There was an error generating a question :(

"); 40 | 41 | $( "#response-spinner" ).toggleClass('d-none'); 42 | $( "#response" ).toggleClass('d-none'); 43 | }); 44 | } 45 | function seed(q,t) 46 | { 47 | $('#sem_input').val(q); 48 | $('#template').val(t); 49 | } 50 | function seedWithMoose() 51 | { 52 | q = "What is the weight of an average moose?" 53 | t = "How much is a surgeon's income?" 54 | seed(q,t) 55 | } -------------------------------------------------------------------------------- /torchseq/demo/static/spinner.css: -------------------------------------------------------------------------------- 1 | .lds-ellipsis { 2 | display: inline-block; 3 | position: relative; 4 | width: 64px; 5 | height: 64px; 6 | } 7 | .lds-ellipsis div { 8 | position: absolute; 9 | top: 27px; 10 | width: 11px; 11 | height: 11px; 12 | border-radius: 50%; 13 | background: #000; 14 | animation-timing-function: cubic-bezier(0, 1, 1, 0); 15 | } 16 | .lds-ellipsis div:nth-child(1) { 17 | left: 6px; 18 | animation: lds-ellipsis1 0.6s infinite; 19 | } 20 | .lds-ellipsis div:nth-child(2) { 21 | left: 6px; 22 | animation: lds-ellipsis2 0.6s infinite; 23 | } 24 | .lds-ellipsis div:nth-child(3) { 25 | left: 26px; 26 | animation: lds-ellipsis2 0.6s infinite; 27 | } 28 | .lds-ellipsis div:nth-child(4) { 29 | left: 45px; 30 | animation: lds-ellipsis3 0.6s infinite; 31 | } 32 | @keyframes lds-ellipsis1 { 33 | 0% { 34 | transform: scale(0); 35 | } 36 | 100% { 37 | transform: scale(1); 38 | } 39 | } 40 | @keyframes lds-ellipsis3 { 41 | 0% { 42 | transform: scale(1); 43 | } 44 | 100% { 45 | transform: scale(0); 46 | } 47 | } 48 | @keyframes lds-ellipsis2 { 49 | 0% { 50 | transform: translate(0, 0); 51 | } 52 | 100% { 53 | transform: translate(19px, 0); 54 | } 55 | } -------------------------------------------------------------------------------- /torchseq/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/eval/__init__.py -------------------------------------------------------------------------------- /torchseq/eval/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def parse_eval_args(arg_str=None): 5 | parser = argparse.ArgumentParser( 6 | description="TorchSeq Evaluation runner", 7 | ) 8 | 9 | parser.add_argument("-V", "--version", action="store_true", help="Display version") 10 | 11 | parser.add_argument("--model", type=str, metavar="MODEL", help="Path to model folder", required=True) 12 | 13 | parser.add_argument("--recipe", type=str, metavar="RECIPE", help="Name of recipe to run", required=True) 14 | 15 | parser.add_argument("--test", action="store_true", help="Use test set") 16 | 17 | # Paths 18 | parser.add_argument("--data_path", type=str, metavar="DATA", default="./data/", help="Path to data sources") 19 | parser.add_argument( 20 | "--output_path", type=str, metavar="OUTPUT", default="./evalruns/", help="Path to output folder" 21 | ) 22 | 23 | # Runtime 24 | parser.add_argument("--cpu", action="store_true", help="Disable CUDA") 25 | parser.add_argument("--amp", action="store_true", help="Enable AMP") 26 | 27 | args = parser.parse_args(arg_str) 28 | 29 | return args 30 | -------------------------------------------------------------------------------- /torchseq/eval/cli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import importlib 4 | import torch 5 | import os 6 
| 7 | import torchseq 8 | from torchseq.eval.args import parse_eval_args 9 | import torchseq.eval.recipes as recipes 10 | 11 | 12 | def main(): 13 | logging.basicConfig( 14 | level=logging.INFO, format="%(asctime)s\t%(levelname)s\t%(name)s\t%(message)s", datefmt="%H:%M" 15 | ) 16 | logger = logging.getLogger("Eval") 17 | 18 | args = parse_eval_args() 19 | 20 | if args.version: 21 | print(torchseq.__version__) 22 | return 23 | 24 | if args.amp: 25 | logger.info("Using matmul precision = high") 26 | torch.set_float32_matmul_precision("high") 27 | 28 | logger.info("TorchSeq eval runner") 29 | 30 | cfg_patch = {} 31 | 32 | config_path = os.path.join(args.model, "config.json") 33 | 34 | if not os.path.exists(config_path): 35 | raise Exception("No config file found in path {:}".format(args.model)) 36 | 37 | # First load the model 38 | # instance = model_from_path(args.model, config_patch=cfg_patch, use_cuda=(not args.cpu)) 39 | # logger.info("Loaded model from {:}".format(args.model)) 40 | 41 | # Then load the data 42 | # ??? 43 | 44 | # Run the recipe 45 | 46 | recipe_module = importlib.import_module("torchseq.eval.recipes." + args.recipe, None) 47 | if recipe_module is not None: 48 | recipe: recipes.EvalRecipe = recipe_module.Recipe(args.model, args.data_path, args.test, args.cpu, logger) 49 | result = recipe.run() 50 | else: 51 | logger.error("No recipe called {:} found!".format(args.recipe)) 52 | 53 | # Post-process 54 | 55 | # Publish! 56 | 57 | print(json.dumps(result, indent=2)) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /torchseq/eval/recipes/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import EvalRecipe as EvalRecipe 2 | -------------------------------------------------------------------------------- /torchseq/eval/recipes/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | from torchseq.utils.model_loader import config_from_path 4 | from torchseq.utils.config import Config 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class EvalRecipe(ABC): 9 | model_path: str 10 | data_path: str 11 | config: Config 12 | split_str: str 13 | test: bool 14 | cpu: bool 15 | name: str = "" 16 | 17 | def __init__(self, model_path: str, data_path: str, test=False, cpu: bool = False, logger=None): 18 | # self.args = args 19 | self.model_path = model_path 20 | self.data_path = data_path 21 | self.config = config_from_path(model_path) 22 | self.test = test 23 | self.split_str = "test" if test else "dev" 24 | self.cpu = cpu 25 | 26 | self.logger = logger 27 | 28 | self.log(f"Running EvalRecipe: {self.name}") 29 | 30 | def log(self, text): 31 | if self.logger is not None: 32 | self.logger.info(text) 33 | 34 | @abstractmethod 35 | def run(self) -> Dict[str, Any]: 36 | raise Exception("Tried to call run() on a base EvalRecipe object!") 37 | -------------------------------------------------------------------------------- /torchseq/eval/recipes/opagg/extractive_summaries.py: -------------------------------------------------------------------------------- 1 | from torchseq.eval.recipes import EvalRecipe 2 | from torchseq.utils.model_loader import model_from_path 3 | from torchseq.metric_hooks.opsumm_cluster_aug import OpSummClusterAugMetricHook 4 | from torchseq.utils.timer import Timer 5 | 6 | 7 | class Recipe(EvalRecipe): 8 | name: str = "opagg.extractive_summaries" 9 | 10 | 
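    # Example invocation (a sketch for illustration): recipes are resolved by torchseq/eval/cli.py
    # above via importlib, relative to torchseq.eval.recipes, so this recipe can be run with
    # something like
    #     python -m torchseq.eval.cli --model ./runs/<model_dir> --recipe opagg.extractive_summaries --test
    # assuming the torchseq package is installed. <model_dir> is a placeholder; it must contain a
    # config.json (cli.py checks for it), and --test switches EvalRecipe.split_str from "dev" to "test".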
def run(self): 11 | result = {} 12 | 13 | instance = model_from_path(self.model_path, use_cuda=(not self.cpu)) 14 | 15 | with Timer(template="\tTime: {:.3f} seconds", show_readout=False) as t: 16 | scores, res = OpSummClusterAugMetricHook.eval_extract_summaries_and_score( 17 | self.config, instance, test=self.test 18 | ) 19 | print("\tExtractive R2 = {:0.2f}".format(scores["extractive"]["rouge2"])) 20 | print("\tSC_ins = {:0.2f}".format(scores["extractive"]["sc_ins"])) 21 | print("\tSC_refs = {:0.2f}".format(scores["extractive"]["sc_refs"])) 22 | clustering_time = t.time 23 | 24 | with Timer(template="\tTime: {:.3f} seconds", show_readout=False) as t: 25 | score = OpSummClusterAugMetricHook.eval_compare_selected_clusters_to_oracle( 26 | self.config, instance, res["evidence"], test=self.test 27 | ) 28 | print( 29 | "\tARI: {:0.3f}, (oracle {:0.1f} vs pred {:0.1f})".format( 30 | score["ari"], score["oracle_mean_size"], score["pred_mean_size"] 31 | ) 32 | ) 33 | ari_time = t.time 34 | 35 | with Timer(template="\tTime: {:.3f} seconds", show_readout=False) as t: 36 | prev_scores = OpSummClusterAugMetricHook.eval_cluster_prevalence( 37 | self.config, instance, res["evidence"], test=self.test 38 | ) 39 | print("\tPrev: ", prev_scores) 40 | prevalence_time = t.time 41 | 42 | return result 43 | -------------------------------------------------------------------------------- /torchseq/metric_hooks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/metric_hooks/__init__.py -------------------------------------------------------------------------------- /torchseq/metric_hooks/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Any 2 | 3 | import torch 4 | from torchseq.utils.config import Config 5 | from torchseq.utils.tokenizer import Tokenizer 6 | from torchseq.agents.base import BaseAgent 7 | 8 | 9 | class MetricHook: 10 | type: str # should be either 'live' or 'slow' - live metrics are calculated every epoch, slow metrics only for evaluation 11 | config: Config 12 | tokenizer: Tokenizer 13 | src_field: Optional[str] 14 | tgt_field: Optional[str] 15 | scores: Dict[str, List[float]] 16 | 17 | def __init__( 18 | self, config: Config, tokenizer: Tokenizer, src_field: Optional[str] = None, tgt_field: Optional[str] = None 19 | ): 20 | self.config = config 21 | self.tokenizer = tokenizer 22 | self.src_field = src_field 23 | self.tgt_field = tgt_field 24 | 25 | def on_begin_epoch(self, use_test: bool = False): 26 | self.scores = {} 27 | 28 | def on_batch( 29 | self, 30 | batch: Dict[str, torch.Tensor], 31 | logits: torch.Tensor, 32 | output: List[str], 33 | memory: Dict[str, torch.Tensor], 34 | use_test: bool = False, 35 | ): 36 | raise NotImplementedError("You need to implement on_batch for your MetricHook!") 37 | 38 | def on_end_epoch(self, agent: BaseAgent, use_test: bool = False) -> Dict[str, float]: 39 | return {k: sum(v) / len(v) for k, v in self.scores.items() if len(v) > 0} 40 | -------------------------------------------------------------------------------- /torchseq/metric_hooks/default.py: -------------------------------------------------------------------------------- 1 | from torchseq.metric_hooks.base import MetricHook 2 | from torchseq.utils.perplexity import get_perplexity 3 | from torchseq.utils.tokenizer import Tokenizer 4 | import torch 5 | 6 | 7 | class 
DefaultMetricHook(MetricHook): 8 | type = "live" 9 | 10 | # def __init__(self, config, tokenizer, src_field=None, tgt_field=None): 11 | # super().__init__(config, tokenizer, src_field, tgt_field) 12 | 13 | def on_begin_epoch(self, use_test=False): 14 | self.scores = {"ppl": []} 15 | 16 | def on_batch(self, batch, logits, output, memory, use_test=False): 17 | if self.tgt_field is not None and logits is not None: 18 | self.scores["ppl"].extend( 19 | get_perplexity( 20 | logits, 21 | batch[self.tgt_field], 22 | vocab_size=self.config.prepro.get_first(["output_vocab_size", "vocab_size"]), 23 | ignore_index=self.tokenizer.pad_id, 24 | ).tolist() 25 | ) 26 | -------------------------------------------------------------------------------- /torchseq/metric_hooks/prevalence_metric_old.py: -------------------------------------------------------------------------------- 1 | # Modified version of the metric from https://github.com/cdmalon/opinion-prevalence 2 | # Original is a CLI, this provides a more generic OO wrapper 3 | # See https://arxiv.org/abs/2307.14305 4 | 5 | import nltk.tokenize 6 | from summac.model_summac import SummaCZS 7 | from tqdm import tqdm 8 | 9 | 10 | class PrevalenceMetric: 11 | threshold = 0.04 12 | 13 | def __init__(self): 14 | self.model = SummaCZS( 15 | granularity="document", model_name="mnli", bins="percentile", use_con=False, device="cuda" 16 | ) 17 | 18 | def get_prevalence( 19 | self, 20 | reviews, 21 | generated_summaries, 22 | product_names=None, 23 | pbar=False, 24 | ignore_redundancy=False, 25 | summaries_are_sentences=False, 26 | ): 27 | threshold = 0.04 28 | 29 | if product_names is None: 30 | product_names = [""] * len(reviews) 31 | 32 | prevalences = [] 33 | redundancies = [] 34 | trivials = [] 35 | for curr_reviews, summ, product_name in tqdm( 36 | zip(reviews, generated_summaries, product_names), disable=(not pbar), total=len(reviews) 37 | ): 38 | nsent = 0 39 | prevalence = 0 40 | redundancy = 0 41 | trivial_count = 0 42 | 43 | if not summaries_are_sentences: 44 | sents = nltk.tokenize.sent_tokenize(summ) 45 | else: 46 | sents = summ 47 | 48 | for i, generated in enumerate(sents): 49 | nsent = nsent + 1 50 | 51 | implied = 0 52 | tot = 0 53 | 54 | # trivial = "I bought {:}.".format(product_name) 55 | trivial = "I stayed at {:}.".format(product_name) 56 | if self.model.score([trivial], [generated])["scores"][0] > threshold: 57 | # output = output + " " + generated + " (T)" 58 | trivial_count += 1 59 | # print('Triv: ', generated) 60 | continue 61 | 62 | redundant = False 63 | for j in range(i): 64 | if self.model.score([sents[j]], [generated])["scores"][0] > threshold: 65 | # print("Redund: ", sents[j], generated) 66 | redundant = True 67 | redundancy += 1 68 | break 69 | 70 | if redundant and not ignore_redundancy: 71 | continue 72 | 73 | for original in curr_reviews: 74 | tot = tot + 1 75 | score = self.model.score([original], [generated])["scores"][0] 76 | if score > threshold: 77 | implied = implied + 1 78 | # print(implied/tot) 79 | 80 | prevalence = prevalence + (implied / tot) 81 | # output = output + " " + generated + " (" + str(implied) + ")" 82 | 83 | prevalence = prevalence / nsent 84 | 85 | prevalences.append(prevalence) 86 | redundancies.append(redundancy / nsent) 87 | trivials.append(trivial_count / nsent) 88 | 89 | return prevalences, redundancies, trivials 90 | -------------------------------------------------------------------------------- /torchseq/metric_hooks/qg_metric.py: 
-------------------------------------------------------------------------------- 1 | from torchseq.metric_hooks.base import MetricHook 2 | from torchseq.utils.functions import top_k_top_p_filtering, onehot 3 | from torchseq.utils.tokenizer import Tokenizer 4 | import torch 5 | 6 | 7 | class QGMetricHook(MetricHook): 8 | type = "slow" 9 | 10 | # def __init__(self, config, src_field=None, tgt_field=None): 11 | # super().__init__(config, src_field, tgt_field) 12 | 13 | def on_begin_epoch(self, use_test=False): 14 | self.scores = {"qg_metric": []} 15 | 16 | def on_batch(self, batch, logits, output, memory, use_test=False): 17 | # Calc QG metric 18 | # Calculate metric from "On the Importance of Diversity in Question Generation for QA" 19 | omega = 0.7 20 | if self.config.get("nucleus_sampling", None) is not None: 21 | top_p = self.config.nucleus_sampling.cutoff 22 | else: 23 | top_p = 0.9 24 | 25 | nucleus_prob = torch.softmax(top_k_top_p_filtering(logits, top_p=top_p), dim=-1) 26 | gt_onehot = onehot( 27 | batch[self.tgt_field], 28 | N=self.config.prepro.get_first(["output_vocab_size", "vocab_size"]), 29 | ignore_index=self.tokenizer.pad_id, 30 | ) 31 | accuracy = torch.sum(torch.sum(nucleus_prob * gt_onehot, dim=-1), dim=-1) / ( 32 | batch[self.tgt_field + "_len"] - 1 33 | ) 34 | 35 | diversity = torch.sum(torch.sum(torch.gt(nucleus_prob * gt_onehot, 0) * 1.0, dim=-1), dim=-1) / ( 36 | batch[self.tgt_field + "_len"] - 1 37 | ) 38 | 39 | dev_qg_metric = omega * accuracy + (1 - omega) * diversity 40 | 41 | self.scores["qg_metric"].extend(dev_qg_metric.tolist()) 42 | -------------------------------------------------------------------------------- /torchseq/metric_hooks/rouge.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from torchseq.metric_hooks.base import MetricHook 3 | from torchseq.utils.tokenizer import Tokenizer 4 | from torchseq.utils.metrics import bleu_corpus, meteor_corpus, ibleu_corpus 5 | from torchseq.utils.sari import SARIsent 6 | import torch 7 | import numpy as np 8 | 9 | import rouge 10 | 11 | 12 | class RougeMetricHook(MetricHook): 13 | type = "live" # should be either 'live' or 'slow' - live metrics are calculated every epoch, slow metrics only for evaluation 14 | 15 | # def __init__(self, config, src_field=None, tgt_field=None): 16 | # super().__init__(config, src_field, tgt_field) 17 | 18 | def on_begin_epoch(self, use_test=False): 19 | self.scores = {"rouge": {}} 20 | 21 | self.gold_targets = [] 22 | self.pred_targets = [] 23 | self.inputs = [] 24 | 25 | def on_batch(self, batch, logits, output, memory, use_test=False): 26 | if len(output) > 0: 27 | if self.config.eval.data.get("topk", 1) > 1: 28 | self.pred_targets.extend([x[0] for x in output]) 29 | else: 30 | self.pred_targets.extend(output) 31 | self.gold_targets.extend(batch[self.tgt_field + "_text"]) 32 | self.inputs.extend(batch[self.src_field + "_text"]) 33 | 34 | def on_end_epoch(self, _, use_test=False): 35 | evaluator = rouge.Rouge( 36 | metrics=[ 37 | "rouge-l", 38 | "rouge-n", 39 | ], 40 | max_n=4, 41 | limit_length=True, 42 | length_limit=100, 43 | length_limit_type="words", 44 | apply_avg=True, 45 | apply_best=False, 46 | alpha=0.5, # Default F1_score 47 | weight_factor=1.2, 48 | stemming=True, 49 | ) 50 | 51 | scores = evaluator.get_scores(self.pred_targets, self.gold_targets) 52 | 53 | self.scores["rouge"] = scores 54 | 55 | return self.scores 56 | 
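The hooks above all follow the lifecycle defined in metric_hooks/base.py: on_begin_epoch resets self.scores, on_batch accumulates per-example values, and on_end_epoch (the inherited default) averages each non-empty score list. The sketch below shows a minimal custom hook for illustration only; the class name and metric are invented, it assumes topk == 1 decoding, and registering the hook with an agent is not shown in this listing.

from torchseq.metric_hooks.base import MetricHook


class MeanOutputLengthHook(MetricHook):
    type = "live"  # recomputed every epoch, like DefaultMetricHook above

    def on_begin_epoch(self, use_test=False):
        self.scores = {"mean_output_len": []}

    def on_batch(self, batch, logits, output, memory, use_test=False):
        # `output` holds the decoded strings for this batch (when topk > 1 each entry is a
        # list of strings instead; see RougeMetricHook above for how that case is handled)
        if len(output) > 0:
            self.scores["mean_output_len"].extend(float(len(x.split())) for x in output)

    # on_end_epoch is inherited from MetricHook: it averages each non-empty score list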
-------------------------------------------------------------------------------- /torchseq/metric_hooks/textual.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import numpy as np 4 | 5 | from torchseq.metric_hooks.base import MetricHook 6 | from torchseq.utils.tokenizer import Tokenizer 7 | from torchseq.utils.metrics import bleu_corpus, meteor_corpus, ibleu_corpus 8 | from torchseq.utils.sari import SARIsent 9 | import sacrebleu 10 | 11 | 12 | class TextualMetricHook(MetricHook): 13 | type = "live" # should be either 'live' or 'slow' - live metrics are calculated every epoch, slow metrics only for evaluation 14 | 15 | # def __init__(self, config, src_field=None, tgt_field=None): 16 | # super().__init__(config, src_field, tgt_field) 17 | 18 | def on_begin_epoch(self, use_test=False): 19 | self.scores = {"bleu": [], "meteor": [], "em": [], "sari": [], "ibleu": []} 20 | 21 | self.gold_targets = [] 22 | self.pred_targets = [] 23 | self.inputs = [] 24 | 25 | def on_batch(self, batch, logits, output, memory, use_test=False): 26 | if len(output) > 0 and self.tgt_field is not None: 27 | if self.config.eval.data.get("topk", 1) > 1: 28 | self.pred_targets.extend([x[0] for x in output]) 29 | else: 30 | self.pred_targets.extend(output) 31 | if "_refs_text" in batch: 32 | self.gold_targets.extend(batch["_refs_text"]) 33 | else: 34 | self.gold_targets.extend([[x] for x in batch[self.tgt_field + "_text"]]) 35 | self.inputs.extend(batch[self.src_field + "_text"]) 36 | 37 | def on_end_epoch(self, _, use_test=False): 38 | # Flip and pad the references 39 | max_num_refs = max([len(x) for x in self.gold_targets]) 40 | self.gold_targets = [x + [x[0]] * (max_num_refs - len(x)) for x in self.gold_targets] 41 | 42 | # print(len(self.gold_targets), len(self.pred_targets), len(self.inputs)) 43 | 44 | self.scores["bleu"] = sacrebleu.corpus_bleu( 45 | self.pred_targets, list(zip(*self.gold_targets)), lowercase=True 46 | ).score 47 | self.scores["selfbleu"] = sacrebleu.corpus_bleu(self.pred_targets, [self.inputs], lowercase=True).score 48 | 49 | alpha = 0.8 50 | # self.scores["ibleu"] = ibleu_corpus(self.gold_targets, self.pred_targets, self.inputs) 51 | self.scores["ibleu"] = alpha * self.scores["bleu"] - (1 - alpha) * self.scores["selfbleu"] 52 | 53 | self.scores["sari"] = 100 * np.mean( 54 | [ 55 | SARIsent(self.inputs[ix], self.pred_targets[ix], self.gold_targets[ix]) 56 | for ix in range(len(self.pred_targets)) 57 | ] 58 | ) 59 | 60 | self.scores["em"] = np.mean( 61 | [self.pred_targets[ix] in self.gold_targets[ix] for ix in range(len(self.pred_targets))] 62 | ) 63 | 64 | # self.scores["meteor"] = meteor_corpus(self.gold_targets, self.pred_targets) 65 | 66 | return self.scores 67 | -------------------------------------------------------------------------------- /torchseq/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/models/__init__.py -------------------------------------------------------------------------------- /torchseq/models/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Mish(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x): 11 | return x * torch.tanh(F.softplus(x)) 12 | 
13 | 14 | class Swish(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, x): 19 | return x * F.sigmoid(x) 20 | -------------------------------------------------------------------------------- /torchseq/models/aq_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torchseq.models.ctxtans_encoder import ContextAnswerEncoder 7 | from torchseq.models.decoder import SequenceDecoder 8 | 9 | 10 | class TransformerAqModel(nn.Module): 11 | def __init__(self, config, input_tokenizer, output_tokenizer, loss=None): 12 | super().__init__() 13 | self.config = config 14 | 15 | self.loss = loss 16 | 17 | self.ctxt_ans_encoder = ContextAnswerEncoder(config, input_tokenizer) 18 | self.seq_decoder = SequenceDecoder(config, output_tokenizer, embeddings=self.ctxt_ans_encoder.embeddings) 19 | 20 | def forward(self, batch, output, memory=None, tgt_field=None): 21 | if memory is None: 22 | memory = dict() 23 | 24 | if "encoding" not in memory: 25 | encoding, memory = self.ctxt_ans_encoder(batch["c"], batch["c_len"], batch["a_pos"], memory) 26 | memory["encoding"] = encoding 27 | 28 | logits, memory = self.seq_decoder(output, memory) 29 | 30 | return logits, memory 31 | -------------------------------------------------------------------------------- /torchseq/models/hyperbolic.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchseq.models.hyperbolic_utils import RiemannianNormal, WrappedNormal, PoincareBall, GeodesicLayer 6 | from torchseq.utils.logging import Logger 7 | from math import sqrt 8 | 9 | 10 | class HyperbolicBottleneck(nn.Module): 11 | curvature: float 12 | embedding_dim: int 13 | latent_dim: int 14 | kl_weight: float 15 | prior_distribution: str 16 | posterior_distribution: str 17 | 18 | def __init__( 19 | self, 20 | embedding_dim, 21 | latent_dim=None, 22 | curvature=1.0, 23 | kl_weight=1.0, 24 | prior_distribution="wrapped_normal", 25 | posterior_distribution="wrapped_normal", 26 | ): 27 | super().__init__() 28 | 29 | self.curvature = curvature 30 | self.embedding_dim = embedding_dim 31 | self.latent_dim = embedding_dim if latent_dim is None else latent_dim 32 | self.kl_weight = kl_weight 33 | self.prior_distribution = prior_distribution 34 | self.posterior_distribution = posterior_distribution 35 | 36 | self.latent_manifold = PoincareBall(dim=self.latent_dim, c=self.curvature) 37 | 38 | if self.prior_distribution == "riemannian_normal": 39 | self.prior = RiemannianNormal 40 | else: 41 | self.prior = WrappedNormal 42 | 43 | if self.posterior_distribution == "riemannian_normal": 44 | self.posterior = RiemannianNormal 45 | else: 46 | self.posterior = WrappedNormal 47 | 48 | self.mu_proj = nn.Linear(self.embedding_dim, self.latent_dim, bias=True) 49 | self.log_var_proj = nn.Linear(self.embedding_dim, 1, bias=True) 50 | 51 | self._pz_mu = nn.Parameter(torch.zeros(1, self.latent_dim), requires_grad=False) 52 | self._pz_logvar = nn.Parameter(torch.zeros(1, 1), requires_grad=False) 53 | 54 | self.out_proj = GeodesicLayer(self.latent_dim, self.embedding_dim, self.latent_manifold) 55 | 56 | def forward(self, x: torch.Tensor, global_step: int) -> Tuple[torch.Tensor, torch.Tensor]: 57 | x = x.squeeze(1) # Input is bsz x 1 x dim 58 | 59 | mu_preproj, log_var = self.mu_proj(x), torch.log(nn.functional.softplus(self.log_var_proj(x)) + 
1e-5) 60 | 61 | mu = self.latent_manifold.expmap0(mu_preproj) 62 | 63 | # If we're in eval mode, make this deterministic 64 | scale = log_var.exp() if self.training else torch.full_like(log_var, 1e-15) 65 | 66 | qz_x = self.posterior(loc=mu, scale=scale, manifold=self.latent_manifold) 67 | z = qz_x.rsample(torch.Size([])) 68 | 69 | pz = self.prior(loc=self._pz_mu, scale=self._pz_logvar.exp(), manifold=self.latent_manifold) 70 | 71 | KLD = torch.nn.functional.relu(qz_x.log_prob(z) - pz.log_prob(z)).sum(-1) / self.latent_dim 72 | 73 | dev_str = "train" if self.training else "dev" 74 | Logger().log_scalar(f"hyperbolic_{dev_str}/kl", KLD.mean(), global_step) 75 | 76 | z = self.out_proj(z) 77 | 78 | return z.unsqueeze(1), KLD * self.kl_weight 79 | -------------------------------------------------------------------------------- /torchseq/models/lr_schedule.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import LambdaLR 2 | from math import exp, tanh 3 | 4 | 5 | # TODO: control diff types of schedule in a more granular way - eg we may want linear warmup -> poly/exp decay 6 | def get_lr(base_lr, step, scheduled=False, warmup=True): 7 | if scheduled: 8 | step = max(step, 1) 9 | warmup_steps = 10000 10 | if warmup: 11 | return base_lr * min(pow(step, -0.5), step * pow(warmup_steps, -1.5)) 12 | else: 13 | return base_lr * min(pow(step, -0.5), warmup_steps * pow(warmup_steps, -1.5)) 14 | else: 15 | return base_lr 16 | 17 | 18 | def get_scheduler( 19 | optimizer, base_lr, scheduled=False, warmup=True, num_warmup_steps=10000, last_epoch=-1, legacy=True 20 | ): 21 | if legacy: 22 | 23 | def lr_lambda(current_step: int): 24 | if scheduled: 25 | step = max(current_step, 1) 26 | if warmup: 27 | return min(pow(step, -0.5), step * pow(num_warmup_steps, -1.5)) 28 | else: 29 | return min(pow(step, -0.5), num_warmup_steps * pow(num_warmup_steps, -1.5)) 30 | else: 31 | return 1.0 32 | 33 | else: 34 | # Replicate the original BERT LR schedule, but such that it peaks at 1.0 35 | def lr_lambda(current_step: int): 36 | if scheduled: 37 | step = max(current_step, 1) 38 | if warmup and step <= num_warmup_steps: 39 | return min(float(step) / float(num_warmup_steps), 1.0) 40 | else: 41 | return pow(step, -0.5) / pow(num_warmup_steps, -0.5) 42 | else: 43 | return 1.0 44 | 45 | return LambdaLR(optimizer, lr_lambda, last_epoch) 46 | 47 | 48 | def get_hyperbolic_schedule(gamma, step): 49 | return 2 / (1 + exp(-float(step) / float(gamma))) - 1 50 | 51 | 52 | def get_tanh_schedule(gamma, step): 53 | return tanh(float(step) / float(gamma)) ** 4 54 | -------------------------------------------------------------------------------- /torchseq/models/multihead_output.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class MultiHeadOutput(nn.Module): 6 | def __init__( 7 | self, 8 | embedding_dim, 9 | vocab_size, 10 | num_heads=1, 11 | num_projections=1, 12 | projection_init=None, 13 | freeze_projection=False, 14 | variational=False, 15 | normed=False, 16 | ): 17 | super(MultiHeadOutput, self).__init__() 18 | 19 | assert embedding_dim % num_heads == 0, "Embedding dim must be divisible by num heads!" 
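        # Summary of the forward pass defined below: when num_heads > 1 or num_projections > 1,
        # the decoder embedding of size embedding_dim is viewed as num_heads chunks of dim_per_head,
        # each chunk is projected to the vocabulary by a projection whose rows are renormalised to
        # unit norm on every call, and the per-head (and per-projection) logits are combined by a
        # softmax-weighted sum over head_weight(embeds). With num_heads == 1 and num_projections == 1
        # this reduces to a single ordinary nn.Linear output layer.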
20 | 21 | self.num_heads = num_heads 22 | self.num_projections = num_projections 23 | self.dim_per_head = embedding_dim // num_heads 24 | self.variational = variational 25 | 26 | if num_projections > 1: 27 | self.embeds_to_logits = nn.ModuleList( 28 | [nn.Linear(self.dim_per_head, vocab_size, bias=False) for _ in range(num_projections)] 29 | ).cpu() 30 | else: 31 | self.embeds_to_logits = nn.Linear(self.dim_per_head, vocab_size, bias=False).cpu() 32 | 33 | if variational: 34 | self.embeds_to_logvars = nn.Linear(self.dim_per_head, vocab_size, bias=False).cpu() 35 | raise Exception("Variation projection is not correct! Need to implement KL loss") 36 | 37 | self.head_weight = nn.Linear(embedding_dim, num_heads * num_projections, bias=False) 38 | 39 | if projection_init is not None: 40 | self.embeds_to_logits.weight.data = projection_init 41 | 42 | if num_projections > 1: 43 | for layer in self.embeds_to_logits: 44 | layer.weight.requires_grad = not freeze_projection 45 | else: 46 | self.embeds_to_logits.weight.requires_grad = not freeze_projection 47 | 48 | # if normed: 49 | # self.embeds_to_logits = nn.utils.parametrizations.weight_norm(self.embeds_to_logits) 50 | 51 | def forward(self, embeds): 52 | if self.num_heads > 1 or self.num_projections > 1: 53 | bsz = embeds.shape[0] 54 | # Split embeds into num_heads smaller embeddings 55 | embeds_chunked = embeds.view(bsz, -1, self.num_heads, self.dim_per_head) 56 | 57 | # Project each head 58 | if self.num_projections > 1: 59 | with torch.no_grad(): 60 | for layer in self.embeds_to_logits: 61 | layer.weight.div_(torch.norm(layer.weight, dim=1, keepdim=True)) 62 | logits_split = torch.cat([layer(embeds_chunked) for layer in self.embeds_to_logits], dim=2) 63 | else: 64 | with torch.no_grad(): 65 | self.embeds_to_logits.weight.div_(torch.norm(self.embeds_to_logits.weight, dim=1, keepdim=True)) 66 | logits_split = self.embeds_to_logits(embeds_chunked) 67 | 68 | logits_weights = torch.softmax(self.head_weight(embeds), dim=-1).unsqueeze(-1) 69 | 70 | if self.variational: 71 | mu = logits_split 72 | logvar = self.embeds_to_logvars(embeds_chunked).unsqueeze(1) 73 | 74 | def reparameterize(mu, logvar): 75 | std = torch.exp(0.5 * logvar) 76 | eps = torch.randn_like(std) 77 | return mu + eps * std 78 | 79 | logits_split = reparameterize(mu, logvar) 80 | 81 | # Combine logits from each head 82 | logits = torch.sum(logits_split * logits_weights, dim=2) 83 | 84 | # logits = torch.log(logits) 85 | else: 86 | logits = self.embeds_to_logits(embeds) 87 | 88 | if self.variational: 89 | mu = logits 90 | logvar = self.embeds_to_logvars(embeds) 91 | 92 | def reparameterize(mu, logvar): 93 | std = torch.exp(0.5 * logvar) 94 | eps = torch.randn_like(std) 95 | return mu + eps * std 96 | 97 | logits = reparameterize(mu, logvar) 98 | 99 | if self.variational: 100 | return logits 101 | else: 102 | return logits 103 | -------------------------------------------------------------------------------- /torchseq/models/parallel_wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ParallelModel(nn.Module): 6 | def __init__(self, model, loss, tgt_field): 7 | super(ParallelModel, self).__init__() 8 | 9 | self.model = model 10 | self.loss = loss 11 | 12 | self.tgt_field = tgt_field 13 | 14 | def forward(self, batch, *args, **kwargs): 15 | res = self.model(batch, *args, **kwargs) 16 | 17 | this_loss = self.loss(res[0].permute(0, 2, 1), batch[self.tgt_field]) 18 | 19 | loss = 
torch.mean(torch.sum(this_loss, dim=1) / batch[self.tgt_field + "_len"].to(this_loss), dim=0) 20 | 21 | ret = (loss, *res) 22 | 23 | return ret 24 | 25 | 26 | def parallelify(model, loss, tgt_field): 27 | return nn.DataParallel(ParallelModel(model, loss, tgt_field)) 28 | -------------------------------------------------------------------------------- /torchseq/models/positional_embeddings.py: -------------------------------------------------------------------------------- 1 | # https://github.com/pytorch/examples/blob/master/word_language_model/model.py 2 | 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class PositionalEncoding(nn.Module): 11 | r""" 12 | Inject some information about the relative or absolute position of the tokens 13 | in the sequence. The positional encodings have the same dimension as 14 | the embeddings, so that the two can be summed. Here, we use sine and cosine 15 | functions of different frequencies. 16 | 17 | .. math:: 18 | \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) 19 | \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) 20 | \text{where pos is the word position and i is the embed idx) 21 | 22 | Args: 23 | d_model: the embed dim (required). 24 | dropout: the dropout value (default=0.1). 25 | max_len: the max. length of the incoming sequence (default=5000). 26 | 27 | Examples: 28 | >>> pos_encoder = PositionalEncoding(d_model) 29 | 30 | """ 31 | 32 | def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=True): 33 | super(PositionalEncoding, self).__init__() 34 | self.dropout = nn.Dropout(p=dropout) 35 | 36 | pe = torch.zeros(max_len, d_model) 37 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 38 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 39 | pe[:, 0::2] = torch.sin(position * div_term) 40 | pe[:, 1::2] = torch.cos(position * div_term) 41 | pe = pe.unsqueeze(0).transpose(0, 1) 42 | # if not batch_first: 43 | # pe = pe.transpose(0, 1) 44 | # self.register_buffer("pe", pe) 45 | self.pe = nn.Parameter(pe, requires_grad=False) 46 | self.batch_first = batch_first 47 | 48 | def forward(self, x): 49 | r""" 50 | Inputs of forward function 51 | 52 | Args: 53 | x: the sequence fed to the positional encoder model (required). 
54 | 55 | Shape: 56 | x: [sequence length, batch size, embed dim] 57 | output: [sequence length, batch size, embed dim] 58 | 59 | Examples: 60 | >>> output = pos_encoder(x) 61 | 62 | """ 63 | 64 | if self.batch_first: 65 | x = x + self.pe.transpose(0, 1)[:, : x.size(1), :] 66 | else: 67 | x = x + self.pe[: x.size(0), :] 68 | return self.dropout(x) 69 | -------------------------------------------------------------------------------- /torchseq/models/rerankers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/models/rerankers/__init__.py -------------------------------------------------------------------------------- /torchseq/models/rerankers/backtranslate_reranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | from torchseq.models.samplers.teacher_force import TeacherForcedSampler 6 | 7 | 8 | def pad_to_match(x1, x2, pad_id): 9 | l1 = x1.shape[1] 10 | l2 = x2.shape[1] 11 | if l1 == l2: 12 | return x1, x2 13 | 14 | pad_required = max(l1 - l2, l2 - l1) 15 | pad_toks = torch.full((x1.shape[0], pad_required), pad_id, dtype=x1.dtype, device=x1.device) 16 | 17 | if l1 > l2: 18 | x2 = torch.cat(x2, pad_toks, dim=1) 19 | elif l2 > l1: 20 | x1 = torch.cat(x1, pad_toks, dim=1) 21 | 22 | return x1, x2 23 | 24 | 25 | class BacktranslateReranker(nn.Module): 26 | def __init__(self, config, pad_id, device, src_field, model): 27 | super(BacktranslateReranker, self).__init__() 28 | self.config = config 29 | self.device = device 30 | self.pad_id = pad_id 31 | 32 | self.src_field = src_field 33 | self.model = model 34 | self.decoder = TeacherForcedSampler(self.config, self.device) 35 | self.loss = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="none") 36 | 37 | def forward(self, candidates, lengths, batch, tgt_field, scores=None, sort=True, top1=True): 38 | # Flatten to a single (large) batch 39 | candidates_flattened = torch.flatten(candidates, 0, 1) 40 | lengths_flattened = torch.flatten(lengths) 41 | 42 | num_candidates = candidates.shape[1] 43 | original_length = batch[self.src_field].shape[1] 44 | tgt_seq_tiled = torch.repeat_interleave(batch[self.src_field], repeats=num_candidates, dim=0) 45 | tgt_seq_lens_tiled = torch.repeat_interleave(batch[self.src_field + "_len"], repeats=num_candidates, dim=0) 46 | 47 | # Pad candidates to match if necessary 48 | if original_length > candidates_flattened.shape[1]: 49 | pad_toks = torch.full( 50 | (candidates_flattened.shape[0], original_length - candidates_flattened.shape[1]), 51 | self.pad_id, 52 | dtype=candidates_flattened.dtype, 53 | device=candidates_flattened.device, 54 | ) 55 | candidates_flattened = torch.cat([candidates_flattened, pad_toks], dim=1) 56 | 57 | # print("orig", batch[self.src_field].shape) 58 | # print("tiled", tgt_seq_tiled.shape) 59 | # print("cands", candidates_flattened.shape) 60 | 61 | batch_backtranslate = { 62 | self.src_field: candidates_flattened, 63 | self.src_field + "_len": lengths_flattened, 64 | } 65 | 66 | # print(batch_backtranslate) 67 | 68 | # Get nll of original source using candidates as input 69 | _, logits, _ = self.decoder(self.model, batch_backtranslate, self.src_field) 70 | 71 | # Truncate logits if the input was longer than tgt 72 | if logits.shape[1] > tgt_seq_tiled.shape[1]: 73 | logits = logits[:, : tgt_seq_tiled.shape[1]] 74 | # If the input was shorter 75 | if 
tgt_seq_tiled.shape[1] > logits.shape[1]: 76 | print("Shouldnt this have already been padded??") 77 | 78 | this_loss = self.loss(logits.permute(0, 2, 1), tgt_seq_tiled) 79 | nlls = torch.sum(this_loss, dim=1) / (tgt_seq_lens_tiled - 1).to(this_loss) 80 | 81 | # reshape back to beam-wise scores, invert so that higher scores are better (keeping scores +ve) 82 | scores = -1 * nlls.reshape_as(lengths) 83 | 84 | if sort: 85 | scores, sorted_indices = torch.sort(scores, descending=True) 86 | 87 | candidates = torch.gather(candidates, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, candidates.shape[2])) 88 | 89 | if top1: 90 | output = candidates[:, 0, :] 91 | else: 92 | topk = self.config.eval.data.get("topk", None) 93 | if topk is not None: 94 | output = candidates[:, :topk, :] 95 | else: 96 | output = candidates[:, 0, :] 97 | 98 | return output, torch.sum(output != self.pad_id, dim=-1), scores 99 | -------------------------------------------------------------------------------- /torchseq/models/rerankers/combo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | from torchseq.models.rerankers.qa_reranker import QaReranker 6 | from torchseq.models.rerankers.ngram_reranker import NgramReranker 7 | from torchseq.models.rerankers.backtranslate_reranker import BacktranslateReranker 8 | 9 | 10 | class CombinationReranker(nn.Module): 11 | def __init__(self, config, tokenizer, device, src_field, model): 12 | super(CombinationReranker, self).__init__() 13 | self.config = config 14 | self.device = device 15 | self.tokenizer = tokenizer 16 | 17 | self.src_field = src_field 18 | self.model = model 19 | 20 | self.qa_reranker = QaReranker(self.config, tokenizer, self.device) 21 | self.ngram_reranker = NgramReranker(self.config, tokenizer.pad_id, self.device, self.src_field) 22 | self.backtranslate_reranker = BacktranslateReranker( 23 | self.config, tokenizer.pad_id, self.device, self.src_field, self.model 24 | ) 25 | 26 | def forward(self, candidates, lengths, batch, tgt_field, scores=None, sort=True, top1=True): 27 | # store the original seq probs 28 | nll_scores = scores 29 | 30 | # ngram scores are fraction of overlapping toks (lower is better) 31 | _, _, ngram_scores = self.ngram_reranker(candidates, lengths, batch, tgt_field, top1=False, sort=False) 32 | 33 | # backtrans scores are nll of recovering original from candidate (lower is better) 34 | _, _, backtrans_scores = self.backtranslate_reranker( 35 | candidates, lengths, batch, tgt_field, top1=False, sort=False 36 | ) 37 | 38 | # qa scores are F1 score 39 | _, _, qa_scores = self.qa_reranker(candidates, lengths, batch, tgt_field, top1=False, sort=False) 40 | 41 | # QA score should dominate - but if all candidates are unanswerable, then fall back on other scores 42 | scores = (ngram_scores * 1.5 + (backtrans_scores + nll_scores) / 2) * (qa_scores * 0.9 + 0.1) 43 | 44 | if sort: 45 | sorted_scores, sorted_indices = torch.sort(scores, descending=True) 46 | 47 | sorted_seqs = torch.gather(candidates, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, candidates.shape[2])) 48 | 49 | if top1: 50 | output = sorted_seqs[:, 0, :] 51 | else: 52 | topk = self.config.eval.data.get("topk", None) 53 | if topk is not None: 54 | output = sorted_seqs[:, :topk, :] 55 | else: 56 | output = sorted_seqs[:, 0, :] 57 | 58 | return output, torch.sum(output != self.pad_id, dim=-1), sorted_scores 59 | -------------------------------------------------------------------------------- 
/torchseq/models/rerankers/ngram_reranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchseq.utils.tokenizer import Tokenizer 5 | from torchseq.utils.functions import onehot 6 | 7 | 8 | class NgramReranker(nn.Module): 9 | def __init__(self, config, pad_id, device, src_field): 10 | super(NgramReranker, self).__init__() 11 | self.config = config 12 | self.device = device 13 | self.pad_id = pad_id 14 | 15 | self.src_field = src_field 16 | 17 | def forward(self, candidates, lengths, batch, tgt_field, scores=None, sort=True, top1=True): 18 | # Get k-hot representations of the ref and candidate sequences 19 | # Also add in the "beam" dimension 20 | refs_k_hot = ( 21 | torch.sum( 22 | onehot( 23 | batch[self.src_field], 24 | N=self.config.prepro.get_first(["output_vocab_size", "vocab_size"]), 25 | ignore_index=self.pad_id, 26 | ), 27 | -2, 28 | ) 29 | .float() 30 | .unsqueeze(1) 31 | ) 32 | 33 | candidates_k_hot = torch.sum( 34 | onehot( 35 | candidates, 36 | N=self.config.prepro.get_first(["output_vocab_size", "vocab_size"]), 37 | ignore_index=self.pad_id, 38 | ), 39 | -2, 40 | ).float() 41 | 42 | # print(self.src_field, batch[self.src_field].shape) 43 | # print(candidates.shape) 44 | # print(refs_k_hot.shape, candidates_k_hot.shape) 45 | 46 | # take dot product to find token overlap between ref and candidates 47 | scores = torch.matmul(refs_k_hot, candidates_k_hot.transpose(-1, -2)) 48 | 49 | # print(scores.shape) 50 | scores = scores.squeeze(1) / (refs_k_hot.norm(dim=-1) * candidates_k_hot.norm(dim=-1)) 51 | # print(scores.shape) 52 | # Convert to fraction of different tokens, so that highest is best 53 | scores = 1 - scores 54 | 55 | if sort: 56 | scores, sorted_indices = torch.sort(scores, descending=True) 57 | 58 | candidates = torch.gather(candidates, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, candidates.shape[2])) 59 | 60 | if top1: 61 | output = candidates[:, 0, :] 62 | else: 63 | topk = self.config.eval.data.get("topk", None) 64 | if topk is not None: 65 | output = candidates[:, :topk, :] 66 | else: 67 | output = candidates[:, 0, :] 68 | 69 | return output, torch.sum(output != self.pad_id, dim=-1), scores 70 | -------------------------------------------------------------------------------- /torchseq/models/rerankers/qa_reranker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchseq.pretrained.qa import PreTrainedQA 5 | from torchseq.utils.metrics import f1 6 | from torchseq.utils.tokenizer import Tokenizer 7 | 8 | 9 | class QaReranker(nn.Module): 10 | def __init__(self, config, tokenizer, device): 11 | super(QaReranker, self).__init__() 12 | self.config = config 13 | self.device = device 14 | self.tokenizer = tokenizer 15 | 16 | self.qa_model = PreTrainedQA(device=self.device) 17 | 18 | def forward(self, candidates, lengths, batch, tgt_field, scores=None, sort=True, top1=True): 19 | # First, stringify 20 | output_strings = [ 21 | [self.tokenizer.decode(candidates.data[i][j][: lengths[i][j]]) for j in range(len(lengths[i]))] 22 | for i in range(len(lengths)) 23 | ] 24 | 25 | qa_scores = [] 26 | for ix, q_batch in enumerate(output_strings): 27 | contexts_cropped = [ 28 | self.tokenizer.decode(batch["c"][ix][: batch["c_len"][ix]]) for _ in range(len(q_batch)) 29 | ] 30 | answers = self.qa_model.infer_batch(question_list=q_batch, context_list=contexts_cropped) 31 | 32 | # this_scores = [(0 if 
f1(batch["a_text"][ix], ans) > 0.75 else -100) for jx, ans in enumerate(answers)] 33 | this_scores = [f1(batch["a_text"][ix], ans) for jx, ans in enumerate(answers)] 34 | 35 | qa_scores.append(this_scores) 36 | 37 | scores = torch.FloatTensor(qa_scores).to(self.device) 38 | 39 | if sort: 40 | scores, sorted_indices = torch.sort(scores, descending=True) 41 | 42 | candidates = torch.gather(candidates, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, candidates.shape[2])) 43 | 44 | if top1: 45 | output = candidates[:, 0, :] 46 | else: 47 | topk = self.config.eval.data.get("topk", None) 48 | if topk is not None: 49 | output = candidates[:, :topk, :] 50 | else: 51 | output = candidates[:, 0, :] 52 | 53 | return output, torch.sum(output != self.tokenizer.pad_id, dim=-1), scores 54 | -------------------------------------------------------------------------------- /torchseq/models/rerankers/topk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchseq.utils.tokenizer import Tokenizer 5 | 6 | 7 | class TopkReducer(nn.Module): 8 | def __init__(self, config, pad_id, device): 9 | super(TopkReducer, self).__init__() 10 | self.config = config 11 | self.device = device 12 | self.pad_id = pad_id 13 | 14 | def forward(self, candidates, lengths, batch, tgt_field, scores=None, sort=True, top1=True): 15 | # Skip sorting for now - this is unnecessary compute - if a sampling method that does not return sorted output appears this will need to change! 16 | # if sort: 17 | # # Sort with lowest scores first - we want to minimise overlap 18 | # scores, sorted_indices = torch.sort(scores, descending=False) 19 | 20 | # candidates = torch.gather(candidates, 1, sorted_indices.unsqueeze(-1).expand(-1, -1, candidates.shape[2])) 21 | 22 | # Pass-through mode: take the top-1 from a pre-sorted set of candidates (eg beam search) 23 | if top1: 24 | output = candidates[:, 0, :] 25 | 26 | return output, torch.sum(output != self.pad_id, dim=-1), scores 27 | 28 | else: 29 | return candidates, torch.sum(candidates != self.pad_id, dim=-1), scores 30 | -------------------------------------------------------------------------------- /torchseq/models/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/models/samplers/__init__.py -------------------------------------------------------------------------------- /torchseq/models/samplers/greedy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchseq.utils.tokenizer import Tokenizer, FAIRSEQ_LANGUAGE_CODES 6 | from torchseq.utils.config import Config 7 | 8 | 9 | class GreedySampler(nn.Module): 10 | config: Config 11 | device: Union[str, torch.device] 12 | tokenizer: Tokenizer 13 | 14 | def __init__(self, config: Config, tokenizer: Tokenizer, device: Union[str, torch.device]): 15 | super(GreedySampler, self).__init__() 16 | self.config = config 17 | self.device = device 18 | self.tokenizer = tokenizer 19 | 20 | def forward( 21 | self, model: nn.Module, batch: Dict[str, torch.Tensor], tgt_field: str 22 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 23 | curr_batch_size = batch[[k for k in batch.keys() if k[-5:] != "_text"][0]].size()[0] 24 | 25 | max_output_len = 
self.config.eval.data.get("max_out_len", 32) 26 | 27 | BART_HACK = self.config.eval.data.get("prepend_eos", False) 28 | MBART_HACK = self.config.eval.data.get("prepend_langcode", False) 29 | 30 | # Create vector of SOS + placeholder for first prediction 31 | output = torch.LongTensor(curr_batch_size, 1).fill_(self.tokenizer.bos_id).to(self.device) 32 | logits = ( 33 | torch.FloatTensor(curr_batch_size, 1, self.config.prepro.get_first(["output_vocab_size", "vocab_size"])) 34 | .fill_(-torch.inf) 35 | .to(self.device) 36 | ) 37 | logits[:, :, self.tokenizer.bos_id] = 1e12 38 | 39 | output_done = torch.BoolTensor(curr_batch_size).fill_(False).to(self.device) 40 | padding = torch.LongTensor(curr_batch_size).fill_(self.tokenizer.pad_id).to(self.device) 41 | 42 | if BART_HACK: 43 | dummy_token = torch.LongTensor(curr_batch_size, 1).fill_(self.tokenizer.eos_id).to(self.device) 44 | output = torch.cat([dummy_token, output], dim=1) 45 | 46 | if MBART_HACK: 47 | lang_token = batch["tgt_lang"].unsqueeze(-1) 48 | eos_token = torch.LongTensor(curr_batch_size, 1).fill_(self.tokenizer.eos_id).to(self.device) 49 | output = torch.cat([eos_token, lang_token], dim=-1) 50 | 51 | seq_ix = 0 52 | memory: Dict[str, torch.Tensor] = {} 53 | while torch.sum(output_done) < curr_batch_size and seq_ix < max_output_len: 54 | new_logits, memory = model(batch, output, memory) 55 | 56 | new_output = torch.argmax(new_logits, -1) 57 | 58 | # Use pad for the output for elements that have completed 59 | new_output[:, -1] = torch.where(output_done, padding, new_output[:, -1]) 60 | 61 | output = torch.cat([output, new_output[:, -1].unsqueeze(-1)], dim=-1) 62 | 63 | logits = torch.cat([logits, new_logits[:, -1:, :]], dim=1) 64 | 65 | output_done = output_done | (output[:, -1] == self.tokenizer.eos_id) 66 | seq_ix += 1 67 | 68 | # print(BART_HACK, MBART_HACK) 69 | # print(batch['c'][0]) 70 | # print(batch['q'][0]) 71 | # print(output[0]) 72 | # print(self.tokenizer.decode(output[0])) 73 | # exit() 74 | 75 | if BART_HACK: 76 | output = output[:, 1:] 77 | # if MBART_HACK: 78 | # output = output[:, 2:] 79 | 80 | return output, logits, torch.sum(output != self.tokenizer.pad_id, dim=-1), memory 81 | -------------------------------------------------------------------------------- /torchseq/models/samplers/teacher_force.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union, Tuple 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torchseq.utils.tokenizer import Tokenizer 6 | from torchseq.utils.config import Config 7 | 8 | 9 | class TeacherForcedSampler(nn.Module): 10 | config: Config 11 | device: Union[str, torch.device] 12 | tokenizer: Tokenizer 13 | 14 | def __init__(self, config: Config, tokenizer: Tokenizer, device: Union[str, torch.device]): 15 | super(TeacherForcedSampler, self).__init__() 16 | self.config = config 17 | self.device = device 18 | self.tokenizer = tokenizer 19 | 20 | def forward( 21 | self, model: nn.Module, batch: Dict[str, torch.Tensor], tgt_field: str 22 | ) -> Tuple[torch.Tensor, torch.Tensor, None, Dict[str, torch.Tensor]]: 23 | curr_batch_size = batch[[k for k in batch.keys() if k[-5:] != "_text"][0]].size()[0] 24 | max_output_len = batch[tgt_field].size()[1] 25 | 26 | BART_HACK = self.config.eval.data.get("prepend_eos", False) 27 | MBART_HACK = self.config.eval.data.get("prepend_langcode", False) 28 | 29 | # Create vector of SOS + placeholder for first prediction 30 | 31 | logits = ( 32 | torch.FloatTensor(curr_batch_size, 1, 
self.config.prepro.get_first(["output_vocab_size", "vocab_size"])) 33 | .fill_(-torch.inf) 34 | .to(self.device) 35 | ) 36 | 37 | if MBART_HACK: 38 | logits.scatter_(-1, batch["tgt_lang"].unsqueeze(1), float("1e18")) 39 | else: 40 | logits[:, :, self.tokenizer.bos_id] = float("1e18") 41 | 42 | # With a transformer decoder, we can lean on the internal mask to ensure that the model can't see ahead 43 | # ..and then just do a single pass through the whole model using the gold output as input 44 | output = batch[tgt_field][:, : max_output_len - 1].to(self.device) 45 | 46 | if self.config.training.data.get("token_dropout", 0) > 0 and self.training: 47 | rand = torch.rand_like(output, dtype=torch.float) 48 | 49 | masked = torch.full_like(output, self.tokenizer.mask_id) 50 | 51 | output = torch.where( 52 | torch.bitwise_and( 53 | rand < self.config.training.data.get("token_dropout", 0), output != self.tokenizer.pad_id 54 | ), 55 | masked, 56 | output, 57 | ) 58 | 59 | if BART_HACK: 60 | dummy_token = torch.LongTensor(curr_batch_size, 1).fill_(self.tokenizer.eos_id).to(self.device) 61 | output = torch.cat([dummy_token, output], dim=1) 62 | if MBART_HACK: 63 | eos_token = torch.LongTensor(curr_batch_size, 1).fill_(self.tokenizer.eos_id).to(self.device) 64 | 65 | # lang_token = batch["tgt_lang"].unsqueeze(-1) 66 | 67 | output = torch.cat([eos_token, output], dim=1) 68 | # print(output[0]) 69 | # exit() 70 | 71 | memory: Dict[str, torch.Tensor] = {} 72 | pred_logits, memory = model(batch, output, tgt_field=tgt_field, memory=memory) 73 | 74 | # import torch._dynamo as dynamo 75 | 76 | # explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = dynamo.explain( 77 | # model, batch=batch, output=output, tgt_field=tgt_field, memory=memory 78 | # ) 79 | # print(explanation_verbose) 80 | 81 | if BART_HACK or MBART_HACK: 82 | output = output[:, 1:] 83 | 84 | pred_logits = pred_logits[:, 1:, :] 85 | # if MBART_HACK: 86 | # output = output[:, 2:] 87 | 88 | # pred_logits = pred_logits[:, 2:, :] 89 | 90 | logits = torch.cat([logits, pred_logits], dim=1) 91 | 92 | # print(BART_HACK, MBART_HACK) 93 | # print(output) 94 | # print(batch['q']) 95 | # print(torch.argmax(logits, dim=-1)) 96 | # exit() 97 | 98 | return output, logits, None, memory 99 | -------------------------------------------------------------------------------- /torchseq/models/suppression_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchseq.utils.tokenizer import Tokenizer 5 | 6 | # Get a cross-entropy style loss that penalises any token from a given sequence 7 | from torchseq.utils.functions import onehot 8 | 9 | 10 | class SuppressionLoss(nn.Module): 11 | def __init__(self, config): 12 | super(SuppressionLoss, self).__init__() 13 | self.config = config 14 | 15 | def forward(self, logits, penalty_sequence, pad_id): 16 | penalty_onehot = onehot( 17 | penalty_sequence, 18 | N=self.config.prepro.get_first(["output_vocab_size", "vocab_size"]), 19 | ignore_index=pad_id, 20 | ) 21 | 22 | penalty_mask = penalty_onehot.sum(dim=-2, keepdim=True) 23 | penalty_mask = torch.min(penalty_mask, torch.ones_like(penalty_mask)) 24 | probs = nn.functional.softmax(logits, dim=-1) 25 | 26 | loss = penalty_mask * probs 27 | 28 | return loss.sum(dim=-1) 29 | -------------------------------------------------------------------------------- /torchseq/pretrained/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/pretrained/__init__.py -------------------------------------------------------------------------------- /torchseq/pretrained/lexical_paraphraser_bert.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/pretrained/lexical_paraphraser_bert.py -------------------------------------------------------------------------------- /torchseq/pretrained/lm.py: -------------------------------------------------------------------------------- 1 | # from flask import Flask, request, current_app 2 | import json 3 | import logging 4 | import os 5 | 6 | import torch 7 | from transformers import GPT2LMHeadModel, GPT2Model, GPT2Tokenizer 8 | 9 | from tqdm import tqdm 10 | 11 | # TODO: config this 12 | USE_CUDA = True 13 | 14 | 15 | def ceiling_division(n, d): 16 | return -(n // -d) 17 | 18 | 19 | class PretrainedLM: 20 | def __init__(self): 21 | # Load pre-trained model tokenizer (vocabulary) 22 | self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 23 | 24 | # Load pre-trained model (weights) 25 | self.model = GPT2LMHeadModel.from_pretrained("gpt2") 26 | self.model.eval() 27 | 28 | # If you have a GPU, put everything on cuda 29 | if USE_CUDA: 30 | self.model.to(torch.device("cuda")) 31 | 32 | def get_log_prob(self, sentences, silent=False): 33 | if len(sentences) > 32: 34 | log_probs = [] 35 | for b in tqdm(range(ceiling_division(len(sentences), 32)), desc="LM log probs", disable=silent): 36 | start_ix = b * 32 37 | end_ix = min(len(sentences), (b + 1) * 32) 38 | log_probs.extend(self.get_seq_log_prob(self.get_batch(sentences[start_ix:end_ix]))) 39 | else: 40 | log_probs = self.get_seq_log_prob(self.get_batch(sentences)) 41 | 42 | return log_probs 43 | 44 | def get_batch(self, str_in): 45 | tok_unpadded = [self.tokenizer.encode(x, add_prefix_space=True) for x in str_in] 46 | max_len = max([len(x) for x in tok_unpadded]) 47 | tok_batch = [x + [0 for i in range(max_len - len(x))] for x in tok_unpadded] 48 | mask_batch = [[1 for i in range(len(x))] + [0 for i in range(max_len - len(x))] for x in tok_unpadded] 49 | 50 | return tok_batch, mask_batch 51 | 52 | def get_seq_log_prob(self, batch): 53 | tokens, mask = batch 54 | 55 | # Convert inputs to PyTorch tensors 56 | tokens_tensor = torch.tensor(tokens) 57 | mask_tensor = torch.tensor(mask, dtype=torch.float) 58 | 59 | if USE_CUDA: 60 | tokens_tensor = tokens_tensor.to(torch.device("cuda")) 61 | mask_tensor = mask_tensor.to(torch.device("cuda")) 62 | 63 | # Predict all tokens 64 | with torch.no_grad(): 65 | loss, logits = self.model(tokens_tensor, labels=tokens_tensor, attention_mask=mask_tensor)[:2] 66 | all_probs = torch.softmax(logits, -1) 67 | 68 | torch.cuda.empty_cache() 69 | 70 | # print(all_probs.size(), all_probs) 71 | # print(tokens_tensor.unsqueeze(-1).size(), tokens_tensor.unsqueeze(-1)) 72 | 73 | probs = torch.gather(all_probs, 2, tokens_tensor.unsqueeze(-1)).squeeze(-1) 74 | 75 | log_probs = torch.log(probs) 76 | 77 | nll = -1 * torch.sum(log_probs * mask_tensor, -1) / torch.sum(mask_tensor, -1) 78 | return nll.tolist() 79 | -------------------------------------------------------------------------------- /torchseq/pretrained/nli.py: -------------------------------------------------------------------------------- 1 | from typing 
import List 2 | from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline 3 | import torch 4 | from tqdm import tqdm 5 | 6 | 7 | class ListDataset(torch.utils.data.Dataset): 8 | def __init__(self, original_list): 9 | self.original_list = original_list 10 | 11 | def __len__(self): 12 | return len(self.original_list) 13 | 14 | def __getitem__(self, i): 15 | return self.original_list[i] 16 | 17 | 18 | class PretrainedNliModel: 19 | tokenizer: AutoTokenizer 20 | model: AutoModelForSequenceClassification 21 | pipe: pipeline 22 | 23 | def __init__(self, model_id: str = "tomhosking/deberta-v3-base-debiased-nli"): 24 | self.tokenizer = AutoTokenizer.from_pretrained(model_id) 25 | 26 | self.model = AutoModelForSequenceClassification.from_pretrained(model_id).cuda() 27 | 28 | self.ENTAILMENT_LABEL = ( 29 | self.model.config.label2id["ENTAILMENT"] 30 | if "ENTAILMENT" in self.model.config.label2id 31 | else self.model.config.label2id["entailment"] 32 | ) 33 | 34 | self.pipe = pipeline("text-classification", model=self.model, tokenizer=self.tokenizer, device=0) 35 | 36 | def get_scores( 37 | self, 38 | premises: List[str], 39 | hypotheses: List[str], 40 | return_entailment_prob: bool = True, 41 | bsz: int = 64, 42 | progress: bool = False, 43 | ): 44 | dataset = ListDataset([{"text": p, "text_pair": h} for p, h in zip(premises, hypotheses)]) 45 | 46 | outputs = [ 47 | {x["label"]: x["score"] for x in res} 48 | for res in tqdm( 49 | self.pipe(dataset, batch_size=bsz, top_k=None, num_workers=2), 50 | disable=(not progress), 51 | total=len(dataset), 52 | ) 53 | ] 54 | # outputs = [{x["label"]: x["score"] for x in res} for res in outputs] 55 | 56 | if return_entailment_prob: 57 | return [p["ENTAILMENT"] for p in outputs] 58 | else: 59 | return outputs 60 | -------------------------------------------------------------------------------- /torchseq/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/torchseq/utils/__init__.py -------------------------------------------------------------------------------- /torchseq/utils/ari.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.special import binom 3 | from sklearn.metrics import adjusted_rand_score 4 | 5 | 6 | def C2(x): 7 | return binom(x, 2) 8 | 9 | 10 | def get_cluster_ari(predictions, references): 11 | scores = [] 12 | for curr_references, curr_predictions in zip(references, predictions): 13 | n = len( 14 | set( 15 | [x for cluster in curr_references for x in cluster] 16 | + [x for cluster in curr_predictions for x in cluster] 17 | ) 18 | ) 19 | 20 | nij = np.array([[len(set(X) & set(Y)) for Y in curr_predictions] for X in curr_references]) 21 | ai = nij.sum(axis=0) 22 | bj = nij.sum(axis=1) 23 | 24 | numerator = C2(nij).sum() - (C2(ai).sum() * C2(bj).sum()) / C2(n) 25 | 26 | denominator = 0.5 * (C2(ai).sum() + C2(bj).sum()) - (C2(ai).sum() * C2(bj).sum()) / C2(n) 27 | if numerator > 0: 28 | scores.append(numerator / denominator) 29 | else: 30 | scores.append(0) 31 | return np.mean(scores) * 100 32 | -------------------------------------------------------------------------------- /torchseq/utils/cache.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import numpy as np 5 | 6 | 7 | class Cache: 8 | def __init__(self, output_path=""): 
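# Lightweight disk-backed cache: save() writes any torch-serialisable object to
# "<output_path>/cache/<key>.pt" and load() returns it again, or None on a miss.
# Illustrative usage only -- the key name and the compute_embeddings() helper are hypothetical:
#   cache = Cache(output_path="./runs/my_run")
#   embeds = cache.load("train_embeddings")
#   if embeds is None:
#       embeds = compute_embeddings()
#       cache.save("train_embeddings", embeds)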
9 | self.path = os.path.join(output_path, "cache") 10 | os.makedirs(self.path, exist_ok=True) 11 | 12 | def load(self, key): 13 | if not os.path.exists(os.path.join(self.path, f"{key}.pt")): 14 | return None 15 | # obj = np.load(os.path.join(self.path, f"{key}.npy")) 16 | # if isinstance(obj, np.ndarry): 17 | # return torch.from_numpy(obj) 18 | # else: 19 | # return obj 20 | return torch.load(os.path.join(self.path, f"{key}.pt")) 21 | 22 | def save(self, key, obj): 23 | # np.save(os.path.join(self.path, f"{key}.npy"), obj) 24 | torch.save(obj, os.path.join(self.path, f"{key}.pt")) 25 | -------------------------------------------------------------------------------- /torchseq/utils/config.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, no_type_check, TYPE_CHECKING 2 | from copy import deepcopy 3 | 4 | # A simple class to convert a (nested) dictionary to an object 5 | 6 | 7 | @no_type_check 8 | class Config: 9 | d: Dict[str, Any] 10 | 11 | def __init__(self, d): 12 | self.data = d 13 | for a, b in d.items(): 14 | if isinstance(b, (list, tuple)): 15 | setattr(self, a, [Config(x) if isinstance(x, dict) else x for x in b]) 16 | else: 17 | setattr(self, a, Config(b) if isinstance(b, dict) else b) 18 | 19 | # Include for typing 20 | if TYPE_CHECKING: 21 | 22 | def __getattr__(self, name: str) -> Any: 23 | pass 24 | 25 | def get(self, key, default=None) -> Any: 26 | return self.data.get(key, default) 27 | 28 | def get_first(self, keys): 29 | for key in keys: 30 | if key in self.data: 31 | if isinstance(self.data[key], dict): 32 | return Config(self.data[key]) 33 | else: 34 | return self.data[key] 35 | else: 36 | raise KeyError 37 | 38 | def get_path(self, path, default=None): 39 | if path[0] in self.data: 40 | if len(path) > 1: 41 | return getattr(self, path[0]).get_path(path[1:], default) 42 | else: 43 | return self.data[path[0]] 44 | else: 45 | return default 46 | 47 | 48 | def merge_cfg_dicts(main_cfg, cfg_mask): 49 | main_cfg = deepcopy(main_cfg) 50 | for k, v in cfg_mask.items(): 51 | if k in main_cfg and isinstance(main_cfg[k], dict) and isinstance(cfg_mask[k], dict): 52 | main_cfg[k] = merge_cfg_dicts(main_cfg[k], cfg_mask[k]) 53 | elif k == "name" and isinstance(cfg_mask[k], str) and "%" in cfg_mask[k]: 54 | main_cfg[k] = cfg_mask[k].replace("%", main_cfg[k]) 55 | else: 56 | main_cfg[k] = cfg_mask[k] 57 | 58 | return main_cfg 59 | -------------------------------------------------------------------------------- /torchseq/utils/easy_generate.py: -------------------------------------------------------------------------------- 1 | from torchseq.datasets.builder import dataloader_from_config 2 | 3 | 4 | def generate(instance, samples=[]): 5 | data_loader = dataloader_from_config(config=instance.config, data_path=instance.data_path, dev_samples=samples) 6 | _, _, (pred_output, _, _), _ = instance.inference(data_loader.valid_loader, desc="Generating") 7 | 8 | return pred_output 9 | -------------------------------------------------------------------------------- /torchseq/utils/fleiss.py: -------------------------------------------------------------------------------- 1 | """ 2 | fleiss.py by Marco Lui, Dec 2010 3 | 4 | Based on 5 | http://en.wikipedia.org/wiki/Fleiss'_kappa 6 | and 7 | Cardillo G. (2007) Fleisses kappa: compute the Fleiss'es kappa for multiple raters. 
8 | http://www.mathworks.com/matlabcentral/fileexchange/15426 9 | """ 10 | 11 | import numpy 12 | 13 | # from scipy.special import erfc 14 | 15 | 16 | def fleiss(data): 17 | if not len(data.shape) == 2: 18 | raise (ValueError, "input must be 2-dimensional array") 19 | if not issubclass(data.dtype.type, numpy.integer): 20 | raise (TypeError, "expected integer type") 21 | if not numpy.isfinite(data).all(): 22 | raise (ValueError, "all data must be finite") 23 | 24 | raters = data.sum(axis=1) 25 | if (raters - raters.max()).any(): 26 | raise (ValueError, "inconsistent number of raters") 27 | 28 | num_raters = raters[0] 29 | num_subjects, num_category = data.shape 30 | total_ratings = num_subjects * num_raters 31 | 32 | pj = data.sum(axis=0) / float(total_ratings) 33 | pi = ((data * data).sum(axis=1) - num_raters).astype(float) / (num_raters * (num_raters - 1)) 34 | pbar = pi.sum() / num_subjects 35 | pebar = (pj * pj).sum() 36 | 37 | kappa = (pbar - pebar) / (1 - pebar) 38 | return kappa 39 | 40 | 41 | # def fleiss(data): 42 | # if not len(data.shape) == 2: 43 | # raise ValueError, 'input must be 2-dimensional array' 44 | # if not issubclass(data.dtype.type, numpy.integer): 45 | # raise TypeError, 'expected integer type' 46 | # if not numpy.isfinite(data).all(): 47 | # raise ValueError, 'all data must be finite' 48 | # 49 | # raters = data.sum(axis=1) 50 | # if (raters - raters.max()).any(): 51 | # raise ValueError, 'inconsistent number of raters' 52 | # 53 | # n, num_category = data.shape 54 | # # m=sum(x(1,:)); %raters 55 | # m = data[0].sum(axis=0) 56 | # 57 | # # a=n*m; 58 | # a = n * m 59 | # 60 | # # pj=(sum(x)./(a)); %overall proportion of ratings in category j 61 | # pj = data.sum(axis=0) / float(a) 62 | # 63 | # # b=pj.*(1-pj); 64 | # b = pj * (1-pj) 65 | # 66 | # # c=a*(m-1); 67 | # c = a * (m-1) 68 | # 69 | # # d=sum(b); 70 | # d = numpy.sum(b, axis=0) 71 | # 72 | # # kj=1-(sum((x.*(m-x)))./(c.*b)); %the value of kappa for the j-th category 73 | # kj = 1 - ( (data * (m-data)).sum(axis=0) / (c*b) ) 74 | # 75 | # # sekj=realsqrt(2/c); %kj standar error 76 | # sekj = numpy.sqrt(2/c) 77 | # 78 | # # zkj=kj./sekj; 79 | # zkj = kj / sekj 80 | # 81 | # # pkj=(1-0.5*erfc(-abs(zkj)/realsqrt(2)))*2; 82 | # pkj = (1 - 0.5*erfc(-numpy.abs(zkj) / numpy.sqrt(2))) * 2 83 | # 84 | # # k=sum(b.*kj)/d; %Fleiss'es (overall) kappa 85 | # k = (b*kj).sum(axis=0) / d 86 | # 87 | # # sek=realsqrt(2*(d^2-sum(b.*(1-2.*pj))))/sum(b.*realsqrt(c)); %kappa standard error 88 | # sek = numpy.sqrt( 2*(d*d-(b*(1-2*pj)).sum(axis=0)) ) / (b * numpy.sqrt(c)).sum(axis=0) 89 | # 90 | # # ci=k+([-1 1].*(abs(0.5*erfc(-alpha/2/realsqrt(2)))*sek)); %k confidence interval 91 | # # omitted as we are not working out the ci 92 | # 93 | # # z=k/sek; %normalized kappa 94 | # z = k/sek 95 | # 96 | # # p=(1-0.5*erfc(-abs(z)/realsqrt(2)))*2; 97 | # p = (1 - 0.5*erfc(-numpy.abs(z) / numpy.sqrt(2))) * 2 98 | # return k, p 99 | 100 | if __name__ == "__main__": 101 | data = numpy.array( 102 | [ 103 | [0, 0, 0, 0, 14], 104 | [0, 2, 6, 4, 2], 105 | [0, 0, 3, 5, 6], 106 | [0, 3, 9, 2, 0], 107 | [2, 2, 8, 1, 1], 108 | [7, 7, 0, 0, 0], 109 | [3, 2, 6, 3, 0], 110 | [2, 5, 3, 2, 2], 111 | [6, 5, 2, 1, 0], 112 | [0, 2, 2, 3, 7], 113 | ] 114 | ) 115 | 116 | print(fleiss(data)) 117 | -------------------------------------------------------------------------------- /torchseq/utils/logging.py: -------------------------------------------------------------------------------- 1 | # from torch.utils.tensorboard import SummaryWriter 2 | 3 | from 
torchseq.utils.singleton import Singleton 4 | 5 | import os 6 | import torch 7 | 8 | 9 | from torchseq.utils.wandb import wandb_log 10 | from wandb import Histogram as wbHistogram 11 | 12 | 13 | class Logger(metaclass=Singleton): 14 | writer = None 15 | step = 0 16 | 17 | def __init__(self, silent=False, log_path=None, interval=10): 18 | self.silent = silent 19 | self.interval = interval 20 | 21 | # if log_path is not None: 22 | # self.writer = SummaryWriter(log_path) 23 | 24 | def log_scalar(self, key, value, iteration): 25 | # if iteration < self.step: 26 | # raise Exception("What's the first thing that decreases step?!") 27 | self.step = iteration 28 | if iteration % self.interval != 0: 29 | return 30 | wandb_log({key: value}, step=iteration) 31 | if self.writer is not None: 32 | self.writer.add_scalar(key, value, iteration) 33 | 34 | def log_histogram(self, key, value, iteration): 35 | if max(value) >= 512: 36 | value = [x for x in value if x < 512] 37 | if len(value) == 0: 38 | return 39 | 40 | wandb_log({key: wbHistogram(value, num_bins=int(max(value)) + 1)}, step=iteration) 41 | if self.writer is not None: 42 | self.writer.add_histogram(key, torch.Tensor(value), iteration) 43 | 44 | def log_text(self, key, value, iteration): 45 | if self.writer is not None: 46 | self.writer.add_text(key, value, iteration) 47 | -------------------------------------------------------------------------------- /torchseq/utils/loss_dropper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | # https://github.com/ddkang/loss_dropper 5 | # @article{kang2020improved, 6 | # title={Improved Natural Language Generation via Loss Truncation}, 7 | # author={Daniel Kang and Tatsunori Hashimoto}, 8 | # journal={ACL}, 9 | # year={2020} 10 | # } 11 | 12 | 13 | class LossDropper(nn.Module): 14 | def __init__(self, dropc=0.4, min_count=10000, recompute=10000, verbose=True): 15 | super().__init__() 16 | self.keepc = 1.0 - dropc 17 | self.count = 0 18 | self.min_count = min_count 19 | 20 | self.recompute = recompute 21 | self.last_computed = 0 22 | self.percentile_val = 100000000.0 23 | self.cur_idx = 0 24 | 25 | self.verbose = verbose 26 | 27 | self.vals = np.zeros(self.recompute, dtype=np.float32) 28 | 29 | def forward(self, loss): 30 | if loss is None: 31 | return loss 32 | 33 | self.last_computed += loss.numel() 34 | self.count += loss.numel() 35 | if self.count < len(self.vals): 36 | self.vals[self.count - loss.numel() : self.count] = loss.detach().cpu().numpy().flatten() 37 | self.cur_idx += loss.numel() 38 | return (loss < np.inf).type(loss.dtype) 39 | else: 40 | for idx, item in enumerate(loss): 41 | self.vals[self.cur_idx] = item 42 | self.cur_idx += 1 43 | if self.cur_idx >= len(self.vals): 44 | self.cur_idx = 0 45 | if self.count < self.min_count: 46 | return (loss < np.inf).type(loss.dtype) 47 | 48 | if self.last_computed > self.recompute: 49 | self.percentile_val = np.percentile(self.vals, self.keepc * 100) 50 | if self.verbose: 51 | print("Using cutoff", self.percentile_val) 52 | self.last_computed = 0 53 | 54 | mask = (loss < self.percentile_val).type(loss.dtype) 55 | return mask 56 | -------------------------------------------------------------------------------- /torchseq/utils/mckenzie.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import requests 4 | 5 | 6 | def update_mckenzie(progress, metric): 7 | logger = logging.getLogger("McKenzie") 8 
| if "MCKENZIE_ENDPOINT" in os.environ: 9 | try: 10 | job_id = os.environ["SLURM_JOB_ID"] 11 | partition = os.environ["SLURM_JOB_PARTITION"] 12 | endpoint = os.environ["MCKENZIE_ENDPOINT"] 13 | 14 | requests.post( 15 | "http://" + endpoint + "/hooks/update_job/", 16 | data={"jobid": job_id, "partition": partition, "progress": progress, "metric": metric}, 17 | ) 18 | except Exception as e: 19 | logger.warning("Error updating McKenzie: " + repr(e)) 20 | 21 | 22 | def set_status_mckenzie(status): 23 | logger = logging.getLogger("McKenzie") 24 | if "MCKENZIE_ENDPOINT" in os.environ: 25 | try: 26 | job_id = os.environ["SLURM_JOB_ID"] 27 | partition = os.environ["SLURM_JOB_PARTITION"] 28 | endpoint = os.environ["MCKENZIE_ENDPOINT"] 29 | 30 | requests.post( 31 | "http://" + endpoint + "/hooks/update_job/", 32 | data={"jobid": job_id, "partition": partition, "status": status}, 33 | ) 34 | except Exception as e: 35 | logger.warning("Error updating McKenzie: " + repr(e)) 36 | -------------------------------------------------------------------------------- /torchseq/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | from collections import Counter 4 | 5 | from nltk.tokenize import TreebankWordTokenizer, sent_tokenize 6 | from nltk.translate.meteor_score import single_meteor_score 7 | 8 | # from torchseq.utils.bleu import compute_bleu 9 | from torchseq.utils.sari import SARIsent 10 | 11 | import sacrebleu 12 | 13 | 14 | # def tokenize(text): 15 | # # return text.split(' ') 16 | # sents = sent_tokenize(text) 17 | # tokens = [tok.lower() for sent in sents for tok in TreebankWordTokenizer().tokenize(sent)] 18 | # return tokens 19 | 20 | 21 | # # takes a single untokenised string as input 22 | # def bleu(gold, prediction, order=4): 23 | # return compute_bleu([[tokenize(gold)]], [tokenize(prediction)], smooth=False, max_order=order)[0] 24 | 25 | 26 | # takes a list of untokenized strings as inputs 27 | def bleu_corpus(golds, preds, order=4): 28 | return sacrebleu.corpus_bleu(preds, [golds], lowercase=True).score 29 | # return compute_bleu( 30 | # [[tokenize(gold)] for gold in golds], [tokenize(pred) for pred in preds], smooth=False, max_order=order 31 | # )[0] 32 | 33 | 34 | def ibleu_corpus(golds, preds, inputs, alpha=0.8): 35 | return alpha * bleu_corpus(golds, preds) - (1 - alpha) * bleu_corpus(preds, inputs) 36 | # return sum([alpha*bleu(golds[i], preds[i]) - (1-alpha)*bleu(golds[i], inputs[i]) for i in range(len(golds))])/len(golds) 37 | 38 | 39 | def sari_corpus(golds, preds, inputs): 40 | return sum([SARIsent(i, p, g) for g, p, i in zip(golds, preds, inputs)]) / len(golds) 41 | 42 | 43 | def meteor_corpus(golds, preds): 44 | return sum( 45 | [ 46 | single_meteor_score(TreebankWordTokenizer().tokenize(g), TreebankWordTokenizer().tokenize(p)) 47 | for g, p in zip(golds, preds) 48 | ] 49 | ) / len(golds) 50 | 51 | 52 | def f1(gold, prediction): 53 | prediction_tokens = prediction.lower().split() 54 | ground_truth_tokens = gold.lower().split() 55 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 56 | num_same = sum(common.values()) 57 | if num_same == 0: 58 | return 0 59 | precision = 1.0 * num_same / len(prediction_tokens) 60 | recall = 1.0 * num_same / len(ground_truth_tokens) 61 | f1 = (2 * precision * recall) / (precision + recall) 62 | return f1 63 | 64 | 65 | def normalize_answer(s): 66 | def remove_articles(text): 67 | return re.sub(r"\b(a|an|the)\b", " ", text) 68 | 69 | def 
white_space_fix(text): 70 | return " ".join(text.split()) 71 | 72 | def remove_punc(text): 73 | exclude = set(string.punctuation) 74 | return "".join(ch for ch in text if ch not in exclude) 75 | 76 | def lower(text): 77 | return text.lower() 78 | 79 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 80 | -------------------------------------------------------------------------------- /torchseq/utils/model_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | from torchseq.agents.aq_agent import AQAgent 3 | from torchseq.agents.seq2seq_agent import Seq2SeqAgent 4 | from torchseq.agents.retrieval_agent import RetrievalAgent 5 | from torchseq.agents.lm_agent import LangModelAgent 6 | from torchseq.utils.config import Config, merge_cfg_dicts 7 | from torchseq.utils.config_migration import check_config 8 | import torch 9 | import logging 10 | 11 | 12 | AGENT_TYPES = { 13 | "aq": AQAgent, 14 | "langmodel": LangModelAgent, 15 | "para": Seq2SeqAgent, 16 | "seq2seq": Seq2SeqAgent, 17 | "autoencoder": Seq2SeqAgent, 18 | "exemplarguided": Seq2SeqAgent, 19 | "retrieval": RetrievalAgent, 20 | } 21 | 22 | logger = logging.getLogger("Loader") 23 | 24 | 25 | def config_from_path(path_to_model, config_patch=None): 26 | with open(path_to_model + "/config.json") as f: 27 | cfg_dict = json.load(f) 28 | 29 | if config_patch is not None: 30 | cfg_dict = merge_cfg_dicts(cfg_dict, config_patch) 31 | 32 | config = Config(cfg_dict) 33 | 34 | return config 35 | 36 | 37 | def model_from_path( 38 | path_to_model, 39 | output_path="./runs/", 40 | data_path="./data/", 41 | config_patch=None, 42 | training_mode=False, 43 | run_id=None, 44 | **kwargs, 45 | ): 46 | torch.cuda.empty_cache() 47 | 48 | config = config_from_path(path_to_model, config_patch) 49 | 50 | if check_config(config.data): 51 | logger.warning("Config is outdated! 
Run the migration script to update it") 52 | 53 | # run_id = path_to_model.split("/")[-1] if run_id is False else run_id 54 | 55 | checkpoint_path = path_to_model + "/model/checkpoint.pt" 56 | 57 | instance = AGENT_TYPES[config.task]( 58 | config, 59 | run_id, 60 | output_path, 61 | data_path=data_path, 62 | cache_root=path_to_model, 63 | training_mode=training_mode, 64 | **kwargs, 65 | ) 66 | instance.load_checkpoint(checkpoint_path) 67 | if not training_mode: 68 | instance.model.eval() 69 | 70 | return instance 71 | -------------------------------------------------------------------------------- /torchseq/utils/optimizer_group.py: -------------------------------------------------------------------------------- 1 | class OptimizerGroup: 2 | def __init__(self, optimizer_list): 3 | self.optimizers = optimizer_list 4 | 5 | def step(self): 6 | for opt in self.optimizers: 7 | opt.step() 8 | 9 | def load_state_dict(self, state_dict_arr): 10 | if not isinstance(state_dict_arr, list): 11 | state_dict_arr = [state_dict_arr] 12 | for ix, opt in enumerate(self.optimizers): 13 | if state_dict_arr[ix] is not None: 14 | opt.load_state_dict(state_dict_arr[ix]) 15 | 16 | def state_dict(self): 17 | return [opt.state_dict() for opt in self.optimizers] 18 | 19 | def zero_grad(self): 20 | for opt in self.optimizers: 21 | opt.zero_grad() 22 | 23 | def __iter__(self): 24 | for opt in self.optimizers: 25 | yield opt 26 | 27 | def __getitem__(self, ix): 28 | return self.optimizers[ix] 29 | 30 | def __len__(self): 31 | return len(self.optimizers) 32 | 33 | 34 | class SchedulerGroup: 35 | def __init__(self, scheduler_list): 36 | self.schedulers = scheduler_list 37 | 38 | def step(self): 39 | for sched in self.schedulers: 40 | sched.step() 41 | 42 | def load_state_dict(self, state_dict_arr): 43 | if not isinstance(state_dict_arr, list): 44 | state_dict_arr = [state_dict_arr] 45 | for ix, sched in enumerate(self.schedulers): 46 | if state_dict_arr[ix] is not None: 47 | sched.load_state_dict(state_dict_arr[ix]) 48 | 49 | def state_dict(self): 50 | return [sched.state_dict() for sched in self.schedulers] 51 | 52 | def __iter__(self): 53 | for sched in self.schedulers: 54 | yield sched 55 | 56 | def __getitem__(self, ix): 57 | return self.schedulers[ix] 58 | 59 | def __len__(self): 60 | return len(self.schedulers) 61 | -------------------------------------------------------------------------------- /torchseq/utils/perplexity.py: -------------------------------------------------------------------------------- 1 | from torchseq.utils.functions import onehot 2 | import torch 3 | 4 | 5 | def get_perplexity(logits, indices, vocab_size=None, ignore_index=None): 6 | seq_probs = torch.softmax(logits, dim=-1) 7 | seq_oh = onehot(indices, vocab_size, ignore_index) 8 | 9 | seq_entropy = torch.sum(torch.log2(seq_probs + 1e-10) * seq_oh, dim=-1) 10 | 11 | perplexity = torch.pow(2, -torch.sum(seq_entropy, dim=-1) / seq_oh.sum(dim=-1).sum(dim=-1)) 12 | 13 | return perplexity 14 | -------------------------------------------------------------------------------- /torchseq/utils/rouge.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import numpy as np 5 | from rouge_score import rouge_scorer, scoring 6 | 7 | # Based on the GEM ROUGE implementation: https://github.com/GEM-benchmark/GEM-metrics/blob/main/gem_metrics/rouge.py 8 | """ROUGE uses Google implementation (https://github.com/google-research/google-research/tree/master/rouge) 9 | but adds own implementation 
of multi-ref jackknifing. 10 | The Google implementation should be identical to Rouge-155 (except tokenization?), 11 | the jackknifing follows the description of the ROUGE paper. 12 | """ 13 | 14 | from absl import logging 15 | 16 | # Rouge lets us know that it's using the default tokenizer, turn that off: 17 | logging.set_verbosity(logging.WARNING) 18 | 19 | 20 | def get_pairwise_rouge(pred, ref): 21 | rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 22 | rouge = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True) 23 | res = rouge.score(ref, pred) 24 | return {rtype: res[rtype].fmeasure * 100 for rtype in rouge_types} 25 | 26 | 27 | def get_jackknife_rouge(predictions, references, stemming=True): 28 | rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 29 | rouge = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=stemming) 30 | score_list = [] 31 | 32 | for refs, pred in zip( 33 | references, 34 | predictions, 35 | ): 36 | # ROUGE multi-ref jackknifing 37 | if len(refs) > 1: 38 | cur_scores = [rouge.score(ref, pred) for ref in refs] 39 | 40 | # get best score for all leave-one-out sets 41 | best_scores = [] 42 | for leave in range(len(refs)): 43 | cur_scores_leave_one = [cur_scores[s] for s in range(len(refs)) if s != leave] 44 | best_scores.append( 45 | { 46 | rouge_type: max( 47 | [s[rouge_type] for s in cur_scores_leave_one], 48 | key=lambda s: s.fmeasure, 49 | ) 50 | for rouge_type in rouge_types 51 | } 52 | ) 53 | 54 | # average the leave-one-out bests to produce the final score 55 | score = { 56 | rouge_type: scoring.Score( 57 | np.mean([b[rouge_type].precision for b in best_scores]), 58 | np.mean([b[rouge_type].recall for b in best_scores]), 59 | np.mean([b[rouge_type].fmeasure for b in best_scores]), 60 | ) 61 | for rouge_type in rouge_types 62 | } 63 | else: 64 | score = rouge.score(refs[0], pred) 65 | 66 | # convert the named tuples to plain nested dicts 67 | score = { 68 | rouge_type: { 69 | "precision": score[rouge_type].precision, 70 | "recall": score[rouge_type].recall, 71 | "fmeasure": score[rouge_type].fmeasure, 72 | } 73 | for rouge_type in rouge_types 74 | } 75 | score_list.append(score) 76 | 77 | l1_keys = list(score_list[0].keys()) 78 | # l2_keys = score_list[0][l1_keys[0]].keys() 79 | return {key1: round(np.mean([score[key1]["fmeasure"] for score in score_list]) * 100, 5) for key1 in l1_keys} 80 | -------------------------------------------------------------------------------- /torchseq/utils/seed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def set_seed(seed=1029): 9 | random.seed(seed) 10 | os.environ["PYTHONHASHSEED"] = str(seed) 11 | os.environ["__TORCHAQSEED"] = str(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 
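# The two cuDNN flags set below trade speed for reproducibility: benchmark mode is disabled
# so kernel auto-tuning cannot vary between runs, and deterministic mode asks cuDNN to prefer
# deterministic algorithm implementations where they are available.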
16 | torch.backends.cudnn.benchmark = False 17 | torch.backends.cudnn.deterministic = True 18 | 19 | 20 | def init_worker(worker_id): 21 | set_seed(int(os.environ["__TORCHAQSEED"])) 22 | -------------------------------------------------------------------------------- /torchseq/utils/singleton.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class Singleton(type): 5 | _instances: Dict[str, type] = {} 6 | 7 | def __call__(cls, *args, **kwargs): 8 | if cls not in cls._instances: 9 | cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) 10 | return cls._instances[cls] 11 | -------------------------------------------------------------------------------- /torchseq/utils/timer.py: -------------------------------------------------------------------------------- 1 | from time import perf_counter 2 | 3 | 4 | class Timer: 5 | def __init__(self, template="Time: {:.3f} seconds", show_readout=True) -> None: 6 | self.show_readout = show_readout 7 | self.template = template 8 | 9 | def __enter__(self): 10 | self.start = perf_counter() 11 | return self 12 | 13 | def __exit__(self, type, value, traceback): 14 | self.time = perf_counter() - self.start 15 | self.readout = self.template.format(self.time) 16 | if self.show_readout: 17 | print(self.readout) 18 | -------------------------------------------------------------------------------- /torchseq/utils/tokenizer_wordlevel.py: -------------------------------------------------------------------------------- 1 | from tokenizers.implementations.base_tokenizer import BaseTokenizer 2 | from typing import Optional, List, Union 3 | 4 | from tokenizers import Tokenizer, Encoding, AddedToken 5 | from tokenizers.models import WordLevel 6 | from tokenizers.normalizers import unicode_normalizer_from_str, Lowercase, Sequence 7 | import tokenizers 8 | 9 | 10 | class WordLevelTokenizer(BaseTokenizer): 11 | """WordLevelTokenizer 12 | Represents a simple word level tokenization 13 | """ 14 | 15 | def __init__( 16 | self, 17 | vocab_file: Optional[str] = None, 18 | unk_token: Union[str, AddedToken] = "", 19 | bos_token: Union[str, AddedToken] = "", 20 | eos_token: Union[str, AddedToken] = "", 21 | pad_token: Union[str, AddedToken] = "", 22 | mask_token: Union[str, AddedToken] = "", 23 | lowercase: bool = False, 24 | unicode_normalizer: Optional[str] = None, 25 | ): 26 | if vocab_file is not None: 27 | tokenizer = Tokenizer(WordLevel(vocab_file)) 28 | else: 29 | tokenizer = Tokenizer(WordLevel()) 30 | 31 | # Let the tokenizer know about special tokens if they are part of the vocab 32 | if tokenizer.token_to_id(str(unk_token)) is not None: 33 | tokenizer.add_special_tokens([str(unk_token)]) 34 | if tokenizer.token_to_id(str(bos_token)) is not None: 35 | tokenizer.add_special_tokens([str(bos_token)]) 36 | if tokenizer.token_to_id(str(eos_token)) is not None: 37 | tokenizer.add_special_tokens([str(eos_token)]) 38 | if tokenizer.token_to_id(str(pad_token)) is not None: 39 | tokenizer.add_special_tokens([str(pad_token)]) 40 | if tokenizer.token_to_id(str(mask_token)) is not None: 41 | tokenizer.add_special_tokens([str(mask_token)]) 42 | 43 | # Check for Unicode normalization first (before everything else) 44 | normalizers = [] 45 | 46 | if unicode_normalizer: 47 | normalizers += [unicode_normalizer_from_str(unicode_normalizer)] 48 | 49 | if lowercase: 50 | normalizers += [Lowercase()] 51 | 52 | # Create the normalizer structure 53 | if len(normalizers) > 0: 54 | if len(normalizers) > 1: 55 | 
tokenizer.normalizer = Sequence(normalizers) 56 | else: 57 | tokenizer.normalizer = normalizers[0] 58 | 59 | tokenizer.pre_tokenizer = tokenizers.pre_tokenizers.WhitespaceSplit() 60 | 61 | if vocab_file is not None: 62 | bos_token_id = tokenizer.token_to_id(str(bos_token)) 63 | if bos_token_id is None: 64 | raise TypeError("bos_token not found in the vocabulary") 65 | eos_token_id = tokenizer.token_to_id(str(eos_token)) 66 | if eos_token_id is None: 67 | raise TypeError("eos_token not found in the vocabulary") 68 | 69 | # tokenizer.post_processor = tokenizers.processors.BertProcessing( 70 | # (str(bos_token), bos_token_id), (str(eos_token), eos_token_id) 71 | # ) 72 | 73 | parameters = { 74 | "model": "WordLevel", 75 | "unk_token": unk_token, 76 | "bos_token": bos_token, 77 | "eos_token": eos_token, 78 | "pad_token": pad_token, 79 | "mask_token": mask_token, 80 | "lowercase": lowercase, 81 | "unicode_normalizer": unicode_normalizer, 82 | } 83 | 84 | super().__init__(tokenizer, parameters) 85 | -------------------------------------------------------------------------------- /torchseq/utils/wandb.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import wandb 4 | 5 | 6 | def wandb_init(config, run_id=None, path=None): 7 | if "WANDB_API_KEY" in os.environ and "WANDB_USERNAME" in os.environ and os.environ.get("WANDB_USERNAME", "") != "": 8 | # W&B hierarchy is: project > group > job_type > name > id 9 | wandb.init( 10 | project=config.tag, 11 | group=config.get("group", None), 12 | job_type=config.get("job_type", None), 13 | config=config.data, 14 | name=config.get("name", None), 15 | id=run_id, 16 | dir=path, 17 | ) 18 | elif os.environ.get("WANDB_MODE", None) == "disabled": 19 | wandb.init( 20 | project=config.tag, 21 | group=config.get("group", None), 22 | job_type=config.get("job_type", None), 23 | config=config.data, 24 | name=config.get("name", None), 25 | id=run_id, 26 | dir=path, 27 | mode="disabled", 28 | ) 29 | 30 | 31 | def wandb_log(data, step=None): 32 | if ( 33 | "WANDB_API_KEY" in os.environ and "WANDB_USERNAME" in os.environ and os.environ.get("WANDB_USERNAME", "") != "" 34 | ) or os.environ.get("WANDB_MODE", None) == "disabled": 35 | if step >= wandb.run.step: 36 | wandb.log(data, step) 37 | 38 | 39 | def wandb_summary(data): 40 | if ( 41 | "WANDB_API_KEY" in os.environ and "WANDB_USERNAME" in os.environ and os.environ.get("WANDB_USERNAME", "") != "" 42 | ) or os.environ.get("WANDB_MODE", None) == "disabled": 43 | 44 | def stringify_keys(data): 45 | if isinstance(data, dict): 46 | return {str(k): stringify_keys(v) for k, v in data.items()} 47 | else: 48 | return data 49 | 50 | for k, v in stringify_keys(data).items(): 51 | wandb.run.summary[k] = v 52 | --------------------------------------------------------------------------------
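
A minimal end-to-end sketch (not part of the repository) showing how the utilities above fit together: model_from_path in torchseq/utils/model_loader.py restores a trained agent from a run directory containing config.json and model/checkpoint.pt, and generate in torchseq/utils/easy_generate.py runs that agent's validation pipeline over a few ad-hoc samples. The run directory path and the "s1" input field are assumptions here and depend on the dataset and config the model was actually trained with.

from torchseq.utils.model_loader import model_from_path
from torchseq.utils.easy_generate import generate

# Restore a trained agent; the directory must contain config.json and model/checkpoint.pt.
instance = model_from_path("./runs/examples/my_paraphraser", data_path="./data/")

# Run a couple of raw samples through the model and print the predicted outputs.
samples = [{"s1": "The quick brown fox jumps over the lazy dog."}]
print(generate(instance, samples=samples))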