├── .flake8 ├── .github └── workflows │ └── torchseq.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── Makefile ├── README.md ├── codecov.yml ├── configs ├── mbart.json ├── paraphrasing_vae.json ├── patches │ ├── beamsearch_16x8.json │ ├── beamsearch_4x2.json │ ├── beamsearch_4x2_length1.json │ ├── beamsearch_8x4.json │ ├── beamsearch_8x4_length1.json │ ├── dbs.json │ ├── eval_bsz_2.json │ ├── eval_bsz_8.json │ ├── multi_sample.json │ ├── name_vae.json │ ├── newsqa.json │ ├── nq.json │ ├── nucleus05.json │ ├── nucleus075.json │ ├── nucleus09.json │ ├── rerank_backtranslate.json │ ├── rerank_combo.json │ ├── rerank_ngram.json │ ├── rerank_qa.json │ ├── squad.json │ ├── tag_ppeval.json │ ├── udeptoq_squad.json │ └── unfreeze_encdec.json ├── qg_bart.json ├── qg_bert.json └── qg_transformer.json ├── data └── pretrained-vocabs │ ├── ._bart-large-cnn-merges.txt │ ├── ._bart-large-cnn-vocab.json │ ├── ._bart-large-merges.txt │ ├── ._bart-large-vocab.json │ ├── ._bert-base-cased-vocab.txt │ ├── ._bert-base-uncased-vocab.txt │ ├── ._roberta-base-merges.txt │ ├── ._roberta-base-vocab.json │ ├── bart-large-cnn-merges.txt │ ├── bart-large-cnn-vocab.json │ ├── bart-large-merges.txt │ ├── bart-large-vocab.json │ ├── bert-base-cased-vocab.txt │ ├── bert-base-uncased-vocab.txt │ ├── deberta-v3-base-spm.model │ ├── roberta-base-merges.txt │ └── roberta-base-vocab.json ├── docs ├── Makefile ├── _source │ ├── modules.rst │ ├── quickstart.rst │ ├── torchseq.agents.rst │ ├── torchseq.datasets.rst │ ├── torchseq.metric_hooks.rst │ ├── torchseq.models.rerankers.rst │ ├── torchseq.models.rst │ ├── torchseq.models.samplers.rst │ ├── torchseq.pretrained.rst │ ├── torchseq.rst │ └── torchseq.utils.rst ├── conf.py ├── index.rst └── make.bat ├── examples ├── Paraphrasing.ipynb ├── QG_BART_Eval.ipynb └── mBART.ipynb ├── mypy.ini ├── requirements.txt ├── scripts ├── download_data.sh ├── download_models.py ├── generate_3way_wikianswers.py ├── kaggle_prepro.py ├── paranmt_prepro.py ├── 
submit_job.sh └── train_vq_code_predictor.py ├── setup.py ├── tests ├── __init__.py ├── test_eval.py ├── test_misc.py ├── test_pretrained.py ├── test_regression.py ├── test_tokenisation.py └── utils.py └── torchseq ├── __init__.py ├── agents ├── __init__.py ├── aq_agent.py ├── base.py ├── lm_agent.py ├── model_agent.py ├── retrieval_agent.py └── seq2seq_agent.py ├── args.py ├── datasets ├── __init__.py ├── builder.py ├── json_dataset.py ├── json_loader.py ├── lm_dataset.py ├── lm_loader.py ├── loaders.py ├── paraphrase_dataset.py ├── paraphrase_loader.py ├── paraphrase_pair.py ├── qa_dataset.py ├── qa_loader.py └── qa_triple.py ├── demo ├── para_app.py ├── qg_app.py └── static │ ├── para_demo.htm │ ├── qgen.css │ ├── qgen.js │ ├── qq_demo.htm │ ├── separator.js │ └── spinner.css ├── eval ├── __init__.py ├── args.py ├── cli.py └── recipes │ ├── __init__.py │ ├── base.py │ └── opagg │ ├── extractive_summaries.py │ ├── hiro_post.py │ └── hiro_pre.py ├── main.py ├── metric_hooks ├── __init__.py ├── base.py ├── default.py ├── hrq_agg.py ├── opsumm_cluster_aug.py ├── prevalence_metric.py ├── prevalence_metric_old.py ├── qg_metric.py ├── rouge.py ├── semparse.py ├── sep_ae.py └── textual.py ├── models ├── __init__.py ├── activations.py ├── aq_transformer.py ├── bottleneck.py ├── bottleneck_autoencoder.py ├── contrastive_hrq_loss.py ├── contrastive_loss.py ├── contrastive_triplet_loss.py ├── ctxtans_encoder.py ├── decoder.py ├── encoder.py ├── exemplar_guided_autoencoder.py ├── hrq_vae.py ├── hyperbolic.py ├── hyperbolic_utils.py ├── lm_transformer.py ├── lr_schedule.py ├── modular_bottleneck.py ├── multihead_output.py ├── parallel_wrapper.py ├── pooling.py ├── positional_embeddings.py ├── pythae_vq.py ├── ranger.py ├── rerankers │ ├── __init__.py │ ├── backtranslate_reranker.py │ ├── combo.py │ ├── ngram_reranker.py │ ├── qa_reranker.py │ └── topk.py ├── retrieval_model.py ├── samplers │ ├── __init__.py │ ├── beam_search.py │ ├── diverse_beam.py │ ├── greedy.py │ ├── 
parallel_nucleus.py │ └── teacher_force.py ├── suppression_loss.py ├── transformer.py ├── vmf.py ├── vq_code_predictor.py ├── vq_vae.py └── vq_vae_legacy.py ├── pretrained ├── __init__.py ├── lexical_paraphraser_bert.py ├── lm.py ├── nli.py └── qa.py └── utils ├── __init__.py ├── ari.py ├── bart_score.py ├── cache.py ├── config.py ├── config_migration.py ├── easy_generate.py ├── fleiss.py ├── functions.py ├── logging.py ├── loss_dropper.py ├── mckenzie.py ├── metrics.py ├── model_loader.py ├── optimizer_group.py ├── perplexity.py ├── rouge.py ├── sari.py ├── seed.py ├── singleton.py ├── timer.py ├── tokenizer.py ├── tokenizer_wordlevel.py └── wandb.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203,E266,E501,W503,F403,F401,C901,E402 3 | max-line-length = 127 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 -------------------------------------------------------------------------------- /.github/workflows/torchseq.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: TorchSeq 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.9 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: 3.9 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -r requirements.txt 27 | python -m nltk.downloader wordnet 28 | python -m nltk.downloader omw-1.4 29 | pip install . 
30 | - name: Lint with flake8 31 | run: | 32 | pip install flake8 black 33 | # stop the build if there are Python syntax errors or undefined names 34 | # flake8 ./torchseq --count --select=E9,F63,F7,F82 --show-source --statistics 35 | make syntax 36 | make check 37 | - name: Test with pytest 38 | run: | 39 | pip install pytest pytest-cov codecov 40 | make test 41 | - name: Check types 42 | run: | 43 | pip install mypy 44 | mypy ./torchseq --install-types --non-interactive 45 | # - name: Upload coverage 46 | # run: | 47 | # make coverage 48 | 49 | 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | .mypy_cache 4 | 5 | /runs 6 | /runs_bak 7 | /models 8 | /data 9 | /data_bak 10 | /configs 11 | /scripts 12 | /baselines 13 | /plots 14 | /wandb 15 | /cache 16 | 17 | .ipynb_checkpoints/ 18 | *.DS_Store 19 | *.ipynb 20 | 21 | *.sh 22 | 23 | /*env/ 24 | /torchseq/external 25 | .vscode 26 | /notebooks 27 | 28 | .coverage 29 | coverage.xml 30 | 31 | *.egg-info/ 32 | 33 | # Sphinx documentation 34 | docs/_build/ 35 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | # You can also specify other tool versions: 14 | # nodejs: "19" 15 | # rust: "1.64" 16 | # golang: "1.19" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 
24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | # python: 28 | # install: 29 | # - requirements: requirements.txt -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: docs test coverage syntax types test check 2 | 3 | 4 | # Check that source code meets quality standards 5 | check: formatcheck syntax types test 6 | 7 | # Format source code automatically 8 | format: 9 | black --line-length 119 --target-version py39 tests torchseq 10 | 11 | formatcheck: 12 | black --check --line-length 119 --target-version py39 tests torchseq 13 | 14 | # Check syntax 15 | syntax: 16 | flake8 ./torchseq --count --select=E9,F63,F7,F82 --show-source --statistics 17 | 18 | types: 19 | mypy ./torchseq --install-types --non-interactive 20 | 21 | # Run tests for the library 22 | test: 23 | WANDB_USERNAME='' pytest --cov=./torchseq ./tests 24 | 25 | # Run tests for the library with RUN_SLOW=1 set 26 | testall: 27 | WANDB_USERNAME='' RUN_SLOW=1 pytest --cov=./torchseq ./tests 28 | 29 | # Send coverage report to codecov 30 | coverage: 31 | CODECOV_TOKEN="28535f9f-825a-435e-bb4e-e1de2aa63da3" codecov 32 | rm .coverage 33 | rm coverage.xml 34 | 35 | # Build docs 36 | docs: 37 | sphinx-apidoc -f -o ./docs/_source ./torchseq 38 | (cd ./docs && make html) -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # basic 6 | target: auto 7 | threshold: 5% 8 | paths: 9 | - "torchseq" 10 | patch: 11 | default: 12 | target: 5% 13 | ignore: 14 | - ".aqenv/*" 15 | - "torchseq/args.py" 16 | - "torchseq/main.py" 17 | - "torchseq/metric_hooks/*" 18 | - "torchseq/utils/config_migration.py" -------------------------------------------------------------------------------- 
/configs/mbart.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mbart_ae_testcase", 3 | "tag": "test_case", 4 | "group": "example", 5 | "task": "seq2seq", 6 | "training": { 7 | "dataset": "json", 8 | "use_preprocessed_data": false, 9 | "log_interval": 100, 10 | "optimizer": { 11 | "type": "adam", 12 | "lr": 0.002, 13 | "beta1": 0.9, 14 | "beta2": 0.98, 15 | "lr_schedule": true, 16 | "lr_warmup_steps": 1000 17 | }, 18 | "batch_size": 10, 19 | "optim_batch_size": 50, 20 | "clip_gradient": 1, 21 | "num_epochs": 200, 22 | "warmup_epochs": 0, 23 | "suppression_loss_weight": 0.0, 24 | "label_smoothing": 0.1, 25 | "early_stopping_lag": 0, 26 | "early_stopping_patience": 10, 27 | "reset_metrics": true, 28 | "token_dropout": 0.1, 29 | "kl_warmup_steps": 0, 30 | "epoch_steps": 0 31 | }, 32 | "json_dataset": { 33 | "path": "semparse/atis", 34 | "filename": "{split}", 35 | "field_map": [ 36 | { 37 | "type": "copy", 38 | "from": "target", 39 | "to": "target" 40 | }, 41 | { 42 | "type": "copy", 43 | "from": "source", 44 | "to": "source" 45 | } 46 | ] 47 | }, 48 | "eval": { 49 | "eval_batch_size": 8, 50 | "sampler": "beam", 51 | "max_out_len": 400, 52 | "metrics": { 53 | "semparse": { 54 | "run_codepred": false 55 | } 56 | }, 57 | "prepend_langcode": true 58 | }, 59 | "beam_search": { 60 | "beam_width": 5, 61 | "beam_expansion": 5, 62 | "length_alpha": 1.0, 63 | "prevent_repetition": false 64 | }, 65 | "prepro": { 66 | "input_vocab_size": 250054, 67 | "output_vocab_size": 250054, 68 | "sent_window": 0, 69 | "tok_window": 400, 70 | "include_lang_codes": true, 71 | "drop_target_lang_codes": true, 72 | "input_tokenizer": "facebook/mbart-large-50-many-to-many-mmt", 73 | "output_tokenizer": "facebook/mbart-large-50-many-to-many-mmt" 74 | }, 75 | "dropout": 0.1, 76 | "input_raw_embedding_dim": 1024, 77 | "output_raw_embedding_dim": 1024, 78 | "encoder": { 79 | "num_heads": 16, 80 | "dim_feedforward": 4096, 81 | "activation": 
"relu", 82 | "pretrained_encoder": "facebook/mbart-large-50-many-to-many-mmt", 83 | "embedding_dim": 1024, 84 | "freeze_pretrained": true, 85 | "init_embeds_from_tokenizer": false, 86 | "num_layers": 0 87 | }, 88 | "decoder": { 89 | "num_heads": 16, 90 | "dim_feedforward": 4096, 91 | "activation": "relu", 92 | "pretrained_decoder": "facebook/mbart-large-50-many-to-many-mmt", 93 | "embedding_dim": 1024, 94 | "freeze_pretrained": true, 95 | "init_embeds_from_tokenizer": false, 96 | "num_layers": 0 97 | }, 98 | "freeze_embeddings": false, 99 | "freeze_projection": false, 100 | "directional_masks": false, 101 | "bottleneck": { 102 | "embedding_dim": 1024, 103 | "modular": true, 104 | "num_heads": 8, 105 | "modules": [ 106 | { 107 | "range": [ 108 | 0, 109 | 8 110 | ], 111 | "type": "ae", 112 | "pooling": false 113 | } 114 | ] 115 | } 116 | } -------------------------------------------------------------------------------- /configs/paraphrasing_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "paraphrasing_vae", 3 | "tag": "examples", 4 | "task": "para", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | 10 | }, 11 | "json_dataset": { 12 | "path": "wikianswers-pp", 13 | "field_map": [ 14 | { 15 | "type": "sample", 16 | "from": "qs", 17 | "to": "s1" 18 | }, 19 | { 20 | "type": "sample", 21 | "from": "qs", 22 | "to": "s2" 23 | } 24 | ] 25 | }, 26 | "training": { 27 | "dataset": "json", 28 | "use_preprocessed_data": false, 29 | "log_interval": 100, 30 | "lr_schedule": true, 31 | "lr": 0.01, 32 | "beta1": 0.9, 33 | "beta2": 0.98, 34 | "batch_size": 16, 35 | "optim_batch_size": 64, 36 | "clip_gradient": 5, 37 | "num_epochs": 120, 38 | "opt": "adam", 39 | "warmup_epochs": 30, 40 | "suppression_loss_weight": 0.0, 41 | "label_smoothing": 0.0, 42 | "early_stopping_lag": 0, 43 | "reset_metrics": true, 44 | "token_dropout": 0.2, 45 | "kl_warmup_steps": 10000 46 | }, 47 | "eval": { 48 | 
"eval_batch_size": 16, 49 | "sampler": "beam", 50 | "max_out_len": 50 51 | }, 52 | "beam_search": { 53 | "beam_width": 4, 54 | "beam_expansion": 2, 55 | "length_alpha": 1.0 56 | }, 57 | "prepro": { 58 | "vocab_size": 30522, 59 | "sent_window": 0, 60 | "tok_window": 40, 61 | "concat_ctxt_ans": false, 62 | "bio_tagging": true, 63 | "tokenizer": "bert-base-uncased" 64 | }, 65 | 66 | "dropout": 0.1, 67 | 68 | "raw_embedding_dim": 768, 69 | "embedding_dim": 768, 70 | "onehot_bio" : false, 71 | "bio_embedding_dim": 8, 72 | "freeze_embeddings": true, 73 | "freeze_projection": true, 74 | "directional_masks": true, 75 | 76 | "encdec": { 77 | "num_encoder_layers": 5, 78 | "num_decoder_layers": 5, 79 | "num_heads": 8, 80 | "dim_feedforward": 2048, 81 | "activation": "relu", 82 | "bert_encoder": false 83 | }, 84 | "encoder": { 85 | "embedding_dim": 768 86 | }, 87 | "decoder": { 88 | "embedding_dim": 768 89 | }, 90 | "bottleneck": { 91 | "embedding_dim": 768, 92 | "variational": true, 93 | "num_similar_heads": 4 94 | } 95 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_16x8.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam16x8", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "diverse_beam": { 8 | "beam_width": 16, 9 | "beam_expansion": 8, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_4x2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam4x2", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 4, 9 | "beam_expansion": 2, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_4x2_length1.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam4x2L1", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 4, 9 | "beam_expansion": 2, 10 | "length_alpha": 1.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_8x4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam8x4", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 8, 9 | "beam_expansion": 4, 10 | "length_alpha": 2.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/beamsearch_8x4_length1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_beam8x4L1", 3 | "eval": { 4 | "sampler": "beam", 5 | "sample_outputs": true 6 | }, 7 | "beam_search": { 8 | "beam_width": 8, 9 | "beam_expansion": 4, 10 | "length_alpha": 1.0 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/dbs.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_diversebeam16x8_8g", 3 | "eval": { 4 | "sampler": "diverse_beam" 5 | }, 6 | "diverse_beam": { 7 | "beam_width": 16, 8 | "beam_expansion": 8, 9 | "length_alpha": 2.0, 10 | "num_groups": 8, 11 | "penalty_weight": 0.5 12 | } 13 | } -------------------------------------------------------------------------------- /configs/patches/eval_bsz_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "batch_size": 2 4 | }, 5 | "eval": { 6 | "eval_batch_size": 2 7 | } 8 | } -------------------------------------------------------------------------------- /configs/patches/eval_bsz_8.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "training": { 3 | "batch_size": 8 4 | }, 5 | "eval": { 6 | "eval_batch_size": 8 7 | } 8 | } -------------------------------------------------------------------------------- /configs/patches/multi_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_top4", 3 | "eval": { 4 | "topk": 4 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/name_vae.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "vae_%" 3 | } -------------------------------------------------------------------------------- /configs/patches/newsqa.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_newsqa", 3 | "training": { 4 | "dataset": "newsqa" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/nq.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nq", 3 | "training": { 4 | "dataset": "naturalquestions" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/nucleus05.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.5", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.5, 9 | "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/nucleus075.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.75", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.75, 9 
| "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/nucleus09.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_nucleus0.9", 3 | "eval": { 4 | "sampler": "nucleus" 5 | }, 6 | "nucleus_sampling": { 7 | "prevent_repetition": true, 8 | "cutoff": 0.9, 9 | "beam_width": 4, 10 | "length_alpha": 0.01 11 | } 12 | } -------------------------------------------------------------------------------- /configs/patches/rerank_backtranslate.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankbacktrans", 3 | "reranker": { 4 | "strategy": "backtranslate" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_combo.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankcombo", 3 | "reranker": { 4 | "strategy": "combo" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_ngram.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankngram", 3 | "reranker": { 4 | "strategy": "ngram" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/rerank_qa.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_rerankqa", 3 | "reranker": { 4 | "strategy": "qa" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_squad", 3 | "training": { 4 | "dataset": "squad" 5 | } 6 | } -------------------------------------------------------------------------------- 
/configs/patches/tag_ppeval.json: -------------------------------------------------------------------------------- 1 | { 2 | "tag": "ppeval" 3 | } -------------------------------------------------------------------------------- /configs/patches/udeptoq_squad.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_squadudep", 3 | "training": { 4 | "dataset": "models/squad-udep" 5 | } 6 | } -------------------------------------------------------------------------------- /configs/patches/unfreeze_encdec.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "%_unfreeze_lr1e-6", 3 | "training": { 4 | "batch_size": 2, 5 | "lr": 1e-6 6 | }, 7 | "encdec": { 8 | "freeze_encoder": false, 9 | "freeze_decoder": false 10 | } 11 | } -------------------------------------------------------------------------------- /configs/qg_bart.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_bart", 3 | "tag": "examples", 4 | "task": "aq", 5 | "model": "pretrained_adapter", 6 | "env": { 7 | "cuda": true, 8 | "data_path": "./data", 9 | "gpu_device": 0 10 | 11 | }, 12 | "training": { 13 | "dataset": "squad", 14 | "use_preprocessed_data": false, 15 | "log_interval": 100, 16 | "lr_schedule": false, 17 | "lr": 5e-6, 18 | "beta1": 0.9, 19 | "beta2": 0.98, 20 | "batch_size": 4, 21 | "optim_batch_size": 64, 22 | "clip_gradient": 0.1, 23 | "num_epochs": 30, 24 | "opt": "adam", 25 | "warmup_epochs": 10, 26 | "suppression_loss_weight": 0.01, 27 | "label_smoothing": 0.0, 28 | "early_stopping_lag": 0, 29 | "reset_metrics": true, 30 | "token_dropout": 0.0 31 | }, 32 | "eval": { 33 | "eval_batch_size": 4, 34 | "sampler": "beam", 35 | "prepend_eos": true, 36 | "sample_outputs": true 37 | }, 38 | "beam_search": { 39 | "beam_width": 8, 40 | "beam_expansion": 4, 41 | "length_alpha": 2.0 42 | }, 43 | "prepro": { 44 | "vocab_size": 50265, 45 | 
"sent_window": 1, 46 | "tok_window": 300, 47 | "concat_ctxt_ans": true, 48 | "roberta_style_encoding": true, 49 | "bio_tagging": true, 50 | "tokenizer": "bart-large" 51 | }, 52 | 53 | "dropout": 0.1, 54 | 55 | "raw_embedding_dim": 1024, 56 | "embedding_dim": 1024, 57 | "onehot_bio" : false, 58 | "bio_embedding_dim": 8, 59 | "freeze_embeddings": true, 60 | "freeze_projection": true, 61 | "directional_masks": true, 62 | 63 | "encoder_outputs": { 64 | "c_raw": true, 65 | "a_raw": false, 66 | "c_enc": true, 67 | "c_enc_pool": false, 68 | "a_enc": false, 69 | "a_enc_pool": false, 70 | "c_enc_anspool": false, 71 | "c_ans_labels": false 72 | }, 73 | "encdec": { 74 | "num_encoder_layers": 3, 75 | "num_decoder_layers": 3, 76 | "num_heads": 8, 77 | "dim_feedforward": 2048, 78 | "activation": "relu", 79 | "bert_encoder": false, 80 | "bert_model": "facebook/bart-large", 81 | "bert_warmup_epochs": 50, 82 | "freeze_encoder": false, 83 | "freeze_decoder": true 84 | }, 85 | "encoder": { 86 | "embedding_dim": 1024 87 | }, 88 | "decoder": { 89 | "embedding_dim": 1024 90 | } 91 | } -------------------------------------------------------------------------------- /configs/qg_bert.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_bert", 3 | "tag": "examples", 4 | "task": "aq", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | }, 10 | "training": { 11 | "dataset": "squad", 12 | "use_preprocessed_data": false, 13 | "log_interval": 100, 14 | "lr_schedule": true, 15 | "lr": 0.003, 16 | "beta1": 0.9, 17 | "beta2": 0.98, 18 | "batch_size": 8, 19 | "optim_batch_size": 64, 20 | "clip_gradient": 5, 21 | "num_epochs": 50, 22 | "opt": "adam", 23 | "warmup_epochs": 5, 24 | "suppression_loss_weight": 0.1, 25 | "label_smoothing": 0.0, 26 | "early_stopping_lag": 1, 27 | "token_dropout": 0.2 28 | }, 29 | "eval": { 30 | "eval_batch_size": 6, 31 | "sampler": "beam" 32 | }, 33 | "beam_search": { 34 | "beam_width": 
8, 35 | "beam_expansion": 8, 36 | "length_alpha": 2.0 37 | }, 38 | "prepro": { 39 | "vocab_size": 30522, 40 | "sent_window": 0, 41 | "tok_window": 300, 42 | "concat_ctxt_ans": false, 43 | "bio_tagging": true, 44 | "tokenizer": "bert-base-uncased" 45 | }, 46 | "dropout": 0.1, 47 | "raw_embedding_dim": 768, 48 | "embedding_dim": 768, 49 | "onehot_bio": false, 50 | "bio_embedding_dim": 8, 51 | "freeze_embeddings": true, 52 | "freeze_projection": true, 53 | "directional_masks": true, 54 | "encoder_outputs": { 55 | "c_raw": true, 56 | "a_raw": false, 57 | "c_enc": true, 58 | "c_enc_pool": false, 59 | "a_enc": false, 60 | "a_enc_pool": false, 61 | "c_enc_anspool": false, 62 | "c_ans_labels": false 63 | }, 64 | "encdec": { 65 | "num_encoder_layers": 5, 66 | "num_decoder_layers": 5, 67 | "num_heads": 2, 68 | "dim_feedforward": 2048, 69 | "activation": "relu", 70 | "bert_encoder": true, 71 | "bert_model": "bert-base-uncased", 72 | "bert_warmup_epochs": 20, 73 | "bert_finetune": false 74 | }, 75 | "encoder": { 76 | "embedding_dim": 768 77 | }, 78 | "decoder": { 79 | "embedding_dim": 768 80 | } 81 | } -------------------------------------------------------------------------------- /configs/qg_transformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "qg_transformer", 3 | "tag": "examples", 4 | "task": "aq", 5 | "env": { 6 | "cuda": true, 7 | "data_path": "./data", 8 | "gpu_device": 0 9 | }, 10 | "training": { 11 | "dataset": "squad", 12 | "use_preprocessed_data": false, 13 | "log_interval": 100, 14 | "lr_schedule": true, 15 | "lr": 0.003, 16 | "beta1": 0.9, 17 | "beta2": 0.98, 18 | "batch_size": 16, 19 | "optim_batch_size": 64, 20 | "clip_gradient": 5, 21 | "num_epochs": 50, 22 | "opt": "adam", 23 | "warmup_epochs": 5, 24 | "suppression_loss_weight": 0.1, 25 | "label_smoothing": 0.0, 26 | "early_stopping_lag": 1, 27 | "token_dropout": 0.2 28 | }, 29 | "eval": { 30 | "eval_batch_size": 8, 31 | "sampler": "beam" 32 | }, 
33 | "beam_search": { 34 | "beam_width": 8, 35 | "beam_expansion": 4, 36 | "length_alpha": 2.0 37 | }, 38 | "prepro": { 39 | "vocab_size": 30522, 40 | "sent_window": 0, 41 | "tok_window": 300, 42 | "concat_ctxt_ans": false, 43 | "bio_tagging": true, 44 | "tokenizer": "bert-base-uncased" 45 | }, 46 | "dropout": 0.1, 47 | "raw_embedding_dim": 768, 48 | "embedding_dim": 768, 49 | "onehot_bio": false, 50 | "bio_embedding_dim": 8, 51 | "freeze_embeddings": true, 52 | "freeze_projection": true, 53 | "directional_masks": true, 54 | "encoder_outputs": { 55 | "c_raw": true, 56 | "a_raw": false, 57 | "c_enc": true, 58 | "c_enc_pool": false, 59 | "a_enc": false, 60 | "a_enc_pool": false, 61 | "c_enc_anspool": false, 62 | "c_ans_labels": false 63 | }, 64 | "encdec": { 65 | "num_encoder_layers": 5, 66 | "num_decoder_layers": 5, 67 | "num_heads": 2, 68 | "dim_feedforward": 2048, 69 | "activation": "relu", 70 | "bert_encoder": false, 71 | "bert_warmup_epochs": 20, 72 | "bert_finetune": false 73 | }, 74 | "encoder": { 75 | "embedding_dim": 768 76 | }, 77 | "decoder": { 78 | "embedding_dim": 768 79 | } 80 | } -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-cnn-merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-cnn-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-cnn-vocab.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-cnn-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-merges.txt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bart-large-vocab.json: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bart-large-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bert-base-cased-vocab.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bert-base-cased-vocab.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._bert-base-uncased-vocab.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._bert-base-uncased-vocab.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._roberta-base-merges.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._roberta-base-merges.txt -------------------------------------------------------------------------------- /data/pretrained-vocabs/._roberta-base-vocab.json: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/._roberta-base-vocab.json -------------------------------------------------------------------------------- /data/pretrained-vocabs/deberta-v3-base-spm.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomhosking/torchseq/0576fe765c2eba2df6fa00009c5b302672417cb7/data/pretrained-vocabs/deberta-v3-base-spm.model -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_source/modules.rst: -------------------------------------------------------------------------------- 1 | torchseq 2 | ======== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | torchseq 8 | -------------------------------------------------------------------------------- /docs/_source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | 5 | Installation 6 | ------------ 7 | 8 | From a fresh ``venv``, run: 9 | ``` 10 | pip install -r requirements.txt 11 | pip install -e . 12 | 13 | python3 -m nltk.downloader punkt 14 | python3 ./scripts/download_models.py 15 | ``` 16 | 17 | Done! 18 | 19 | Overview 20 | -------- 21 | 22 | TorchSeq is a framework for training and evaluating Seq2Seq models, built in PyTorch. 23 | 24 | 25 | 26 | 27 | CLI Usage 28 | --------- 29 | 30 | TorchSeq installs a CLI - to load a model and evaluate it on the test set, run ``torchseq --test --load /path/to/model``. 31 | 32 | The CLI options are: 33 | 34 | --train Run training 35 | --validate Run validation (i.e., eval on the dev set) 36 | --validate_train Run eval on the training set 37 | --test Run eval on the test set 38 | --silent Disable verbose output 39 | --reload_after_train Use in conjunction with one of the eval commands to reload the best checkpoint once training completes, and evaluate using that 40 | --load_chkpt /path/to/checkpoint.pt Path to checkpoint file 41 | --data_path /path/to/data/ Path to folder containing datasets 42 | --output_path /path/to/output/ Path to dump output 43 | --config,-c /path/to/config.json Path to config file 44 | --patch,-p /path/to/patch.json Path to 'patch' file(s) 45 | --load /path/to/model/ Path to a full model (checkpoint + config) 46 | --cpu Run on CPU 47 | --debug Enable some extra debugging 48 | 49 | 50 | Scripting 51 | --------- 52 | 53 | You can also invoke TorchSeq from within a script, like this: 54 | 55 | ``` 56 | from torchseq.agents.seq2seq_agent import Seq2SeqAgent 57 | from torchseq.datasets.json_loader import JsonDataLoader 58 | import json 59 | from torchseq.utils.config import Config 60 | 61 | with
open(path_to_model + "/config.json") as f: 62 | cfg_dict = json.load(f) 63 | 64 | config = Config(cfg_dict) 65 | 66 | data_loader = JsonDataLoader(config) 67 | 68 | checkpoint_path = path_to_model + "/model/checkpoint.pt" 69 | 70 | instance = Seq2SeqAgent(config=config, run_id=None, output_path="./runs/demo/", silent=True, verbose=False) 71 | 72 | instance.load_checkpoint(checkpoint_path) 73 | instance.model.eval() 74 | 75 | loss, metrics, output, memory = instance.validate(data_loader, save_model=False) 76 | ``` 77 | 78 | In general, a ``torchseq.agents.model_agent.ModelAgent`` object is the main controller - once you've created one from a ``torchseq.utils.config.Config``, you can train it with ``torchseq.agents.base.train``. -------------------------------------------------------------------------------- /docs/_source/torchseq.agents.rst: -------------------------------------------------------------------------------- 1 | torchseq.agents package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.agents.aq\_agent module 8 | -------------------------------- 9 | 10 | .. automodule:: torchseq.agents.aq_agent 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.agents.base module 16 | --------------------------- 17 | 18 | .. automodule:: torchseq.agents.base 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.agents.exemplar\_agent module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.agents.exemplar_agent 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.agents.lm\_agent module 32 | -------------------------------- 33 | 34 | .. automodule:: torchseq.agents.lm_agent 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.agents.meta\_learning\_agent module 40 | -------------------------------------------- 41 | 42 | ..
automodule:: torchseq.agents.meta_learning_agent 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.agents.model\_agent module 48 | ----------------------------------- 49 | 50 | .. automodule:: torchseq.agents.model_agent 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.agents.seq2seq\_agent module 56 | ------------------------------------- 57 | 58 | .. automodule:: torchseq.agents.seq2seq_agent 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | Module contents 64 | --------------- 65 | 66 | .. automodule:: torchseq.agents 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | -------------------------------------------------------------------------------- /docs/_source/torchseq.datasets.rst: -------------------------------------------------------------------------------- 1 | torchseq.datasets package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.datasets.builder module 8 | -------------------------------- 9 | 10 | .. automodule:: torchseq.datasets.builder 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.datasets.json\_dataset module 16 | -------------------------------------- 17 | 18 | .. automodule:: torchseq.datasets.json_dataset 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.datasets.json\_loader module 24 | ------------------------------------- 25 | 26 | .. automodule:: torchseq.datasets.json_loader 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.datasets.lm\_dataset module 32 | ------------------------------------ 33 | 34 | .. automodule:: torchseq.datasets.lm_dataset 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.datasets.lm\_loader module 40 | ----------------------------------- 41 | 42 | .. 
automodule:: torchseq.datasets.lm_loader 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.datasets.loaders module 48 | -------------------------------- 49 | 50 | .. automodule:: torchseq.datasets.loaders 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.datasets.paraphrase\_dataset module 56 | -------------------------------------------- 57 | 58 | .. automodule:: torchseq.datasets.paraphrase_dataset 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.datasets.paraphrase\_loader module 64 | ------------------------------------------- 65 | 66 | .. automodule:: torchseq.datasets.paraphrase_loader 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchseq.datasets.paraphrase\_pair module 72 | ----------------------------------------- 73 | 74 | .. automodule:: torchseq.datasets.paraphrase_pair 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | torchseq.datasets.qa\_dataset module 80 | ------------------------------------ 81 | 82 | .. automodule:: torchseq.datasets.qa_dataset 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | torchseq.datasets.qa\_loader module 88 | ----------------------------------- 89 | 90 | .. automodule:: torchseq.datasets.qa_loader 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | torchseq.datasets.qa\_triple module 96 | ----------------------------------- 97 | 98 | .. automodule:: torchseq.datasets.qa_triple 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | Module contents 104 | --------------- 105 | 106 | .. 
automodule:: torchseq.datasets 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | -------------------------------------------------------------------------------- /docs/_source/torchseq.metric_hooks.rst: -------------------------------------------------------------------------------- 1 | torchseq.metric\_hooks package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.metric\_hooks.base module 8 | ---------------------------------- 9 | 10 | .. automodule:: torchseq.metric_hooks.base 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.metric\_hooks.default module 16 | ------------------------------------- 17 | 18 | .. automodule:: torchseq.metric_hooks.default 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.metric\_hooks.hrq\_agg module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.metric_hooks.hrq_agg 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.metric\_hooks.qg\_metric module 32 | ---------------------------------------- 33 | 34 | .. automodule:: torchseq.metric_hooks.qg_metric 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.metric\_hooks.rouge module 40 | ----------------------------------- 41 | 42 | .. automodule:: torchseq.metric_hooks.rouge 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.metric\_hooks.semparse module 48 | -------------------------------------- 49 | 50 | .. automodule:: torchseq.metric_hooks.semparse 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.metric\_hooks.sep\_ae module 56 | ------------------------------------- 57 | 58 | .. automodule:: torchseq.metric_hooks.sep_ae 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.metric\_hooks.textual module 64 | ------------------------------------- 65 | 66 | .. 
automodule:: torchseq.metric_hooks.textual 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | Module contents 72 | --------------- 73 | 74 | .. automodule:: torchseq.metric_hooks 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | -------------------------------------------------------------------------------- /docs/_source/torchseq.models.rerankers.rst: -------------------------------------------------------------------------------- 1 | torchseq.models.rerankers package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.models.rerankers.backtranslate\_reranker module 8 | -------------------------------------------------------- 9 | 10 | .. automodule:: torchseq.models.rerankers.backtranslate_reranker 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.models.rerankers.combo module 16 | -------------------------------------- 17 | 18 | .. automodule:: torchseq.models.rerankers.combo 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.models.rerankers.ngram\_reranker module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: torchseq.models.rerankers.ngram_reranker 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.models.rerankers.qa\_reranker module 32 | --------------------------------------------- 33 | 34 | .. automodule:: torchseq.models.rerankers.qa_reranker 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.models.rerankers.topk module 40 | ------------------------------------- 41 | 42 | .. automodule:: torchseq.models.rerankers.topk 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. 
automodule:: torchseq.models.rerankers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/_source/torchseq.models.samplers.rst: -------------------------------------------------------------------------------- 1 | torchseq.models.samplers package 2 | ================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.models.samplers.beam\_search module 8 | -------------------------------------------- 9 | 10 | .. automodule:: torchseq.models.samplers.beam_search 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.models.samplers.diverse\_beam module 16 | --------------------------------------------- 17 | 18 | .. automodule:: torchseq.models.samplers.diverse_beam 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.models.samplers.greedy module 24 | -------------------------------------- 25 | 26 | .. automodule:: torchseq.models.samplers.greedy 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.models.samplers.parallel\_nucleus module 32 | ------------------------------------------------- 33 | 34 | .. automodule:: torchseq.models.samplers.parallel_nucleus 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.models.samplers.teacher\_force module 40 | ---------------------------------------------- 41 | 42 | .. automodule:: torchseq.models.samplers.teacher_force 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. 
automodule:: torchseq.models.samplers 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/_source/torchseq.pretrained.rst: -------------------------------------------------------------------------------- 1 | torchseq.pretrained package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.pretrained.lexical\_paraphraser\_bert module 8 | ----------------------------------------------------- 9 | 10 | .. automodule:: torchseq.pretrained.lexical_paraphraser_bert 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.pretrained.lm module 16 | ----------------------------- 17 | 18 | .. automodule:: torchseq.pretrained.lm 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.pretrained.qa module 24 | ----------------------------- 25 | 26 | .. automodule:: torchseq.pretrained.qa 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: torchseq.pretrained 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/_source/torchseq.rst: -------------------------------------------------------------------------------- 1 | torchseq package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | torchseq.agents 11 | torchseq.datasets 12 | torchseq.metric_hooks 13 | torchseq.models 14 | torchseq.pretrained 15 | torchseq.utils 16 | 17 | Submodules 18 | ---------- 19 | 20 | torchseq.args module 21 | -------------------- 22 | 23 | .. automodule:: torchseq.args 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | torchseq.main module 29 | -------------------- 30 | 31 | .. 
automodule:: torchseq.main 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | Module contents 37 | --------------- 38 | 39 | .. automodule:: torchseq 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | -------------------------------------------------------------------------------- /docs/_source/torchseq.utils.rst: -------------------------------------------------------------------------------- 1 | torchseq.utils package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchseq.utils.cache module 8 | --------------------------- 9 | 10 | .. automodule:: torchseq.utils.cache 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchseq.utils.config module 16 | ---------------------------- 17 | 18 | .. automodule:: torchseq.utils.config 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchseq.utils.config\_migration module 24 | --------------------------------------- 25 | 26 | .. automodule:: torchseq.utils.config_migration 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | torchseq.utils.easy\_generate module 32 | ------------------------------------ 33 | 34 | .. automodule:: torchseq.utils.easy_generate 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | torchseq.utils.fleiss module 40 | ---------------------------- 41 | 42 | .. automodule:: torchseq.utils.fleiss 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | torchseq.utils.functions module 48 | ------------------------------- 49 | 50 | .. automodule:: torchseq.utils.functions 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | torchseq.utils.logging module 56 | ----------------------------- 57 | 58 | .. automodule:: torchseq.utils.logging 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | torchseq.utils.loss\_dropper module 64 | ----------------------------------- 65 | 66 | .. 
automodule:: torchseq.utils.loss_dropper 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | torchseq.utils.mckenzie module 72 | ------------------------------ 73 | 74 | .. automodule:: torchseq.utils.mckenzie 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | torchseq.utils.metrics module 80 | ----------------------------- 81 | 82 | .. automodule:: torchseq.utils.metrics 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | 87 | torchseq.utils.model\_loader module 88 | ----------------------------------- 89 | 90 | .. automodule:: torchseq.utils.model_loader 91 | :members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | torchseq.utils.optimizer\_group module 96 | -------------------------------------- 97 | 98 | .. automodule:: torchseq.utils.optimizer_group 99 | :members: 100 | :undoc-members: 101 | :show-inheritance: 102 | 103 | torchseq.utils.perplexity module 104 | -------------------------------- 105 | 106 | .. automodule:: torchseq.utils.perplexity 107 | :members: 108 | :undoc-members: 109 | :show-inheritance: 110 | 111 | torchseq.utils.rouge module 112 | --------------------------- 113 | 114 | .. automodule:: torchseq.utils.rouge 115 | :members: 116 | :undoc-members: 117 | :show-inheritance: 118 | 119 | torchseq.utils.sari module 120 | -------------------------- 121 | 122 | .. automodule:: torchseq.utils.sari 123 | :members: 124 | :undoc-members: 125 | :show-inheritance: 126 | 127 | torchseq.utils.seed module 128 | -------------------------- 129 | 130 | .. automodule:: torchseq.utils.seed 131 | :members: 132 | :undoc-members: 133 | :show-inheritance: 134 | 135 | torchseq.utils.singleton module 136 | ------------------------------- 137 | 138 | .. automodule:: torchseq.utils.singleton 139 | :members: 140 | :undoc-members: 141 | :show-inheritance: 142 | 143 | torchseq.utils.tokenizer module 144 | ------------------------------- 145 | 146 | .. 
automodule:: torchseq.utils.tokenizer 147 | :members: 148 | :undoc-members: 149 | :show-inheritance: 150 | 151 | torchseq.utils.tokenizer\_wordlevel module 152 | ------------------------------------------ 153 | 154 | .. automodule:: torchseq.utils.tokenizer_wordlevel 155 | :members: 156 | :undoc-members: 157 | :show-inheritance: 158 | 159 | torchseq.utils.wandb module 160 | --------------------------- 161 | 162 | .. automodule:: torchseq.utils.wandb 163 | :members: 164 | :undoc-members: 165 | :show-inheritance: 166 | 167 | Module contents 168 | --------------- 169 | 170 | .. automodule:: torchseq.utils 171 | :members: 172 | :undoc-members: 173 | :show-inheritance: 174 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../torchseq')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'torchseq' 21 | copyright = '2020, Tom Hosking' 22 | author = 'Tom Hosking' 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 
30 | extensions = [ 31 | 'recommonmark', 32 | 'sphinx.ext.autodoc', 33 | 'sphinx.ext.napoleon' 34 | ] 35 | 36 | # Add any paths that contain templates here, relative to this directory. 37 | templates_path = ['_templates'] 38 | 39 | # List of patterns, relative to source directory, that match files and 40 | # directories to ignore when looking for source files. 41 | # This pattern also affects html_static_path and html_extra_path. 42 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 43 | 44 | 45 | # -- Options for HTML output ------------------------------------------------- 46 | 47 | # The theme to use for HTML and HTML Help pages. See the documentation for 48 | # a list of builtin themes. 49 | # 50 | html_theme = 'sphinx_rtd_theme' 51 | 52 | # Add any paths that contain custom static files (such as style sheets) here, 53 | # relative to this directory. They are copied after the builtin static files, 54 | # so a file named "default.css" will overwrite the builtin "default.css". 55 | html_static_path = ['_static'] -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. torchseq documentation master file, created by 2 | sphinx-quickstart on Tue Dec 1 11:20:55 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to torchseq's documentation! 7 | ==================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Getting Started: 12 | 13 | _source/quickstart 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | :caption: Contents: 18 | 19 | _source/modules 20 | 21 | 22 | .. Features 23 | .. -------- 24 | 25 | .. - Be awesome 26 | .. - Make things faster 27 | 28 | .. Installation 29 | .. ------------ 30 | 31 | .. Install $project by running: 32 | 33 | .. install project 34 | 35 | .. Contribute 36 | .. ---------- 37 | 38 | .. 
- Issue Tracker: github.com/$project/$project/issues 39 | .. - Source Code: github.com/$project/$project 40 | 41 | .. Support 42 | .. ------- 43 | 44 | .. If you are having issues, please let us know. 45 | .. We have a mailing list located at: project@google-groups.com 46 | 47 | .. License 48 | .. ------- 49 | 50 | .. The project is licensed under the BSD license. 51 | 52 | 53 | Indices and tables 54 | ================== 55 | 56 | * :ref:`genindex` 57 | * :ref:`modindex` 58 | * :ref:`search` 59 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /examples/Paraphrasing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import json\n", 10 | "from torchseq.agents.para_agent import ParaphraseAgent\n", 11 | "from torchseq.datasets.json_loader import JsonDataLoader\n", 12 | "from torchseq.utils.config import Config\n", 13 | "from torchseq.metric_hooks.textual import TextualMetricHook\n", 14 | "import torch\n", 15 | "\n", 16 | "model_path = '../models/examples/20210222_152157_paraphrasing_vae/'\n", 17 | "# model_path = '../models/examples/20210503_184659_paraphrasing_vqvae/'\n", 18 | "# model_path = '../models/examples/20210225_112226_paraphrasing_ae/'\n", 19 | "\n", 20 | "\n", 21 | "# Load the config\n", 22 | "with open(model_path + 'config.json') as f:\n", 23 | " cfg_dict = json.load(f)\n", 24 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 25 | "\n", 26 | "\n", 27 | "config = Config(cfg_dict)\n", 28 | "\n", 29 | "# Load the model\n", 30 | "instance = ParaphraseAgent(config=config, run_id=None, output_path=\"./runs/examples/paraphrasing_eval\", silent=False, verbose=False, training_mode=False)\n", 31 | "instance.load_checkpoint(model_path + 'model/checkpoint.pt')\n", 32 | "instance.model.eval()\n", 33 | "\n", 34 | "# Create a dataset\n", 35 | "data_loader = JsonDataLoader(config)\n", 36 | "\n", 37 | "# Run inference on the test split\n", 38 | "# test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = 
instance.inference(data_loader.test_loader, metric_hooks=[TextualMetricHook(config, 's1', 's2')])\n", 39 | "\n", 40 | "# Done!\n", 41 | "# print(all_metrics['ibleu'])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 5, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stderr", 51 | "output_type": "stream", 52 | "text": [ 53 | "Validating after 0 epochs: 100%|██████████| 1/1 [00:00<00:00, 6.07it/s]" 54 | ] 55 | }, 56 | { 57 | "name": "stdout", 58 | "output_type": "stream", 59 | "text": [ 60 | "['what is the oldest cat in the world?']\n" 61 | ] 62 | }, 63 | { 64 | "name": "stderr", 65 | "output_type": "stream", 66 | "text": [ 67 | "\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "# You can now run your model on your own dataset\n", 73 | "\n", 74 | "examples = [\n", 75 | " {'q': 'Who was the oldest cat in the world?'},\n", 76 | "]\n", 77 | "\n", 78 | "\n", 79 | "cfg_dict[\"json_dataset\"] = {\n", 80 | " \"path\": None,\n", 81 | " \"field_map\": [\n", 82 | " {\n", 83 | " \"type\": \"copy\",\n", 84 | " \"from\": \"q\",\n", 85 | " \"to\": \"s1\"\n", 86 | " },\n", 87 | " {\n", 88 | " \"type\": \"copy\",\n", 89 | " \"from\": \"q\",\n", 90 | " \"to\": \"s2\"\n", 91 | " }\n", 92 | " ]\n", 93 | "}\n", 94 | "\n", 95 | "config = Config(cfg_dict)\n", 96 | "\n", 97 | " \n", 98 | "data_loader_custom = JsonDataLoader(config, test_samples=examples)\n", 99 | "\n", 100 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader_custom.test_loader)\n", 101 | "\n", 102 | "print(pred_output)" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "Python 3", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 
124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.8.5" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /examples/QG_BART_Eval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Validating after 0 epochs: 100%|██████████| 2970/2970 [58:25<00:00, 1.18s/it] \n" 13 | ] 14 | }, 15 | { 16 | "name": "stdout", 17 | "output_type": "stream", 18 | "text": [ 19 | "21.065812894587264\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "import json\n", 25 | "from torchseq.agents.aq_agent import AQAgent\n", 26 | "from torchseq.datasets.qa_loader import QADataLoader\n", 27 | "from torchseq.utils.config import Config\n", 28 | "from torchseq.metric_hooks.textual import TextualMetricHook\n", 29 | "import torch\n", 30 | "\n", 31 | "model_path = '../models/examples/20210223_191015_qg_bart/'\n", 32 | "\n", 33 | "\n", 34 | "# Load the config\n", 35 | "with open(model_path + 'config.json') as f:\n", 36 | " cfg_dict = json.load(f)\n", 37 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 38 | "\n", 39 | "config = Config(cfg_dict)\n", 40 | "\n", 41 | "# Load the model\n", 42 | "instance = AQAgent(config=config, run_id=None, output_path=\"./runs/examples/qg_bert_eval\", silent=False, verbose=False, training_mode=False)\n", 43 | "instance.load_checkpoint(model_path + 'model/checkpoint.pt')\n", 44 | "instance.model.eval()\n", 45 | "\n", 46 | "# Create a dataset\n", 47 | "data_loader = QADataLoader(config)\n", 48 | "\n", 49 | "# Run inference on the test split\n", 50 | "test_loss, all_metrics, (pred_output, gold_output, 
gold_input), memory_values_to_return = instance.inference(data_loader.test_loader, metric_hooks=[TextualMetricHook(config, 'c', 'q')])\n", 51 | "\n", 52 | "# Done!\n", 53 | "print(all_metrics['bleu'])" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stderr", 63 | "output_type": "stream", 64 | "text": [ 65 | "Validating after 0 epochs: 100%|██████████| 1/1 [00:00<00:00, 2.43it/s]" 66 | ] 67 | }, 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "['Who was the oldest cat?', 'How long did Creme Puff live?']\n" 73 | ] 74 | }, 75 | { 76 | "name": "stderr", 77 | "output_type": "stream", 78 | "text": [ 79 | "\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "# You can now run your model on your own dataset\n", 85 | "\n", 86 | "examples = [\n", 87 | " {'c': 'Creme Puff was the oldest cat.', 'a': 'Creme Puff'},\n", 88 | " {'c': 'Creme Puff lived for 38 years and 3 days', 'a': '38 years and 3 days'},\n", 89 | "]\n", 90 | "\n", 91 | "# The examples need the answer character position, and a placeholder for the question\n", 92 | "examples = [\n", 93 | " {**ex, 'a_pos': ex['c'].index(ex['a']), 'q': ''} for ex in examples\n", 94 | "]\n", 95 | " \n", 96 | "data_loader_custom = QADataLoader(config, test_samples=examples)\n", 97 | "\n", 98 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader_custom.test_loader)\n", 99 | "\n", 100 | "print(pred_output)" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python 3", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": 
"text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.8.5" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 4 132 | } 133 | -------------------------------------------------------------------------------- /examples/mBART.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "[\"Nous pouvons traduire de l'anglais en français\", 'We can translate from French to English']\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import json\n", 18 | "from torchseq.agents.para_agent import ParaphraseAgent\n", 19 | "from torchseq.datasets.json_loader import JsonDataLoader\n", 20 | "from torchseq.utils.config import Config\n", 21 | "import torch\n", 22 | "\n", 23 | "\n", 24 | "\n", 25 | "with open('../configs/mbart.json') as f:\n", 26 | " cfg_dict = json.load(f)\n", 27 | "cfg_dict[\"env\"][\"data_path\"] = \"../data/\"\n", 28 | "cfg_dict[\"eval\"][\"eval_batch_size\"] = 1\n", 29 | "\n", 30 | "cfg_dict['training'][\"dataset\"] = 'json'\n", 31 | "cfg_dict[\"json_dataset\"] = {\n", 32 | " \"path\": None,\n", 33 | " \"field_map\": [\n", 34 | " {\"type\": \"copy\", \"from\": \"input\", \"to\": \"s2\"},\n", 35 | " {\"type\": \"copy\", \"from\": \"input\", \"to\": \"s1\"},\n", 36 | " ],\n", 37 | "}\n", 38 | "\n", 39 | "config = Config(cfg_dict)\n", 40 | "\n", 41 | "instance = ParaphraseAgent(config=config, run_id=None, output_path=\"./runs/examples/mbart/\", silent=True, verbose=False)\n", 42 | "\n", 43 | "instance.model.eval()\n", 44 | "\n", 45 | "examples = [\n", 46 | " {'input': 'We can translate from English to French', 'src_lang': 'en_XX', 'tgt_lang': 'fr_XX'},\n", 47 | " {'input': 'Nous pouvons traduire de Francais en Anglais', 'src_lang': 'fr_XX', 'tgt_lang': 
'en_XX'}\n", 48 | "]\n", 49 | " \n", 50 | "data_loader = JsonDataLoader(config, test_samples=examples)\n", 51 | "\n", 52 | "test_loss, all_metrics, (pred_output, gold_output, gold_input), memory_values_to_return = instance.inference(data_loader.test_loader)\n", 53 | "\n", 54 | "print(pred_output)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [] 63 | } 64 | ], 65 | "metadata": { 66 | "kernelspec": { 67 | "display_name": "Python 3", 68 | "language": "python", 69 | "name": "python3" 70 | }, 71 | "language_info": { 72 | "codemirror_mode": { 73 | "name": "ipython", 74 | "version": 3 75 | }, 76 | "file_extension": ".py", 77 | "mimetype": "text/x-python", 78 | "name": "python", 79 | "nbconvert_exporter": "python", 80 | "pygments_lexer": "ipython3", 81 | "version": "3.8.5" 82 | } 83 | }, 84 | "nbformat": 4, 85 | "nbformat_minor": 4 86 | } 87 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_return_any = True 3 | warn_unused_configs = True 4 | ignore_missing_imports = True 5 | exclude = (?x)(torchseq/external/) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard==2.15.0 2 | torch==2.2.2 3 | tqdm>=4.62 4 | scipy>=1.5 5 | lightning==2.1.0 6 | 7 | nltk>=3.6.7 8 | transformers==4.39.2 9 | tokenizers==0.15.2 10 | jsonlines>=2 11 | sacrebleu>=2.0 12 | py-rouge 13 | rouge-score 14 | opentsne 15 | sentencepiece==0.1.95 16 | pydantic==1.10.13 17 | truecase==0.0.14 18 | # summac 19 | 20 | compress_json==1.0.10 21 | 22 | matplotlib 23 | 24 | wandb==0.15.12 25 | protobuf<4 26 | 27 | pytest>=6.2 28 | pytest-cov>=2.12 29 | codecov>=2.1 30 | flake8==6.0.0 31 | black==24.3.0 32 | mypy==1.2.0 33 | 34 | sphinx>=3.3 
"""Fetch every pretrained model, tokenizer and NLTK resource listed below.

Each `from_pretrained` / `nltk.download` call is made purely for its
side effect of downloading the artefact; the returned objects are discarded.
"""
import nltk
from transformers import (
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BartModel,
    BertForQuestionAnswering,
    BertModel,
    BertTokenizer,
    MBartModel,
    MBartTokenizer,
    RobertaModel,
)

# summac is an optional dependency (it is commented out of the package
# requirements), so its absence must not abort the other downloads.
try:
    from summac.model_summac import SummaCConv
except ImportError:
    SummaCConv = None


mod = BartModel.from_pretrained('facebook/bart-large')
mod = RobertaModel.from_pretrained('roberta-base')

mod = BertModel.from_pretrained('bert-base-uncased')
mod = BertModel.from_pretrained('bert-base-cased')

mod = BertTokenizer.from_pretrained('bert-base-uncased')
mod = BertTokenizer.from_pretrained('bert-base-cased')

mod = BertForQuestionAnswering.from_pretrained("bert-large-uncased-whole-word-masking-finetuned-squad")


nltk.download('punkt', force=True)
nltk.download('wordnet', force=True)


tokenizer = AutoTokenizer.from_pretrained("tomhosking/deberta-v3-base-debiased-nli")
model = AutoModelForSequenceClassification.from_pretrained("tomhosking/deberta-v3-base-debiased-nli")


tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
model = AutoModel.from_pretrained("microsoft/deberta-v3-base")


mod = MBartTokenizer.from_pretrained('facebook/mbart-large-50', src_lang='en_XX', tgt_lang='en_XX')
mod = MBartTokenizer.from_pretrained('facebook/mbart-large-50-many-to-many-mmt', src_lang='en_XX', tgt_lang='en_XX')
mod = MBartModel.from_pretrained('facebook/mbart-large-50-many-to-many-mmt')


if SummaCConv is None:
    print("summac is not installed - skipping the SummaC NLI model downloads")
else:
    # Same two configurations as before ("vitc" then "mnli"), expressed as a loop.
    for nli_model in ("vitc", "mnli"):
        model_conv = SummaCConv(
            models=[nli_model],
            bins="percentile",
            granularity="sentence",
            nli_labels="e",
            device="cpu",
            start_file="default",
            agg="mean",
        )

        for imager in model_conv.imagers:
            imager.load_nli()
'w') as f_test: 25 | csv_reader = csv.reader(f) 26 | for ix, cols in enumerate(csv_reader): 27 | if ix == 0: 28 | continue 29 | 30 | # cols = line.split('\t') 31 | 32 | if len(cols[3]) > 500 or len(cols[4]) > 500: 33 | continue 34 | 35 | if cols[5] == '1': 36 | rand = np.random.random() 37 | 38 | qid1 = cols[1] 39 | qid2 = cols[2] 40 | 41 | if qid1 in qids_train or qid2 in qids_train: 42 | print('Forcing to train') 43 | f_train.write('{:}\t{:}\n'.format(cols[3], cols[4])) 44 | 45 | elif qid1 in qids_dev or qid2 in qids_dev: 46 | print('Forcing to dev') 47 | f_dev.write('{:}\t{:}\n'.format(cols[3], cols[4])) 48 | elif qid1 in qids_test or qid2 in qids_test: 49 | print('Forcing to test') 50 | f_test.write('{:}\t{:}\n'.format(cols[3], cols[4])) 51 | else: 52 | # 0.8-0.9 53 | if rand >= (1-DEVTEST_FRACTION*2) and rand < (1-DEVTEST_FRACTION) and cols[5] == '1': 54 | 55 | f_dev.write('{:}\t{:}\n'.format(cols[3], cols[4])) 56 | qids_dev.add(qid1) 57 | qids_dev.add(qid2) 58 | # 0.9+ 59 | elif rand >= (1-DEVTEST_FRACTION) and cols[5] == '1': 60 | f_test.write('{:}\t{:}\n'.format(cols[3], cols[4])) 61 | qids_test.add(qid1) 62 | qids_test.add(qid2) 63 | # Under 0.8 64 | else: 65 | f_train.write('{:}\t{:}\n'.format(cols[3], cols[4])) 66 | qids_train.add(qid1) 67 | qids_train.add(qid2) 68 | included += 1 69 | else: 70 | skipped += 1 71 | 72 | 73 | print(included, skipped) -------------------------------------------------------------------------------- /scripts/paranmt_prepro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import os 5 | 6 | from math import floor 7 | 8 | filtered_pairs = [] 9 | 10 | included = 0 11 | skipped = 0 12 | 13 | np.random.seed(0) 14 | 15 | TRAIN_FRACTION = 0.97 16 | 17 | # with open('./data/paranmt/para-nmt-50m.txt') as f: 18 | # with open('./data/paranmt/paraphrases.train.txt', 'w') as f_train: 19 | # with open('./data/paranmt/paraphrases.dev.txt', 'w') as 
f_dev: 20 | # for line in f: 21 | # cols = line.strip().split('\t') 22 | 23 | # if line.strip() == '' or len(cols[0]) > 300 or len(cols[1]) > 300 or len(cols[0]) < 30 or len(cols[1]) < 30: 24 | # continue 25 | 26 | # if float(cols[2]) < 0.9 and float(cols[2]) > 0.5: 27 | # if np.random.random() > TRAIN_FRACTION: 28 | # f_dev.write('{:}\t{:}\n'.format(cols[1].strip(), cols[0].strip())) 29 | # else: 30 | # f_train.write('{:}\t{:}\n'.format(cols[1].strip(), cols[0].strip())) 31 | # included += 1 32 | # else: 33 | # skipped += 1 34 | 35 | 36 | # os.makedirs('./data/parabank-qs') 37 | with open('../../data/parabank/parabank-1.0-large-diverse/parabank.50m.tsv') as f: 38 | with open('./data/parabank-qs/paraphrases.train.txt', 'w') as f_train: 39 | with open('./data/parabank-qs/paraphrases.dev.txt', 'w') as f_dev: 40 | for line in f: 41 | cols = line.strip().split('\t') 42 | 43 | if line.strip() == '' or len(cols[0]) > 300 or len(cols[1]) > 300 or len(cols[0]) < 30 or len(cols[1]) < 30: 44 | continue 45 | 46 | if cols[0][-1] != '?' 
#!/bin/bash
# Submit a job with sbatch and report it to the McKenzie hook.
#
# All arguments are forwarded to sbatch unchanged. When a `--config <file>`
# pair is present, the job name and tag are additionally scraped from the
# "name" / "tag" entries of that file and passed to the hook.

MCKENZIE_HOOK=~/mckenzie/scripts/hook.sh

jobName="UNK"
jobTag="custom"


POSITIONAL=()
while [[ $# -gt 0 ]]
do
    key="$1"

    case $key in
        --config)
            CONFIG="$2"
            POSITIONAL+=("$1") # save it in an array for later
            POSITIONAL+=("$2") # save it in an array for later
            jobName=$(grep \"name\"\: "$CONFIG" | sed -E 's/.+\"name\": \"(.*)\"\,/\1/')
            jobTag=$(grep \"tag\"\: "$CONFIG" | sed -E 's/.+\"tag\": \"(.*)\"\,/\1/')
            shift # past argument
            shift # past value
            ;;
        *)    # unknown option
            POSITIONAL+=("$1") # save it in an array for later
            shift # past argument
            ;;
    esac
done
set -- "${POSITIONAL[@]}" # restore positional parameters


# Quote "$@" so that arguments containing spaces reach sbatch intact.
res=$(sbatch "$@")


jobId=`echo $res | sed -E 's/Submitted batch job ([0-9]+)/\1/'`

# The sed substitution only changes the text when sbatch printed its success
# line, so an unchanged string means the submission failed.
if [ "$jobId" != "$res" ]
then
    # Only look up the partition once we know there is a real job id.
    partition=`scontrol show job "$jobId" | grep "Partition=([^\s]+)" -Po | sed s/Partition=//`

    ${MCKENZIE_HOOK} -a 1 -i $jobId -p $partition > /dev/null

    echo "Batch job ID $jobId -> $jobTag/$jobName"

    ${MCKENZIE_HOOK} -i $jobId -p $partition -n $jobTag/$jobName > /dev/null
    sleep 5
else
    echo "Error submitting job!"
    # Print sbatch's actual output (previously the literal word "res" was echoed).
    echo "$res"
fi
def test_bleu():
    """Corpus BLEU over three toy sentence pairs must match the pinned reference score."""
    references = ["The dog bit the man.", "It was not unexpected.", "The man bit him first."]
    hypotheses = ["The dog bit the man.", "It wasn't surprising.", "The man had just bitten him."]
    expected_score = 45.0675

    score = metrics.bleu_corpus(hypotheses, references)
    deviation = abs(score - expected_score)

    assert deviation < 0.001, "BLEU score for basic examples differs from SacreBLEU reference!"
def test_meteor():
    """Check the corpus METEOR score of one hypothesis/reference pair against a pinned value."""
    preds = ["It is a guide to action which ensures that the military always obeys the commands of the party"]
    refs = ["It is a guide to action that ensures that the military will forever heed Party commands"]

    score = metrics.meteor_corpus(refs, preds)

    # Compare the *absolute* difference: the earlier one-sided check
    # (`score - 0.7398 < 1e-4`) passed for any score below the reference,
    # so it could never catch a regression that lowered the score.
    # NOTE(review): if this now fails, re-baseline 0.7398 against the metric
    # implementation that is actually installed.
    assert abs(score - 0.7398) < 1e-4, "METEOR score for basic example differs from reference!"
def test_functions():
    """Exercise top-k/top-p filtering, cosine similarity and batchify from torchseq.utils.functions."""
    neg_inf = -torch.inf

    # The three filtering calls share this tensor and run in a fixed order;
    # keep it that way in case the helper modifies its input.
    probs = torch.Tensor([[0.1, 0.2, 0.3, 0.4], [0.13, 0.7, 0.17, 0.0]])

    expected_inf_filled = torch.tensor([[neg_inf, neg_inf, 0.3000, 0.4000], [neg_inf, 0.7000, 0.1700, neg_inf]])
    top2_default = tsfunctions.top_k_top_p_filtering(probs, top_k=2)
    assert top2_default.equal(expected_inf_filled)

    expected_zero_filled = torch.tensor([[0.0000, 0.0000, 0.3000, 0.4000], [0.0000, 0.7000, 0.1700, 0.0000]])
    top2_zero = tsfunctions.top_k_top_p_filtering(probs, top_k=2, filter_value=0)
    assert top2_zero.equal(expected_zero_filled)

    expected_nucleus = torch.tensor([[0.0000, 0.0000, 0.3000, 0.4000], [0.0000, 0.7000, 0.1700, 0.0000]])
    nucleus = tsfunctions.top_k_top_p_filtering(probs, top_p=0.5, filter_value=0)
    assert nucleus.equal(expected_nucleus)

    # Cosine similarity of two rows against a single one-hot vector
    x = torch.Tensor([[0.0, 0.0, 0.0, 1.0], [0.13, 0.7, 0.17, 0.0]])
    y = torch.Tensor([[0.0, 0.0, 0.0, 1.0]])
    similarities = tsfunctions.cos_sim(x, y)
    assert similarities.equal(torch.tensor([1.0, 0.0000]))

    # Six items in batches of four -> one full batch and one remainder
    test_data = list(range(6))
    batches = list(tsfunctions.batchify(test_data, 4))
    assert batches == [(0, [0, 1, 2, 3]), (1, [4, 5])]
@test_utils.slow
def test_lm():
    """Score three sentences with the pretrained LM and check each score falls in its pinned range."""
    lm = PretrainedLM()

    sentences = [
        "The first thing",
        "Creme Puff was the oldest cat.",
        " Variational Autoencoders (VAEs) provide a theoretically-backed and popular framework for deep generative models.",
    ]
    preds = lm.get_log_prob(sentences)

    print(preds)

    assert len(preds) == 3, "Failed to produce the right number of predictions?!"

    simple_score, normal_score, hard_score = preds[0], preds[1], preds[2]
    assert simple_score < 9, "Simple sentence score is too high"
    assert simple_score > 8, "Simple sentence score is too low"
    assert abs(normal_score - 8.8) < 0.1, "Normal sentence score out of range"
    assert abs(hard_score - 10.8) < 0.1, "Hard sentence score out of range"
def test_roberta_basic():
    """Tokenise a short sentence with roberta-base; check the ids and that decoding restores the text."""
    TEST_STRING = "This is a test sentence."
    expected_ids = [0, 713, 16, 10, 1296, 3645, 4, 2]

    tokenizer = Tokenizer("roberta-base")

    token_ids = [tok["id"] for tok in tokenizer.tokenise(TEST_STRING)]
    decoded = tokenizer.decode(torch.LongTensor(token_ids))

    assert token_ids == expected_ids, "RoBERTa tok ids are wrong for basic example!"
    assert decoded == TEST_STRING, "RoBERTa tokenisation isn't reversible for basic example!"
def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable ``key``.

    Args:
        key: Name of the environment variable to inspect.
        default: Value returned unchanged when the variable is not set.

    Returns:
        ``default`` if the variable is unset, otherwise ``True`` or ``False``
        according to its (case-insensitive) value.

    Raises:
        ValueError: If the variable is set to something that is neither a
            recognised truthy nor a recognised falsy string.
    """
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        return default

    # Accept the same vocabulary as distutils.util.strtobool, but parse it
    # here: distutils is deprecated (PEP 632) and no longer ships with
    # Python 3.12+, so this function should not depend on it.
    normalised = value.lower()
    if normalised in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalised in ("n", "no", "f", "false", "off", "0"):
        return False

    # More values are supported, but let's keep the message simple.
    raise ValueError("If set, {} must be yes or no.".format(key))
class BaseAgent:
    """
    Base class declaring the operations that every concrete agent must implement.

    Subclasses are expected to assign ``self.model`` (and ``self.loss``, which is
    moved to the GPU when CUDA is used) before calling :meth:`set_device`.
    """

    # Whether the agent actually runs on CUDA; assigned by set_device()
    cuda: bool
    use_lightning: bool = False

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger("Agent")

    def set_device(self, use_cuda=True):
        """
        Choose the torch device and, when running on CUDA, move the model and loss onto it
        :param use_cuda: request GPU execution; an error is raised if no CUDA device exists
        :return:
        """
        # set cuda flag
        self.cuda_available = torch.cuda.is_available()
        if self.cuda_available and not use_cuda:
            self.logger.warning("You have a CUDA device, so you should probably enable CUDA")

        if use_cuda and not self.cuda_available:
            self.logger.error("Use CUDA is set to true, but no CUDA devices were found!")
            raise Exception("No CUDA devices found")

        # Logical (not bitwise) combination, always stored as a real bool
        self.cuda = bool(self.cuda_available and use_cuda)

        # Use getattr so that a subclass which never assigned a model gets the
        # explanatory message below instead of an AttributeError, and compare
        # against None so that container modules defining __len__ are not
        # mistaken for "no model".
        if getattr(self, "model", None) is None:
            raise Exception("You need to define your model before calling set_device!")

        if self.cuda:
            self.device = torch.device("cuda")

            self.logger.info("Model will run on *****GPU-CUDA***** ")

            # Applied per submodule rather than via self.model.to(self.device)
            self.model.apply(to_device_unless_marked(self.device))

            self.loss.to(self.device)

            # TODO: Enable torch.compile for pytorch 2.0
            # (self.model = torch.compile(self.model, dynamic=True, mode='reduce-overhead'))

        else:
            self.device = torch.device("cpu")

            self.logger.info("Model will run on *****CPU*****")

    def load_checkpoint(self, file_name):
        """
        Latest checkpoint loader
        :param file_name: name of the checkpoint file
        :return:
        """
        raise NotImplementedError

    def save_checkpoint(self, file_name="checkpoint.pt", is_best=0):
        """
        Checkpoint saver
        :param file_name: name of the checkpoint file
        :param is_best: boolean flag to indicate whether current checkpoint's metric is the best so far
        :return:
        """
        raise NotImplementedError

    def train(self, data_loader) -> None:
        """
        Main training loop
        :return:
        """
        raise NotImplementedError

    def train_one_epoch(self):
        """
        One epoch of training
        :return:
        """
        raise NotImplementedError

    def validate(self):
        """
        One cycle of model validation
        :return:
        """
        raise NotImplementedError

    def finalize(self):
        """
        Finalizes all the operations of the 2 Main classes of the process, the operator and the data loader
        :return:
        """
        raise NotImplementedError
def parse_args():
    """Build the TorchSeq command-line parser and return the parsed ``sys.argv`` namespace."""
    cli = argparse.ArgumentParser(description="TorchSeq")

    def add_switch(*names, text):
        # Boolean flag: False unless present on the command line
        cli.add_argument(*names, action="store_true", help=text)

    def add_string(*names, metavar, default, text, **extra):
        # String-valued option
        cli.add_argument(*names, type=str, metavar=metavar, default=default, help=text, **extra)

    add_switch("-V", "--version", text="Display version")

    # Config stuff
    add_string("-c", "--config", metavar="CONFIG", default="./configs/default.json", text="Path to config file")
    add_string(
        "-p",
        "--patch",
        metavar="PATCH",
        default=None,
        text="Config mask(s) to overwrite main config with",
        action="append",
    )

    # Actions
    add_switch("--train", text="Run training?")
    add_switch("--validate", text="Run eval on dev?")
    add_switch("--validate_train", text="Run eval on train?")
    add_switch("--test", text="Run eval on test?")
    add_switch("--silent", text="Disable logging")
    add_switch("--verbose", text="Extra logging")
    add_switch("--reload_after_train", text="Reload model after training to do a validation run")
    add_switch(
        "--copy_chkpt",
        text="Save a copy of the checkpoint in current output dir, even if loading from elsewhere",
    )

    # Model loading
    add_string("--load_chkpt", metavar="CHECKPOINT", default=None, text="Path to checkpoint file")
    add_string("-l", "--load", metavar="MODEL", default=None, text="Path to model folder")
    add_switch("--nocache", text="Disable loading from an old cache")

    # Paths
    add_string("-d", "--data_path", metavar="DATA", default="./data/", text="Path to data sources")
    add_string("-o", "--output_path", metavar="OUTPUT", default="./runs/", text="Path to output folder")

    # Runtime
    add_switch("--cpu", text="Disable CUDA")
    add_switch("--debug", text="Enable debug mode")
    add_switch("--lightning", text="Enable Lightning")
    add_switch("--amp", text="Enable AMP")

    return cli.parse_args()
"models/squad_nq_newsqa-udep", 26 | "models/squad_nq_newsqa-udep-deptree", 27 | "models/naturalquestions-udep", 28 | "models/newsqa-udep", 29 | "models/naturalquestions-udep-deptree", 30 | "models/newsqa-udep-deptree", 31 | ] 32 | or config.training.dataset[:5] == "qdmr-" 33 | or "kaggle-" in config.training.dataset 34 | ): 35 | data_loader = ParaphraseDataLoader( 36 | config=config, 37 | data_path=data_path, 38 | train_samples=train_samples, 39 | dev_samples=dev_samples, 40 | test_samples=test_samples, 41 | ) 42 | elif ( 43 | config.training.dataset 44 | in [ 45 | "squad", 46 | "newsqa", 47 | "msmarco", 48 | "naturalquestions", 49 | "drop", 50 | "nq_newsqa", 51 | "squad_nq_newsqa", 52 | "inquisitive", 53 | ] 54 | or config.training.dataset[:5] == "squad" 55 | or config.training.dataset[:3] == "qa/" 56 | ): 57 | data_loader = QADataLoader( 58 | config=config, 59 | data_path=data_path, 60 | train_samples=train_samples, 61 | dev_samples=dev_samples, 62 | test_samples=test_samples, 63 | ) 64 | elif config.training.dataset in [ 65 | "json", 66 | ]: 67 | data_loader = JsonDataLoader( 68 | config=config, 69 | data_path=data_path, 70 | train_samples=train_samples, 71 | dev_samples=dev_samples, 72 | test_samples=test_samples, 73 | ) 74 | elif config.training.dataset in ["ptb", "wikitext103"]: 75 | data_loader = LangmodellingDataLoader( 76 | config=config, 77 | data_path=data_path, 78 | train_samples=train_samples, 79 | dev_samples=dev_samples, 80 | test_samples=test_samples, 81 | ) 82 | else: 83 | raise Exception("Unrecognised dataset: {:}".format(config.training.dataset)) 84 | 85 | return data_loader 86 | -------------------------------------------------------------------------------- /torchseq/datasets/lm_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from itertools import cycle 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import IterableDataset, Dataset 7 | 8 | 9 | from 
class LangmodellingDataset(Dataset):
    """Map-style dataset over a ``sentences.<variant>.txt`` file, one sentence per line.

    ``variant`` is ``"dev"`` when ``dev`` is set, otherwise ``"test"`` when
    ``test`` is set, otherwise ``"train"``. The whole file is read into
    memory at construction time.

    Args:
        path: Directory containing the ``sentences.*.txt`` files.
        config: Config object; ``config.prepro.tok_window`` is read by ``__getitem__``.
        dev: Select the dev split.
        test: Select the test split. A missing file is tolerated only when this
            flag is set: ``exists`` becomes ``False`` and the dataset is empty.
        repeat: Stored on the instance; not used by this class itself.

    Raises:
        FileNotFoundError: if the split file is missing and ``test`` is not set.
    """

    def __init__(self, path, config, dev=False, test=False, repeat=False):
        self.config = config
        self.repeat = repeat
        self.path = path
        self.variant = "dev" if dev else ("test" if test else "train")

        self.exists = True
        self.length = 0
        # Always defined, so len() works even when the optional split is absent.
        self.samples = []

        split_file = os.path.join(self.path, "sentences.{:}.txt".format(self.variant))
        if test and not os.path.exists(split_file):
            self.exists = False
        else:
            with open(split_file) as f:
                self.samples = [line.strip() for line in f]
            self.length = len(self.samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return LangmodellingDataset.to_tensor({"sent": self.samples[idx]}, tok_window=self.config.prepro.tok_window)

    @staticmethod
    def to_tensor(x, tok_window=64):
        """Tokenise ``x["sent"]`` and return a dict of tensors plus the raw text.

        ``tok_window`` is forwarded to ``LangmodellingInstance`` so the
        requested truncation length is actually applied.
        """
        parsed = LangmodellingInstance(x["sent"], tok_window=tok_window)

        return {
            "sent": torch.LongTensor(parsed.sent_as_ids()),
            "sent_len": torch.LongTensor([len(parsed._doc)]),
            "sent_text": x["sent"],
        }

    @staticmethod
    def pad_and_order_sequences(batch):
        """Collate a list of ``to_tensor`` samples into a single batch dict.

        Keys ending in ``_text`` are gathered into plain lists. Every other
        key is right-padded to the longest entry in the batch and stacked
        along dim 0. Note that the samples in ``batch`` are padded in place.
        """
        keys = batch[0].keys()
        tensor_keys = [k for k in keys if k[-5:] != "_text"]
        max_lens = {k: max(len(x[k]) for x in batch) for k in tensor_keys}
        # Look the pad id up once rather than once per element per key.
        pad_id = Tokenizer().pad_id

        for x in batch:
            for k in tensor_keys:
                # "a_pos" is padded with 0, everything else with the tokenizer's pad id.
                fill = 0 if k == "a_pos" else pad_id
                x[k] = F.pad(x[k], (0, max_lens[k] - len(x[k])), value=fill)

        tensor_batch = {}
        for k in keys:
            if k[-5:] != "_text":
                # squeeze(1) drops the singleton dim of e.g. "sent_len".
                # NOTE(review): it would also collapse "sent" if every sequence
                # in the batch had length 1 — confirm whether that can occur.
                tensor_batch[k] = torch.stack([x[k] for x in batch], 0).squeeze(1)
            else:
                tensor_batch[k] = [x[k] for x in batch]

        return tensor_batch


class LangmodellingInstance:
    """A single tokenised sentence, truncated to at most ``tok_window`` tokens."""

    def __init__(self, sent_text, tok_window=256):
        self._doc = Tokenizer().tokenise(sent_text)

        if len(self._doc) > tok_window:
            self._doc = self._doc[:tok_window]

    def sent_as_ids(self):
        """Return the token ids of the (possibly truncated) sentence."""
        return [tok["id"] for tok in self._doc]
Use JsonDataset instead") 24 | 25 | tokenizer.DATA_PATH = data_path 26 | Tokenizer(config.prepro.tokenizer) 27 | 28 | train = LangmodellingDataset( 29 | os.path.join(data_path, self.config.training.dataset), 30 | config=config, 31 | dev=False, 32 | test=False, 33 | repeat=(self.config.training.data.get("epoch_steps", 0) > 0), 34 | ) 35 | valid = LangmodellingDataset( 36 | os.path.join(data_path, self.config.training.dataset), config=config, dev=True, test=False 37 | ) 38 | test = LangmodellingDataset( 39 | os.path.join(data_path, self.config.training.dataset), config=config, dev=False, test=True 40 | ) 41 | 42 | self.len_train_data = len(train) 43 | self.len_valid_data = len(valid) 44 | # self.len_test_data = len(test) 45 | 46 | # TODO: check whether running in silent mode 47 | self.logger.info( 48 | "Loaded {:} training and {:} validation examples from {:}".format( 49 | self.len_train_data, 50 | self.len_valid_data, 51 | os.path.join(data_path, self.config.training.dataset), 52 | ) 53 | ) 54 | 55 | # self.train_iterations = (self.len_train_data + self.config.training.batch_size - 1) // self.config.training.batch_size 56 | # self.valid_iterations = (self.len_valid_data + self.config.training.batch_size - 1) // self.config.training.batch_size 57 | # self.test_iterations = (self.len_test_data + self.config.training.batch_size - 1) // self.config.training.batch_size 58 | 59 | self.train_loader = DataLoader( 60 | train, 61 | batch_size=config.training.batch_size, 62 | shuffle=self.config.training.data.get("shuffle_data", True), 63 | num_workers=0, 64 | collate_fn=LangmodellingDataset.pad_and_order_sequences, 65 | worker_init_fn=init_worker, 66 | ) 67 | 68 | self.valid_loader = DataLoader( 69 | valid, 70 | batch_size=config.eval.eval_batch_size, 71 | shuffle=False, 72 | num_workers=0, 73 | collate_fn=LangmodellingDataset.pad_and_order_sequences, 74 | worker_init_fn=init_worker, 75 | ) 76 | if test.exists: 77 | self.test_loader = DataLoader( 78 | test, 79 | 
def load_qa_dataset(path, dev=False, test=False, v2=False, filename=None):
    """Load the ``data`` list from a SQuAD-format JSON file.

    Args:
        path: Prefix the filename is appended to by plain string
            concatenation, so it must end with a path separator.
        dev: Use the dev split.
        test: Use the test split (v1.1 only, and only when ``dev`` is not set).
        v2: Use the SQuAD v2.0 default filenames.
        filename: Explicit filename; overrides the default name selection.

    Returns:
        The value stored under the ``"data"`` key of the JSON document.

    A version-mismatch warning is printed only when the default filename was
    used; an explicitly named file is assumed to be intentional.
    """
    expected_version = "v2.0" if v2 else "1.1"
    # Record this before filling in the default: the warning below depends on
    # whether the caller chose the file.
    using_default_file = filename is None
    if using_default_file:
        if v2:
            filename = "dev-v2.0.json" if dev else "train-v2.0.json"
        elif test and not dev:
            filename = "test-v1.1.json"
        else:
            filename = "dev-v1.1.json" if dev else "train-v1.1.json"

    # JSON interchange text is UTF-8; don't depend on the platform default encoding.
    with open(path + filename, encoding="utf-8") as dataset_file:
        dataset_json = json.load(dataset_file)

    if using_default_file and "version" in dataset_json and dataset_json["version"] != expected_version:
        print("Expected SQuAD v-" + expected_version + ", but got dataset with v-" + str(dataset_json["version"]))
    return dataset_json["data"]


def _most_common_answer(answers):
    """Return the most frequent ``(text, start)`` pair, or ``("", None)`` if there are none."""
    if not answers:
        return "", None
    ans_count = defaultdict(int)
    for ans in answers:
        ans_count[(ans["text"], int(ans["answer_start"]))] += 1
    # max() keeps the first-seen pair on ties, like a stable descending sort.
    return max(ans_count, key=ans_count.get)


def _squad_v2_element(context, qa, dev):
    """Build the 5-tuple ``(context, question, text, start, is_impossible)`` for a v2 question."""
    if qa["is_impossible"]:
        # Plausible answers are deliberately ignored for the dev split.
        plausible = [] if dev else qa.get("plausible_answers", [])
        if plausible:
            return (context, qa["question"], plausible[0]["text"], int(plausible[0]["answer_start"]), True)
        return (context, qa["question"], "", None, True)
    first = qa["answers"][0]
    return (context, qa["question"], first["text"], int(first["answer_start"]), False)


def load_squad_triples(path, dev=False, test=False, v2=False, as_dict=False, ans_list=False, filename=None):
    """Flatten a SQuAD-format dataset into per-question tuples.

    Args:
        path, dev, test, v2, filename: forwarded to ``load_qa_dataset``.
        as_dict: Return a dict keyed by question id instead of a list.
        ans_list: (v1.1 only) return all answer texts and start offsets as
            two lists instead of the single most frequent answer.

    Returns:
        For v1.1: ``(context, question, answer_text, answer_start)`` tuples.
        For v2: ``(context, question, answer_text, answer_start, is_impossible)``
        tuples, using the first gold answer or, for impossible questions
        outside the dev split, the first plausible answer when one is present.
        Missing answers are represented as ``("", None)``.
    """
    raw_data = load_qa_dataset(path, dev=dev, test=test, v2=v2, filename=filename)
    triples = {} if as_dict else []
    for doc in raw_data:
        for para in doc["paragraphs"]:
            context = para["context"]
            for qa in para["qas"]:
                qa_id = qa["id"]
                if v2:
                    # v2 never uses the majority answer. Skipping it avoids an
                    # IndexError on the empty answer list of impossible questions.
                    el = _squad_v2_element(context, qa, dev)
                elif ans_list:
                    ans_text = [a["text"] for a in qa["answers"]]
                    ans_pos = [int(a["answer_start"]) for a in qa["answers"]]
                    el = (context, qa["question"], ans_text, ans_pos)
                else:
                    ans_text, ans_pos = _most_common_answer(qa["answers"])
                    el = (context, qa["question"], ans_text, ans_pos)

                if as_dict:
                    triples[qa_id] = el
                else:
                    triples.append(el)
    return triples
Use JsonDataset instead") 23 | 24 | tokenizer.DATA_PATH = data_path 25 | Tokenizer(config.prepro.tokenizer) 26 | 27 | train = ParaphraseDataset( 28 | os.path.join(data_path, self.config.training.dataset), 29 | config=config, 30 | dev=False, 31 | test=False, 32 | repeat=(self.config.training.data.get("epoch_steps", 0) > 0), 33 | length_limit=self.config.training.get("truncate_dataset", None), 34 | ) 35 | valid = ParaphraseDataset( 36 | os.path.join(data_path, self.config.training.dataset), 37 | config=config, 38 | dev=True, 39 | test=False, 40 | length_limit=self.config.eval.get("truncate_dataset", None), 41 | ) 42 | test = ParaphraseDataset( 43 | os.path.join(data_path, self.config.training.dataset), 44 | config=config, 45 | dev=False, 46 | test=True, 47 | length_limit=self.config.eval.get("truncate_dataset", None), 48 | ) 49 | 50 | self.len_train_data = len(train) 51 | self.len_valid_data = len(valid) 52 | # self.len_test_data = len(test) 53 | 54 | # TODO: check whether running in silent mode 55 | self.logger.info( 56 | "Loaded {:} training and {:} validation examples from {:}".format( 57 | self.len_train_data, 58 | self.len_valid_data, 59 | os.path.join(data_path, self.config.training.dataset), 60 | ) 61 | ) 62 | 63 | # self.train_iterations = (self.len_train_data + self.config.training.batch_size - 1) // self.config.training.batch_size 64 | # self.valid_iterations = (self.len_valid_data + self.config.training.batch_size - 1) // self.config.training.batch_size 65 | # self.test_iterations = (self.len_test_data + self.config.training.batch_size - 1) // self.config.training.batch_size 66 | 67 | self.train_loader = DataLoader( 68 | train, 69 | batch_size=config.training.batch_size, 70 | # shuffle=self.config.training.data.get("shuffle_data", True), 71 | num_workers=0, 72 | collate_fn=ParaphraseDataset.pad_and_order_sequences, 73 | worker_init_fn=init_worker, 74 | ) 75 | 76 | self.valid_loader = DataLoader( 77 | valid, 78 | batch_size=config.eval.eval_batch_size, 79 | 
class ParaphrasePair:
    """A tokenised sentence pair with an optional template sentence.

    Args:
        sent1_text: First sentence.
        sent2_text: Second sentence.
        template: Optional template sentence; when ``None`` no template is
            stored and ``template_as_ids`` must not be called.
        is_paraphrase: Label stored on the instance.
        tok_window: Maximum number of tokens kept for each sentence and for
            the template.
    """

    def __init__(self, sent1_text, sent2_text, template=None, is_paraphrase=True, tok_window=64):
        self._s1_doc = Tokenizer().tokenise(sent1_text)
        self._s2_doc = Tokenizer().tokenise(sent2_text)
        # Tokenise the template itself; sent2_text was previously tokenised here by mistake.
        self._template_doc = Tokenizer().tokenise(template) if template is not None else None
        self.is_paraphrase = is_paraphrase

        if len(self._s1_doc) > tok_window:
            self._s1_doc = self._s1_doc[:tok_window]
        if len(self._s2_doc) > tok_window:
            self._s2_doc = self._s2_doc[:tok_window]
        # Apply the same window to the template so all three sequences are bounded alike.
        if self._template_doc is not None and len(self._template_doc) > tok_window:
            self._template_doc = self._template_doc[:tok_window]

    def s1_as_ids(self):
        """Return the token ids of the first sentence."""
        return [tok["id"] for tok in self._s1_doc]

    def s2_as_ids(self):
        """Return the token ids of the second sentence."""
        return [tok["id"] for tok in self._s2_doc]

    def template_as_ids(self):
        """Return the token ids of the template.

        Raises:
            TypeError: if the pair was built without a template.
        """
        return [tok["id"] for tok in self._template_doc]
app = Flask(__name__)


@app.route("/")
def index():
    """Redirect the site root to the static demo page."""
    # NOTE(review): the repository's static folder appears to contain
    # "qq_demo.htm" rather than "qg_demo.htm" — confirm this target exists.
    return redirect("/static/qg_demo.htm")


@app.route("/api/generate")
def generate():
    """Generate question(s) for a context/answer pair given as query parameters.

    Query parameters:
        context: Passage text.
        answer: Answer span; it must occur verbatim in ``context``.

    Returns:
        A JSON response holding the predicted output of the agent, or a
        400 JSON error when ``answer`` is not a substring of ``context``.
    """
    context = request.args["context"]
    answer = request.args["answer"]

    # str.find returns -1 when the answer is absent. Reject the request
    # instead of handing a meaningless offset to the model.
    a_pos = context.find(answer)
    if a_pos < 0:
        error = {"error": "answer must be a substring of context"}
        return Response(json.dumps(error, indent=2), status=400, mimetype="application/json")

    query = {"c": context, "a": answer, "a_pos": a_pos, "q": ""}

    data_loader = QADataLoader(app.agent.config, test_samples=[query])
    # Only the predictions are needed; loss, metrics, gold data and memory are discarded.
    _, _, (pred_output, _, _), _ = app.agent.inference(data_loader.test_loader)

    return Response(json.dumps(pred_output, indent=2), mimetype="application/json")


@app.route("/api/ping")
def ping():
    """Liveness probe."""
    return "ack"


def init():
    """Load the model config and checkpoint, and attach the agent to ``app``."""
    model_path = "./models/examples/20210222_145021_qg_bert/"

    # Get the config
    with open(model_path + "config.json") as f:
        cfg_dict = json.load(f)

    # Override a few bits: the demo requests a single candidate.
    cfg_dict["eval"]["topk"] = 1

    config = Config(cfg_dict)

    checkpoint_path = model_path + "model/checkpoint.pt"

    # NOTE(review): training_mode=True looks odd for an inference-only demo — confirm it is required.
    app.agent = AQAgent(config=config, run_id=None, output_path="./runs/parademo/", training_mode=True)

    app.agent.load_checkpoint(checkpoint_path)
    app.agent.model.eval()


def main():
    """Initialise the agent and serve the demo on port 5004."""
    init()
    with app.app_context():
        app.run(host="0.0.0.0", port=5004, processes=1)


if __name__ == "__main__":
    main()
4 | 5 | 6 | 7 | 8 | 9 |An interactive demo of the Separator paraphrasing model. Enter a question and an exemplar, then click generate to rephrase the question into that form.
36 | 37 |