├── .circleci └── unittest │ ├── linux │ └── scripts │ │ ├── environment.yml │ │ ├── install.sh │ │ ├── post_process.sh │ │ ├── run_test.sh │ │ └── setup_env.sh │ └── windows │ └── scripts │ ├── environment.yml │ ├── install.sh │ ├── install_conda.bat │ ├── post_process.sh │ ├── run_test.sh │ └── setup_env.sh ├── .clang-format ├── .flake8 ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.md │ ├── documentation.md │ ├── feature-request.md │ └── questions-help-support.md ├── pytorch-probot.yml ├── scripts │ └── validate_binaries.sh └── workflows │ ├── bandit.yml │ ├── build-docs.yml │ ├── build-wheels-linux.yml │ ├── build-wheels-m1.yml │ ├── build-wheels-windows.yml │ ├── codeql.yml │ ├── integration-test.yml │ ├── lint.yml │ ├── test-linux-cpu.yml │ ├── test-linux-gpu.yml │ ├── test-macos-cpu.yml │ ├── test-windows-cpu.yml │ ├── validate-binaries.yml │ └── validate-nightly-binaries.yml ├── .gitignore ├── .gitmodules ├── .pre-commit-config.yaml ├── .prettierignore ├── .prettierrc.yaml ├── .python3 ├── CMakeLists.txt ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── CONTRIBUTING_DATASETS.md ├── LICENSE ├── README.rst ├── benchmark ├── benchmark_basic_english_normalize.py ├── benchmark_bert_tokenizer.py ├── benchmark_experimental_vectors.py ├── benchmark_pytext_vocab.py ├── benchmark_roberta_model.py ├── benchmark_roberta_pipeline.py ├── benchmark_sentencepiece.py ├── benchmark_torcharrow_ops.py ├── benchmark_vocab.py ├── data_construction.py ├── mha_block.py └── utils.py ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── _static │ ├── css │ │ └── pytorch_theme.css │ └── img │ │ ├── pytorch-logo-dark.png │ │ ├── pytorch-logo-dark.svg │ │ ├── pytorch-logo-flame.png │ │ ├── pytorch-logo-flame.svg │ │ └── torchtext_logo.png │ ├── _templates │ └── layout.html │ ├── conf.py │ ├── data_functional.rst │ ├── data_metrics.rst │ ├── data_utils.rst │ ├── datasets.rst │ ├── experimental_models_utils.rst │ ├── experimental_transforms.rst │ ├── experimental_vectors.rst │ ├── experimental_vocab.rst │ ├── functional.rst │ ├── index.rst │ ├── logo.rst │ ├── models.rst │ ├── nn_modules.rst │ ├── transforms.rst │ ├── utils.rst │ └── vocab.rst ├── examples ├── data_pipeline │ ├── roberta_dataframe.py │ └── roberta_datapipe.py ├── libtorchtext │ ├── .gitignore │ ├── CMakeLists.txt │ ├── README.md │ ├── build.sh │ └── tokenizer │ │ ├── CMakeLists.txt │ │ ├── README.md │ │ ├── create_tokenizer.py │ │ └── main.cpp ├── text_classification │ ├── README.md │ ├── model.py │ ├── predict.py │ ├── run_script.sh │ └── train.py ├── torcharrow │ ├── README.md │ └── roberta_sst2_training_with_torcharrow.py ├── tutorials │ ├── README.rst │ ├── sst2_classification_non_distributed.py │ └── t5_demo.py └── vocab │ ├── fairseq_vocab.py │ └── test.csv ├── notebooks ├── hf_vs_tt_t5.ipynb ├── hf_with_torchtext_gen.ipynb └── torchscriptable_t5_with_torchtext.ipynb ├── packaging ├── build_conda.sh ├── build_wheel.sh ├── cut_release.sh ├── pkg_helpers.bash ├── torchtext │ ├── bld.bat │ ├── build.sh │ └── meta.yaml ├── vc_env_helper.bat └── vs2019 │ ├── activate.bat │ ├── conda_build_config.yaml │ ├── install_activate.bat │ ├── install_runtime.bat │ └── meta.yaml ├── pyproject.toml ├── pytest.ini ├── readthedocs.yml ├── requirements.txt ├── run-clang-format.py ├── setup.py ├── test ├── integration_tests │ ├── __init__.py │ ├── conftest.py │ ├── test_generate.py │ ├── test_roberta_models.py │ └── test_t5_models.py ├── smoke_tests │ └── smoke_tests.py └── torchtext_unittest │ ├── __init__.py │ ├── asset │ 
├── SST2 │ │ └── SST-2.zip │ ├── bert_base_cased_vocab.txt │ ├── bert_base_uncased_vocab.txt │ ├── clip_encoder.json │ ├── clip_vocab.bpe │ ├── glove.6B.zip │ ├── glove.840B.300d.zip │ ├── gpt2_bpe_encoder.json │ ├── gpt2_bpe_vocab.bpe │ ├── label_names.txt │ ├── openai-gpt-merges.txt │ ├── openai-gpt-vocab.json │ ├── raw_datasets.jsonl │ ├── roberta.base.output.pt │ ├── roberta.distilled.output.pt │ ├── roberta.large.output.pt │ ├── spm_example.model │ ├── t5.base.encoder.output.pt │ ├── t5.base.generation.output.pt │ ├── t5.base.model.output.pt │ ├── t5.flan.base.encoder.output.pt │ ├── t5.flan.base.generation.output.pt │ ├── t5.flan.base.model.output.pt │ ├── t5.large.encoder.output.pt │ ├── t5.large.generation.output.pt │ ├── t5.large.model.output.pt │ ├── t5.small.encoder.output.pt │ ├── t5.small.generation.output.pt │ ├── t5.small.model.output.pt │ ├── t5_tokenizer_base.model │ ├── text_normalization_ag_news_ref_results.test │ ├── text_normalization_ag_news_test.csv │ ├── vectors_test.csv │ ├── vocab_raw_text_test.txt │ ├── vocab_test.txt │ ├── vocab_test2.txt │ ├── wiki.en.vec │ ├── xlmr.base.output.pt │ └── xlmr.large.output.pt │ ├── common │ ├── __init__.py │ ├── assets.py │ ├── case_utils.py │ ├── parameterized_utils.py │ └── torchtext_test_case.py │ ├── csrc │ ├── __init__.py │ └── test_gpt2_bpe_tokenizer.py │ ├── data │ ├── __init__.py │ ├── test_dataset_utils.py │ ├── test_functional.py │ ├── test_jit.py │ ├── test_metrics.py │ ├── test_modules.py │ └── test_utils.py │ ├── datasets │ ├── __init__.py │ ├── common.py │ ├── test_agnews.py │ ├── test_amazonreviews.py │ ├── test_cc100.py │ ├── test_cnndm.py │ ├── test_cola.py │ ├── test_conll2000chunking.py │ ├── test_dbpedia.py │ ├── test_enwik9.py │ ├── test_imdb.py │ ├── test_iwslt2016.py │ ├── test_iwslt2017.py │ ├── test_mnli.py │ ├── test_mrpc.py │ ├── test_multi30k.py │ ├── test_penntreebank.py │ ├── test_qnli.py │ ├── test_qqp.py │ ├── test_rte.py │ ├── test_sogounews.py │ ├── test_squads.py │ ├── test_sst2.py │ ├── test_stsb.py │ ├── test_udpos.py │ ├── test_wikitexts.py │ ├── test_wnli.py │ ├── test_yahooanswers.py │ └── test_yelpreviews.py │ ├── models │ ├── __init__.py │ ├── gpu_tests │ │ └── models_gpu_test.py │ ├── models_cpu_test.py │ ├── roberta_models_test_impl.py │ ├── t5_models_test_impl.py │ ├── t5_test_transforms.py │ └── test_transformers.py │ ├── prototype │ ├── __init__.py │ ├── test_functional.py │ ├── test_transforms.py │ ├── test_vectors.py │ └── test_with_asset.py │ ├── test_build.py │ ├── test_functional.py │ ├── test_transforms.py │ ├── test_utils.py │ └── test_vocab.py ├── third_party └── CMakeLists.txt ├── tools ├── __init__.py ├── conda │ └── torchtext │ │ └── meta.yaml └── setup_helpers │ ├── __init__.py │ └── extension.py ├── torchtext ├── __init__.py ├── _download_hooks.py ├── _extension.py ├── _internal │ ├── __init__.py │ └── module_utils.py ├── csrc │ ├── CMakeLists.txt │ ├── bert_tokenizer.cpp │ ├── bert_tokenizer.h │ ├── clip_tokenizer.cpp │ ├── clip_tokenizer.h │ ├── common.cpp │ ├── common.h │ ├── export.h │ ├── gpt2_bpe_tokenizer.cpp │ ├── gpt2_bpe_tokenizer.h │ ├── regex.cpp │ ├── regex.h │ ├── regex_tokenizer.cpp │ ├── regex_tokenizer.h │ ├── register_pybindings.cpp │ ├── register_torchbindings.cpp │ ├── sentencepiece.cpp │ ├── sentencepiece.h │ ├── vectors.cpp │ ├── vectors.h │ ├── vocab.cpp │ ├── vocab.h │ ├── vocab_factory.cpp │ └── vocab_factory.h ├── data │ ├── __init__.py │ ├── datasets_utils.py │ ├── functional.py │ ├── metrics.py │ └── utils.py ├── datasets │ ├── __init__.py │ 
├── ag_news.py │ ├── amazonreviewfull.py │ ├── amazonreviewpolarity.py │ ├── cc100.py │ ├── cnndm.py │ ├── cola.py │ ├── conll2000chunking.py │ ├── dbpedia.py │ ├── enwik9.py │ ├── imdb.py │ ├── iwslt2016.py │ ├── iwslt2017.py │ ├── mnli.py │ ├── mrpc.py │ ├── multi30k.py │ ├── penntreebank.py │ ├── qnli.py │ ├── qqp.py │ ├── rte.py │ ├── sogounews.py │ ├── squad1.py │ ├── squad2.py │ ├── sst2.py │ ├── stsb.py │ ├── udpos.py │ ├── wikitext103.py │ ├── wikitext2.py │ ├── wnli.py │ ├── yahooanswers.py │ ├── yelpreviewfull.py │ └── yelpreviewpolarity.py ├── experimental │ ├── __init__.py │ ├── transforms.py │ ├── vectors.py │ └── vocab_factory.py ├── functional.py ├── lib │ └── .gitignore ├── models │ ├── __init__.py │ ├── roberta │ │ ├── __init__.py │ │ ├── bundler.py │ │ ├── model.py │ │ └── modules.py │ └── t5 │ │ ├── __init__.py │ │ ├── bundler.py │ │ ├── model.py │ │ ├── modules.py │ │ └── t5_transform.py ├── nn │ ├── __init__.py │ └── modules │ │ ├── __init__.py │ │ └── multiheadattention.py ├── prototype │ ├── __init__.py │ ├── asset │ │ ├── get_checksum.sh │ │ └── get_checksums_fast_text.py │ ├── generate.py │ ├── transforms.py │ ├── vectors.py │ └── vocab_factory.py ├── transforms.py ├── utils.py └── vocab │ ├── __init__.py │ ├── vectors.py │ ├── vocab.py │ └── vocab_factory.py └── version.txt /.circleci/unittest/linux/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | dependencies: 4 | - codecov 5 | - pip 6 | - pip: 7 | - dataclasses 8 | - nltk 9 | - requests 10 | - revtok 11 | - pytest 12 | - pytest-cov 13 | - pytest-pythonpath 14 | - sacremoses 15 | - spacy 16 | - tqdm 17 | - expecttest 18 | - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 19 | - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 20 | -------------------------------------------------------------------------------- /.circleci/unittest/linux/scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | unset PYTORCH_VERSION 4 | # For unit tests, nightly PyTorch is installed in the following section, 5 | # so there is no need to set PYTORCH_VERSION. 6 | # In fact, keeping PYTORCH_VERSION would force us to hardcode the PyTorch version in the config.
7 | 8 | set -e 9 | 10 | case "$(uname -s)" in 11 | Darwin*) os=MacOSX;; 12 | *) os=Linux 13 | esac 14 | 15 | eval "$(./conda/bin/conda shell.bash hook)" 16 | conda activate ./env 17 | 18 | printf "* Installing PyTorch\n" 19 | ( 20 | if [ "${os}" == MacOSX ] ; then 21 | # TODO: this can be removed as soon as the linking issue is resolved 22 | # see https://github.com/pytorch/pytorch/issues/62424 for details 23 | MKL_CONSTRAINT='mkl==2021.2.0' 24 | else 25 | MKL_CONSTRAINT='' 26 | fi 27 | set -x 28 | conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} ${MKL_CONSTRAINT} pytorch cpuonly 29 | ) 30 | 31 | 32 | printf "* Installing torchtext\n" 33 | python setup.py develop 34 | 35 | printf "* Installing parameterized\n" 36 | pip install parameterized 37 | -------------------------------------------------------------------------------- /.circleci/unittest/linux/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | codecov 9 | -------------------------------------------------------------------------------- /.circleci/unittest/linux/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/bin/conda shell.bash hook)" 6 | conda activate ./env 7 | 8 | python -m torch.utils.collect_env 9 | cd test 10 | pytest --cov=torchtext --junitxml=test-results/junit.xml -v --durations 20 torchtext_unittest 11 | -------------------------------------------------------------------------------- /.circleci/unittest/linux/scripts/setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script sets up the environment in which the unit tests are run. 4 | # To speed up the CI time, the resulting environment is cached. 5 | # 6 | # Do not install PyTorch and torchtext here, otherwise they also get cached. 7 | 8 | set -e 9 | 10 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | root_dir="$(git rev-parse --show-toplevel)" 12 | conda_dir="${root_dir}/conda" 13 | env_dir="${root_dir}/env" 14 | 15 | cd "${root_dir}" 16 | 17 | case "$(uname -s)" in 18 | Darwin*) os=MacOSX;; 19 | *) os=Linux 20 | esac 21 | 22 | # 1. Install conda at ./conda 23 | if [ ! -d "${conda_dir}" ]; then 24 | printf "* Installing conda\n" 25 | curl --silent -L -o miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" 26 | bash ./miniconda.sh -b -f -p "${conda_dir}" 27 | fi 28 | eval "$(${conda_dir}/bin/conda shell.bash hook)" 29 | 30 | # 2. Create test environment at ./env 31 | if [ ! -d "${env_dir}" ]; then 32 | printf "* Creating a test environment\n" 33 | conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" 34 | fi 35 | conda activate "${env_dir}" 36 | 37 | 38 | # 3. Install minimal build tools 39 | pip --quiet install "cmake>=3.18.0" ninja 40 | 41 | # 4. Install Conda dependencies 42 | printf "* Installing dependencies (except PyTorch)\n" 43 | conda env update --file "${this_dir}/environment.yml" --prune 44 | 45 | # 5.
Download 46 | printf "* Downloading SpaCy English models\n" 47 | python -m spacy download en_core_web_sm 48 | printf "* Downloading SpaCy German models\n" 49 | python -m spacy download de_core_news_sm 50 | -------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | dependencies: 4 | - codecov 5 | - pip 6 | - setuptools == 58.0.4 7 | - spacy 8 | - pip: 9 | - dataclasses 10 | - nltk 11 | - requests 12 | - revtok 13 | - pytest 14 | - pytest-cov 15 | - pytest-pythonpath 16 | - sacremoses 17 | - tqdm 18 | - certifi 19 | - expecttest 20 | - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 21 | - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 22 | -------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | unset PYTORCH_VERSION 4 | # For unit tests, nightly PyTorch is installed in the following section, 5 | # so there is no need to set PYTORCH_VERSION. 6 | # In fact, keeping PYTORCH_VERSION would force us to hardcode the PyTorch version in the config. 7 | 8 | set -e 9 | 10 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | root_dir="$(git rev-parse --show-toplevel)" 12 | 13 | cd "${root_dir}" 14 | 15 | eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" 16 | conda activate ./env 17 | 18 | printf "* Installing PyTorch\n" 19 | conda install -y -c "pytorch-${UPLOAD_CHANNEL}" ${CONDA_CHANNEL_FLAGS} pytorch cpuonly 20 | 21 | printf "* Installing pywin32_postinstall script\n" 22 | curl --output pywin32_postinstall.py https://raw.githubusercontent.com/mhammond/pywin32/main/pywin32_postinstall.py 23 | python pywin32_postinstall.py -install 24 | 25 | printf "* Installing torchtext\n" 26 | "$root_dir/packaging/vc_env_helper.bat" python setup.py develop 27 | 28 | printf "* Installing parameterized\n" 29 | pip install parameterized 30 | -------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/install_conda.bat: -------------------------------------------------------------------------------- 1 | start /wait "" "%miniconda_exe%" /S /InstallationType=JustMe /RegisterPython=0 /AddToPath=0 /D=%tmp_conda% 2 | -------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/post_process.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" 6 | conda activate ./env 7 | 8 | codecov 9 | -------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/run_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e 4 | 5 | eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')" 6 | conda activate ./env 7 | 8 | python -m torch.utils.collect_env 9 | cd test 10 | pytest --cov=torchtext --junitxml=test-results/junit.xml -v --durations 20 torchtext_unittest 11 |
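
The Linux and Windows run_test.sh scripts above both assume that setup_env.sh and install.sh have already populated the cached ./env environment. A small Python smoke script along the following lines can confirm that environment is intact before launching the full pytest run (illustrative only, not a file in this repository; it imports the packages installed by environment.yml and the SpaCy models downloaded in step 5 of setup_env.sh):

```python
# Illustrative smoke check (not part of the repo): verify the cached test
# environment produced by setup_env.sh / install.sh before running pytest.
import spacy
import torch
import torchtext

print("torch:", torch.__version__)
print("torchtext:", torchtext.__version__)

# The SpaCy models downloaded in step 5 of setup_env.sh
for name in ("en_core_web_sm", "de_core_news_sm"):
    nlp = spacy.load(name)
    print(name, [token.text for token in nlp("Hello world")])
```
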
-------------------------------------------------------------------------------- /.circleci/unittest/windows/scripts/setup_env.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script sets up the environment in which the unit tests are run. 4 | # To speed up the CI time, the resulting environment is cached. 5 | # 6 | # Do not install PyTorch and torchtext here, otherwise they also get cached. 7 | 8 | set -e 9 | 10 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 11 | root_dir="$(git rev-parse --show-toplevel)" 12 | conda_dir="${root_dir}/conda" 13 | env_dir="${root_dir}/env" 14 | 15 | cd "${root_dir}" 16 | 17 | # 1. Install conda at ./conda 18 | if [ ! -d "${conda_dir}" ]; then 19 | printf "* Installing conda\n" 20 | export tmp_conda="$(echo $conda_dir | tr '/' '\\')" 21 | export miniconda_exe="$(echo $root_dir | tr '/' '\\')\\miniconda.exe" 22 | curl --output miniconda.exe https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe 23 | "$this_dir/install_conda.bat" 24 | unset tmp_conda 25 | unset miniconda_exe 26 | fi 27 | eval "$(${conda_dir}/Scripts/conda.exe 'shell.bash' 'hook')" 28 | 29 | # 2. Create test environment at ./env 30 | if [ ! -d "${env_dir}" ]; then 31 | printf "* Creating a test environment\n" 32 | conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" 33 | fi 34 | conda activate "${env_dir}" 35 | 36 | # 3. Install minimal build tools 37 | pip --quiet install "cmake>=3.18.0" ninja 38 | 39 | # 4. Install Conda dependencies 40 | printf "* Installing dependencies (except PyTorch)\n" 41 | conda env update --file "${this_dir}/environment.yml" --prune 42 | 43 | # 5. Download 44 | printf "* Downloading SpaCy English models\n" 45 | python -m spacy download en_core_web_sm 46 | printf "* Downloading SpaCy German models\n" 47 | python -m spacy download de_core_news_sm 48 | -------------------------------------------------------------------------------- /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | AccessModifierOffset: -1 3 | AlignAfterOpenBracket: AlwaysBreak 4 | AlignConsecutiveAssignments: false 5 | AlignConsecutiveDeclarations: false 6 | AlignEscapedNewlinesLeft: true 7 | AlignOperands: false 8 | AlignTrailingComments: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: false 11 | AllowShortCaseLabelsOnASingleLine: false 12 | AllowShortFunctionsOnASingleLine: Empty 13 | AllowShortIfStatementsOnASingleLine: false 14 | AllowShortLoopsOnASingleLine: false 15 | AlwaysBreakAfterReturnType: None 16 | AlwaysBreakBeforeMultilineStrings: true 17 | AlwaysBreakTemplateDeclarations: true 18 | BinPackArguments: false 19 | BinPackParameters: false 20 | BraceWrapping: 21 | AfterClass: false 22 | AfterControlStatement: false 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterObjCDeclaration: false 27 | AfterStruct: false 28 | AfterUnion: false 29 | BeforeCatch: false 30 | BeforeElse: false 31 | IndentBraces: false 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeBraces: Attach 34 | BreakBeforeTernaryOperators: true 35 | BreakConstructorInitializersBeforeComma: false 36 | BreakAfterJavaFieldAnnotations: false 37 | BreakStringLiterals: false 38 | ColumnLimit: 80 39 | CommentPragmas: "^ IWYU pragma:" 40 | CompactNamespaces: false 41 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 42 | ConstructorInitializerIndentWidth: 4 43 
| ContinuationIndentWidth: 4 44 | Cpp11BracedListStyle: true 45 | DerivePointerAlignment: false 46 | DisableFormat: false 47 | ForEachMacros: [FOR_EACH_RANGE, FOR_EACH] 48 | IncludeCategories: 49 | - Regex: '^<.*\.h(pp)?>' 50 | Priority: 1 51 | - Regex: "^<.*" 52 | Priority: 2 53 | - Regex: ".*" 54 | Priority: 3 55 | IndentCaseLabels: true 56 | IndentWidth: 2 57 | IndentWrappedFunctionNames: false 58 | KeepEmptyLinesAtTheStartOfBlocks: false 59 | MacroBlockBegin: "" 60 | MacroBlockEnd: "" 61 | MaxEmptyLinesToKeep: 1 62 | NamespaceIndentation: None 63 | ObjCBlockIndentWidth: 2 64 | ObjCSpaceAfterProperty: false 65 | ObjCSpaceBeforeProtocolList: false 66 | PenaltyBreakBeforeFirstCallParameter: 1 67 | PenaltyBreakComment: 300 68 | PenaltyBreakFirstLessLess: 120 69 | PenaltyBreakString: 1000 70 | PenaltyExcessCharacter: 1000000 71 | PenaltyReturnTypeOnItsOwnLine: 2000000 72 | PointerAlignment: Left 73 | ReflowComments: true 74 | SortIncludes: true 75 | SpaceAfterCStyleCast: false 76 | SpaceBeforeAssignmentOperators: true 77 | SpaceBeforeParens: ControlStatements 78 | SpaceInEmptyParentheses: false 79 | SpacesBeforeTrailingComments: 1 80 | SpacesInAngles: false 81 | SpacesInContainerLiterals: true 82 | SpacesInCStyleCastParentheses: false 83 | SpacesInParentheses: false 84 | SpacesInSquareBrackets: false 85 | Standard: Cpp11 86 | TabWidth: 8 87 | UseTab: Never 88 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | E401,E402,E501,E722,W503,W504,F821,B006,B007,B008,B009, 4 | # https://github.com/PyCQA/pycodestyle/issues/373 5 | E203 6 | select = 7 | B,C,E,F,P,T4,W,B9, 8 | # Missing argument descriptions in the docstring 9 | D417, 10 | # TorchFix 11 | TOR0,TOR1,TOR2 12 | max-line-length = 120 13 | exclude = docs/source,third_party 14 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # To exclude autogenerated files from code reviews 2 | .circleci/config.yml linguist-generated=true 3 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F41B Bug Report" 3 | about: Submit a bug report to help us improve TorchText 4 | --- 5 | 6 | ## 🐛 Bug 7 | 8 | **Describe the bug** A clear and concise description of what the bug is. 9 | 10 | **To Reproduce** Steps to reproduce the behavior: 11 | 12 | 1. Go to '...' 13 | 2. Click on '....' 14 | 3. Scroll down to '....' 15 | 4. See error 16 | 17 | **Expected behavior** A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** If applicable, add screenshots to help explain your problem. 20 | 21 | **Environment** 22 | 23 | Please copy and paste the output from our 24 | [environment collection script](https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py) (or 25 | fill out the checklist below manually). 26 | 27 | You can get the script and run it with: 28 | 29 | ``` 30 | wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py 31 | # For security purposes, please check the contents of collect_env.py before running it. 
32 | python collect_env.py 33 | python -c "import torchtext; print(\"torchtext version is \", torchtext.__version__)" 34 | ``` 35 | 36 | - PyTorch Version (e.g., 1.0): 37 | - OS (e.g., Linux): 38 | - How you installed PyTorch (`conda`, `pip`, source): 39 | - Build command you used (if compiling from source): 40 | - Python version: 41 | - CUDA/cuDNN version: 42 | - GPU models and configuration: 43 | - Any other relevant information: 44 | 45 | **Additional context** Add any other context about the problem here. 46 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F4DA Documentation" 3 | about: Report an issue related to TorchText 4 | --- 5 | 6 | ## 📚 Documentation 7 | 8 | **Description** 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "\U0001F680Feature Request" 3 | about: Submit a proposal/request for a new TorchText feature 4 | --- 5 | 6 | ## 🚀 Feature 7 | 8 | 9 | 10 | **Motivation** 11 | 12 | 13 | 14 | **Pitch** 15 | 16 | 17 | 18 | **Alternatives** 19 | 20 | 21 | 22 | **Additional context** 23 | 24 | 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "❓Questions/Help/Support" 3 | about: Do you need support? We have resources. 4 | --- 5 | 6 | ## ❓ Questions and Help 7 | 8 | **Description** 9 | 10 | 11 | -------------------------------------------------------------------------------- /.github/pytorch-probot.yml: -------------------------------------------------------------------------------- 1 | tracking_issue: 876 2 | -------------------------------------------------------------------------------- /.github/scripts/validate_binaries.sh: -------------------------------------------------------------------------------- 1 | 2 | if [[ ${MATRIX_PACKAGE_TYPE} = "conda" ]]; then 3 | conda install -y torchtext -c ${PYTORCH_CONDA_CHANNEL} 4 | else 5 | pip install ${PYTORCH_PIP_PREFIX} torchtext --index-url ${PYTORCH_PIP_DOWNLOAD_URL} 6 | fi 7 | 8 | python ./test/smoke_tests/smoke_tests.py 9 | -------------------------------------------------------------------------------- /.github/workflows/bandit.yml: -------------------------------------------------------------------------------- 1 | # GitHub Actions Bandit Workflow 2 | 3 | name: Bandit 4 | 5 | on: 6 | pull_request: 7 | branches: [main] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | # Task will fail if any high-severity issues are found 19 | # Ignoring submodules 20 | - name: Run Bandit Security Analysis 21 | run: | 22 | python -m pip install bandit 23 | python -m bandit -r . 
-x ./third_party -lll 24 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-linux.yml: -------------------------------------------------------------------------------- 1 | name: Build Linux Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: linux 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | with-cuda: disable 29 | with-rocm: disable 30 | build: 31 | needs: generate-matrix 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | include: 36 | - repository: pytorch/text 37 | pre-script: "" 38 | post-script: "" 39 | smoke-test-script: test/smoke_tests/smoke_tests.py 40 | package-name: torchtext 41 | name: ${{ matrix.repository }} 42 | uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main 43 | with: 44 | repository: ${{ matrix.repository }} 45 | ref: "" 46 | test-infra-repository: pytorch/test-infra 47 | test-infra-ref: main 48 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 49 | pre-script: ${{ matrix.pre-script }} 50 | post-script: ${{ matrix.post-script }} 51 | package-name: ${{ matrix.package-name }} 52 | smoke-test-script: ${{ matrix.smoke-test-script }} 53 | trigger-event: ${{ github.event_name }} 54 | -------------------------------------------------------------------------------- /.github/workflows/build-wheels-m1.yml: -------------------------------------------------------------------------------- 1 | name: Build M1 Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: macos-arm64 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | build: 29 | needs: generate-matrix 30 | strategy: 31 | fail-fast: false 32 | matrix: 33 | include: 34 | - repository: pytorch/text 35 | pre-script: "" 36 | post-script: "" 37 | package-name: torchtext 38 | smoke-test-script: test/smoke_tests/smoke_tests.py 39 | name: ${{ matrix.repository }} 40 | uses: pytorch/test-infra/.github/workflows/build_wheels_macos.yml@main 41 | with: 42 | repository: ${{ matrix.repository }} 43 | ref: "" 44 | test-infra-repository: pytorch/test-infra 45 | test-infra-ref: main 46 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 47 | pre-script: ${{ matrix.pre-script }} 48 | post-script: ${{ matrix.post-script }} 49 | package-name: ${{ matrix.package-name }} 50 | smoke-test-script: ${{ matrix.smoke-test-script }} 51 | runner-type: macos-m1-stable 52 | trigger-event: ${{ github.event_name }} 53 | -------------------------------------------------------------------------------- 
/.github/workflows/build-wheels-windows.yml: -------------------------------------------------------------------------------- 1 | name: Build Windows Wheels 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | tags: 11 | # NOTE: Binary build pipelines should only get triggered on release candidate builds 12 | # Release candidate tags look like: v1.11.0-rc1 13 | - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ 14 | workflow_dispatch: 15 | 16 | permissions: 17 | id-token: write 18 | contents: read 19 | 20 | jobs: 21 | generate-matrix: 22 | uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main 23 | with: 24 | package-type: wheel 25 | os: windows 26 | test-infra-repository: pytorch/test-infra 27 | test-infra-ref: main 28 | with-cuda: disable 29 | build: 30 | needs: generate-matrix 31 | strategy: 32 | fail-fast: false 33 | matrix: 34 | include: 35 | - repository: pytorch/text 36 | pre-script: "" 37 | env-script: packaging/vc_env_helper.bat 38 | post-script: "" 39 | smoke-test-script: test/smoke_tests/smoke_tests.py 40 | package-name: torchtext 41 | name: ${{ matrix.repository }} 42 | uses: pytorch/test-infra/.github/workflows/build_wheels_windows.yml@main 43 | with: 44 | repository: ${{ matrix.repository }} 45 | ref: "" 46 | test-infra-repository: pytorch/test-infra 47 | test-infra-ref: main 48 | build-matrix: ${{ needs.generate-matrix.outputs.matrix }} 49 | pre-script: ${{ matrix.pre-script }} 50 | env-script: ${{ matrix.env-script }} 51 | post-script: ${{ matrix.post-script }} 52 | package-name: ${{ matrix.package-name }} 53 | smoke-test-script: ${{ matrix.smoke-test-script }} 54 | trigger-event: ${{ github.event_name }} 55 | -------------------------------------------------------------------------------- /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # GitHub Actions CodeQL Workflow 2 | 3 | name: CodeQL 4 | 5 | on: 6 | pull_request: 7 | branches: [main] 8 | 9 | workflow_dispatch: 10 | 11 | jobs: 12 | build: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | 18 | - name: Initialize CodeQL 19 | uses: github/codeql-action/init@v1 20 | with: 21 | languages: python, cpp 22 | 23 | - name: Install Ninja 24 | run: | 25 | sudo apt-get update -y 26 | sudo apt-get install -y ninja-build 27 | 28 | - name: Update submodules 29 | run: git submodule update --init --recursive 30 | 31 | - name: Install Torch 32 | run: | 33 | python -m pip install cmake 34 | sudo ln -s /usr/bin/ninja /usr/bin/ninja-build 35 | 36 | - name: Build TorchText 37 | run: | 38 | python -m pip install setuptools==65.7.0 39 | python setup.py develop --user 40 | 41 | # If any code scanning alerts are found, they will be under Security -> CodeQL 42 | # Link: https://github.com/pytorch/text/security/code-scanning 43 | - name: Perform CodeQL Analysis 44 | uses: github/codeql-action/analyze@v1 45 | -------------------------------------------------------------------------------- /.github/workflows/integration-test.yml: -------------------------------------------------------------------------------- 1 | name: Integration Test 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | 7 | workflow_dispatch: 8 | 9 | jobs: 10 | tests: 11 | strategy: 12 | matrix: 13 | python_version: ["3.8"] 14 | fail-fast: false 15 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 16 | with: 17 | runner: linux.12xlarge 18 | repository: pytorch/text 19 | script: | 20 | # Mark Build Directory Safe 
21 | git config --global --add safe.directory /__w/text/text 22 | # Set up Environment Variables 23 | export PYTHON_VERSION="${{ matrix.python_version }}" 24 | export VERSION="cpu" 25 | export CUDATOOLKIT="cpuonly" 26 | # Set CHANNEL 27 | if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then 28 | export CHANNEL=test 29 | else 30 | export CHANNEL=nightly 31 | fi 32 | # Create Conda Env 33 | conda create -yp ci_env python="${PYTHON_VERSION}" 34 | conda activate /work/ci_env 35 | python3 -m pip --quiet install cmake>=3.18.0 ninja 36 | conda env update --file ".circleci/unittest/linux/scripts/environment.yml" --prune 37 | # TorchText-specific Setup 38 | printf "* Downloading SpaCy English models\n" 39 | python -m spacy download en_core_web_sm 40 | printf "* Downloading SpaCy German models\n" 41 | python -m spacy download de_core_news_sm 42 | # Install PyTorch, Torchvision 43 | set -ex 44 | conda install \ 45 | --yes \ 46 | -c "pytorch-${CHANNEL}" \ 47 | -c nvidia "pytorch-${CHANNEL}"::pytorch[build="*${VERSION}*"] \ 48 | "${CUDATOOLKIT}" 49 | python3 setup.py develop 50 | # Install integration test dependencies 51 | python3 -m pip --quiet install parameterized 52 | python3 -m pip --quiet install requests 53 | python3 -m pip --quiet install sentencepiece 54 | python3 -m pip --quiet install tqdm 55 | python3 -m pip --quiet install expecttest 56 | # Run Tests 57 | python3 -m torch.utils.collect_env 58 | cd test 59 | pytest integration_tests -v --use-tmp-hub-dir 60 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | jobs: 13 | python-source-and-configs: 14 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 15 | with: 16 | repository: pytorch/text 17 | script: | 18 | set -euo pipefail 19 | 20 | echo '::group::Setup environment' 21 | CONDA_PATH=$(which conda) 22 | eval "$(${CONDA_PATH} shell.bash hook)" 23 | conda create --name ci --quiet --yes python=3.8 pip 24 | conda activate ci 25 | echo '::endgroup::' 26 | 27 | echo '::group::Install lint tools' 28 | pip install --progress-bar=off pre-commit 29 | echo '::endgroup::' 30 | 31 | echo '::group::Lint Python source and configs' 32 | set +e 33 | echo $LD_LIBRARY_PATH 34 | export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" 35 | pre-commit run --all-files 36 | 37 | if [ $? 
-ne 0 ]; then 38 | git --no-pager diff 39 | exit 1 40 | fi 41 | echo '::endgroup::' 42 | 43 | c-source: 44 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 45 | with: 46 | repository: pytorch/text 47 | script: | 48 | set -euo pipefail 49 | 50 | echo '::group::Setup environment' 51 | CONDA_PATH=$(which conda) 52 | eval "$(${CONDA_PATH} shell.bash hook)" 53 | conda create --name ci --quiet --yes -c conda-forge python=3.8 ncurses=5 libgcc 54 | conda activate ci 55 | export LD_LIBRARY_PATH="${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}" 56 | echo '::endgroup::' 57 | 58 | echo '::group::Install lint tools' 59 | curl https://oss-clang-format.s3.us-east-2.amazonaws.com/linux64/clang-format-linux64 -o ./clang-format 60 | chmod +x ./clang-format 61 | echo '::endgroup::' 62 | 63 | echo '::group::Lint C source' 64 | set +e 65 | python run-clang-format.py \ 66 | --recursive \ 67 | --clang-format-executable=./clang-format \ 68 | torchtext/csrc 69 | 70 | if [ $? -ne 0 ]; then 71 | git --no-pager diff 72 | exit 1 73 | fi 74 | echo '::endgroup::' 75 | -------------------------------------------------------------------------------- /.github/workflows/test-linux-cpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on Linux CPU 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | jobs: 16 | tests: 17 | strategy: 18 | matrix: 19 | python_version: ["3.8", "3.9", "3.10"] 20 | fail-fast: false 21 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 22 | with: 23 | runner: linux.12xlarge 24 | repository: pytorch/text 25 | script: | 26 | # Mark Build Directory Safe 27 | git config --global --add safe.directory /__w/text/text 28 | 29 | # Set up Environment Variables 30 | export PYTHON_VERSION="${{ matrix.python_version }}" 31 | export VERSION="cpu" 32 | export CUDATOOLKIT="cpuonly" 33 | 34 | # Set CHANNEL 35 | if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then 36 | export CHANNEL=test 37 | else 38 | export CHANNEL=nightly 39 | fi 40 | 41 | # Create Conda Env 42 | conda create -yp ci_env python="${PYTHON_VERSION}" 43 | conda activate /work/ci_env 44 | python3 -m pip --quiet install cmake>=3.18.0 ninja 45 | conda env update --file ".circleci/unittest/linux/scripts/environment.yml" --prune 46 | 47 | # TorchText-specific Setup 48 | printf "* Downloading SpaCy English models\n" 49 | python -m spacy download en_core_web_sm 50 | printf "* Downloading SpaCy German models\n" 51 | python -m spacy download de_core_news_sm 52 | 53 | # Install PyTorch, Torchvision 54 | set -ex 55 | conda install \ 56 | --yes \ 57 | -c "pytorch-${CHANNEL}" \ 58 | -c nvidia "pytorch-${CHANNEL}"::pytorch[build="*${VERSION}*"] \ 59 | "${CUDATOOLKIT}" 60 | python3 setup.py develop 61 | python3 -m pip install parameterized 62 | 63 | # Run Tests 64 | python3 -m torch.utils.collect_env 65 | cd test 66 | python3 -m pytest --cov=torchtext --junitxml=test-results/junit.xml -v --durations 20 torchtext_unittest 67 | -------------------------------------------------------------------------------- /.github/workflows/test-linux-gpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on Linux GPU 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | 
CHANNEL: "nightly" 14 | 15 | jobs: 16 | tests: 17 | strategy: 18 | matrix: 19 | python_version: ["3.8"] 20 | cuda_arch_version: ["11.7"] 21 | fail-fast: false 22 | uses: pytorch/test-infra/.github/workflows/linux_job.yml@main 23 | with: 24 | runner: linux.g5.4xlarge.nvidia.gpu 25 | repository: pytorch/text 26 | gpu-arch-type: cuda 27 | gpu-arch-version: ${{ matrix.cuda_arch_version }} 28 | timeout: 120 29 | script: | 30 | # Mark Build Directory Safe 31 | git config --global --add safe.directory /__w/text/text 32 | 33 | # Set up Environment Variables 34 | export PYTHON_VERSION="${{ matrix.python_version }}" 35 | export VERSION="${{ matrix.cuda_arch_version }}" 36 | export CUDATOOLKIT="pytorch-cuda=${VERSION}" 37 | 38 | # Set CHANNEL 39 | if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then 40 | export CHANNEL=test 41 | else 42 | export CHANNEL=nightly 43 | fi 44 | 45 | # Create Conda Env 46 | conda create --quiet -yp ci_env python="${PYTHON_VERSION}" 47 | conda activate /work/ci_env 48 | python3 -m pip --quiet install cmake>=3.18.0 ninja 49 | conda env update --file ".circleci/unittest/linux/scripts/environment.yml" --prune 50 | 51 | # TorchText-specific Setup 52 | printf "* Downloading SpaCy English models\n" 53 | python -m spacy download en_core_web_sm 54 | printf "* Downloading SpaCy German models\n" 55 | python -m spacy download de_core_news_sm 56 | 57 | # Install PyTorch 58 | set -ex 59 | conda install \ 60 | --yes \ 61 | --quiet \ 62 | -c "pytorch-${CHANNEL}" \ 63 | -c nvidia "pytorch-${CHANNEL}"::pytorch[build="*${VERSION}*"] \ 64 | "${CUDATOOLKIT}" 65 | python3 setup.py develop 66 | python3 -m pip install parameterized --quiet 67 | 68 | # Run Tests 69 | python3 -m torch.utils.collect_env 70 | cd test 71 | python3 -m pytest --junitxml=test-results/junit.xml -v --durations 20 -m gpu_test torchtext_unittest 72 | -------------------------------------------------------------------------------- /.github/workflows/test-macos-cpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on Macos CPU 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | jobs: 16 | tests: 17 | strategy: 18 | matrix: 19 | python_version: ["3.8", "3.9", "3.10"] 20 | fail-fast: false 21 | uses: pytorch/test-infra/.github/workflows/macos_job.yml@main 22 | with: 23 | runner: macos-12 24 | repository: pytorch/text 25 | timeout: 60 26 | script: | 27 | # Mark Build Directory Safe 28 | git config --global --add safe.directory /__w/text/text 29 | 30 | # Set up Environment Variables 31 | export PYTHON_VERSION="${{ matrix.python_version }}" 32 | export VERSION="cpu" 33 | export CUDATOOLKIT="cpuonly" 34 | 35 | # Set CHANNEL 36 | if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then 37 | export CHANNEL=test 38 | else 39 | export CHANNEL=nightly 40 | fi 41 | 42 | # TODO: this can be removed as soon as linking issue could be resolved 43 | # see https://github.com/pytorch/pytorch/issues/62424 from details 44 | export MKL_CONSTRAINT='mkl==2021.2.0' 45 | 46 | # Create Conda Env 47 | conda create -yp ./ci_env python="${PYTHON_VERSION}" 48 | conda activate ./ci_env 49 | python3 -m pip --quiet install cmake>=3.18.0 ninja 50 | conda env update --file ".circleci/unittest/linux/scripts/environment.yml" --prune 51 
| 52 | # TorchText-specific Setup 53 | printf "* Downloading SpaCy English models\n" 54 | python -m spacy download en_core_web_sm 55 | printf "* Downloading SpaCy German models\n" 56 | python -m spacy download de_core_news_sm 57 | 58 | # Install PyTorch, Torchvision 59 | set -ex 60 | conda install \ 61 | --yes \ 62 | -c "pytorch-${CHANNEL}" \ 63 | -c nvidia \ 64 | "${MKL_CONSTRAINT}" \ 65 | pytorch \ 66 | "${CUDATOOLKIT}" 67 | python3 setup.py develop 68 | python3 -m pip install parameterized 69 | 70 | # Run Tests 71 | python3 -m torch.utils.collect_env 72 | cd test 73 | python3 -m pytest --cov=torchtext --junitxml=test-results/junit.xml -v --durations 20 torchtext_unittest 74 | -------------------------------------------------------------------------------- /.github/workflows/test-windows-cpu.yml: -------------------------------------------------------------------------------- 1 | name: Unit-tests on Windows CPU 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - nightly 8 | - main 9 | - release/* 10 | workflow_dispatch: 11 | 12 | env: 13 | CHANNEL: "nightly" 14 | 15 | jobs: 16 | tests: 17 | strategy: 18 | matrix: 19 | python_version: ["3.8", "3.9", "3.10"] 20 | fail-fast: false 21 | uses: pytorch/test-infra/.github/workflows/windows_job.yml@main 22 | with: 23 | runner: windows.4xlarge 24 | repository: pytorch/text 25 | script: | 26 | set -euxo pipefail 27 | 28 | # Mark Build Directory Safe 29 | git config --global --add safe.directory /__w/text/text 30 | 31 | # Set up Environment Variables 32 | export PYTHON_VERSION="${{ matrix.python_version }}" 33 | export VERSION="cpu" 34 | 35 | # Set CHANNEL 36 | if [[ (${GITHUB_EVENT_NAME} = 'pull_request' && (${GITHUB_BASE_REF} = 'release'*)) || (${GITHUB_REF} = 'refs/heads/release'*) ]]; then 37 | export CHANNEL=test 38 | else 39 | export CHANNEL=nightly 40 | fi 41 | 42 | # Create Conda Env 43 | conda create -y --name ci_env python="${PYTHON_VERSION}" 44 | conda activate ci_env 45 | python -m pip --quiet install cmake>=3.18.0 ninja 46 | conda env update --file ".circleci/unittest/windows/scripts/environment.yml" --prune 47 | 48 | # TorchText-specific Setup 49 | printf "* Downloading SpaCy English models\n" 50 | python -m spacy download en_core_web_sm 51 | printf "* Downloading SpaCy German models\n" 52 | python -m spacy download de_core_news_sm 53 | 54 | # Install PyTorch, Torchvision 55 | conda install \ 56 | --yes \ 57 | -c "pytorch-${CHANNEL}" \ 58 | pytorch \ 59 | cpuonly 60 | 61 | printf "* Installing pywin32_postinstall script\n" 62 | curl --output pywin32_postinstall.py https://raw.githubusercontent.com/mhammond/pywin32/main/pywin32_postinstall.py 63 | python pywin32_postinstall.py -install 64 | 65 | "packaging/vc_env_helper.bat" python setup.py develop 66 | python -m pip install parameterized 67 | 68 | # Run Tests 69 | python -m torch.utils.collect_env 70 | cd test 71 | python -m pytest --cov=torchtext --junitxml=test-results/junit.xml -v --durations 20 torchtext_unittest 72 | -------------------------------------------------------------------------------- /.github/workflows/validate-binaries.yml: -------------------------------------------------------------------------------- 1 | name: Validate binaries 2 | 3 | on: 4 | workflow_call: 5 | inputs: 6 | channel: 7 | description: "Channel to use (nightly, test, release, all)" 8 | required: false 9 | type: string 10 | default: release 11 | os: 12 | description: "Operating system to generate for (linux, windows, macos, macos-arm64)" 13 | required: true 14 | type: string 15 | ref: 16 | 
description: "Reference to checkout, defaults to empty" 17 | default: "" 18 | required: false 19 | type: string 20 | workflow_dispatch: 21 | inputs: 22 | channel: 23 | description: "Channel to use (nightly, test, release, all)" 24 | required: true 25 | type: choice 26 | options: 27 | - release 28 | - nightly 29 | - test 30 | - all 31 | os: 32 | description: "Operating system to generate for (linux, windows, macos)" 33 | required: true 34 | type: choice 35 | default: all 36 | options: 37 | - windows 38 | - linux 39 | - macos 40 | - all 41 | ref: 42 | description: "Reference to checkout, defaults to empty" 43 | default: "" 44 | required: false 45 | type: string 46 | pytorch_version: 47 | description: "PyTorch version to validate (ie. 2.0, 2.2.2, etc.) - optional" 48 | default: "" 49 | required: false 50 | type: string 51 | jobs: 52 | validate-binaries: 53 | uses: pytorch/test-infra/.github/workflows/validate-domain-library.yml@main 54 | with: 55 | package_type: "conda,wheel" 56 | version: ${{ inputs.version }} 57 | os: ${{ inputs.os }} 58 | channel: ${{ inputs.channel }} 59 | repository: "pytorch/text" 60 | smoke_test: "source ./.github/scripts/validate_binaries.sh" 61 | install_torch: true 62 | -------------------------------------------------------------------------------- /.github/workflows/validate-nightly-binaries.yml: -------------------------------------------------------------------------------- 1 | # Scheduled validation of the nightly binaries 2 | name: cron 3 | 4 | on: 5 | schedule: 6 | # At 5:30 pm UTC (7:30 am PDT) 7 | - cron: "30 17 * * *" 8 | # Have the ability to trigger this job manually through the API 9 | workflow_dispatch: 10 | push: 11 | branches: 12 | - main 13 | paths: 14 | - .github/workflows/validate-nightly-binaries.yml 15 | - .github/workflows/validate-binaries.yml 16 | - .github/scripts/validate_binaries.sh 17 | pull_request: 18 | paths: 19 | - .github/workflows/validate-nightly-binaries.yml 20 | - .github/workflows/validate-binaries.yml 21 | - .github/scripts/validate_binaries.sh 22 | jobs: 23 | nightly: 24 | uses: ./.github/workflows/validate-binaries.yml 25 | with: 26 | channel: nightly 27 | os: all 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | *.zip 3 | *~ 4 | .vector_cache 5 | .idea/ 6 | 7 | # Documentation 8 | docs/build 9 | 10 | # Download folder 11 | .data 12 | 13 | # Created by https://www.gitignore.io/api/python 14 | 15 | ### Python ### 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 | .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | .hypothesis/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | docs/src/ 81 | docs/source/tutorials 82 | docs/source/gen_modules 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # computed checksum files 88 | torchtext/experimental/asset/.checksums/ 89 | 90 | # Jupyter Notebook 91 | .ipynb_checkpoints 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule.* 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # VSCode project settings 119 | .vscode 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | 127 | # End of https://www.gitignore.io/api/python 128 | 129 | # vim 130 | *.swp 131 | *.swo 132 | 133 | torchtext/version.py 134 | 135 | # Thirdparty directories 136 | third_party/*/ 137 | 138 | # Mac OS .DS_Store files 139 | .DS_Store 140 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/sentencepiece"] 2 | path = third_party/sentencepiece 3 | url = https://github.com/google/sentencepiece 4 | ignore = dirty 5 | [submodule "third_party/re2"] 6 | path = third_party/re2 7 | url = https://github.com/google/re2 8 | ignore = dirty 9 | [submodule "third_party/double-conversion"] 10 | path = third_party/double-conversion 11 | url = https://github.com/google/double-conversion 12 | ignore = dirty 13 | [submodule "third_party/utf8proc"] 14 | path = third_party/utf8proc 15 | url = https://github.com/JuliaStrings/utf8proc 16 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | node: 16.14.2 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.0.1 7 | hooks: 8 | - id: trailing-whitespace 9 | - id: mixed-line-ending 10 | args: 11 | - --fix=lf 12 | - id: end-of-file-fixer 13 | 14 | - repo: https://github.com/pre-commit/mirrors-prettier 15 | rev: v2.5.1 16 | hooks: 17 | - id: prettier 18 | types_or: 19 | - markdown 20 | - toml 21 | - yaml 22 | 23 | - repo: https://github.com/omnilib/ufmt 24 | rev: v1.3.1 25 | hooks: 26 | - id: ufmt 27 | additional_dependencies: 28 | - black == 21.4b2 29 | - usort == 0.6.4 30 | 31 | - repo: https://github.com/pycqa/flake8 32 | rev: 4.0.1 33 | hooks: 34 | - id: flake8 35 | additional_dependencies: 36 | - flake8-docstrings == 1.6.0 37 | - torchfix == 0.0.2 38 | args: 39 | - --config=.flake8 40 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | packaging/* 2 | .circleci/config.yml 3 | 
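
The .pre-commit-config.yaml above defines the same hooks (whitespace fixers, prettier, ufmt with pinned black/usort, flake8 with TorchFix) that CI exercises through the Lint workflow shown earlier. A minimal local driver could look like the following sketch (illustrative, not a file in this repository; it assumes pre-commit is installed and that a clang-format binary has been downloaded to ./clang-format as in .github/workflows/lint.yml):

```python
# Illustrative local lint driver mirroring the two jobs in
# .github/workflows/lint.yml (assumptions noted above).
import subprocess

# Python sources and configs: runs the hooks from .pre-commit-config.yaml
subprocess.run(["pre-commit", "run", "--all-files"], check=True)

# C++ sources: the repo's run-clang-format.py wrapper, as in the c-source job
subprocess.run(
    [
        "python",
        "run-clang-format.py",
        "--recursive",
        "--clang-format-executable=./clang-format",
        "torchtext/csrc",
    ],
    check=True,
)
```
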
-------------------------------------------------------------------------------- /.prettierrc.yaml: -------------------------------------------------------------------------------- 1 | proseWrap: always 2 | printWidth: 120 3 | -------------------------------------------------------------------------------- /.python3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/.python3 -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | 3 | # Most of the configurations are taken from PyTorch 4 | # https://github.com/pytorch/pytorch/blob/0c9fb4aff0d60eaadb04e4d5d099fb1e1d5701a9/CMakeLists.txt 5 | 6 | # Use compiler ID "AppleClang" instead of "Clang" for XCode. 7 | # Not setting this sometimes makes XCode C compiler gets detected as "Clang", 8 | # even when the C++ one is detected as "AppleClang". 9 | cmake_policy(SET CMP0010 NEW) 10 | cmake_policy(SET CMP0025 NEW) 11 | 12 | # Suppress warning flags in default MSVC configuration. It's not 13 | # mandatory that we do this (and we don't if cmake is old), but it's 14 | # nice when it's possible, and it's possible on our Windows configs. 15 | if(NOT CMAKE_VERSION VERSION_LESS 3.15.0) 16 | cmake_policy(SET CMP0092 NEW) 17 | endif() 18 | 19 | project(torchtext) 20 | 21 | 22 | # check and set CMAKE_CXX_STANDARD 23 | string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard) 24 | if(env_cxx_standard GREATER -1) 25 | message( 26 | WARNING "C++ standard version definition detected in environment variable." 27 | "PyTorch requires -std=c++17. Please remove -std=c++ settings in your environment.") 28 | endif() 29 | 30 | set(CMAKE_CXX_STANDARD 17) 31 | set(CMAKE_C_STANDARD 11) 32 | 33 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 34 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 35 | 36 | # Apple specific 37 | if(APPLE) 38 | # Get clang version on macOS 39 | execute_process( COMMAND ${CMAKE_CXX_COMPILER} --version OUTPUT_VARIABLE clang_full_version_string ) 40 | string(REGEX REPLACE "Apple LLVM version ([0-9]+\\.[0-9]+).*" "\\1" CLANG_VERSION_STRING ${clang_full_version_string}) 41 | message( STATUS "CLANG_VERSION_STRING: " ${CLANG_VERSION_STRING} ) 42 | 43 | # RPATH stuff 44 | set(CMAKE_MACOSX_RPATH ON) 45 | 46 | set(CMAKE_SHARED_LIBRARY_SUFFIX ".so") 47 | endif() 48 | 49 | # Options 50 | option(BUILD_TORCHTEXT_PYTHON_EXTENSION "Build Python extension" OFF) 51 | 52 | set(CMAKE_MODULE_PATH "${CMAKE_MODULE_PATH};${CMAKE_CURRENT_SOURCE_DIR}/cmake") 53 | set(TORCH_INSTALL_PREFIX "${CMAKE_PREFIX_PATH}/../.." 
CACHE STRING "Install path for torch") 54 | set(TORCH_COMPILED_WITH_CXX_ABI "-D_GLIBCXX_USE_CXX11_ABI=0" CACHE STRING "Compile torchtext with cxx11_abi") 55 | 56 | find_library(TORCH_C10_LIBRARY c10 PATHS "${TORCH_INSTALL_PREFIX}/lib") 57 | find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib") 58 | find_library(TORCH_CPU_LIBRARY torch_cpu PATHS "${TORCH_INSTALL_PREFIX}/lib") 59 | 60 | if(MSVC) 61 | set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") 62 | endif() 63 | 64 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_COMPILED_WITH_CXX_ABI} -Wall ${TORCH_CXX_FLAGS}") 65 | 66 | add_subdirectory(third_party) 67 | add_subdirectory(torchtext/csrc) 68 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) James Bradbury and Soumith Chintala 2016, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /benchmark/benchmark_basic_english_normalize.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torchtext.data.utils import get_tokenizer 5 | from torchtext.datasets import AG_NEWS 6 | from torchtext.prototype.transforms import basic_english_normalize 7 | 8 | 9 | def benchmark_basic_english_normalize(): 10 | def _run_benchmark_lookup(train, tokenizer): 11 | t0 = time.monotonic() 12 | for (_, text) in train: 13 | tokenizer(text) 14 | print("Tokenization time:", time.monotonic() - t0) 15 | 16 | existing_basic_english_tokenizer = get_tokenizer("basic_english") 17 | experimental_basic_english_normalize = basic_english_normalize() 18 | experimental_jit_basic_english_normalize = torch.jit.script(experimental_basic_english_normalize) 19 | 20 | # existing eager lookup 21 | train = AG_NEWS(split="train") 22 | print("BasicEnglishNormalize - Eager Mode") 23 | _run_benchmark_lookup(train, existing_basic_english_tokenizer) 24 | 25 | # experimental eager lookup 26 | train = AG_NEWS(split="train") 27 | print("BasicEnglishNormalize Experimental - Eager Mode") 28 | _run_benchmark_lookup(train, experimental_basic_english_normalize) 29 | 30 | # experimental jit lookup 31 | train = AG_NEWS(split="train") 32 | print("BasicEnglishNormalize Experimental - Jit Mode") 33 | _run_benchmark_lookup(train, experimental_jit_basic_english_normalize) 34 | 35 | 36 | if __name__ == "__main__": 37 | benchmark_basic_english_normalize() 38 | -------------------------------------------------------------------------------- /benchmark/benchmark_bert_tokenizer.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from benchmark.utils import Timer 4 | from tokenizers import Tokenizer as hf_tokenizer_lib 5 | from torchtext.datasets import EnWik9 6 | from torchtext.transforms import BERTTokenizer as tt_bert_tokenizer 7 | from transformers import BertTokenizer as hf_bert_tokenizer_slow 8 | 9 | 10 | VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt" 11 | 12 | 13 | def benchmark_bert_tokenizer(args): 14 | tt_tokenizer = tt_bert_tokenizer(VOCAB_FILE, return_tokens=True) 15 | hf_tokenizer_slow = hf_bert_tokenizer_slow.from_pretrained("bert-base-uncased") 16 | hf_tokenizer_fast = hf_tokenizer_lib.from_pretrained("bert-base-uncased") 17 | dp = EnWik9().header(args.num_samples).batch(args.batch_size) 18 | samples = list(dp) 19 | 20 | with Timer("Running TorchText BERT Tokenizer on non-batched input"): 21 | for batch in samples: 22 | for s in batch: 23 | tt_tokenizer(s) 24 | 25 | with Timer("Running HF BERT Tokenizer (slow) on non-batched input"): 26 | for batch in samples: 27 | for s in batch: 28 | hf_tokenizer_slow.tokenize(s) 29 | 30 | with Timer("Running HF BERT Tokenizer (fast) on non-batched input"): 31 | for batch in samples: 32 | for s in batch: 33 | hf_tokenizer_fast.encode(s) 34 | 35 | with Timer("Running TorchText BERT Tokenizer on batched input"): 36 | for batch in samples: 37 | tt_tokenizer(batch) 38 | 39 | with Timer("Running HF BERT Tokenizer (fast) on batched input"): 40 | for batch in samples: 41 | hf_tokenizer_fast.encode_batch(batch) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = ArgumentParser() 46 | parser.add_argument("--num-samples", default=10000, type=int) 47 | parser.add_argument("--batch-size", default=100, type=int) 48 | 49 | 
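# Example invocation (values are illustrative, not benchmarked defaults):
#   python benchmark_bert_tokenizer.py --num-samples 1000 --batch-size 50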
benchmark_bert_tokenizer(parser.parse_args()) 50 | -------------------------------------------------------------------------------- /benchmark/benchmark_experimental_vectors.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torchtext.prototype.datasets import AG_NEWS 5 | from torchtext.prototype.vectors import FastText as FastTextExperimental 6 | from torchtext.vocab import FastText 7 | 8 | 9 | def benchmark_experimental_vectors(): 10 | def _run_benchmark_lookup(tokens, vector): 11 | t0 = time.monotonic() 12 | for token in tokens: 13 | vector[token] 14 | print("Lookup time:", time.monotonic() - t0) 15 | 16 | train = AG_NEWS(split="train") 17 | vocab = train.get_vocab() 18 | tokens = [] 19 | for (label, text) in train: 20 | for id in text.tolist(): 21 | tokens.append(vocab.itos[id]) 22 | 23 | # existing FastText construction 24 | print("FastText Existing Construction") 25 | t0 = time.monotonic() 26 | fast_text = FastText() 27 | print("Construction time:", time.monotonic() - t0) 28 | 29 | # experimental FastText construction 30 | print("FastText Experimental Construction") 31 | t0 = time.monotonic() 32 | fast_text_experimental = FastTextExperimental(validate_file=False) 33 | print("Construction time:", time.monotonic() - t0) 34 | 35 | # existing FastText eager lookup 36 | print("FastText Existing - Eager Mode") 37 | _run_benchmark_lookup(tokens, fast_text) 38 | 39 | # experimental FastText eager lookup 40 | print("FastText Experimental - Eager Mode") 41 | _run_benchmark_lookup(tokens, fast_text_experimental) 42 | 43 | # experimental FastText jit lookup 44 | print("FastText Experimental - Jit Mode") 45 | jit_fast_text_experimental = torch.jit.script(fast_text_experimental) 46 | _run_benchmark_lookup(tokens, jit_fast_text_experimental) 47 | 48 | 49 | if __name__ == "__main__": 50 | benchmark_experimental_vectors() 51 | -------------------------------------------------------------------------------- /benchmark/benchmark_roberta_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from benchmark.utils import Timer 5 | from torchtext.functional import to_tensor 6 | from torchtext.models import XLMR_BASE_ENCODER, XLMR_LARGE_ENCODER, ROBERTA_BASE_ENCODER, ROBERTA_LARGE_ENCODER 7 | 8 | ENCODERS = { 9 | "xlmr_base": XLMR_BASE_ENCODER, 10 | "xlmr_large": XLMR_LARGE_ENCODER, 11 | "roberta_base": ROBERTA_BASE_ENCODER, 12 | "roberta_large": ROBERTA_LARGE_ENCODER, 13 | } 14 | 15 | 16 | def basic_model_input(encoder): 17 | transform = encoder.transform() 18 | input_batch = ["Hello world", "How are you!"] 19 | return to_tensor(transform(input_batch), padding_value=1) 20 | 21 | 22 | def _train(model, model_input): 23 | model_out = model(model_input) 24 | model_out.backward(torch.ones_like(model_out)) 25 | model.zero_grad() 26 | 27 | 28 | def run(args): 29 | encoder_name = args.encoder 30 | num_passes = args.num_passes 31 | warmup_passes = args.warmup_passes 32 | model_input = args.model_input 33 | 34 | encoder = ENCODERS.get(encoder_name, None) 35 | if not encoder: 36 | raise NotImplementedError("Given encoder [{}] is not available".format(encoder_name)) 37 | 38 | model = encoder.get_model() 39 | if model_input == "basic": 40 | model_input = basic_model_input(encoder) 41 | else: 42 | raise NotImplementedError("Given model input [{}] is not available".format(model_input)) 43 | 44 | model.eval() 45 | for _ in range(warmup_passes):
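# The warm-up passes below are deliberately untimed: they let one-off costs
# (allocator growth, lazy initialization, caches) settle so that the Timer
# blocks measure steady-state throughput. Running them under torch.no_grad()
# would additionally avoid building autograd graphs during warm-up.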
46 | model(model_input) 47 | 48 | with Timer("Executing model forward"): 49 | with torch.no_grad(): 50 | for _ in range(num_passes): 51 | model(model_input) 52 | 53 | model.train() 54 | with Timer("Executing model forward/backward"): 55 | for _ in range(num_passes): 56 | _train(model, model_input) 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = ArgumentParser() 61 | parser.add_argument("--encoder", default="xlmr_base", type=str) 62 | parser.add_argument("--num-passes", default=50, type=int) 63 | parser.add_argument("--warmup-passes", default=10, type=int) 64 | parser.add_argument("--model-input", default="basic", type=str) 65 | run(parser.parse_args()) 66 | -------------------------------------------------------------------------------- /benchmark/benchmark_sentencepiece.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | 4 | from torchtext.data.functional import load_sp_model as load_torchbind_sp_model 5 | from torchtext.datasets import DATASETS 6 | from torchtext.prototype.transforms import load_sp_model as load_pybind_sp_model 7 | from torchtext.utils import download_from_url 8 | 9 | 10 | def benchmark_sentencepiece(args): 11 | def _run_benchmark(train, spm_processor): 12 | t0 = time.monotonic() 13 | for (_, text) in train: 14 | spm_processor(text) 15 | print("Sentencepiece processor time:", time.monotonic() - t0) 16 | 17 | # Download a pretrained sentencepiece model 18 | sp_model_path = download_from_url( 19 | "https://pytorch.s3.amazonaws.com/models/text/pretrained_spm/text_unigram_15000.model" 20 | ) 21 | 22 | # existing sentencepiece model with torchbind 23 | train = DATASETS[args.dataset](split="train") 24 | sp_model = load_torchbind_sp_model(sp_model_path) 25 | print("SentencePiece EncodeAsIds - torchbind") 26 | _run_benchmark(train, sp_model.EncodeAsIds) 27 | 28 | # experimental sentencepiece model with pybind 29 | train = DATASETS[args.dataset](split="train") 30 | sp_model = load_pybind_sp_model(sp_model_path) 31 | print("SentencePiece EncodeAsIds - pybind") 32 | _run_benchmark(train, sp_model.EncodeAsIds) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser(description="SentencePiece benchmark") 37 | parser.add_argument("--dataset", type=str, default="AG_NEWS", help="Dataset for performance benchmark") 38 | args = parser.parse_args() 39 | benchmark_sentencepiece(args) 40 | 41 | # Running with AG_NEWS 42 | # SentencePiece EncodeAsIds - torchbind 43 | # Sentencepiece processor time: 11.536989663727582 44 | # SentencePiece EncodeAsIds - pybind 45 | # Sentencepiece processor time: 11.38821320142597 46 | 47 | # Running with YelpReviewFull 48 | # SentencePiece EncodeAsIds - torchbind 49 | # Sentencepiece processor time: 224.23954573180526 50 | # SentencePiece EncodeAsIds - pybind 51 | # Sentencepiece processor time: 217.134037473239 52 | -------------------------------------------------------------------------------- /benchmark/data_construction.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | from torchtext.prototype import datasets 4 | 5 | 6 | def benchmark_construction(name, Dataset): 7 | t0 = time.perf_counter() 8 | print(name, end="") 9 | (d,) = Dataset(data_select=("train",)) 10 | print(" construction time {0:.2f}s".format(time.perf_counter() - t0)) 11 | del d 12 | 13 | 14 | def benchmark_raw_construction(name, Dataset): 15 | print(name, end="") 16 | if name == "WMTNewsCrawl": 17 | d =
Dataset(data_select=("train",)) 18 | else: 19 | d = Dataset() 20 | del d 21 | 22 | 23 | if __name__ == "__main__": 24 | for name, Dataset in datasets.DATASETS.items(): 25 | benchmark_construction(name, Dataset) 26 | -------------------------------------------------------------------------------- /benchmark/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | 4 | class Timer: 5 | """Basic utility class to calculate execution time. It can also be used a context manager.""" 6 | 7 | def __init__(self, text=""): 8 | self._text = text 9 | self._start = None 10 | 11 | def start(self): 12 | if self._start is not None: 13 | raise Exception("Timer is already running. Call .stop() to stop it") 14 | 15 | self._start = time.perf_counter() 16 | 17 | def stop(self): 18 | if self._start is None: 19 | raise Exception("Timer is not running. Call .start() to start the timer.") 20 | 21 | elapsed = time.perf_counter() - self._start 22 | 23 | print("{} ... Total running time: {}".format(self._text, elapsed)) 24 | 25 | def __enter__(self): 26 | self.start() 27 | return self 28 | 29 | def __exit__(self, *exc_info): 30 | self.stop() 31 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = -W # turn warnings into errors 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = torchtext 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | docset: html 16 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/text/ --force $(BUILDDIR)/html/ 17 | 18 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 19 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 20 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 21 | 22 | .PHONY: help Makefile 23 | 24 | # Catch-all target: route all unknown targets to Sphinx using the new 25 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 26 | %: Makefile 27 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 28 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=torchtext 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Jinja2<3.1.0 2 | sphinx==5.1.1 3 | -e git+https://github.com/pytorch/pytorch_sphinx_theme.git@cece053#egg=pytorch_sphinx_theme 4 | sphinx_gallery==0.11.1 5 | matplotlib 6 | regex 7 | -------------------------------------------------------------------------------- /docs/source/_static/img/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/docs/source/_static/img/pytorch-logo-dark.png -------------------------------------------------------------------------------- /docs/source/_static/img/pytorch-logo-dark.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/docs/source/_static/img/pytorch-logo-dark.svg -------------------------------------------------------------------------------- /docs/source/_static/img/pytorch-logo-flame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/docs/source/_static/img/pytorch-logo-flame.png -------------------------------------------------------------------------------- /docs/source/_static/img/pytorch-logo-flame.svg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/docs/source/_static/img/pytorch-logo-flame.svg -------------------------------------------------------------------------------- /docs/source/_static/img/torchtext_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/docs/source/_static/img/torchtext_logo.png -------------------------------------------------------------------------------- /docs/source/data_functional.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.data.functional 5 | =========================== 6 | 7 | .. automodule:: torchtext.data.functional 8 | .. currentmodule:: torchtext.data.functional 9 | 10 | :hidden:`generate_sp_model` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: generate_sp_model 14 | 15 | :hidden:`load_sp_model` 16 | ~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | .. autofunction:: load_sp_model 19 | 20 | :hidden:`sentencepiece_numericalizer` 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | .. autofunction:: sentencepiece_numericalizer 24 | 25 | :hidden:`sentencepiece_tokenizer` 26 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 27 | 28 | .. autofunction:: sentencepiece_tokenizer 29 | 30 | :hidden:`custom_replace` 31 | ~~~~~~~~~~~~~~~~~~~~~~~~ 32 | 33 | .. autofunction:: custom_replace 34 | 35 | :hidden:`simple_space_split` 36 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 37 | 38 | .. autofunction:: simple_space_split 39 | 40 | :hidden:`numericalize_tokens_from_iterator` 41 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 42 | 43 | ..
autofunction:: numericalize_tokens_from_iterator 44 | 45 | 46 | :hidden:`filter_wikipedia_xml` 47 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 48 | 49 | .. autofunction:: filter_wikipedia_xml 50 | 51 | :hidden:`to_map_style_dataset` 52 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 53 | 54 | .. autofunction:: to_map_style_dataset 55 | -------------------------------------------------------------------------------- /docs/source/data_metrics.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.data.metrics 5 | =========================== 6 | 7 | .. automodule:: torchtext.data.metrics 8 | .. currentmodule:: torchtext.data.metrics 9 | 10 | :hidden:`bleu_score` 11 | ~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: bleu_score 14 | -------------------------------------------------------------------------------- /docs/source/data_utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.data.utils 5 | =========================== 6 | 7 | .. automodule:: torchtext.data.utils 8 | .. currentmodule:: torchtext.data.utils 9 | 10 | :hidden:`get_tokenizer` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: get_tokenizer 14 | 15 | :hidden:`ngrams_iterator` 16 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | .. autofunction:: ngrams_iterator 19 | -------------------------------------------------------------------------------- /docs/source/experimental_models_utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.experimental.models.utils 5 | =================================== 6 | 7 | .. automodule:: torchtext.experimental.models.utils 8 | .. currentmodule:: torchtext.experimental.models.utils 9 | 10 | :hidden:`count_model_param` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: count_model_param 14 | -------------------------------------------------------------------------------- /docs/source/experimental_transforms.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.experimental.transforms 5 | ================================= 6 | 7 | .. automodule:: torchtext.experimental.transforms 8 | .. currentmodule:: torchtext.experimental.transforms 9 | 10 | :hidden:`BasicEnglishNormalize` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: BasicEnglishNormalize 14 | :members: 15 | :special-members: __init__ 16 | 17 | :hidden:`RegexTokenizer` 18 | ~~~~~~~~~~~~~~~~~~~~~~~~ 19 | 20 | .. autoclass:: RegexTokenizer 21 | :members: 22 | :special-members: __init__ 23 | 24 | :hidden:`TextSequentialTransforms` 25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | 27 | .. autoclass:: TextSequentialTransforms 28 | :members: 29 | :special-members: __init__ 30 | 31 | :hidden:`load_sp_model` 32 | ~~~~~~~~~~~~~~~~~~~~~~~ 33 | 34 | .. autofunction:: load_sp_model 35 | 36 | :hidden:`sentencepiece_tokenizer` 37 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 38 | 39 | .. autofunction:: sentencepiece_tokenizer 40 | 41 | :hidden:`SentencePieceTokenizer` 42 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 43 | 44 | .. autoclass:: SentencePieceTokenizer 45 | :members: 46 | :special-members: __init__ 47 | 48 | :hidden:`sentencepiece_processor` 49 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 50 | 51 | .. 
autofunction:: sentencepiece_processor 52 | 53 | :hidden:`SentencePieceProcessor` 54 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 55 | 56 | .. autoclass:: SentencePieceProcessor 57 | :members: 58 | :special-members: __init__ 59 | 60 | :hidden:`VocabTransform` 61 | ~~~~~~~~~~~~~~~~~~~~~~~~ 62 | 63 | .. autoclass:: VocabTransform 64 | :members: 65 | :special-members: __init__ 66 | 67 | :hidden:`VectorTransform` 68 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 69 | 70 | .. autoclass:: VectorTransform 71 | :members: 72 | :special-members: __init__ 73 | -------------------------------------------------------------------------------- /docs/source/experimental_vectors.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.experimental.vectors 5 | ============================== 6 | 7 | .. automodule:: torchtext.experimental.vectors 8 | .. currentmodule:: torchtext.experimental.vectors 9 | 10 | :hidden:`Vectors` 11 | ~~~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: Vectors 14 | :members: 15 | :special-members: 16 | 17 | :hidden:`build_vectors` 18 | ~~~~~~~~~~~~~~~~~~~~~~~ 19 | 20 | .. autofunction:: build_vectors 21 | 22 | :hidden:`load_vectors_from_file_path` 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. autofunction:: load_vectors_from_file_path 26 | 27 | :hidden:`FastText` 28 | ~~~~~~~~~~~~~~~~~~ 29 | 30 | .. autofunction:: FastText 31 | 32 | :hidden:`GloVe` 33 | ~~~~~~~~~~~~~~~ 34 | 35 | .. autofunction:: GloVe 36 | -------------------------------------------------------------------------------- /docs/source/experimental_vocab.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.experimental.vocab_factory 5 | ==================================== 6 | 7 | .. automodule:: torchtext.experimental.vocab_factory 8 | .. currentmodule:: torchtext.experimental.vocab_factory 9 | 10 | :hidden:`load_vocab_from_file` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: load_vocab_from_file 14 | 15 | :hidden:`build_vocab_from_text_file` 16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | .. autofunction:: build_vocab_from_text_file 19 | -------------------------------------------------------------------------------- /docs/source/functional.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.functional 5 | =========================== 6 | 7 | .. automodule:: torchtext.functional 8 | .. currentmodule:: torchtext.functional 9 | 10 | to_tensor 11 | --------- 12 | 13 | .. autofunction:: to_tensor 14 | 15 | 16 | truncate 17 | -------- 18 | 19 | .. autofunction:: truncate 20 | 21 | 22 | add_token 23 | --------- 24 | 25 | .. autofunction:: add_token 26 | 27 | str_to_int 28 | ---------- 29 | 30 | .. autofunction:: str_to_int 31 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | torchtext 2 | ========= 3 | .. image:: _static/img/torchtext_logo.png 4 | 5 | .. warning:: 6 | 7 | TorchText development is stopped and the ``0.18`` release (April 2024) will 8 | be the last stable release of the library. 9 | 10 | This library is part of the `PyTorch 11 | <https://pytorch.org/>`_ project. PyTorch is an open source 12 | machine learning framework.
13 | 14 | Features described in this documentation are classified by release status: 15 | 16 | *Stable:* These features will be maintained long-term and there should generally 17 | be no major performance limitations or gaps in documentation. 18 | We also expect to maintain backwards compatibility (although 19 | breaking changes can happen and notice will be given one release ahead 20 | of time). 21 | 22 | *Beta:* Features are tagged as Beta because the API may change based on 23 | user feedback, because the performance needs to improve, or because 24 | coverage across operators is not yet complete. For Beta features, we are 25 | committing to seeing the feature through to the Stable classification. 26 | We are not, however, committing to backwards compatibility. 27 | 28 | *Prototype:* These features are typically not available as part of 29 | binary distributions like PyPI or Conda, except sometimes behind run-time 30 | flags, and are at an early stage for feedback and testing. 31 | 32 | 33 | The :mod:`torchtext` package consists of data processing utilities and 34 | popular datasets for natural language. 35 | 36 | .. toctree:: 37 | :maxdepth: 1 38 | :caption: Torchtext Documentation 39 | :hidden: 40 | 41 | Index 42 | logo 43 | 44 | .. toctree:: 45 | :maxdepth: 2 46 | :caption: Package Reference 47 | 48 | nn_modules 49 | data_functional 50 | data_metrics 51 | data_utils 52 | datasets 53 | torchtext.vocab 54 | torchtext.utils 55 | transforms 56 | functional 57 | models 58 | 59 | Getting Started 60 | --------------- 61 | 62 | .. toctree:: 63 | :maxdepth: 1 64 | :caption: Getting Started 65 | 66 | tutorials/sst2_classification_non_distributed 67 | tutorials/t5_demo 68 | 69 | 70 | .. automodule:: torchtext 71 | :members: 72 | 73 | .. toctree:: 74 | :maxdepth: 1 75 | :caption: PyTorch Libraries 76 | 77 | PyTorch 78 | torchaudio 79 | torchtext 80 | torchvision 81 | TorchElastic 82 | TorchServe 83 | PyTorch on XLA Devices 84 | -------------------------------------------------------------------------------- /docs/source/logo.rst: -------------------------------------------------------------------------------- 1 | TorchText Logo 2 | =============== 3 | 4 | If you use TorchText in your project and want to mention TorchText, you can use the TorchText logo. There are a couple of variations. You can download them from `here `__. 5 | 6 | Please follow `the guideline `__ for proper usage. 7 | 8 | .. warning:: 9 | 10 | Please do not alter the logo. The guideline lists examples of improper usage as well, so please check them out before using the logos. 11 | 12 | Icon 13 | ---- 14 | 15 | .. image:: https://download.pytorch.org/torchtext/logo/v1/TorchText_Symbol_fullColor_RGB.png 16 | :width: 400 17 | 18 | Horizontal 19 | ---------- 20 | 21 | .. image:: https://download.pytorch.org/torchtext/logo/v1/TorchText_Horiz_fullColor_RGB.png 22 | :width: 400 23 | 24 | | 25 | 26 | .. image:: https://download.pytorch.org/torchtext/logo/v1/TorchText_Horiz_black_RGB.png 27 | :width: 400 28 | 29 | | 30 | 31 | .. raw:: html 32 | 33 |
34 | 35 |
36 | 37 | Vertical 38 | -------- 39 | 40 | .. image:: https://download.pytorch.org/torchtext/logo/v1/TorchText_Vertical_fullColor_RGB.png 41 | :width: 400 42 | 43 | | 44 | 45 | .. image:: https://download.pytorch.org/torchtext/logo/v1/TorchText_Vertical_black_RGB.png 46 | :width: 400 47 | 48 | | 49 | 50 | .. raw:: html 51 | 52 |
53 | 54 |
55 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.models 5 | =========================== 6 | 7 | .. automodule:: torchtext.models 8 | .. currentmodule:: torchtext.models 9 | 10 | RobertaBundle 11 | ------------- 12 | 13 | .. autoclass:: RobertaBundle 14 | :members: transform 15 | 16 | .. automethod:: get_model 17 | 18 | XLMR_BASE_ENCODER 19 | ----------------- 20 | 21 | .. container:: py attribute 22 | 23 | .. autodata:: XLMR_BASE_ENCODER 24 | :no-value: 25 | 26 | 27 | XLMR_LARGE_ENCODER 28 | ------------------ 29 | 30 | .. container:: py attribute 31 | 32 | .. autodata:: XLMR_LARGE_ENCODER 33 | :no-value: 34 | 35 | ROBERTA_BASE_ENCODER 36 | -------------------- 37 | 38 | .. container:: py attribute 39 | 40 | .. autodata:: ROBERTA_BASE_ENCODER 41 | :no-value: 42 | 43 | 44 | ROBERTA_LARGE_ENCODER 45 | --------------------- 46 | 47 | .. container:: py attribute 48 | 49 | .. autodata:: ROBERTA_LARGE_ENCODER 50 | :no-value: 51 | -------------------------------------------------------------------------------- /docs/source/nn_modules.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.nn 5 | ======================================= 6 | 7 | .. automodule:: torchtext.nn 8 | .. currentmodule:: torchtext.nn 9 | 10 | :hidden:`MultiheadAttentionContainer` 11 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: MultiheadAttentionContainer 14 | :members: 15 | :special-members: __init__ 16 | 17 | :hidden:`InProjContainer` 18 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 19 | 20 | .. autoclass:: InProjContainer 21 | :members: 22 | :special-members: __init__ 23 | 24 | :hidden:`ScaledDotProduct` 25 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 26 | 27 | .. autoclass:: ScaledDotProduct 28 | :members: 29 | :special-members: __init__ 30 | -------------------------------------------------------------------------------- /docs/source/transforms.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.transforms 5 | =========================== 6 | 7 | .. automodule:: torchtext.transforms 8 | .. currentmodule:: torchtext.transforms 9 | 10 | Transforms are common text transforms. They can be chained together using :class:`torch.nn.Sequential` or using :class:`torchtext.transforms.Sequential` to support torch-scriptability. 11 | 12 | SentencePieceTokenizer 13 | ---------------------- 14 | 15 | .. autoclass:: SentencePieceTokenizer 16 | 17 | .. automethod:: forward 18 | 19 | GPT2BPETokenizer 20 | ---------------- 21 | 22 | .. autoclass:: GPT2BPETokenizer 23 | 24 | .. automethod:: forward 25 | 26 | CLIPTokenizer 27 | ------------- 28 | 29 | .. autoclass:: CLIPTokenizer 30 | 31 | .. automethod:: forward 32 | 33 | RegexTokenizer 34 | -------------- 35 | 36 | .. autoclass:: RegexTokenizer 37 | 38 | .. automethod:: forward 39 | 40 | BERTTokenizer 41 | ------------- 42 | 43 | .. autoclass:: BERTTokenizer 44 | 45 | .. automethod:: forward 46 | 47 | VocabTransform 48 | -------------- 49 | 50 | .. autoclass:: VocabTransform 51 | 52 | .. automethod:: forward 53 | 54 | ToTensor 55 | -------- 56 | 57 | .. autoclass:: ToTensor 58 | 59 | .. automethod:: forward 60 | 61 | LabelToIndex 62 | ------------ 63 | 64 | .. autoclass:: LabelToIndex 65 | 66 | .. 
automethod:: forward 67 | 68 | Truncate 69 | -------- 70 | 71 | .. autoclass:: Truncate 72 | 73 | .. automethod:: forward 74 | 75 | AddToken 76 | -------- 77 | 78 | .. autoclass:: AddToken 79 | 80 | .. automethod:: forward 81 | 82 | Sequential 83 | ---------- 84 | 85 | .. autoclass:: Sequential 86 | 87 | .. automethod:: forward 88 | 89 | PadTransform 90 | ------------ 91 | 92 | .. autoclass:: PadTransform 93 | 94 | .. automethod:: forward 95 | 96 | StrToIntTransform 97 | ----------------- 98 | 99 | .. autoclass:: StrToIntTransform 100 | 101 | .. automethod:: forward 102 | 103 | CharBPETokenizer 104 | ---------------- 105 | 106 | .. autoclass:: CharBPETokenizer 107 | 108 | .. automethod:: forward 109 | -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.utils 5 | =========================== 6 | 7 | .. automodule:: torchtext.utils 8 | .. currentmodule:: torchtext.utils 9 | 10 | :hidden:`reporthook` 11 | ~~~~~~~~~~~~~~~~~~~~ 12 | 13 | .. autofunction:: reporthook 14 | 15 | :hidden:`download_from_url` 16 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 17 | 18 | .. autofunction:: download_from_url 19 | 20 | :hidden:`extract_archive` 21 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 22 | 23 | .. autofunction:: extract_archive 24 | -------------------------------------------------------------------------------- /docs/source/vocab.rst: -------------------------------------------------------------------------------- 1 | .. role:: hidden 2 | :class: hidden-section 3 | 4 | torchtext.vocab 5 | =========================== 6 | 7 | .. automodule:: torchtext.vocab 8 | .. currentmodule:: torchtext.vocab 9 | 10 | :hidden:`Vocab` 11 | ~~~~~~~~~~~~~~~ 12 | 13 | .. autoclass:: Vocab 14 | :members: 15 | :special-members: 16 | 17 | :hidden:`vocab` 18 | ~~~~~~~~~~~~~~~ 19 | 20 | .. autofunction:: vocab 21 | 22 | :hidden:`build_vocab_from_iterator` 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | .. autofunction:: build_vocab_from_iterator 26 | 27 | 28 | :hidden:`Vectors` 29 | ~~~~~~~~~~~~~~~~~ 30 | 31 | .. autoclass:: Vectors 32 | :members: 33 | :special-members: __init__ 34 | 35 | Pretrained Word Embeddings 36 | -------------------------- 37 | 38 | :hidden:`GloVe` 39 | ~~~~~~~~~~~~~~~ 40 | 41 | .. autoclass:: GloVe 42 | :members: 43 | 44 | :hidden:`FastText` 45 | ~~~~~~~~~~~~~~~~~~ 46 | 47 | .. autoclass:: FastText 48 | :members: 49 | 50 | :hidden:`CharNGram` 51 | ~~~~~~~~~~~~~~~~~~~ 52 | 53 | .. 
autoclass:: CharNGram 54 | :members: 55 | -------------------------------------------------------------------------------- /examples/data_pipeline/roberta_datapipe.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from functools import partial 3 | from typing import Dict, Any 4 | 5 | import torchtext.functional as F 6 | import torchtext.transforms as T 7 | from torch.hub import load_state_dict_from_url 8 | from torch.nn import Module 9 | from torch.utils.data import DataLoader 10 | from torchtext.datasets import SST2 11 | 12 | 13 | class RobertaTransformDataPipe(Module): 14 | def __init__(self) -> None: 15 | super().__init__() 16 | # Instantiate various transforms 17 | 18 | # Tokenizer to split input text into tokens 19 | encoder_json_path = "https://download.pytorch.org/models/text/gpt2_bpe_encoder.json" 20 | vocab_bpe_path = "https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe" 21 | self.tokenizer = T.GPT2BPETokenizer(encoder_json_path, vocab_bpe_path) 22 | 23 | # vocabulary converting tokens to IDs 24 | vocab_path = "https://download.pytorch.org/models/text/roberta.vocab.pt" 25 | self.vocab = T.VocabTransform(load_state_dict_from_url(vocab_path)) 26 | 27 | # Add BOS token to the beginning of sentence 28 | self.add_bos = T.AddToken(token=0, begin=True) 29 | 30 | # Add EOS token to the end of sentence 31 | self.add_eos = T.AddToken(token=2, begin=False) 32 | 33 | def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: 34 | tokens = self.tokenizer(input["text"]) 35 | tokens = F.truncate(tokens, max_seq_len=254) 36 | tokens = self.vocab(tokens) 37 | tokens = self.add_bos(tokens) 38 | tokens = self.add_eos(tokens) 39 | input["tokens"] = tokens 40 | return input 41 | 42 | 43 | def main(args): 44 | # Instantiate transform 45 | transform = RobertaTransformDataPipe() 46 | 47 | # Create SST2 datapipe and apply pre-processing 48 | batch_size = args.batch_size 49 | train_dp = SST2(split="train") 50 | train_dp = train_dp.batch(batch_size).rows2columnar(["text", "label"]) 51 | 52 | # Apply text pre-processing 53 | train_dp = train_dp.map(transform) 54 | 55 | # convert to Tensor 56 | train_dp = train_dp.map(partial(F.to_tensor, padding_value=1), input_col="tokens") 57 | train_dp = train_dp.map(F.to_tensor, input_col="label") 58 | 59 | # create DataLoader 60 | dl = DataLoader(train_dp, batch_size=None) 61 | 62 | train_steps = args.train_steps 63 | for i, batch in enumerate(dl): 64 | if i == train_steps: 65 | break 66 | 67 | # model_input = batch["tokens"] 68 | # target = batch["label"] 69 | ... 
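# A hypothetical training step for this loop (sketch only -- `model`, `optim`,
# and `loss_fn` are not defined anywhere in this example script):
#     logits = model(batch["tokens"])
#     loss = loss_fn(logits, batch["label"])
#     loss.backward()
#     optim.step()
#     optim.zero_grad()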
70 | 71 | 72 | if __name__ == "__main__": 73 | parser = ArgumentParser() 74 | parser.add_argument("--batch-size", default=4, type=int) 75 | parser.add_argument("--train-steps", default=-1, type=int) 76 | main(parser.parse_args()) 77 | -------------------------------------------------------------------------------- /examples/libtorchtext/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | **/*.pt 3 | **/*.bpe 4 | **/*.json 5 | -------------------------------------------------------------------------------- /examples/libtorchtext/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18 FATAL_ERROR) 2 | project(libtorchtext_cpp_example) 3 | 4 | SET(BUILD_TORCHTEXT_PYTHON_EXTENSION OFF CACHE BOOL "Build Python binding") 5 | 6 | find_package(Torch REQUIRED) 7 | message("libtorchtext CMakeLists: ${TORCH_CXX_FLAGS}") 8 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") 9 | 10 | add_subdirectory(../.. libtorchtext) 11 | add_subdirectory(tokenizer) 12 | -------------------------------------------------------------------------------- /examples/libtorchtext/README.md: -------------------------------------------------------------------------------- 1 | # Libtorchtext Examples 2 | 3 | - [Tokenizer](./tokenizer) 4 | 5 | ## Build 6 | 7 | The example applications in this directory depend on `libtorch` and `libtorchtext`. If you have a working `PyTorch`, you 8 | already have `libtorch`. Please refer to 9 | [this tutorial](https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html) for the use of `libtorch` and 10 | TorchScript. 11 | 12 | `libtorchtext` is the library of torchtext's C++ components without the Python components. It is currently not distributed, 13 | and it will be built alongside the applications. 14 | 15 | To build `libtorchtext` and the example applications, run the following commands. 16 | 17 | ```bash 18 | chmod +x build.sh # give script execute permission 19 | ./build.sh 20 | ``` 21 | 22 | For the usage of each application, refer to the corresponding application directory. 23 | -------------------------------------------------------------------------------- /examples/libtorchtext/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -eux 4 | 5 | this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | build_dir="${this_dir}/build" 7 | 8 | mkdir -p "${build_dir}" 9 | cd "${build_dir}" 10 | 11 | git submodule update 12 | cmake \ 13 | -DCMAKE_PREFIX_PATH="$(python -c 'import torch;print(torch.utils.cmake_prefix_path)')" \ 14 | -DRE2_BUILD_TESTING:BOOL=OFF \ 15 | -DBUILD_TESTING:BOOL=OFF \ 16 | -DSPM_ENABLE_SHARED=OFF \ 17 | .. 18 | cmake --build .
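# Note: the plain "git submodule update" above assumes the submodules were already
# initialized; on a fresh clone you may need "git submodule update --init --recursive"
# first. Passing a parallelism flag to "cmake --build ." (e.g. -j) can speed up the build.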
19 | -------------------------------------------------------------------------------- /examples/libtorchtext/tokenizer/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(tokenize main.cpp) 2 | target_link_libraries(tokenize "${TORCH_LIBRARIES}" "${TORCHTEXT_LIBRARY}") 3 | set_property(TARGET tokenize PROPERTY CXX_STANDARD 17) 4 | -------------------------------------------------------------------------------- /examples/libtorchtext/tokenizer/README.md: -------------------------------------------------------------------------------- 1 | # Tokenizer 2 | 3 | This example demonstrates how you can use torchtext's `GPT2BPETokenizer` in a C++ environment. 4 | 5 | ## Steps 6 | 7 | ### 1. Download necessary artifacts 8 | 9 | First we download the `gpt2_bpe_vocab.bpe` and `gpt2_bpe_encoder.json` artifacts, both of which are needed to construct the 10 | `GPT2BPETokenizer` object. 11 | 12 | ```bash 13 | curl -O https://download.pytorch.org/models/text/gpt2_bpe_vocab.bpe 14 | curl -O https://download.pytorch.org/models/text/gpt2_bpe_encoder.json 15 | ``` 16 | 17 | ### 2. Create tokenizer TorchScript file 18 | 19 | Next we create our tokenizer object and save it as a TorchScript object. We also print out the output of the tokenizer 20 | on a sample sentence and verify that the output is the same before and after saving and re-loading the tokenizer. In the 21 | next steps we will load and execute the tokenizer in our C++ application. The C++ code is found in 22 | [`main.cpp`](./main.cpp). 23 | 24 | ```bash 25 | tokenizer_file="tokenizer.pt" 26 | python create_tokenizer.py --tokenizer-file "${tokenizer_file}" 27 | ``` 28 | 29 | ### 3. Build the application 30 | 31 | Please refer to [the top level README.md](../README.md) 32 | 33 | ### 4. Run the application 34 | 35 | Now we run the C++ application `tokenize` with the TorchScript object we created in Step 2. The tokenizer is run with 36 | the following sentence as input and we verify that the output is the same as that of Step 2.
37 | 38 | In [the top level directory](../) 39 | 40 | ```bash 41 | ./build/tokenizer/tokenize "tokenizer/${tokenizer_file}" 42 | ``` 43 | -------------------------------------------------------------------------------- /examples/libtorchtext/tokenizer/create_tokenizer.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from torchtext import transforms 5 | 6 | 7 | def main(args): 8 | tokenizer_file = args.tokenizer_file 9 | sentence = "The green grasshopper jumped over the fence" 10 | 11 | # create tokenizer object 12 | encoder_json = "gpt2_bpe_encoder.json" 13 | bpe_vocab = "gpt2_bpe_vocab.bpe" 14 | tokenizer = transforms.GPT2BPETokenizer(encoder_json_path=encoder_json, vocab_bpe_path=bpe_vocab) 15 | 16 | # script and save tokenizer 17 | tokenizer = torch.jit.script(tokenizer) 18 | print(tokenizer(sentence)) 19 | torch.jit.save(tokenizer, tokenizer_file) 20 | 21 | # load saved tokenizer and verify outputs match 22 | t = torch.jit.load(tokenizer_file) 23 | print(t(sentence)) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser() 28 | parser.add_argument("--tokenizer-file", default="tokenizer.pt", type=str) 29 | main(parser.parse_args()) 30 | -------------------------------------------------------------------------------- /examples/libtorchtext/tokenizer/main.cpp: -------------------------------------------------------------------------------- 1 | #include <iostream> 2 | #include <string> 3 | 4 | #include <memory> 5 | #include <vector> 6 | #include <torch/script.h> 7 | 8 | int main(int argc, const char* argv[]) { 9 | std::cout << "Loading model...\n"; 10 | 11 | torch::jit::script::Module module; 12 | try { 13 | module = torch::jit::load(argv[1]); 14 | } catch (const c10::Error& e) { 15 | return -1; 16 | } 17 | 18 | torch::NoGradGuard no_grad; // ensures that autograd is off 19 | torch::jit::IValue tokens_ivalue = module.forward(std::vector<torch::jit::IValue>( 20 | 1, std::string("The green grasshopper jumped over the fence"))); 21 | std::cout << "Result: " << tokens_ivalue << std::endl; 22 | 23 | return 0; 24 | } 25 | -------------------------------------------------------------------------------- /examples/text_classification/README.md: -------------------------------------------------------------------------------- 1 | # This is an example to train a text classification model 2 | 3 | In the basic case, users can train the sentiment model in model.py with the AG_NEWS dataset in torchtext.datasets. 4 | 5 | To try the example, run the following script: 6 | 7 | ```bash 8 | ./run_script.sh 9 | ``` 10 | 11 | In addition, one can also use a sentencepiece tokenizer as shown below. A text classification model is developed and 12 | applied to reproduce the YelpReviewFull results from fastText. 13 | 14 | To try the example, simply run the following commands: 15 | 16 | ```bash 17 | python train.py YelpReviewFull --device cuda --use-sp-tokenizer True --num-epochs 10 --embed-dim 64 18 | ``` 19 | -------------------------------------------------------------------------------- /examples/text_classification/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | r""" 4 | The model is composed of an nn.EmbeddingBag layer and a linear layer. 5 | 6 | nn.EmbeddingBag computes the mean of 'bags' of embeddings. The text 7 | entries here have different lengths. nn.EmbeddingBag requires no 8 | padding because the sentence lengths are saved in offsets.
9 | Therefore, this method is much faster than the original one 10 | with TorchText Iterator and Batch. 11 | 12 | Additionally, since it accumulates the average across the embeddings on the fly, 13 | nn.EmbeddingBag improves performance and memory efficiency 14 | when processing a sequence of tensors. 15 | 16 | """ 17 | 18 | 19 | class TextClassificationModel(nn.Module): 20 | def __init__(self, vocab_size, embed_dim, num_class): 21 | super(TextClassificationModel, self).__init__() 22 | self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) 23 | self.fc = nn.Linear(embed_dim, num_class) 24 | self.init_weights() 25 | 26 | def init_weights(self): 27 | initrange = 0.5 28 | self.embedding.weight.data.uniform_(-initrange, initrange) 29 | self.fc.weight.data.uniform_(-initrange, initrange) 30 | self.fc.bias.data.zero_() 31 | 32 | def forward(self, text, offsets): 33 | embedded = self.embedding(text, offsets) 34 | return self.fc(embedded) 35 | -------------------------------------------------------------------------------- /examples/text_classification/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | 4 | import torch 5 | from torchtext.data.utils import get_tokenizer, ngrams_iterator 6 | from torchtext.prototype.transforms import load_sp_model, PRETRAINED_SP_MODEL, SentencePieceTokenizer 7 | from torchtext.utils import download_from_url 8 | 9 | 10 | def predict(text, model, dictionary, tokenizer, ngrams): 11 | r""" 12 | The predict() function here is used to test the model on a sample text. 13 | The input text is numericalized with the vocab and then sent to 14 | the model for inference. 15 | 16 | Args: 17 | text: a sample text string 18 | model: the trained model 19 | dictionary: a vocab object for the information of string-to-index 20 | tokenizer: tokenizer object to split text into tokens 21 | ngrams: the number of ngrams. 22 | """ 23 | with torch.no_grad(): 24 | text = torch.tensor(dictionary(list(ngrams_iterator(tokenizer(text), ngrams)))) 25 | output = model(text, torch.tensor([0])) 26 | return output.argmax(1).item() + 1 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description="Predict text from stdin given model and dictionary") 31 | parser.add_argument("model", help="the path for model") 32 | parser.add_argument("dictionary", help="the path for dictionary") 33 | parser.add_argument("--ngrams", type=int, default=2, help="ngrams (default=2)") 34 | parser.add_argument( 35 | "--use-sp-tokenizer", type=bool, default=False, help="use sentencepiece tokenizer (default=False)" 36 | ) 37 | args = parser.parse_args() 38 | 39 | model = torch.load(args.model) 40 | dictionary = torch.load(args.dictionary) 41 | if args.use_sp_tokenizer: 42 | sp_model_path = download_from_url(PRETRAINED_SP_MODEL["text_unigram_15000"]) 43 | sp_model = load_sp_model(sp_model_path) 44 | tokenizer = SentencePieceTokenizer(sp_model) 45 | else: 46 | tokenizer = get_tokenizer("basic_english") 47 | for line in sys.stdin: 48 | print(predict(line, model, dictionary, tokenizer, args.ngrams)) 49 | -------------------------------------------------------------------------------- /examples/text_classification/run_script.sh: -------------------------------------------------------------------------------- 1 | if [ !
-d ".data" ]; then 2 | mkdir .data 3 | fi 4 | 5 | python train.py AG_NEWS --device cpu --save-model-path model.i --dictionary vocab.i 6 | cut -f 2- -d "," .data/AG_NEWS/test.csv | python predict.py model.i vocab.i > predict_script.o 7 | 8 | # To train using pre-trained sentencepiece tokenizer 9 | # python train.py AG_NEWS --device cpu --save-model-path model.i --dictionary vocab.i --use-sp-tokenizer True 10 | 11 | # To run spm with YelpReviewFull 12 | # python train.py YelpReviewFull --device cuda --save-model-path model.i --dictionary vocab.i --use-sp-tokenizer True 13 | -------------------------------------------------------------------------------- /examples/torcharrow/README.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | This example shows end-2-end training for SST-2 binary classification using the RoBERTa model and TorchArrow based text 4 | pre-processing. The main motivation for this example is to demonstrate the authoring of a text processing pipeline on 5 | top of TorchArrow DataFrame. 6 | 7 | ## Installation and Usage 8 | 9 | The example depends on TorchArrow and TorchData. 10 | 11 | #### TorchArrow 12 | 13 | Install it from source following instructions at https://github.com/pytorch/torcharrow#from-source. Note that some of 14 | the natively integrated text operators (`bpe_tokenize` for tokenization, `lookup_indices` for vocabulary look-up) used 15 | in this example depend on the torch library. By default, TorchArrow doesn’t take dependency on the torch library. Hence 16 | make sure to use flag `USE_TORCH=1` during TorchArrow installation (this is also the reason why we cannot depend on 17 | nightly releases) 18 | 19 | ``` 20 | USE_TORCH=1 python setup.py install 21 | ``` 22 | 23 | #### TorchData 24 | 25 | To install TorchData follow instructions at https://github.com/pytorch/data#installation 26 | 27 | #### Usage 28 | 29 | To run example from command line run following command: 30 | 31 | ```bash 32 | python roberta_sst2_training_with_torcharrow.py \ 33 | --batch-size 16 \ 34 | --num-epochs 1 \ 35 | --learning-rate 1e-5 36 | ``` 37 | -------------------------------------------------------------------------------- /examples/tutorials/README.rst: -------------------------------------------------------------------------------- 1 | Tutorials 2 | ========= 3 | -------------------------------------------------------------------------------- /examples/vocab/fairseq_vocab.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Dict, List, Optional 3 | 4 | from fairseq.data.dictionary import Dictionary 5 | from torchtext.vocab import Vocab 6 | 7 | 8 | def build_fairseq_vocab( 9 | vocab_file: str, 10 | dictionary_class: Dictionary = Dictionary, 11 | special_token_replacements: Dict[str, str] = None, 12 | unk_token: str = "", 13 | max_vocab: int = -1, 14 | min_count: int = -1, 15 | tokens_to_add: Optional[List[str]] = None, 16 | ): 17 | """Function builds a torchtext Vocab for models pre-trained using Fairseq 18 | modules. 19 | 20 | The dictionary class can take any Fairseq Dictionary class and is 21 | used to load the vocab file. 
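    Example (illustrative only; assumes ``dict.txt`` is a Fairseq-format
    vocabulary file with one "<token> <count>" pair per line)::

        v = build_fairseq_vocab("dict.txt", min_count=5, tokens_to_add=["<special>"])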
22 | 23 | """ 24 | if not special_token_replacements: 25 | special_token_replacements = { 26 | "<pad>": "__PAD__", 27 | "<s>": "__BEGIN_OF_SENTENCE__", 28 | "</s>": "__END_OF_SENTENCE__", 29 | "<unk>": "__UNKNOWN__", 30 | "<mask>": "__MASK__", 31 | } 32 | unk_replacement = ( 33 | special_token_replacements[unk_token] if unk_token in special_token_replacements else unk_token 34 | ) 35 | special_tokens_to_remove = [special_pair[0] for special_pair in special_token_replacements.items()] 36 | special_tokens_to_add = tuple( 37 | special_pair[1] for special_pair in special_token_replacements.items() if special_pair[0] != unk_token 38 | ) 39 | 40 | with open(vocab_file) as f: 41 | dictionary = dictionary_class.load(f) 42 | # finalize will sort the dict based on frequency so only do this if 43 | # a min_count or max_vocab size is specified 44 | if min_count > 0 or max_vocab > 0: 45 | dictionary.finalize(threshold=min_count, nwords=max_vocab, padding_factor=1) 46 | if tokens_to_add: 47 | for token in tokens_to_add: 48 | dictionary.add_symbol(token) 49 | 50 | dictionary_items = list(zip(dictionary.symbols, dictionary.count)) 51 | 52 | ordered_dict = OrderedDict() 53 | # add special tokens to beginning of ordered_dict 54 | for s in special_tokens_to_add: 55 | ordered_dict[s] = 1 56 | 57 | # add all other tokens from dictionary_items 58 | for token, freq in dictionary_items: 59 | ordered_dict[token] = freq 60 | 61 | # remove special_tokens_to_remove from dict 62 | for s in special_tokens_to_remove: 63 | if s in ordered_dict: 64 | del ordered_dict[s] 65 | 66 | return Vocab(ordered_dict, unk_token=unk_replacement) 67 | -------------------------------------------------------------------------------- /examples/vocab/test.csv: -------------------------------------------------------------------------------- 1 | "2","3 US boxers punched out of Games","Athens -- Vanes Martirosyan became the second American to bow out of the Olympic boxing tournament Thursday when he was defeated 20-11 by Lorenzo Aragon of Cuba in their welterweight bout at 152 pounds. " 2 | "3","Before-the Bell; Rouse Co. Shares Jump"," <A HREF=""http://www.investor.reuters.com/FullQuote.aspx?ticker=RSE.N target=/stocks/quickinfo/fullquote"">RSE.N</A> jumped before the bell after General Growth Properties Inc. <A HREF=""http://www.investor.reuters.com/FullQuote.aspx?ticker=GGP.N target=/stocks/quickinfo/fullquote"">GGP.N</A>, the No. 2 U.S. shopping mall owner, on Friday said it would buy Rouse for \$7.2 billion." 3 | "3","Services make big gains in Japan","Tertiary index comes in at almost double expectations, drives up yen and helps Nikkei overcome oil. LONDON (Reuters) - The yen hit a four-week high against the dollar Friday as stronger-than-expected Japanese service sector data raised optimism about the ..." 4 | "3","Google shares bounce up 18 in trading debut","In the stock #39;s first day of trading, investors bought, sold and flipped shares at a furious pace, with the price ending just above \$100 - 18 percent higher than where it started. It was, in other words, everything the company #39;s founders, Sergy Brin and ..." 5 | "3","Stocks Lower as Oil Prices Steam Higher","With the much-ballyhooed initial public offering of Google behind them and oil chugging to a new record high, investors took a step back today." 6 | "2","Don #39;t expect Tiger to relinquish his top ranking without a fight","They #39;re calling Ohio a quot;battleground state, quot; one of the two or three places likely to decide November #39;s presidential election.
On local TV, the Bush and Kerry ads air so frequently that it #39;s easy to forget it #39;s Bob Costas who actually runs the country. " 7 | -------------------------------------------------------------------------------- /notebooks/hf_vs_tt_t5.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "### Ensuring the TorchText T5 implementation matches other OSS implementations\n", 8 | "\n", 9 | "> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 29, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "from transformers import T5Model\n", 19 | "from torchtext.prototype.models import T5_BASE\n", 20 | "\n", 21 | "import torch" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 30, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "input_sentence = [\"translate to Spanish: My name is Joe\"]\n", 31 | "output_sentence = [\"Me llamo Joe\"]\n", 32 | "\n", 33 | "transform = T5_BASE.transform()\n", 34 | "tt_t5_model = T5_BASE.get_model()\n", 35 | "\n", 36 | "hf_t5_model = T5Model.from_pretrained(\"t5-base\")" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 31, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "tokenized_sentence = transform(input_sentence)\n", 46 | "tokenized_output = transform(output_sentence)\n", 47 | "\n", 48 | "tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n", 49 | "hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n", 50 | "\n", 51 | "assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n", 52 | "assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))" 53 | ] 54 | } 55 | ], 56 | "metadata": { 57 | "kernelspec": { 58 | "display_name": "Python 3.9.13 ('torchtext39')", 59 | "language": "python", 60 | "name": "python3" 61 | }, 62 | "language_info": { 63 | "codemirror_mode": { 64 | "name": "ipython", 65 | "version": 3 66 | }, 67 | "file_extension": ".py", 68 | "mimetype": "text/x-python", 69 | "name": "python", 70 | "nbconvert_exporter": "python", 71 | "pygments_lexer": "ipython3", 72 | "version": "3.9.13" 73 | }, 74 | "orig_nbformat": 4, 75 | "vscode": { 76 | "interpreter": { 77 | "hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" 78 | } 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 2 83 | } 84 | -------------------------------------------------------------------------------- /packaging/build_conda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | . 
"$script_dir/pkg_helpers.bash" 6 | 7 | export BUILD_TYPE="conda" 8 | export NO_CUDA_PACKAGE=1 9 | setup_env 10 | export SOURCE_ROOT_DIR="$PWD" 11 | setup_conda_pytorch_constraint 12 | setup_visual_studio_constraint 13 | 14 | conda build $CONDA_CHANNEL_FLAGS --no-anaconda-upload --python "$PYTHON_VERSION" packaging/torchtext 15 | -------------------------------------------------------------------------------- /packaging/build_wheel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 5 | . "$script_dir/pkg_helpers.bash" 6 | 7 | export BUILD_TYPE="wheel" 8 | export NO_CUDA_PACKAGE=1 9 | setup_env 10 | setup_wheel_python 11 | pip_install numpy future cmake>=3.18.0 ninja 12 | setup_pip_pytorch_version 13 | git submodule update --init --recursive 14 | python setup.py clean 15 | if [[ "$OSTYPE" == "msys" ]]; then 16 | "$script_dir/vc_env_helper.bat" python setup.py bdist_wheel 17 | else 18 | python setup.py bdist_wheel 19 | fi 20 | -------------------------------------------------------------------------------- /packaging/cut_release.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # Usage (run from root of project): 4 | # TEST_INFRA_BRANCH=release/2.1 RELEASE_BRANCH=release/2.1 RELEASE_VERSION=2.1.0 packaging/cut_release.sh 5 | # 6 | # TEST_INFRA_BRANCH: The release branch of test-infra that houses all reusable 7 | # workflows 8 | # 9 | # RELEASE_BRANCH: The name of the release branch for this repo 10 | # 11 | # RELEASE_VERSION: Version of this current release 12 | 13 | set -eou pipefail 14 | 15 | # Create and Check out to Release Branch 16 | git checkout -b "${RELEASE_BRANCH}" 17 | 18 | # Change all GitHub Actions to reference the test-infra release branch 19 | # as opposed to main. 
20 | for i in .github/workflows/*.yml; do 21 | if [[ "$OSTYPE" == "darwin"* ]]; then 22 | sed -i '' -e s#@main#@"${TEST_INFRA_BRANCH}"# $i; 23 | sed -i '' -e s#test-infra-ref:[[:space:]]main#"test-infra-ref: ${TEST_INFRA_BRANCH}"# $i; 24 | else 25 | sed -i -e s#@main#@"${TEST_INFRA_BRANCH}"# $i; 26 | sed -i -e s#test-infra-ref:[[:space:]]main#"test-infra-ref: ${TEST_INFRA_BRANCH}"# $i; 27 | fi 28 | done 29 | 30 | # Update the Release Version in version.txt 31 | echo "${RELEASE_VERSION}" >version.txt 32 | 33 | # Optional 34 | git add .github/workflows version.txt 35 | git commit -m "[RELEASE-ONLY CHANGES] Branch Cut for Release ${RELEASE_VERSION}" 36 | git push origin "${RELEASE_BRANCH}" 37 | -------------------------------------------------------------------------------- /packaging/torchtext/bld.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | python setup.py install --single-version-externally-managed --record=record.txt 4 | if errorlevel 1 exit /b 1 5 | -------------------------------------------------------------------------------- /packaging/torchtext/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -ex 3 | 4 | python setup.py install --single-version-externally-managed --record=record.txt 5 | -------------------------------------------------------------------------------- /packaging/torchtext/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: torchtext 3 | version: "{{ environ.get('BUILD_VERSION') }}" 4 | 5 | source: 6 | path: "{{ environ.get('SOURCE_ROOT_DIR') }}" 7 | 8 | requirements: 9 | build: 10 | - {{ compiler('c') }} # [win] 11 | - {{ compiler('cxx') }} # [win] 12 | - cmake 13 | 14 | host: 15 | - python 16 | - setuptools 17 | - cpuonly 18 | {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT', 'pytorch') }} 19 | {{ environ.get('CONDA_EXTRA_BUILD_CONSTRAINT', '') }} 20 | 21 | 22 | run: 23 | - python 24 | - requests 25 | - tqdm 26 | {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} 27 | 28 | build: 29 | string: py{{py}} 30 | script_env: 31 | - BUILD_VERSION 32 | - MACOSX_DEPLOYMENT_TARGET 33 | 34 | test: 35 | imports: 36 | - torchtext 37 | - torchtext.datasets 38 | - torchtext.data 39 | - torchtext.prototype 40 | 41 | source_files: 42 | - test 43 | 44 | requires: 45 | - pytest 46 | - cpuonly 47 | 48 | about: 49 | home: https://github.com/pytorch/text 50 | license: BSD 51 | license_file: LICENSE 52 | summary: 'Data loaders and abstractions for text and NLP' 53 | -------------------------------------------------------------------------------- /packaging/vc_env_helper.bat: -------------------------------------------------------------------------------- 1 | @echo on 2 | 3 | set VC_VERSION_LOWER=16 4 | set VC_VERSION_UPPER=17 5 | 6 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [%VC_VERSION_LOWER%^,%VC_VERSION_UPPER%^) -property installationPath`) do ( 7 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 8 | set "VS15INSTALLDIR=%%i" 9 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 10 | goto vswhere 11 | ) 12 | ) 13 | 14 | :vswhere 15 | if "%VSDEVCMD_ARGS%" == "" ( 16 | call "%VS15VCVARSALL%" x64 || exit /b 1 17 | ) else ( 18 | call "%VS15VCVARSALL%" x64 %VSDEVCMD_ARGS% || exit /b 1 19 | ) 20 | 21 | @echo on 22 | 23 | set DISTUTILS_USE_SDK=1 24 | 25 | set args=%1 26 | shift 27 | :start 
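rem The loop below re-assembles every remaining command-line argument into
rem %args%, shifting through them one token at a time until none are left.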
28 | if [%1] == [] goto done 29 | set args=%args% %1 30 | shift 31 | goto start 32 | 33 | :done 34 | if "%args%" == "" ( 35 | echo Usage: vc_env_helper.bat [command] [args] 36 | echo e.g. vc_env_helper.bat cl /c test.cpp 37 | ) 38 | 39 | %args% || exit /b 1 40 | -------------------------------------------------------------------------------- /packaging/vs2019/activate.bat: -------------------------------------------------------------------------------- 1 | :: Set env vars that tell distutils to use the compiler that we put on path 2 | SET DISTUTILS_USE_SDK=1 3 | SET MSSdk=1 4 | 5 | SET "VS_VERSION=16.0" 6 | SET "VS_MAJOR=16" 7 | SET "VS_YEAR=2019" 8 | 9 | set "MSYS2_ARG_CONV_EXCL=/AI;/AL;/OUT;/out" 10 | set "MSYS2_ENV_CONV_EXCL=CL" 11 | 12 | :: For Python 3.5+, ensure that we link with the dynamic runtime. See 13 | :: http://stevedower.id.au/blog/building-for-python-3-5-part-two/ for more info 14 | set "PY_VCRUNTIME_REDIST=%PREFIX%\\bin\\vcruntime140.dll" 15 | 16 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [16^,17^) -property installationPath`) do ( 17 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 18 | set "VSINSTALLDIR=%%i\" 19 | goto :vswhere 20 | ) 21 | ) 22 | 23 | :vswhere 24 | 25 | :: Shorten PATH to avoid the `input line too long` error. 26 | SET MyPath=%PATH% 27 | 28 | setlocal EnableDelayedExpansion 29 | 30 | SET TempPath="%MyPath:;=";"%" 31 | SET var= 32 | FOR %%a IN (%TempPath%) DO ( 33 | IF EXIST %%~sa ( 34 | SET "var=!var!;%%~sa" 35 | ) 36 | ) 37 | 38 | set "TempPath=!var:~1!" 39 | endlocal & set "PATH=%TempPath%" 40 | 41 | :: Shorten current directory too 42 | FOR %%A IN (.) DO CD "%%~sA" 43 | 44 | :: other things added by install_activate.bat at package build time 45 | -------------------------------------------------------------------------------- /packaging/vs2019/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | blas_impl: 2 | - mkl # [x86_64] 3 | c_compiler: 4 | - vs2019 # [win] 5 | cxx_compiler: 6 | - vs2019 # [win] 7 | python: 8 | - 3.5 9 | - 3.6 10 | # This differs from target_platform in that it determines what subdir the compiler 11 | # will target, not what subdir the compiler package will be itself. 12 | # For example, we need a win-64 vs2008_win-32 package, so that we compile win-32 13 | # code on win-64 miniconda. 
14 | cross_compiler_target_platform: 15 | - win-64 # [win] 16 | target_platform: 17 | - win-64 # [win] 18 | vc: 19 | - 14 20 | zip_keys: 21 | - # [win] 22 | - vc # [win] 23 | - c_compiler # [win] 24 | - cxx_compiler # [win] 25 | -------------------------------------------------------------------------------- /packaging/vs2019/install_activate.bat: -------------------------------------------------------------------------------- 1 | set YEAR=2019 2 | set VER=16 3 | 4 | mkdir "%PREFIX%\etc\conda\activate.d" 5 | COPY "%RECIPE_DIR%\activate.bat" "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 6 | 7 | IF "%cross_compiler_target_platform%" == "win-64" ( 8 | set "target_platform=amd64" 9 | echo SET "CMAKE_GENERATOR=Visual Studio %VER% %YEAR% Win64" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 10 | echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 11 | IF "%VSDEVCMD_ARGS%" == "" ( 12 | echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x64 >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 13 | echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 14 | echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 15 | echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x86_amd64 >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 16 | ) ELSE ( 17 | echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x64 %VSDEVCMD_ARGS% >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 18 | echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 19 | echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 20 | echo CALL "VC\Auxiliary\Build\vcvarsall.bat" x86_amd64 %VSDEVCMD_ARGS% >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 21 | ) 22 | echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 23 | ) else ( 24 | set "target_platform=x86" 25 | echo SET "CMAKE_GENERATOR=Visual Studio %VER% %YEAR%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 26 | echo pushd "%%VSINSTALLDIR%%" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 27 | echo CALL "VC\Auxiliary\Build\vcvars32.bat" >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 28 | echo popd >> "%PREFIX%\etc\conda\activate.d\vs%YEAR%_compiler_vars.bat" 29 | ) 30 | -------------------------------------------------------------------------------- /packaging/vs2019/install_runtime.bat: -------------------------------------------------------------------------------- 1 | set VC_PATH=x86 2 | if "%ARCH%"=="64" ( 3 | set VC_PATH=x64 4 | ) 5 | 6 | set MSC_VER=2019 7 | 8 | rem :: This should always be present for VC installed with VS. Not sure about VC installed with Visual C++ Build Tools 2015 9 | rem FOR /F "usebackq tokens=3*" %%A IN (`REG QUERY "HKEY_LOCAL_MACHINE\Software\Microsoft\DevDiv\VC\Servicing\14.0\IDE.x64" /v UpdateVersion`) DO ( 10 | rem set SP=%%A 11 | rem ) 12 | 13 | rem if not "%SP%" == "%PKG_VERSION%" ( 14 | rem echo "Version detected from registry: %SP%" 15 | rem echo "does not match version of package being built (%PKG_VERSION%)" 16 | rem echo "Do you have current updates for VS 2015 installed?" 17 | rem exit 1 18 | rem ) 19 | 20 | 21 | REM ========== REQUIRES Win 10 SDK be installed, or files otherwise copied to location below!
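rem Note on the checks below: robocopy exit codes 0-7 all indicate success
rem (some combination of files copied or skipped); codes 8 and above indicate
rem failure, hence the GEQ 8 / LSS 8 comparisons rather than a plain errorlevel test.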
22 | robocopy "C:\Program Files (x86)\Windows Kits\10\Redist\ucrt\DLLs\%VC_PATH%" "%LIBRARY_BIN%" *.dll /E 23 | robocopy "C:\Program Files (x86)\Windows Kits\10\Redist\ucrt\DLLs\%VC_PATH%" "%PREFIX%" *.dll /E 24 | if %ERRORLEVEL% GEQ 8 exit 1 25 | 26 | REM ========== This one comes from visual studio 2019 27 | set "VC_VER=142" 28 | 29 | for /f "usebackq tokens=*" %%i in (`"%ProgramFiles(x86)%\Microsoft Visual Studio\Installer\vswhere.exe" -legacy -products * -version [16^,17^) -property installationPath`) do ( 30 | if exist "%%i" if exist "%%i\VC\Auxiliary\Build\vcvarsall.bat" ( 31 | set "VS15VCVARSALL=%%i\VC\Auxiliary\Build\vcvarsall.bat" 32 | goto :eof 33 | ) 34 | ) 35 | 36 | @setlocal 37 | call "%VS15VARSALL%" x64 38 | 39 | set "REDIST_ROOT=%VCToolsRedistDir%%VC_PATH%" 40 | 41 | robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.CRT" "%LIBRARY_BIN%" *.dll /E 42 | if %ERRORLEVEL% LSS 8 exit 0 43 | robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.CRT" "%PREFIX%" *.dll /E 44 | if %ERRORLEVEL% LSS 8 exit 0 45 | robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.OpenMP" "%LIBRARY_BIN%" *.dll /E 46 | if %ERRORLEVEL% LSS 8 exit 0 47 | robocopy "%REDIST_ROOT%\Microsoft.VC%VC_VER%.OpenMP" "%PREFIX%" *.dll /E 48 | if %ERRORLEVEL% LSS 8 exit 0 49 | @endlocal 50 | -------------------------------------------------------------------------------- /packaging/vs2019/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set vcver="14.2" %} 2 | {% set vcfeature="14" %} 3 | {% set vsyear="2019" %} 4 | {% set fullver="15.4.27004.2010" %} 5 | 6 | package: 7 | name: vs{{ vsyear }} 8 | version: {{ fullver }} 9 | 10 | build: 11 | skip: True [not win] 12 | script_env: 13 | - VSDEVCMD_ARGS # [win] 14 | 15 | outputs: 16 | - name: vs{{ vsyear }}_{{ cross_compiler_target_platform }} 17 | script: install_activate.bat 18 | track_features: 19 | # VS 2019 is binary-compatible with VS 2017/vc 14.1 and 2015/vc14. Tools are "v142". 
20 | strong: 21 | - vc{{ vcfeature }} 22 | about: 23 | summary: Activation and version verification of MSVC {{ vcver }} (VS {{ vsyear }}) compiler 24 | license: BSD 3-clause 25 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.usort] 2 | 3 | first_party_detection = false 4 | 5 | [tool.black] 6 | 7 | line-length = 120 8 | target-version = ["py37"] 9 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | addopts = --ignore-glob=test/torchtext_unittest/datasets/* 3 | testpaths = test/ 4 | python_paths = ./ 5 | markers = 6 | gpu_test: marks cuda tests 7 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | build: 2 | image: latest 3 | 4 | python: 5 | version: 3.5 6 | setup_py_install: true 7 | 8 | # Don't build any extra formats 9 | formats: [] 10 | 11 | requirements_file: docs/requirements.txt 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Progress bars on iterators 2 | tqdm 3 | 4 | # Downloading data and other files 5 | requests 6 | 7 | # Optional NLP tools 8 | nltk 9 | spacy 10 | sacremoses 11 | git+https://github.com/jekbradbury/revtok.git 12 | 13 | # Documentation 14 | Sphinx 15 | 16 | # Required for tests only: 17 | 18 | # Run unit tests 19 | pytest 20 | expecttest 21 | parameterized 22 | 23 | # Lets pytest find our code by automatically modifying PYTHONPATH 24 | pytest-pythonpath 25 | 26 | # Coverage statistics 27 | pytest-cov 28 | codecov 29 | 30 | # To parse untrusted XML data 31 | defusedxml 32 | -------------------------------------------------------------------------------- /test/integration_tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/integration_tests/__init__.py -------------------------------------------------------------------------------- /test/integration_tests/conftest.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | 3 | import pytest 4 | import torch 5 | 6 | 7 | def pytest_addoption(parser): 8 | parser.addoption( 9 | "--use-tmp-hub-dir", 10 | action="store_true", 11 | help=( 12 | "When provided, tests will use temporary directory as Torch Hub directory. " 13 | "Downloaded models will be deleted after each test." 
14 | ), 15 | ) 16 | 17 | 18 | @pytest.fixture(autouse=True, scope="class") 19 | def temp_hub_dir(tmp_path_factory, pytestconfig): 20 | if not pytestconfig.getoption("use_tmp_hub_dir"): 21 | yield 22 | else: 23 | tmp_dir = tmp_path_factory.mktemp("hub", numbered=True).resolve() 24 | org_dir = torch.hub.get_dir() 25 | torch.hub.set_dir(tmp_dir) 26 | yield 27 | torch.hub.set_dir(org_dir) 28 | shutil.rmtree(tmp_dir, ignore_errors=True) 29 | -------------------------------------------------------------------------------- /test/integration_tests/test_roberta_models.py: -------------------------------------------------------------------------------- 1 | import pytest # noqa: F401 2 | import torch 3 | from parameterized import parameterized, parameterized_class 4 | from torchtext.models import ( 5 | ROBERTA_BASE_ENCODER, 6 | ROBERTA_LARGE_ENCODER, 7 | ROBERTA_DISTILLED_ENCODER, 8 | XLMR_BASE_ENCODER, 9 | XLMR_LARGE_ENCODER, 10 | ) 11 | from torchtext_unittest.common.assets import get_asset_path 12 | from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase 13 | 14 | BUNDLERS = { 15 | "xlmr_base": XLMR_BASE_ENCODER, 16 | "xlmr_large": XLMR_LARGE_ENCODER, 17 | "roberta_base": ROBERTA_BASE_ENCODER, 18 | "roberta_large": ROBERTA_LARGE_ENCODER, 19 | "roberta_distilled": ROBERTA_DISTILLED_ENCODER, 20 | } 21 | 22 | 23 | @parameterized_class( 24 | ("model_name",), 25 | [ 26 | ("xlmr_base",), 27 | ("xlmr_large",), 28 | ("roberta_base",), 29 | ("roberta_large",), 30 | ("roberta_distilled",), 31 | ], 32 | ) 33 | class TestRobertaEncoders(TorchtextTestCase): 34 | def _roberta_encoders(self, is_jit, encoder, expected_asset_name, test_text): 35 | """Verify pre-trained XLM-R and Roberta models in torchtext produce 36 | the same output as the reference implementation within fairseq 37 | """ 38 | expected_asset_path = get_asset_path(expected_asset_name) 39 | 40 | transform = encoder.transform() 41 | model = encoder.get_model() 42 | model = model.eval() 43 | 44 | if is_jit: 45 | transform = torch.jit.script(transform) 46 | model = torch.jit.script(model) 47 | 48 | model_input = torch.tensor(transform([test_text])) 49 | actual = model(model_input) 50 | expected = torch.load(expected_asset_path) 51 | torch.testing.assert_close(actual, expected) 52 | 53 | @parameterized.expand(["jit", "not_jit"]) 54 | def test_models(self, name): 55 | configuration, type = self.model_name.split("_") 56 | 57 | expected_asset_name = f"{configuration}.{type}.output.pt" 58 | is_jit = name == "jit" 59 | if configuration == "xlmr": 60 | test_text = "XLMR base Model Comparison" 61 | else: 62 | test_text = "Roberta base Model Comparison" 63 | 64 | self._roberta_encoders( 65 | is_jit=is_jit, 66 | encoder=BUNDLERS[configuration + "_" + type], 67 | expected_asset_name=expected_asset_name, 68 | test_text=test_text, 69 | ) 70 | -------------------------------------------------------------------------------- /test/smoke_tests/smoke_tests.py: -------------------------------------------------------------------------------- 1 | """Run smoke tests""" 2 | 3 | import torchtext 4 | 5 | 6 | print("torchtext version is ", torchtext.__version__) 7 | -------------------------------------------------------------------------------- /test/torchtext_unittest/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/__init__.py 
-------------------------------------------------------------------------------- /test/torchtext_unittest/asset/SST2/SST-2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/SST2/SST-2.zip -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/glove.6B.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/glove.6B.zip -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/glove.840B.300d.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/glove.840B.300d.zip -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/label_names.txt: -------------------------------------------------------------------------------- 1 | test 2 | label 3 | indices 4 | -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/roberta.base.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/roberta.base.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/roberta.distilled.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/roberta.distilled.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/roberta.large.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/roberta.large.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/spm_example.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/spm_example.model -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.base.encoder.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.base.encoder.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.base.generation.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.base.generation.output.pt -------------------------------------------------------------------------------- 
/test/torchtext_unittest/asset/t5.base.model.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.base.model.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.flan.base.encoder.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.flan.base.encoder.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.flan.base.generation.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.flan.base.generation.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.flan.base.model.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.flan.base.model.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.large.encoder.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.large.encoder.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.large.generation.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.large.generation.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.large.model.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.large.model.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.small.encoder.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.small.encoder.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.small.generation.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.small.generation.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5.small.model.output.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5.small.model.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/t5_tokenizer_base.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/t5_tokenizer_base.model -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/vectors_test.csv: -------------------------------------------------------------------------------- 1 | a,1 0 0 2 | b,0 1 0 3 | -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/vocab_raw_text_test.txt: -------------------------------------------------------------------------------- 1 | Fears for T N pension after talks Unions 2 | representing workers at Turner Newall say they are 'disappointed' 3 | after talks with stricken parent firm Federal Mogul. 4 | -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/vocab_test.txt: -------------------------------------------------------------------------------- 1 | b 2 | a 3 | c 4 | a 5 | b 6 | a 7 | c 8 | -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/vocab_test2.txt: -------------------------------------------------------------------------------- 1 | 2 | 3 | . 4 | the 5 | , 6 | to 7 | a 8 | of 9 | in 10 | and 11 | s 12 | on 13 | for 14 | #39 15 | ( 16 | ) 17 | - 18 | ' 19 | that 20 | with 21 | as 22 | at 23 | is 24 | its 25 | new 26 | by 27 | it 28 | said 29 | reuters 30 | -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/xlmr.base.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/xlmr.base.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/asset/xlmr.large.output.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/asset/xlmr.large.output.pt -------------------------------------------------------------------------------- /test/torchtext_unittest/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/common/__init__.py -------------------------------------------------------------------------------- /test/torchtext_unittest/common/assets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import shutil 4 | from pathlib import Path 5 | 6 | _ASSET_DIR = (Path(__file__).parent.parent / "asset").resolve() 7 | 8 | 9 | def get_asset_path(*path_components): 10 | """Get the path to the file under `test/assets` directory.""" 11 | return str(_ASSET_DIR.joinpath(*path_components)) 12 | 13 | 14 | def conditional_remove(f): 15 | for path in glob.glob(f): 16 | if os.path.isfile(path): 17 | os.remove(path) 18 | elif 
os.path.isdir(path): 19 | shutil.rmtree(path) 20 | -------------------------------------------------------------------------------- /test/torchtext_unittest/common/parameterized_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import product 3 | 4 | from parameterized import param, parameterized 5 | 6 | from .assets import get_asset_path 7 | 8 | 9 | def load_params(*paths): 10 | with open(get_asset_path(*paths), "r") as file: 11 | return [param(json.loads(line)) for line in file] 12 | 13 | 14 | def _name_func(func, _, params): 15 | strs = [] 16 | for arg in params.args: 17 | if isinstance(arg, tuple): 18 | strs.append("_".join(str(a) for a in arg)) 19 | else: 20 | strs.append(str(arg)) 21 | # sanitize the test name 22 | name = parameterized.to_safe_name("_".join(strs)) 23 | return f"{func.__name__}_{name}" 24 | 25 | 26 | def nested_params(*params_set): 27 | """Generate the cartesian product of the given list of parameters. 28 | Args: 29 | params_set (list of parameters): Parameters. When using ``parameterized.param`` class, 30 | all the parameters have to be specified with the class, only using kwargs. 31 | """ 32 | flatten = [p for params in params_set for p in params] 33 | 34 | # Parameters to be nested are given as list of plain objects 35 | if all(not isinstance(p, param) for p in flatten): 36 | args = list(product(*params_set)) 37 | return parameterized.expand(args, name_func=_name_func) 38 | 39 | # Parameters to be nested are given as list of `parameterized.param` 40 | if not all(isinstance(p, param) for p in flatten): 41 | raise TypeError("When using ``parameterized.param``, all the parameters have to be of the ``param`` type.") 42 | if any(p.args for p in flatten): 43 | raise ValueError( 44 | "When using ``parameterized.param``, all the parameters have to be provided as keyword argument." 45 | ) 46 | args = [param()] 47 | for params in params_set: 48 | args = [param(**x.kwargs, **y.kwargs) for x in args for y in params] 49 | return parameterized.expand(args) 50 | -------------------------------------------------------------------------------- /test/torchtext_unittest/csrc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/csrc/__init__.py -------------------------------------------------------------------------------- /test/torchtext_unittest/csrc/test_gpt2_bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | import regex as re 2 | import torch 3 | import torchtext # noqa: F401 4 | 5 | from ..common.torchtext_test_case import TorchtextTestCase 6 | 7 | 8 | class TestGPT2BPETokenizer(TorchtextTestCase): 9 | def test_gpt2_bpe_pre_tokenizer(self) -> None: 10 | # Regex pattern for GPT-2 BPE which includes the negative lookahead 11 | # Reference: https://github.com/pytorch/fairseq/blob/main/fairseq/data/encoders/gpt2_bpe_utils.py#L69 12 | gpt2_bpe_pattern = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 13 | test_cases = [ 14 | # test spaces 15 | "Lorem ipsum dolor sit amet.", 16 | "Lorem ipsum dolor sit amet.", 17 | "Lorem ipsum dolor sit amet. 
", 18 | "Lorem ipsum dolor sit amet ", 19 | "Lorem\x0d\x0dipsum dolor sit amet\r\r", 20 | "Lorem ipsum\x20dolor sit amet", 21 | "Lorem ipsum\x20\x20\x20dolor sit amet", 22 | "Lorem ipsum\x20\x20 dolor sit amet", 23 | # test tabs 24 | "Lorem ipsum dolor sit \t\t\t amet.", 25 | "Lorem ipsum dolor sit \t\t\t\tamet.", 26 | "Lorem ipsum dolor sit \x09\x09amet.", 27 | "Lorem ipsum dolor sit \x09\x09 amet.", 28 | "Lorem ipsum dolor sit \x09\x09 amet. ", 29 | "Lorem ipsum dolor sit \t \tamet.", 30 | "Lorem ipsum dolor sit amet \t", 31 | "Lorem ipsum\tdolor sit amet", 32 | # test carriage returns 33 | "Lorem ipsum\r\r dolor sit amet", 34 | "Lorem ipsum\r\r dolor sit amet\r\r", 35 | "Lorem ipsum \x0d\x0ddolor sit amet.", 36 | "Lorem ipsum\x0ddolor sit amet.", 37 | "Lorem ipsum\x0d\x0d dolor sit amet.", 38 | "Lorem ipsum\x0d\x0d dolor sit amet.\x0d", 39 | # test form feeds 40 | "Lorem ipsum\f\fdolor sit amet\f", 41 | "Lorem ipsum\f\f dolor sit amet\f ", 42 | "Lorem ipsum\x0c\x0c dolor sit amet", 43 | "Lorem \x0c\x0c\x0c\x0cipsum dolor sit amet", 44 | # test vertical tabs 45 | "Lorem ipsum dolor sit\vamet.", 46 | "Lorem ipsum dolor sit\v\vamet.", 47 | "Lorem ipsum dolor sit\v\v amet.", 48 | "Lorem ipsum dolor sit\v\v amet. \v", 49 | "Lorem ipsum dolor sit\x0b\x0b amet. \v ", 50 | "Lorem ipsum dolor sit\x0bamet.", 51 | "Lorem ipsum dolor sit\x0b\x0bamet.", 52 | "Lorem ipsum dolor sit\x0b\x0b amet.", 53 | ] 54 | for t in test_cases: 55 | self.assertEqual(re.findall(gpt2_bpe_pattern, t), torch.ops.torchtext.gpt2_bpe_pre_tokenizer(t)) 56 | -------------------------------------------------------------------------------- /test/torchtext_unittest/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/data/__init__.py -------------------------------------------------------------------------------- /test/torchtext_unittest/data/test_dataset_utils.py: -------------------------------------------------------------------------------- 1 | from parameterized import parameterized 2 | from torch.utils.data.datapipes.iter import IterableWrapper 3 | from torchtext.data.datasets_utils import _ParseIOBData 4 | 5 | from ..common.torchtext_test_case import TorchtextTestCase 6 | 7 | 8 | class TestDatasetUtils(TorchtextTestCase): 9 | @parameterized.expand( 10 | [ 11 | [lambda it: list(_ParseIOBData(IterableWrapper(it), sep=" "))], 12 | [lambda it: list(IterableWrapper(it).read_iob(sep=" "))], 13 | ] 14 | ) 15 | def test_iob_datapipe(self, pipe_fn): 16 | iob = ["Alex I-PER", "is O", "going O", "to O", "Los I-LOC", "Angeles I-LOC", "in O", "California I-LOC"] 17 | iterable = [("ignored.txt", e) for e in iob] 18 | iob_dp = pipe_fn(iterable) 19 | # There's only one example in this dataset 20 | self.assertEqual(len(iob_dp), 1) 21 | # The length of the list of surface forms is the number of lines in the example 22 | self.assertEqual(len(iob_dp[0][0]), len(iob)) 23 | # The length of the list labels is the number of lines in the example 24 | self.assertEqual(len(iob_dp[0][1]), len(iob)) 25 | iob = [ 26 | "Alex I-PER", 27 | "is O", 28 | "going O", 29 | "to O", 30 | "Los I-LOC", 31 | "Angeles I-LOC", 32 | "in O", 33 | "California I-LOC", 34 | "", 35 | "Alex I-PER", 36 | "is O", 37 | "going O", 38 | "to O", 39 | "Los I-LOC", 40 | "Angeles I-LOC", 41 | "in O", 42 | "California I-LOC", 43 | ] 44 | iterable = [("ignored.txt", e) for e in iob] 45 | iob_dp = pipe_fn(iterable) 46 | # There 
are two examples in this dataset 47 | self.assertEqual(len(iob_dp), 2) 48 | # The length of the first list of surface forms is the length of everything before the empty line. 49 | # The length of the first labels is the length of everything before the empty line. 50 | self.assertEqual(len(iob_dp[0][0]), iob.index("")) 51 | self.assertEqual(len(iob_dp[0][1]), iob.index("")) 52 | # The length of the second list of surface forms is the length of everything after the empty line. 53 | # The length of the second labels is the length of everything after the empty line. 54 | self.assertEqual(len(iob_dp[1][0]), len(iob) - iob.index("") - 1) 55 | self.assertEqual(len(iob_dp[1][1]), len(iob) - iob.index("") - 1) 56 | -------------------------------------------------------------------------------- /test/torchtext_unittest/data/test_jit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.testing import assert_close 3 | from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct 4 | 5 | from ..common.torchtext_test_case import TorchtextTestCase 6 | 7 | 8 | class TestJIT(TorchtextTestCase): 9 | def test_torchscript_multiheadattention(self) -> None: 10 | embed_dim, nhead, tgt_len, src_len, bsz = 10, 5, 6, 10, 64 11 | # Build torchtext MultiheadAttention models 12 | in_proj_container = InProjContainer( 13 | torch.nn.Linear(embed_dim, embed_dim, bias=False), 14 | torch.nn.Linear(embed_dim, embed_dim, bias=False), 15 | torch.nn.Linear(embed_dim, embed_dim, bias=False), 16 | ) 17 | 18 | MHA = MultiheadAttentionContainer( 19 | nhead, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim, bias=False) 20 | ) 21 | query = torch.rand((tgt_len, bsz, embed_dim)) 22 | key = value = torch.rand((src_len, bsz, embed_dim)) 23 | attn_mask = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool) 24 | attn_mask = torch.stack([attn_mask] * (bsz * nhead)) 25 | mha_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask) 26 | 27 | ts_MHA = torch.jit.script(MHA) 28 | ts_mha_output, ts_attn_weights = ts_MHA(query, key, value, attn_mask=attn_mask) 29 | assert_close(mha_output, ts_mha_output) 30 | assert_close(attn_weights, ts_attn_weights) 31 | -------------------------------------------------------------------------------- /test/torchtext_unittest/data/test_utils.py: -------------------------------------------------------------------------------- 1 | from torchtext.data import get_tokenizer 2 | 3 | from ..common.torchtext_test_case import TorchtextTestCase 4 | 5 | 6 | class TestUtils(TorchtextTestCase): 7 | TEST_STR = "A string, particularly one with slightly complex punctuation." 8 | 9 | def test_get_tokenizer_split(self) -> None: 10 | # Test the default case with str.split 11 | assert get_tokenizer(str.split) == str.split 12 | assert get_tokenizer(str.split)(self.TEST_STR) == str.split(self.TEST_STR) 13 | 14 | def test_get_tokenizer_toktok(self) -> None: 15 | # Test Toktok option. Test strings taken from NLTK doctests. 16 | # Note that internally, MosesTokenizer converts to unicode if applicable 17 | toktok_tokenizer = get_tokenizer("toktok") 18 | assert toktok_tokenizer(self.TEST_STR) == [ 19 | "A", 20 | "string", 21 | ",", 22 | "particularly", 23 | "one", 24 | "with", 25 | "slightly", 26 | "complex", 27 | "punctuation", 28 | ".", 29 | ] 30 | 31 | # Test that errors are raised for invalid input arguments.
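# get_tokenizer accepts a callable or one of a fixed set of tokenizer names,
# so both a non-string argument and an unrecognized name should raise ValueError.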
32 | with self.assertRaises(ValueError): 33 | get_tokenizer(1) 34 | with self.assertRaises(ValueError): 35 | get_tokenizer("some other string") 36 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/datasets/__init__.py -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/common.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from parameterized import parameterized 4 | from torch.utils.data.graph import traverse_dps 5 | from torch.utils.data.graph_settings import get_all_graph_pipes 6 | from torchdata.dataloader2.linter import _check_shuffle_before_sharding 7 | from torchdata.datapipes.iter import Shuffler, ShardingFilter 8 | from torchtext.datasets import DATASETS 9 | 10 | from ..common.torchtext_test_case import TorchtextTestCase 11 | 12 | 13 | class TestDatasetPickling(TorchtextTestCase): 14 | @parameterized.expand([(f,) for f in DATASETS.values()]) 15 | def test_pickling(self, dataset_fn): 16 | dp = dataset_fn() 17 | if type(dp) == tuple: 18 | dp = list(dp) 19 | else: 20 | dp = [dp] 21 | 22 | for dp_split in dp: 23 | pickle.loads(pickle.dumps(dp_split)) 24 | 25 | 26 | class TestShuffleShardDatasetWrapper(TorchtextTestCase): 27 | # Note that for ordering (i.e., shuffle before sharding), TorchData will emit a linter warning 28 | # Modify this test once that linter warning is available 29 | @parameterized.expand([(f,) for f in DATASETS.values()]) 30 | def test_shuffle_shard_wrapper(self, dataset_fn): 31 | dp = dataset_fn() 32 | if type(dp) == tuple: 33 | dp = list(dp) 34 | else: 35 | dp = [dp] 36 | 37 | for dp_split in dp: 38 | _check_shuffle_before_sharding(dp_split) 39 | 40 | dp_graph = get_all_graph_pipes(traverse_dps(dp_split)) 41 | for annotation_dp_type in [Shuffler, ShardingFilter]: 42 | if not any(isinstance(dp, annotation_dp_type) for dp in dp_graph): 43 | raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") 44 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_agnews.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from unittest.mock import patch 4 | 5 | from parameterized import parameterized 6 | from torchtext.datasets.ag_news import AG_NEWS 7 | 8 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 9 | from ..common.torchtext_test_case import TorchtextTestCase 10 | 11 | 12 | def _get_mock_dataset(root_dir): 13 | """ 14 | root_dir: directory to the mocked dataset 15 | """ 16 | base_dir = os.path.join(root_dir, "AG_NEWS") 17 | os.makedirs(base_dir, exist_ok=True) 18 | 19 | seed = 1 20 | mocked_data = defaultdict(list) 21 | for file_name in ("train.csv", "test.csv"): 22 | txt_file = os.path.join(base_dir, file_name) 23 | with open(txt_file, "w", encoding="utf-8") as f: 24 | for i in range(5): 25 | label = seed % 4 + 1 26 | rand_string = get_random_unicode(seed) 27 | dataset_line = (label, f"{rand_string} {rand_string}") 28 | # append line to correct dataset split 29 | mocked_data[os.path.splitext(file_name)[0]].append(dataset_line) 30 |
f.write(f'"{label}","{rand_string}","{rand_string}"\n') 31 | seed += 1 32 | 33 | return mocked_data 34 | 35 | 36 | class TestAGNews(TempDirMixin, TorchtextTestCase): 37 | root_dir = None 38 | samples = [] 39 | patcher = None 40 | 41 | @classmethod 42 | def setUpClass(cls): 43 | super().setUpClass() 44 | cls.root_dir = cls.get_base_temp_dir() 45 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 46 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 47 | cls.patcher.start() 48 | 49 | @classmethod 50 | def tearDownClass(cls): 51 | cls.patcher.stop() 52 | super().tearDownClass() 53 | 54 | @parameterized.expand(["train", "test"]) 55 | def test_agnews(self, split): 56 | dataset = AG_NEWS(root=self.root_dir, split=split) 57 | 58 | samples = list(dataset) 59 | expected_samples = self.samples[split] 60 | for sample, expected_sample in zip_equal(samples, expected_samples): 61 | self.assertEqual(sample, expected_sample) 62 | 63 | @parameterized.expand(["train", "test"]) 64 | def test_agnews_split_argument(self, split): 65 | dataset1 = AG_NEWS(root=self.root_dir, split=split) 66 | (dataset2,) = AG_NEWS(root=self.root_dir, split=(split,)) 67 | 68 | for d1, d2 in zip_equal(dataset1, dataset2): 69 | self.assertEqual(d1, d2) 70 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_cc100.py: -------------------------------------------------------------------------------- 1 | import lzma 2 | import os 3 | from collections import defaultdict 4 | from unittest.mock import patch 5 | 6 | from parameterized import parameterized 7 | from torchtext.datasets import CC100 8 | from torchtext.datasets.cc100 import VALID_CODES 9 | 10 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 11 | from ..common.torchtext_test_case import TorchtextTestCase 12 | 13 | 14 | def _get_mock_dataset(root_dir): 15 | """ 16 | root_dir: directory to the mocked dataset 17 | """ 18 | base_dir = os.path.join(root_dir, "CC100") 19 | os.makedirs(base_dir, exist_ok=True) 20 | 21 | seed = 1 22 | mocked_data = defaultdict(list) 23 | 24 | for language_code in VALID_CODES: 25 | file_name = f"{language_code}.txt.xz" 26 | compressed_file = os.path.join(base_dir, file_name) 27 | with lzma.open(compressed_file, "wt", encoding="utf-8") as f: 28 | for i in range(5): 29 | rand_string = get_random_unicode(seed) 30 | content = f"{rand_string}\n" 31 | f.write(content) 32 | mocked_data[language_code].append((language_code, rand_string)) 33 | seed += 1 34 | 35 | return mocked_data 36 | 37 | 38 | class TestCC100(TempDirMixin, TorchtextTestCase): 39 | @classmethod 40 | def setUpClass(cls): 41 | super().setUpClass() 42 | cls.root_dir = cls.get_base_temp_dir() 43 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 44 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 45 | cls.patcher.start() 46 | 47 | @classmethod 48 | def tearDownClass(cls): 49 | cls.patcher.stop() 50 | super().tearDownClass() 51 | 52 | @parameterized.expand(VALID_CODES) 53 | def test_cc100(self, language_code): 54 | dataset = CC100(root=self.root_dir, language_code=language_code) 55 | 56 | samples = list(dataset) 57 | expected_samples = self.samples[language_code] 58 | for sample, expected_sample in zip_equal(samples, expected_samples): 59 | self.assertEqual(sample, expected_sample) 60 | 
-------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_enwik9.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | from unittest.mock import patch 4 | 5 | from torchtext.datasets.enwik9 import EnWik9 6 | 7 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 8 | from ..common.torchtext_test_case import TorchtextTestCase 9 | 10 | 11 | def _get_mock_dataset(root_dir): 12 | """ 13 | root_dir: directory to the mocked dataset 14 | """ 15 | base_dir = os.path.join(root_dir, "EnWik9") 16 | temp_dataset_dir = os.path.join(base_dir, "temp_dataset_dir") 17 | os.makedirs(temp_dataset_dir, exist_ok=True) 18 | 19 | seed = 1 20 | file_name = "enwik9" 21 | txt_file = os.path.join(temp_dataset_dir, file_name) 22 | mocked_data = [] 23 | with open(txt_file, "w", encoding="utf-8") as f: 24 | for i in range(5): 25 | rand_string = "<" + get_random_unicode(seed) + ">" 26 | dataset_line = f"'{rand_string}'" 27 | f.write(f"'{rand_string}'\n") 28 | 29 | # append line to correct dataset split 30 | mocked_data.append(dataset_line) 31 | seed += 1 32 | 33 | compressed_dataset_path = os.path.join(base_dir, "enwik9.zip") 34 | # create zip file from dataset folder 35 | with zipfile.ZipFile(compressed_dataset_path, "w") as zip_file: 36 | txt_file = os.path.join(temp_dataset_dir, file_name) 37 | zip_file.write(txt_file, arcname=file_name) 38 | 39 | return mocked_data 40 | 41 | 42 | class TestEnWik9(TempDirMixin, TorchtextTestCase): 43 | root_dir = None 44 | samples = [] 45 | 46 | @classmethod 47 | def setUpClass(cls): 48 | super().setUpClass() 49 | cls.root_dir = cls.get_base_temp_dir() 50 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 51 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 52 | cls.patcher.start() 53 | 54 | @classmethod 55 | def tearDownClass(cls): 56 | cls.patcher.stop() 57 | super().tearDownClass() 58 | 59 | def test_enwik9(self) -> None: 60 | dataset = EnWik9(root=self.root_dir) 61 | 62 | samples = list(dataset) 63 | expected_samples = self.samples 64 | for sample, expected_sample in zip_equal(samples, expected_samples): 65 | self.assertEqual(sample, expected_sample) 66 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_mrpc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from unittest.mock import patch 4 | 5 | from parameterized import parameterized 6 | from torchtext.datasets.mrpc import MRPC 7 | 8 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 9 | from ..common.torchtext_test_case import TorchtextTestCase 10 | 11 | 12 | def _get_mock_dataset(root_dir): 13 | """ 14 | root_dir: directory to the mocked dataset 15 | """ 16 | base_dir = os.path.join(root_dir, "MRPC") 17 | os.makedirs(base_dir, exist_ok=True) 18 | 19 | seed = 1 20 | mocked_data = defaultdict(list) 21 | for file_name, file_type in [("msr_paraphrase_train.txt", "train"), ("msr_paraphrase_test.txt", "test")]: 22 | txt_file = os.path.join(base_dir, file_name) 23 | with open(txt_file, "w", encoding="utf-8") as f: 24 | f.write("Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n") 25 | for i in range(5): 26 | label = seed % 2 27 | rand_string_1 = get_random_unicode(seed) 28 | rand_string_2 = get_random_unicode(seed + 1) 29 | 
dataset_line = (label, rand_string_1, rand_string_2) 30 | f.write(f"{label}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\n") 31 | 32 | # append line to correct dataset split 33 | mocked_data[file_type].append(dataset_line) 34 | seed += 1 35 | 36 | return mocked_data 37 | 38 | 39 | class TestMRPC(TempDirMixin, TorchtextTestCase): 40 | root_dir = None 41 | samples = [] 42 | 43 | @classmethod 44 | def setUpClass(cls): 45 | super().setUpClass() 46 | cls.root_dir = cls.get_base_temp_dir() 47 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 48 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 49 | cls.patcher.start() 50 | 51 | @classmethod 52 | def tearDownClass(cls): 53 | cls.patcher.stop() 54 | super().tearDownClass() 55 | 56 | @parameterized.expand(["train", "test"]) 57 | def test_mrpc(self, split): 58 | dataset = MRPC(root=self.root_dir, split=split) 59 | 60 | samples = list(dataset) 61 | expected_samples = self.samples[split] 62 | for sample, expected_sample in zip_equal(samples, expected_samples): 63 | self.assertEqual(sample, expected_sample) 64 | 65 | @parameterized.expand(["train", "test"]) 66 | def test_mrpc_split_argument(self, split): 67 | dataset1 = MRPC(root=self.root_dir, split=split) 68 | (dataset2,) = MRPC(root=self.root_dir, split=(split,)) 69 | 70 | for d1, d2 in zip_equal(dataset1, dataset2): 71 | self.assertEqual(d1, d2) 72 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_penntreebank.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from unittest.mock import patch 4 | 5 | from parameterized import parameterized 6 | from torchtext.datasets.penntreebank import PennTreebank 7 | 8 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 9 | from ..common.torchtext_test_case import TorchtextTestCase 10 | 11 | 12 | def _get_mock_dataset(root_dir): 13 | """ 14 | root_dir: directory to the mocked dataset 15 | """ 16 | base_dir = os.path.join(root_dir, "PennTreebank") 17 | os.makedirs(base_dir, exist_ok=True) 18 | 19 | seed = 1 20 | mocked_data = defaultdict(list) 21 | for file_name in ("ptb.train.txt", "ptb.valid.txt", "ptb.test.txt"): 22 | txt_file = os.path.join(base_dir, file_name) 23 | with open(txt_file, "w", encoding="utf-8") as f: 24 | for i in range(5): 25 | rand_string = get_random_unicode(seed) 26 | dataset_line = f"{rand_string}" 27 | # append line to correct dataset split 28 | split = file_name.replace("ptb.", "").replace(".txt", "") 29 | mocked_data[split].append(dataset_line) 30 | f.write(f"{rand_string}\n") 31 | seed += 1 32 | 33 | return mocked_data 34 | 35 | 36 | class TestPennTreebank(TempDirMixin, TorchtextTestCase): 37 | root_dir = None 38 | samples = [] 39 | 40 | @classmethod 41 | def setUpClass(cls): 42 | super().setUpClass() 43 | cls.root_dir = cls.get_base_temp_dir() 44 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 45 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 46 | cls.patcher.start() 47 | 48 | @classmethod 49 | def tearDownClass(cls): 50 | cls.patcher.stop() 51 | super().tearDownClass() 52 | 53 | @parameterized.expand(["train", "valid", "test"]) 54 | def test_penn_treebank(self, split): 55 | dataset = PennTreebank(root=self.root_dir, split=split) 56 | 57 | samples = list(dataset) 58 | expected_samples =
self.samples[split] 59 | for sample, expected_sample in zip_equal(samples, expected_samples): 60 | self.assertEqual(sample, expected_sample) 61 | 62 | @parameterized.expand(["train", "valid", "test"]) 63 | def test_penn_treebank_split_argument(self, split): 64 | dataset1 = PennTreebank(root=self.root_dir, split=split) 65 | (dataset2,) = PennTreebank(root=self.root_dir, split=(split,)) 66 | 67 | for d1, d2 in zip_equal(dataset1, dataset2): 68 | self.assertEqual(d1, d2) 69 | -------------------------------------------------------------------------------- /test/torchtext_unittest/datasets/test_qqp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest.mock import patch 3 | 4 | from torchtext.datasets.qqp import QQP 5 | 6 | from ..common.case_utils import TempDirMixin, zip_equal, get_random_unicode 7 | from ..common.torchtext_test_case import TorchtextTestCase 8 | 9 | 10 | def _get_mock_dataset(root_dir): 11 | """ 12 | root_dir: directory to the mocked dataset 13 | """ 14 | base_dir = os.path.join(root_dir, "QQP") 15 | os.makedirs(base_dir, exist_ok=True) 16 | 17 | seed = 1 18 | file_name = "quora_duplicate_questions.tsv" 19 | txt_file = os.path.join(base_dir, file_name) 20 | mocked_data = [] 21 | with open(txt_file, "w", encoding="utf-8") as f: 22 | f.write("id\tqid1\tqid2\tquestion1\tquestion2\tis_duplicate\n") 23 | for i in range(5): 24 | label = seed % 2 25 | rand_string_1 = get_random_unicode(seed) 26 | rand_string_2 = get_random_unicode(seed + 1) 27 | dataset_line = (label, rand_string_1, rand_string_2) 28 | # append line to correct dataset split 29 | mocked_data.append(dataset_line) 30 | f.write(f"{i}\t{i}\t{i}\t{rand_string_1}\t{rand_string_2}\t{label}\n") 31 | seed += 1 32 | 33 | return mocked_data 34 | 35 | 36 | class TestQQP(TempDirMixin, TorchtextTestCase): 37 | root_dir = None 38 | samples = [] 39 | 40 | @classmethod 41 | def setUpClass(cls): 42 | super().setUpClass() 43 | cls.root_dir = cls.get_base_temp_dir() 44 | cls.samples = _get_mock_dataset(os.path.join(cls.root_dir, "datasets")) 45 | cls.patcher = patch("torchdata.datapipes.iter.util.cacheholder._hash_check", return_value=True) 46 | cls.patcher.start() 47 | 48 | @classmethod 49 | def tearDownClass(cls): 50 | cls.patcher.stop() 51 | super().tearDownClass() 52 | 53 | def test_qqp(self) -> None: 54 | dataset = QQP(root=self.root_dir) 55 | 56 | samples = list(dataset) 57 | expected_samples = self.samples 58 | for sample, expected_sample in zip_equal(samples, expected_samples): 59 | self.assertEqual(sample, expected_sample) 60 | -------------------------------------------------------------------------------- /test/torchtext_unittest/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/models/__init__.py -------------------------------------------------------------------------------- /test/torchtext_unittest/models/gpu_tests/models_gpu_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pytest 4 | import torch 5 | from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase 6 | from torchtext_unittest.models.roberta_models_test_impl import RobertaBaseTestModels 7 | from torchtext_unittest.models.t5_models_test_impl import T5BaseTestModels 8 | 9 | 10 | @pytest.mark.gpu_test 11 | @unittest.skipIf(not torch.cuda.is_available(), 
reason="CUDA is not available") 12 | class TestModels32GPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase): 13 | dtype = torch.float32 14 | device = torch.device("cuda") 15 | -------------------------------------------------------------------------------- /test/torchtext_unittest/models/models_cpu_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..common.torchtext_test_case import TorchtextTestCase 4 | from .roberta_models_test_impl import RobertaBaseTestModels 5 | from .t5_models_test_impl import T5BaseTestModels 6 | 7 | 8 | class TestModels32CPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase): 9 | dtype = torch.float32 10 | device = torch.device("cpu") 11 | -------------------------------------------------------------------------------- /test/torchtext_unittest/models/t5_test_transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchtext.models import T5Transform 3 | from torchtext_unittest.common.assets import get_asset_path 4 | from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase 5 | 6 | 7 | class TestTransforms(TorchtextTestCase): 8 | def _t5tokenizer(self, test_scripting): 9 | asset_name = "t5_tokenizer_base.model" 10 | asset_path = get_asset_path(asset_name) 11 | transform = T5Transform(asset_path, max_seq_len=512, eos_idx=1, padding_idx=0) 12 | if test_scripting: 13 | transform = torch.jit.script(transform) 14 | 15 | # test encode; input is a single string 16 | encode_seq = "Hello World!, how are you?" 17 | actual = transform(encode_seq) 18 | expected = torch.tensor([8774, 1150, 55, 6, 149, 33, 25, 58, 1]) 19 | self.assertEqual(actual, expected) 20 | 21 | # test encode; input is a batched string 22 | encode_seq = ["Hello World!, how are you?"] 23 | actual = transform(encode_seq) 24 | expected = torch.tensor([[8774, 1150, 55, 6, 149, 33, 25, 58, 1]]) 25 | self.assertEqual(actual, expected) 26 | 27 | # test decode; input is a list of token ids 28 | decode_seq = [8774, 1150, 55, 6, 149, 33, 25, 58, 1] 29 | actual = transform.decode(decode_seq) 30 | expected = "Hello World!, how are you?" 
31 | self.assertEqual(actual, expected) 32 | 33 | # test decode; input is a batched list of token ids 34 | decode_seq = [[8774, 1150, 55, 6, 149, 33, 25, 58, 1]] 35 | actual = transform.decode(decode_seq) 36 | expected = ["Hello World!, how are you?"] 37 | self.assertEqual(actual, expected) 38 | 39 | def test_t5tokenizer(self) -> None: 40 | """test tokenization on string input (encode) and translation from token ids to strings (decode)""" 41 | self._t5tokenizer(test_scripting=False) 42 | 43 | def test_t5tokenizer_jit(self) -> None: 44 | """test tokenization on string input (encode) and translation from token ids to strings (decode) with scripting""" 45 | self._t5tokenizer(test_scripting=True) 46 | -------------------------------------------------------------------------------- /test/torchtext_unittest/models/test_transformers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..common.parameterized_utils import nested_params 4 | from ..common.torchtext_test_case import TorchtextTestCase 5 | 6 | 7 | class TestTransformers(TorchtextTestCase): 8 | @nested_params( 9 | [True, False], 10 | [True, False], 11 | ) 12 | def test_padded_input_inference(self, with_no_grad, return_all_layers): 13 | """test that TransformerEncoder inference produces the same output with and without padding""" 14 | from torchtext.models import RobertaEncoderConf, RobertaModel 15 | 16 | def encoder_inference(encoder, input_lst, with_no_grad): 17 | if with_no_grad: 18 | with torch.no_grad(): 19 | res = [encoder(eval_input) for eval_input in input_lst] 20 | else: 21 | res = [encoder(eval_input) for eval_input in input_lst] 22 | return res 23 | 24 | # RoBERTa base config except with fewer layers (2 instead of 12) 25 | pad_idx = 1 26 | encoder_conf = RobertaEncoderConf( 27 | vocab_size=250002, 28 | embedding_dim=768, 29 | ffn_dimension=3072, 30 | padding_idx=pad_idx, 31 | max_seq_len=514, 32 | num_attention_heads=12, 33 | num_encoder_layers=2, 34 | dropout=0.1, 35 | scaling=None, 36 | normalize_before=False, 37 | ) 38 | model = RobertaModel(encoder_conf) 39 | model = model.eval() 40 | # TODO: make return_all_layers a property of RobertaEncoderConf so it can be passed as arg 41 | model.encoder.transformer.return_all_layers = return_all_layers 42 | 43 | # token ids for the string "some text" produced by the xlmr_base tokenizer 44 | input_no_pad = torch.Tensor([[0, 3060, 7986, 2]]).to(torch.int) 45 | data_len = input_no_pad.shape[1] # sequence length of non-pad data 46 | # add two padding tokens to input_no_pad 47 | input_pad = torch.Tensor([[0, 3060, 7986, 2, pad_idx, pad_idx]]).to(torch.int) 48 | input_lst = [input_no_pad, input_pad] 49 | 50 | output_no_pad, output_pad = encoder_inference(model, input_lst, with_no_grad) 51 | torch.testing.assert_close(output_no_pad, output_pad[:, :data_len, :]) 52 | -------------------------------------------------------------------------------- /test/torchtext_unittest/prototype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/test/torchtext_unittest/prototype/__init__.py -------------------------------------------------------------------------------- /third_party/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") 2 | 3 | 4 | if(POLICY CMP0091) 5 | cmake_policy(SET CMP0091 NEW) 6 | endif() 7 | 8 | 
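# Note: EXCLUDE_FROM_ALL below keeps these vendored third-party targets out of
# the default `all` target; they are compiled only when a torchtext target
# that links against them is built.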
add_subdirectory(re2 EXCLUDE_FROM_ALL) 9 | add_subdirectory(double-conversion EXCLUDE_FROM_ALL) 10 | add_subdirectory(sentencepiece EXCLUDE_FROM_ALL) 11 | add_subdirectory(utf8proc EXCLUDE_FROM_ALL) 12 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/tools/__init__.py -------------------------------------------------------------------------------- /tools/conda/torchtext/meta.yaml: -------------------------------------------------------------------------------- 1 | package: 2 | name: torchtext 3 | version: 0.4.0 4 | 5 | source: 6 | url: https://github.com/pytorch/text/archive/0.4.0.zip 7 | 8 | requirements: 9 | build: 10 | - python 11 | - setuptools 12 | 13 | run: 14 | - python 15 | - tqdm 16 | - numpy >=1.11 17 | - pytorch >=1.2 18 | - requests 19 | 20 | build: 21 | number: 1 22 | noarch: python 23 | script: python setup.py install --single-version-externally-managed --record=record.txt 24 | 25 | test: 26 | imports: 27 | - torchtext 28 | - torchtext.data 29 | 30 | about: 31 | home: https://github.com/pytorch/text 32 | license: BSD 33 | license_file: LICENSE 34 | summary: "PyTorch Data loaders and abstractions for text and NLP" 35 | -------------------------------------------------------------------------------- /tools/setup_helpers/__init__.py: -------------------------------------------------------------------------------- 1 | from .extension import * # noqa 2 | -------------------------------------------------------------------------------- /torchtext/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torch.hub import _get_torch_home 4 | 5 | _WARN = True 6 | _TORCHTEXT_DEPRECATION_MSG = ( 7 | "\n/!\\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\\ \n" 8 | "Torchtext is deprecated and the last released version will be 0.18 (this one). 
" 9 | "You can silence this warning by calling the following at the beginnign of your scripts: " 10 | "`import torchtext; torchtext.disable_torchtext_deprecation_warning()`" 11 | ) 12 | 13 | def disable_torchtext_deprecation_warning(): 14 | global _WARN 15 | _WARN = False 16 | 17 | # the following import has to happen first in order to load the torchtext C++ library 18 | from torchtext import _extension # noqa: F401 19 | 20 | _TEXT_BUCKET = "https://download.pytorch.org/models/text/" 21 | 22 | _CACHE_DIR = os.path.expanduser(os.path.join(_get_torch_home(), "text")) 23 | 24 | try: 25 | from .version import __version__, git_version # noqa: F401 26 | except ImportError: 27 | pass 28 | 29 | __all__ = [ 30 | "data", 31 | "nn", 32 | "datasets", 33 | "utils", 34 | "vocab", 35 | "transforms", 36 | "functional", 37 | "models", 38 | "prototype", 39 | "experimental", 40 | ] 41 | -------------------------------------------------------------------------------- /torchtext/_download_hooks.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import requests 4 | 5 | # This is to allow monkey-patching in fbcode 6 | from torch.hub import load_state_dict_from_url # noqa 7 | from tqdm import tqdm 8 | 9 | 10 | def _stream_response(r, chunk_size=16 * 1024): 11 | total_size = int(r.headers.get("Content-length", 0)) 12 | with tqdm(total=total_size, unit="B", unit_scale=1) as t: 13 | for chunk in r.iter_content(chunk_size): 14 | if chunk: 15 | t.update(len(chunk)) 16 | yield chunk 17 | 18 | 19 | def _get_response_from_google_drive(url): 20 | confirm_token = None 21 | session = requests.Session() 22 | response = session.get(url, stream=True) 23 | for k, v in response.cookies.items(): 24 | if k.startswith("download_warning"): 25 | confirm_token = v 26 | if confirm_token is None: 27 | if "Quota exceeded" in str(response.content): 28 | raise RuntimeError( 29 | "Google drive link {} is currently unavailable, because the quota was exceeded.".format(url) 30 | ) 31 | else: 32 | raise RuntimeError("Internal error: confirm_token was not found in Google drive link.") 33 | 34 | url = url + "&confirm=" + confirm_token 35 | response = session.get(url, stream=True) 36 | 37 | if "content-disposition" not in response.headers: 38 | raise RuntimeError("Internal error: headers don't contain content-disposition.") 39 | 40 | filename = re.findall('filename="(.+)"', response.headers["content-disposition"]) 41 | if filename is None: 42 | raise RuntimeError("Filename could not be autodetected") 43 | filename = filename[0] 44 | 45 | return response, filename 46 | 47 | 48 | class DownloadManager: 49 | def get_local_path(self, url, destination): 50 | if "drive.google.com" not in url: 51 | response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"}, stream=True) 52 | else: 53 | response, filename = _get_response_from_google_drive(url) 54 | 55 | with open(destination, "wb") as f: 56 | for chunk in _stream_response(response): 57 | f.write(chunk) 58 | 59 | 60 | _DATASET_DOWNLOAD_MANAGER = DownloadManager() 61 | _TEST_DOWNLOAD_MANAGER = DownloadManager() 62 | -------------------------------------------------------------------------------- /torchtext/_extension.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import torch 5 | from torchtext._internal import module_utils as _mod_utils 6 | 7 | _LIB_DIR = Path(__file__).parent / "lib" 8 | 9 | 10 | def _get_lib_path(lib: str): 11 | suffix = "pyd" if os.name 
== "nt" else "so" 12 | path = _LIB_DIR / f"{lib}.{suffix}" 13 | return path 14 | 15 | 16 | def _load_lib(lib: str) -> bool: 17 | """Load extension module 18 | 19 | Note: 20 | In case `torchtext` is deployed with `pex` format, the library file 21 | is not in a standard location. 22 | In this case, we expect that `libtorchtext` is available somewhere 23 | in the search path of dynamic loading mechanism, so that importing 24 | `_torchtext` will have library loader find and load `libtorchtext`. 25 | This is the reason why the function should not raising an error when the library 26 | file is not found. 27 | 28 | Returns: 29 | bool: 30 | True if the library file is found AND the library loaded without failure. 31 | False if the library file is not found (like in the case where torchtext 32 | is deployed with pex format, thus the shared library file is 33 | in a non-standard location.). 34 | If the library file is found but there is an issue loading the library, 35 | (such as missing dependency) then this function raises the exception as-is. 36 | 37 | Raises: 38 | Exception: 39 | If the library file is found, but there is an issue loading the library file, 40 | (when underlying `ctype.DLL` throws an exception), this function will pass 41 | the exception as-is, instead of catching it and returning bool. 42 | The expected case is `OSError` thrown by `ctype.DLL` when a dynamic dependency 43 | is not found. 44 | This behavior was chosen because the expected failure case is not recoverable. 45 | If a dependency is missing, then users have to install it. 46 | """ 47 | path = _get_lib_path(lib) 48 | if not path.exists(): 49 | return False 50 | torch.ops.load_library(path) 51 | return True 52 | 53 | 54 | def _init_extension(): 55 | if not _mod_utils.is_module_available("torchtext._torchtext"): 56 | raise ImportError("torchtext C++ Extension is not found.") 57 | 58 | _load_lib("libtorchtext") 59 | # This import is for initializing the methods registered via PyBind11 60 | # This has to happen after the base library is loaded 61 | from torchtext import _torchtext # noqa 62 | 63 | 64 | _init_extension() 65 | -------------------------------------------------------------------------------- /torchtext/_internal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pytorch/text/bde7ecdb6ba9179ccd30cde60a6550478d0a359f/torchtext/_internal/__init__.py -------------------------------------------------------------------------------- /torchtext/_internal/module_utils.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | 4 | def is_module_available(*modules: str) -> bool: 5 | r"""Returns if a top-level module with :attr:`name` exists *without** 6 | importing it. This is generally safer than try-catch block around a 7 | `import X`. It avoids third party libraries breaking assumptions of some of 8 | our tests, e.g., setting multiprocessing start method when imported 9 | (see librosa/#747, torchvision/#544). 
10 | """ 11 | return all(importlib.util.find_spec(m) is not None for m in modules) 12 | -------------------------------------------------------------------------------- /torchtext/csrc/bert_tokenizer.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace torchtext { 7 | 8 | typedef std::basic_string UString; 9 | typedef ska_ordered::order_preserving_flat_hash_map 10 | IndexDict; 11 | 12 | // stores (do_lower_case, strip_accents, never_split, list of tokens in 13 | // vocabulary) 14 | typedef std::tuple< 15 | bool, 16 | c10::optional, 17 | std::vector, 18 | std::vector> 19 | BERTEncoderStates; 20 | 21 | struct BERTEncoder : torch::CustomClassHolder { 22 | TORCHTEXT_API BERTEncoder( 23 | const std::string& vocab_file, 24 | bool do_lower_case, 25 | c10::optional strip_accents, 26 | std::vector never_split); 27 | BERTEncoder( 28 | Vocab vocab, 29 | bool do_lower_case, 30 | c10::optional strip_accents, 31 | std::vector never_split); 32 | TORCHTEXT_API std::vector Tokenize(std::string text); 33 | TORCHTEXT_API std::vector Encode(std::string text); 34 | TORCHTEXT_API std::vector> BatchTokenize( 35 | std::vector text); 36 | TORCHTEXT_API std::vector> BatchEncode( 37 | std::vector text); 38 | 39 | Vocab vocab_; 40 | bool do_lower_case_; 41 | c10::optional strip_accents_ = {}; 42 | std::vector never_split_; 43 | std::set never_split_set_; 44 | 45 | protected: 46 | UString _clean( 47 | const UString& text, 48 | bool strip_accents, 49 | bool is_never_split_token); 50 | void _max_seg(const std::string& s, std::vector& results); 51 | UString _basic_tokenize(const UString& token, bool is_never_split_token); 52 | void split_( 53 | const std::string& str, 54 | std::vector& tokens, 55 | const char& delimiter = ' '); 56 | static std::string kUnkToken; 57 | }; 58 | 59 | TORCHTEXT_API BERTEncoderStates 60 | _serialize_bert_encoder(const c10::intrusive_ptr& self); 61 | TORCHTEXT_API c10::intrusive_ptr _deserialize_bert_encoder( 62 | BERTEncoderStates states); 63 | } // namespace torchtext 64 | -------------------------------------------------------------------------------- /torchtext/csrc/clip_tokenizer.h: -------------------------------------------------------------------------------- 1 | #ifndef CLIP_TOKENIZER_H_ 2 | #define CLIP_TOKENIZER_H_ 3 | 4 | #include 5 | #include 6 | 7 | namespace torchtext { 8 | 9 | typedef std::tuple< 10 | std::unordered_map, 11 | std::unordered_map, 12 | std::string, 13 | std::unordered_map, 14 | bool> 15 | CLIPEncoderStatesPybind; 16 | 17 | typedef std::tuple< 18 | c10::Dict, 19 | c10::Dict, 20 | std::string, 21 | c10::Dict, 22 | bool> 23 | CLIPEncoderStatesTorchbind; 24 | 25 | struct CLIPEncoder : GPT2BPEEncoder { 26 | public: 27 | using GPT2BPEEncoder::GPT2BPEEncoder; 28 | 29 | TORCHTEXT_API std::vector Encode(const std::string& text); 30 | TORCHTEXT_API std::vector Tokenize(const std::string& text); 31 | 32 | protected: 33 | TORCHTEXT_API std::vector BPE_( 34 | const std::vector& token_list) override; 35 | 36 | TORCHTEXT_API std::vector PreTokenize_( 37 | std::string input) override; 38 | }; 39 | 40 | TORCHTEXT_API CLIPEncoderStatesPybind 41 | _serialize_clip_encoder_pybind(const c10::intrusive_ptr& self); 42 | CLIPEncoderStatesTorchbind _serialize_clip_encoder_torchbind( 43 | const c10::intrusive_ptr& self); 44 | TORCHTEXT_API c10::intrusive_ptr _deserialize_clip_encoder_pybind( 45 | CLIPEncoderStatesPybind states); 46 | c10::intrusive_ptr _deserialize_clip_encoder_torchbind( 47 
| CLIPEncoderStatesTorchbind states); 48 | 49 | } // namespace torchtext 50 | 51 | #endif // CLIP_TOKENIZER_H_ 52 | -------------------------------------------------------------------------------- /torchtext/csrc/common.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace torchtext { 8 | namespace impl { 9 | 10 | int64_t divup(int64_t x, int64_t y) { 11 | return (x + y - 1) / y; 12 | } 13 | 14 | void infer_offsets( 15 | const std::string& file_path, 16 | int64_t num_lines, 17 | int64_t chunk_size, 18 | std::vector<size_t>& offsets, 19 | int64_t num_header_lines) { 20 | std::ifstream fin; 21 | fin.open(file_path, std::ios::in); 22 | 23 | while (num_header_lines > 0) { 24 | fin.ignore(std::numeric_limits<std::streamsize>::max(), '\n'); 25 | num_header_lines--; 26 | } 27 | offsets.push_back(fin.tellg()); 28 | size_t offset = 0; 29 | while (fin.ignore(std::numeric_limits<std::streamsize>::max(), '\n')) { 30 | offset++; 31 | if (offset % chunk_size == 0) { 32 | offsets.push_back(fin.tellg()); 33 | } 34 | } 35 | } 36 | 37 | } // namespace impl 38 | } // namespace torchtext 39 | -------------------------------------------------------------------------------- /torchtext/csrc/common.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace torchtext { 8 | 9 | namespace impl { 10 | TORCHTEXT_API int64_t divup(int64_t x, int64_t y); 11 | TORCHTEXT_API void infer_offsets( 12 | const std::string& file_path, 13 | int64_t num_lines, 14 | int64_t chunk_size, 15 | std::vector<size_t>& offsets, 16 | int64_t num_header_lines = 0); 17 | } // namespace impl 18 | } // namespace torchtext 19 | -------------------------------------------------------------------------------- /torchtext/csrc/export.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Define the visibility of symbols. 4 | // The original logic and background can be found here. 5 | // https://github.com/pytorch/pytorch/blob/bcc02769bef1d7b89bec724223284958b7c5b564/c10/macros/Export.h#L49-L55 6 | // 7 | // In the context of torchtext, the logic is simpler at the moment. 8 | // 9 | // The torchtext custom operations are implemented in 10 | // `torchtext/lib/libtorchtext.[so|pyd]`. Some symbols are referenced from 11 | // `torchtext._torchtext`. 12 | // 13 | // On Windows, symbols of a dynamic library are hidden by default, while on 14 | // Linux/macOS they are visible. 15 | // 16 | // At the moment we do not expect torchtext libraries to be built/linked 17 | // statically. We assume they are always shared.
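//
// Illustrative example (the operator name below is hypothetical): a
// declaration such as
//
//   TORCHTEXT_API void my_custom_op();
//
// resolves to TORCHTEXT_EXPORT when compiled as part of libtorchtext itself
// (TORCHTEXT_BUILD_MAIN_LIB is defined) and to TORCHTEXT_IMPORT in client
// code that merely links against the library.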
18 | 19 | #ifdef _WIN32 20 | #define TORCHTEXT_EXPORT __declspec(dllexport) 21 | #define TORCHTEXT_IMPORT __declspec(dllimport) 22 | #else // _WIN32 23 | #if defined(__GNUC__) 24 | #define TORCHTEXT_EXPORT __attribute__((__visibility__("default"))) 25 | #else // defined(__GNUC__) 26 | #define TORCHTEXT_EXPORT 27 | #endif // defined(__GNUC__) 28 | #define TORCHTEXT_IMPORT TORCHTEXT_EXPORT 29 | #endif // _WIN32 30 | 31 | #ifdef TORCHTEXT_BUILD_MAIN_LIB 32 | #define TORCHTEXT_API TORCHTEXT_EXPORT 33 | #else 34 | #define TORCHTEXT_API TORCHTEXT_IMPORT 35 | #endif 36 | -------------------------------------------------------------------------------- /torchtext/csrc/regex.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | namespace torchtext { 4 | 5 | Regex::Regex(const std::string& re_str) : re_str_(re_str) { 6 | compiled_pattern_ = new RE2(re_str_); 7 | } 8 | 9 | Regex::~Regex() { 10 | delete compiled_pattern_; 11 | } 12 | 13 | std::string Regex::Sub(std::string str, const std::string& repl) const { 14 | RE2::GlobalReplace(&str, *compiled_pattern_, repl); 15 | return str; 16 | } 17 | 18 | bool Regex::FindAndConsume(re2::StringPiece* input, std::string* text) const { 19 | return RE2::FindAndConsume(input, *compiled_pattern_, text); 20 | } 21 | 22 | std::string _serialize_regex(const c10::intrusive_ptr& self) { 23 | return self->re_str_; 24 | } 25 | 26 | c10::intrusive_ptr _deserialize_regex(std::string&& state) { 27 | return c10::make_intrusive(std::move(state)); 28 | } 29 | 30 | } // namespace torchtext 31 | -------------------------------------------------------------------------------- /torchtext/csrc/regex.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace torchtext { 8 | struct Regex : torch::CustomClassHolder { 9 | private: 10 | RE2* compiled_pattern_; 11 | 12 | public: 13 | std::string re_str_; 14 | 15 | TORCHTEXT_API Regex(const std::string& re_str); 16 | TORCHTEXT_API ~Regex(); 17 | TORCHTEXT_API std::string Sub(std::string str, const std::string& repl) const; 18 | TORCHTEXT_API bool FindAndConsume(re2::StringPiece* input, std::string* text) 19 | const; 20 | }; 21 | 22 | TORCHTEXT_API std::string _serialize_regex( 23 | const c10::intrusive_ptr& self); 24 | TORCHTEXT_API c10::intrusive_ptr _deserialize_regex(std::string&& state); 25 | 26 | } // namespace torchtext 27 | -------------------------------------------------------------------------------- /torchtext/csrc/regex_tokenizer.cpp: -------------------------------------------------------------------------------- 1 | #include // @manual 2 | #include 3 | 4 | namespace torchtext { 5 | 6 | RegexTokenizer::RegexTokenizer( 7 | const std::vector& patterns, 8 | const std::vector& replacements, 9 | const bool to_lower = false) 10 | : patterns_(std::move(patterns)), 11 | replacements_(std::move(replacements)), 12 | to_lower_(to_lower) { 13 | TORCH_CHECK( 14 | patterns.size() == replacements.size(), 15 | "Expected `patterns` and `replacements` to have same size!"); 16 | 17 | for (const auto& pattern : patterns_) { 18 | compiled_patterns_.push_back(new RE2(pattern)); 19 | } 20 | } 21 | 22 | std::vector RegexTokenizer::forward(std::string str) const { 23 | // str tolower 24 | if (to_lower_) { 25 | std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { 26 | return std::tolower(c); 27 | }); 28 | } 29 | 30 | for (size_t i = 0; i < compiled_patterns_.size(); i++) { 
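 // Apply each compiled pattern in order; RE2::GlobalReplace rewrites every
 // occurrence of patterns_[i] in `str` with replacements_[i], in place.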
31 | RE2::GlobalReplace(&str, *compiled_patterns_[i], replacements_[i]); 32 | } 33 | 34 | std::vector tokens; 35 | split_(str, tokens); 36 | return tokens; 37 | } 38 | 39 | void RegexTokenizer::split_( 40 | std::string& str, 41 | std::vector& tokens, 42 | const char& delimiter) const { 43 | std::stringstream ss(str); 44 | std::string token; 45 | 46 | while (std::getline(ss, token, delimiter)) { 47 | if (!token.empty()) { 48 | tokens.push_back(token); 49 | } 50 | } 51 | } 52 | 53 | RegexTokenizerStates _serialize_regex_tokenizer( 54 | const c10::intrusive_ptr& self) { 55 | return std::make_tuple(self->patterns_, self->replacements_, self->to_lower_); 56 | } 57 | 58 | c10::intrusive_ptr _deserialize_regex_tokenizer( 59 | RegexTokenizerStates&& states) { 60 | return c10::make_intrusive( 61 | std::move(std::get<0>(states)), 62 | std::move(std::get<1>(states)), 63 | std::get<2>(states)); 64 | } 65 | 66 | } // namespace torchtext 67 | -------------------------------------------------------------------------------- /torchtext/csrc/regex_tokenizer.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | namespace torchtext { 6 | 7 | typedef std::tuple, std::vector, bool> 8 | RegexTokenizerStates; 9 | 10 | struct RegexTokenizer : torch::CustomClassHolder { 11 | private: 12 | std::vector compiled_patterns_; 13 | void split_( 14 | std::string& str, 15 | std::vector& tokens, 16 | const char& delimiter = ' ') const; 17 | 18 | public: 19 | std::vector patterns_; 20 | std::vector replacements_; 21 | bool to_lower_; 22 | 23 | TORCHTEXT_API explicit RegexTokenizer( 24 | const std::vector& patterns, 25 | const std::vector& replacements, 26 | const bool to_lower); 27 | TORCHTEXT_API std::vector forward(std::string str) const; 28 | }; 29 | 30 | TORCHTEXT_API RegexTokenizerStates 31 | _serialize_regex_tokenizer(const c10::intrusive_ptr& self); 32 | TORCHTEXT_API c10::intrusive_ptr _deserialize_regex_tokenizer( 33 | RegexTokenizerStates&& states); 34 | 35 | } // namespace torchtext 36 | -------------------------------------------------------------------------------- /torchtext/csrc/sentencepiece.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | namespace torchtext { 7 | 8 | struct SentencePiece : torch::CustomClassHolder { 9 | private: 10 | sentencepiece::SentencePieceProcessor processor_; 11 | 12 | public: 13 | // content_ holds the serialized model data passed at initialization. 14 | // We need this because the underlying SentencePieceProcessor class does not 15 | // provide a serialization mechanism, yet we still need to be able to serialize 16 | // the model so that we can save the scripted object. pickle will get the 17 | // serialized model from this content_ member, thus it needs to be public.
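 // (Round-trip sketch: a serialization hook saves `content_`, and
 // deserialization rebuilds the processor via the `SentencePiece(content)`
 // constructor declared below.)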
18 | std::string content_; 19 | 20 | TORCHTEXT_API explicit SentencePiece(const std::string& content); 21 | TORCHTEXT_API std::vector Encode(const std::string& input) const; 22 | TORCHTEXT_API std::vector EncodeAsIds( 23 | const std::string& input) const; 24 | TORCHTEXT_API std::string DecodeIds(const std::vector& ids) const; 25 | TORCHTEXT_API std::vector EncodeAsPieces( 26 | const std::string& input) const; 27 | TORCHTEXT_API std::string DecodePieces( 28 | const std::vector& pieces) const; 29 | TORCHTEXT_API int64_t GetPieceSize() const; 30 | TORCHTEXT_API int64_t unk_id() const; 31 | TORCHTEXT_API int64_t PieceToId(const std::string& piece) const; 32 | TORCHTEXT_API std::string IdToPiece(const int64_t id) const; 33 | }; 34 | 35 | void generate_sp_model( 36 | const std::string& filename, 37 | const int64_t& vocab_size, 38 | const std::string& model_type, 39 | const std::string& model_prefix); 40 | c10::intrusive_ptr load_sp_model(const std::string& path); 41 | c10::intrusive_ptr load_sp_model_string(std::string content); 42 | 43 | } // namespace torchtext 44 | -------------------------------------------------------------------------------- /torchtext/csrc/vectors.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace torchtext { 5 | 6 | typedef std::vector StringList; 7 | typedef ska_ordered::order_preserving_flat_hash_map 8 | VectorsMap; 9 | typedef ska_ordered::order_preserving_flat_hash_map 10 | IndexMap; 11 | typedef std::tuple< 12 | std::string, 13 | std::vector, 14 | std::vector, 15 | std::vector> 16 | VectorsStates; 17 | 18 | struct Vectors : torch::CustomClassHolder { 19 | public: 20 | const std::string version_str_ = "0.0.1"; 21 | IndexMap stoi_; 22 | VectorsMap stovec_; 23 | torch::Tensor vectors_; 24 | torch::Tensor unk_tensor_; 25 | 26 | explicit Vectors( 27 | const IndexMap& stoi, 28 | torch::Tensor vectors, 29 | torch::Tensor unk_tensor); 30 | TORCHTEXT_API explicit Vectors( 31 | const std::vector& tokens, 32 | const std::vector& indices, 33 | torch::Tensor vectors, 34 | torch::Tensor unk_tensor); 35 | TORCHTEXT_API std::unordered_map get_stoi(); 36 | TORCHTEXT_API torch::Tensor __getitem__(const std::string& token); 37 | TORCHTEXT_API torch::Tensor lookup_vectors( 38 | const std::vector& tokens); 39 | TORCHTEXT_API void __setitem__( 40 | const std::string& token, 41 | const torch::Tensor& vector); 42 | TORCHTEXT_API int64_t __len__(); 43 | }; 44 | 45 | TORCHTEXT_API VectorsStates 46 | _serialize_vectors(const c10::intrusive_ptr& self); 47 | TORCHTEXT_API c10::intrusive_ptr _deserialize_vectors( 48 | VectorsStates states); 49 | 50 | TORCHTEXT_API std::tuple> 51 | _load_token_and_vectors_from_file( 52 | const std::string& file_path, 53 | const std::string& delimiter_str, 54 | const int64_t num_cpus, 55 | c10::optional opt_unk_tensor); 56 | 57 | } // namespace torchtext 58 | -------------------------------------------------------------------------------- /torchtext/csrc/vocab_factory.h: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include // @manual 4 | 5 | namespace py = pybind11; 6 | 7 | namespace torchtext { 8 | 9 | TORCHTEXT_API Vocab _build_vocab_from_text_file_using_python_tokenizer( 10 | const std::string& file_path, 11 | const int64_t min_freq, 12 | py::object tokenizer); 13 | 14 | TORCHTEXT_API Vocab _load_vocab_from_file( 15 | const std::string& file_path, 16 | const int64_t min_freq, 17 | const int64_t num_cpus); 18 | 19 | 
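// Builds a Vocab from a raw text file: each line is tokenized with the given
// TorchScript `tokenizer` module, and tokens that occur at least `min_freq`
// times are kept; reading is parallelized over `num_cpus`.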
TORCHTEXT_API Vocab _build_vocab_from_text_file( 20 | const std::string& file_path, 21 | const int64_t min_freq, 22 | const int64_t num_cpus, 23 | torch::jit::script::Module tokenizer); 24 | } // namespace torchtext 25 | -------------------------------------------------------------------------------- /torchtext/data/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | 7 | from .functional import ( 8 | custom_replace, 9 | filter_wikipedia_xml, 10 | generate_sp_model, 11 | load_sp_model, 12 | numericalize_tokens_from_iterator, 13 | sentencepiece_numericalizer, 14 | sentencepiece_tokenizer, 15 | simple_space_split, 16 | to_map_style_dataset, 17 | ) 18 | from .metrics import bleu_score 19 | from .utils import get_tokenizer, interleave_keys 20 | 21 | __all__ = [ 22 | "bleu_score", 23 | "get_tokenizer", 24 | "interleave_keys", 25 | "generate_sp_model", 26 | "load_sp_model", 27 | "sentencepiece_numericalizer", 28 | "sentencepiece_tokenizer", 29 | "custom_replace", 30 | "simple_space_split", 31 | "numericalize_tokens_from_iterator", 32 | "filter_wikipedia_xml", 33 | "to_map_style_dataset", 34 | ] 35 | -------------------------------------------------------------------------------- /torchtext/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | 7 | import importlib 8 | 9 | from .ag_news import AG_NEWS 10 | from .amazonreviewfull import AmazonReviewFull 11 | from .amazonreviewpolarity import AmazonReviewPolarity 12 | from .cc100 import CC100 13 | from .cnndm import CNNDM 14 | from .cola import CoLA 15 | from .conll2000chunking import CoNLL2000Chunking 16 | from .dbpedia import DBpedia 17 | from .enwik9 import EnWik9 18 | from .imdb import IMDB 19 | from .iwslt2016 import IWSLT2016 20 | from .iwslt2017 import IWSLT2017 21 | from .mnli import MNLI 22 | from .mrpc import MRPC 23 | from .multi30k import Multi30k 24 | from .penntreebank import PennTreebank 25 | from .qnli import QNLI 26 | from .qqp import QQP 27 | from .rte import RTE 28 | from .sogounews import SogouNews 29 | from .squad1 import SQuAD1 30 | from .squad2 import SQuAD2 31 | from .sst2 import SST2 32 | from .stsb import STSB 33 | from .udpos import UDPOS 34 | from .wikitext103 import WikiText103 35 | from .wikitext2 import WikiText2 36 | from .wnli import WNLI 37 | from .yahooanswers import YahooAnswers 38 | from .yelpreviewfull import YelpReviewFull 39 | from .yelpreviewpolarity import YelpReviewPolarity 40 | 41 | DATASETS = { 42 | "AG_NEWS": AG_NEWS, 43 | "AmazonReviewFull": AmazonReviewFull, 44 | "AmazonReviewPolarity": AmazonReviewPolarity, 45 | "CC100": CC100, 46 | "CoLA": CoLA, 47 | "CoNLL2000Chunking": CoNLL2000Chunking, 48 | "DBpedia": DBpedia, 49 | "EnWik9": EnWik9, 50 | "IMDB": IMDB, 51 | "IWSLT2016": IWSLT2016, 52 | "IWSLT2017": IWSLT2017, 53 | "MNLI": MNLI, 54 | "MRPC": MRPC, 55 | "Multi30k": Multi30k, 56 | "PennTreebank": PennTreebank, 57 | "QNLI": QNLI, 58 | "QQP": QQP, 59 | "RTE": RTE, 60 | "SQuAD1": SQuAD1, 61 | "SQuAD2": SQuAD2, 62 | "SogouNews": SogouNews, 63 | "SST2": SST2, 64 | "STSB": STSB, 65 | "UDPOS": UDPOS, 66 | "WikiText103": WikiText103, 67 | "WikiText2": WikiText2, 68 | "WNLI": WNLI, 69 | "YahooAnswers": YahooAnswers, 70 | "YelpReviewFull": YelpReviewFull, 71 | 
"YelpReviewPolarity": YelpReviewPolarity, 72 | "CNNDM": CNNDM, 73 | } 74 | 75 | URLS = {} 76 | NUM_LINES = {} 77 | MD5 = {} 78 | for dataset in DATASETS: 79 | dataset_module_path = "torchtext.datasets." + dataset.lower() 80 | dataset_module = importlib.import_module(dataset_module_path) 81 | URLS[dataset] = dataset_module.URL 82 | NUM_LINES[dataset] = dataset_module.NUM_LINES 83 | MD5[dataset] = dataset_module.MD5 84 | 85 | __all__ = sorted(list(map(str, DATASETS.keys()))) 86 | -------------------------------------------------------------------------------- /torchtext/datasets/ag_news.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | from typing import Union, Tuple 4 | 5 | from torchtext._internal.module_utils import is_module_available 6 | from torchtext.data.datasets_utils import ( 7 | _wrap_split_argument, 8 | _create_dataset_directory, 9 | ) 10 | 11 | URL = { 12 | "train": "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv", 13 | "test": "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv", 14 | } 15 | 16 | MD5 = { 17 | "train": "b1a00f826fdfbd249f79597b59e1dc12", 18 | "test": "d52ea96a97a2d943681189a97654912d", 19 | } 20 | 21 | NUM_LINES = { 22 | "train": 120000, 23 | "test": 7600, 24 | } 25 | 26 | DATASET_NAME = "AG_NEWS" 27 | 28 | 29 | def _filepath_fn(root, split, _=None): 30 | return os.path.join(root, split + ".csv") 31 | 32 | 33 | def _modify_res(t): 34 | return int(t[0]), " ".join(t[1:]) 35 | 36 | 37 | @_create_dataset_directory(dataset_name=DATASET_NAME) 38 | @_wrap_split_argument(("train", "test")) 39 | def AG_NEWS(root: str, split: Union[Tuple[str], str]): 40 | """AG_NEWS Dataset 41 | 42 | .. warning:: 43 | 44 | Using datapipes is still currently subject to a few caveats. If you wish 45 | to use this dataset with shuffling, multi-processing, or distributed 46 | learning, please see :ref:`this note ` for further 47 | instructions. 48 | 49 | For additional details refer to https://paperswithcode.com/dataset/ag-news 50 | 51 | Number of lines per split: 52 | - train: 120000 53 | - test: 7600 54 | 55 | Args: 56 | root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') 57 | split: split or splits to be returned. Can be a string or tuple of strings. Default: (`train`, `test`) 58 | 59 | :returns: DataPipe that yields tuple of label (1 to 4) and text 60 | :rtype: (int, str) 61 | """ 62 | if not is_module_available("torchdata"): 63 | raise ModuleNotFoundError( 64 | "Package `torchdata` not found. 
Please install it following the instructions at https://github.com/pytorch/data" 65 | ) 66 | from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa 67 | 68 | url_dp = IterableWrapper([URL[split]]) 69 | cache_dp = url_dp.on_disk_cache( 70 | filepath_fn=partial(_filepath_fn, root, split), 71 | hash_dict={_filepath_fn(root, split): MD5[split]}, 72 | hash_type="md5", 73 | ) 74 | cache_dp = HttpReader(cache_dp) 75 | cache_dp = cache_dp.end_caching(mode="wb", same_filepath_fn=True) 76 | 77 | data_dp = FileOpener(cache_dp, encoding="utf-8") 78 | return data_dp.parse_csv().map(fn=_modify_res).shuffle().set_shuffle(False).sharding_filter() 79 | -------------------------------------------------------------------------------- /torchtext/datasets/enwik9.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from torchtext._internal.module_utils import is_module_available 5 | from torchtext.data.datasets_utils import _create_dataset_directory 6 | 7 | URL = "http://mattmahoney.net/dc/enwik9.zip" 8 | 9 | MD5 = "3e773f8a1577fda2e27f871ca17f31fd" 10 | 11 | _PATH = "enwik9.zip" 12 | 13 | NUM_LINES = {"train": 13147026} 14 | 15 | DATASET_NAME = "EnWik9" 16 | 17 | 18 | def _filepath_fn(root, _=None): 19 | return os.path.join(root, _PATH) 20 | 21 | 22 | def _extracted_filepath_fn(root, _=None): 23 | return os.path.join(root, os.path.splitext(_PATH)[0]) 24 | 25 | 26 | @_create_dataset_directory(dataset_name=DATASET_NAME) 27 | def EnWik9(root: str): 28 | """EnWik9 dataset 29 | 30 | .. warning:: 31 | 32 | Using datapipes is still currently subject to a few caveats. If you wish 33 | to use this dataset with shuffling, multi-processing, or distributed 34 | learning, please see :ref:`this note ` for further 35 | instructions. 36 | 37 | For additional details refer to http://mattmahoney.net/dc/textdata.html 38 | 39 | Number of lines in dataset: 13147026 40 | 41 | Args: 42 | root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') 43 | 44 | :returns: DataPipe that yields raw text rows from the EnWik9 dataset 45 | :rtype: str 46 | """ 47 | if not is_module_available("torchdata"): 48 | raise ModuleNotFoundError( 49 | "Package `torchdata` not found. 
Please install it following the instructions at https://github.com/pytorch/data" 50 | ) 51 | from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa 52 | 53 | url_dp = IterableWrapper([URL]) 54 | cache_compressed_dp = url_dp.on_disk_cache( 55 | filepath_fn=partial(_filepath_fn, root), 56 | hash_dict={_filepath_fn(root): MD5}, 57 | hash_type="md5", 58 | ) 59 | cache_compressed_dp = HttpReader(cache_compressed_dp).end_caching(mode="wb", same_filepath_fn=True) 60 | 61 | cache_decompressed_dp = cache_compressed_dp.on_disk_cache(filepath_fn=partial(_extracted_filepath_fn, root)) 62 | cache_decompressed_dp = FileOpener(cache_decompressed_dp, mode="b").load_from_zip() 63 | cache_decompressed_dp = cache_decompressed_dp.end_caching(mode="wb", same_filepath_fn=True) 64 | 65 | data_dp = FileOpener(cache_decompressed_dp, encoding="utf-8") 66 | return data_dp.readlines(return_path=False).shuffle().set_shuffle(False).sharding_filter() 67 | -------------------------------------------------------------------------------- /torchtext/datasets/qqp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from functools import partial 3 | 4 | from torchtext._internal.module_utils import is_module_available 5 | from torchtext.data.datasets_utils import _create_dataset_directory 6 | 7 | URL = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv" 8 | 9 | MD5 = "b6d5672bd9dc1e66ab2bb020ebeafb8d" 10 | 11 | _PATH = "quora_duplicate_questions.tsv" 12 | 13 | NUM_LINES = {"train": 404290} 14 | 15 | DATASET_NAME = "QQP" 16 | 17 | 18 | def _filepath_fn(root, _=None): 19 | return os.path.join(root, _PATH) 20 | 21 | 22 | def _modify_res(x): 23 | return (int(x[-1]), x[3], x[4]) 24 | 25 | 26 | @_create_dataset_directory(dataset_name=DATASET_NAME) 27 | def QQP(root: str): 28 | """QQP dataset 29 | 30 | .. warning:: 31 | 32 | Using datapipes is still currently subject to a few caveats. If you wish 33 | to use this dataset with shuffling, multi-processing, or distributed 34 | learning, please see :ref:`this note ` for further 35 | instructions. 36 | 37 | For additional details refer to https://quoradata.quora.com/First-Quora-Dataset-Release-Question-Pairs 38 | 39 | Args: 40 | root: Directory where the datasets are saved. Default: os.path.expanduser('~/.torchtext/cache') 41 | 42 | :returns: DataPipe that yields rows from QQP dataset (label (int), question1 (str), question2 (str)) 43 | :rtype: (int, str, str) 44 | """ 45 | if not is_module_available("torchdata"): 46 | raise ModuleNotFoundError( 47 | "Package `torchdata` not found. 
Please install it following the instructions at https://github.com/pytorch/data" 48 | ) 49 | from torchdata.datapipes.iter import FileOpener, GDriveReader, HttpReader, IterableWrapper # noqa 50 | 51 | url_dp = IterableWrapper([URL]) 52 | cache_dp = url_dp.on_disk_cache( 53 | filepath_fn=partial(_filepath_fn, root), 54 | hash_dict={_filepath_fn(root): MD5}, 55 | hash_type="md5", 56 | ) 57 | cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True) 58 | cache_dp = FileOpener(cache_dp, encoding="utf-8") 59 | # the first line of the file is the TSV header row and needs to be skipped 60 | parsed_data = cache_dp.parse_csv(skip_lines=1, delimiter="\t").map(_modify_res) 61 | return parsed_data.shuffle().set_shuffle(False).sharding_filter() 62 | -------------------------------------------------------------------------------- /torchtext/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | -------------------------------------------------------------------------------- /torchtext/experimental/transforms.py: -------------------------------------------------------------------------------- 1 | def __getattr__(name): 2 | 3 | moved_apis = [ 4 | "basic_english_normalize", 5 | "BasicEnglishNormalize", 6 | "PRETRAINED_SP_MODEL", 7 | "load_sp_model", 8 | "sentencepiece_tokenizer", 9 | "SentencePieceTokenizer", 10 | "sentencepiece_processor", 11 | "SentencePieceProcessor", 12 | "VocabTransform", 13 | "VectorTransform", 14 | ] 15 | 16 | if name in moved_apis: 17 | import warnings 18 | 19 | warnings.warn( 20 | "experimental package has been moved to prototype. You may change all imports from `torchtext.experimental` to `torchtext.prototype`", 21 | UserWarning, 22 | ) 23 | 24 | if name == "basic_english_normalize": 25 | from torchtext.prototype.transforms import basic_english_normalize 26 | 27 | return basic_english_normalize 28 | elif name == "BasicEnglishNormalize": 29 | from torchtext.prototype.transforms import BasicEnglishNormalize 30 | 31 | return BasicEnglishNormalize 32 | elif name == "PRETRAINED_SP_MODEL": 33 | from torchtext.prototype.transforms import PRETRAINED_SP_MODEL 34 | 35 | return PRETRAINED_SP_MODEL 36 | elif name == "load_sp_model": 37 | from torchtext.prototype.transforms import load_sp_model 38 | 39 | return load_sp_model 40 | elif name == "sentencepiece_tokenizer": 41 | from torchtext.prototype.transforms import sentencepiece_tokenizer 42 | 43 | return sentencepiece_tokenizer 44 | elif name == "SentencePieceTokenizer": 45 | from torchtext.prototype.transforms import SentencePieceTokenizer 46 | 47 | return SentencePieceTokenizer 48 | elif name == "sentencepiece_processor": 49 | from torchtext.prototype.transforms import sentencepiece_processor 50 | 51 | return sentencepiece_processor 52 | elif name == "VocabTransform": 53 | from torchtext.prototype.transforms import VocabTransform 54 | 55 | return VocabTransform 56 | else: 57 | from torchtext.prototype.transforms import VectorTransform 58 | 59 | return VectorTransform 60 | 61 | raise AttributeError(f"module {__name__} has no attribute {name}") 62 | -------------------------------------------------------------------------------- /torchtext/experimental/vectors.py: -------------------------------------------------------------------------------- 1 | def __getattr__(name): 2 | 3 | moved_apis = ["FastText", "GloVe", "load_vectors_from_file_path", "build_vectors", 
"Vectors"] 4 | 5 | if name in moved_apis: 6 | import warnings 7 | 8 | warnings.warn( 9 | "experimental package has been moved to prototype. You may change all imports from `torchtext.experimental` to `torchtext.prototype`", 10 | UserWarning, 11 | ) 12 | 13 | if name == "FastText": 14 | from torchtext.prototype.vectors import FastText 15 | 16 | return FastText 17 | elif name == "GloVe": 18 | from torchtext.prototype.vectors import GloVe 19 | 20 | return GloVe 21 | elif name == "load_vectors_from_file_path": 22 | from torchtext.prototype.vectors import load_vectors_from_file_path 23 | 24 | return load_vectors_from_file_path 25 | elif name == "build_vectors": 26 | from torchtext.prototype.vectors import build_vectors 27 | 28 | return build_vectors 29 | else: 30 | from torchtext.prototype.vectors import Vectors 31 | 32 | return Vectors 33 | 34 | raise AttributeError(f"module {__name__} has no attribute {name}") 35 | -------------------------------------------------------------------------------- /torchtext/experimental/vocab_factory.py: -------------------------------------------------------------------------------- 1 | def __getattr__(name): 2 | moved_apis = ["build_vocab_from_text_file", "load_vocab_from_file"] 3 | if name in moved_apis: 4 | import warnings 5 | 6 | warnings.warn( 7 | "experimental package has been moved to prototype. You may change all imports from `torchtext.experimental` to `torchtext.prototype`", 8 | UserWarning, 9 | ) 10 | 11 | if name == "build_vocab_from_text_file": 12 | from torchtext.prototype.vocab_factory import build_vocab_from_text_file 13 | 14 | return build_vocab_from_text_file 15 | else: 16 | from torchtext.prototype.vocab_factory import load_vocab_from_file 17 | 18 | return load_vocab_from_file 19 | 20 | raise AttributeError(f"module {__name__} has no attribute {name}") 21 | -------------------------------------------------------------------------------- /torchtext/lib/.gitignore: -------------------------------------------------------------------------------- 1 | # For smooth build process, this `lib` directory has to exist. 
2 | # git won't allow to have empty directory, so adding .gitignore 3 | # https://stackoverflow.com/a/932982 4 | * 5 | !.gitignore 6 | -------------------------------------------------------------------------------- /torchtext/models/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | from .roberta import * # noqa: F401, F403 7 | from .t5 import * # noqa: F401, F403 8 | -------------------------------------------------------------------------------- /torchtext/models/roberta/__init__.py: -------------------------------------------------------------------------------- 1 | from .bundler import ( 2 | ROBERTA_BASE_ENCODER, 3 | ROBERTA_LARGE_ENCODER, 4 | ROBERTA_DISTILLED_ENCODER, 5 | RobertaBundle, 6 | XLMR_BASE_ENCODER, 7 | XLMR_LARGE_ENCODER, 8 | ) 9 | from .model import RobertaClassificationHead, RobertaEncoderConf, RobertaModel 10 | 11 | __all__ = [ 12 | "RobertaEncoderConf", 13 | "RobertaClassificationHead", 14 | "RobertaModel", 15 | "RobertaBundle", 16 | "XLMR_BASE_ENCODER", 17 | "XLMR_LARGE_ENCODER", 18 | "ROBERTA_BASE_ENCODER", 19 | "ROBERTA_LARGE_ENCODER", 20 | "ROBERTA_DISTILLED_ENCODER", 21 | ] 22 | -------------------------------------------------------------------------------- /torchtext/models/t5/__init__.py: -------------------------------------------------------------------------------- 1 | from .bundler import ( 2 | FLAN_T5_BASE_ENCODER, 3 | FLAN_T5_BASE, 4 | FLAN_T5_BASE_GENERATION, 5 | FLAN_T5_LARGE_ENCODER, 6 | FLAN_T5_LARGE, 7 | FLAN_T5_LARGE_GENERATION, 8 | FLAN_T5_XL_ENCODER, 9 | FLAN_T5_XL, 10 | FLAN_T5_XL_GENERATION, 11 | FLAN_T5_XXL_ENCODER, 12 | FLAN_T5_XXL, 13 | FLAN_T5_XXL_GENERATION, 14 | T5_11B, 15 | T5_11B_ENCODER, 16 | T5_11B_GENERATION, 17 | T5_3B, 18 | T5_3B_ENCODER, 19 | T5_3B_GENERATION, 20 | T5_BASE, 21 | T5_BASE_ENCODER, 22 | T5_BASE_GENERATION, 23 | T5_LARGE, 24 | T5_LARGE_ENCODER, 25 | T5_LARGE_GENERATION, 26 | T5_SMALL, 27 | T5_SMALL_ENCODER, 28 | T5_SMALL_GENERATION, 29 | T5Bundle, 30 | ) 31 | from .model import T5Conf, T5Model 32 | from .t5_transform import T5Transform 33 | 34 | __all__ = [ 35 | "T5Conf", 36 | "T5Model", 37 | "T5Bundle", 38 | "T5_BASE_ENCODER", 39 | "T5_BASE", 40 | "T5_BASE_GENERATION", 41 | "T5_SMALL_ENCODER", 42 | "T5_SMALL", 43 | "T5_SMALL_GENERATION", 44 | "T5_LARGE_ENCODER", 45 | "T5_LARGE", 46 | "T5_LARGE_GENERATION", 47 | "T5_3B_ENCODER", 48 | "T5_3B", 49 | "T5_3B_GENERATION", 50 | "T5_11B_ENCODER", 51 | "T5_11B", 52 | "T5_11B_GENERATION", 53 | "FLAN_T5_BASE_ENCODER", 54 | "FLAN_T5_BASE", 55 | "FLAN_T5_BASE_GENERATION", 56 | "FLAN_T5_LARGE_ENCODER", 57 | "FLAN_T5_LARGE", 58 | "FLAN_T5_LARGE_GENERATION", 59 | "FLAN_T5_XL_ENCODER", 60 | "FLAN_T5_XL", 61 | "FLAN_T5_XL_GENERATION", 62 | "FLAN_T5_XXL_ENCODER", 63 | "FLAN_T5_XXL", 64 | "FLAN_T5_XXL_GENERATION", 65 | "T5Transform", 66 | ] 67 | -------------------------------------------------------------------------------- /torchtext/nn/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | from .modules import * # noqa: F401,F403 7 | -------------------------------------------------------------------------------- /torchtext/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.multiheadattention import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct 2 | 3 | __all__ = ["InProjContainer", "MultiheadAttentionContainer", "ScaledDotProduct"] 4 | -------------------------------------------------------------------------------- /torchtext/prototype/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | from . import transforms 7 | 8 | __all__ = ["transforms"] 9 | -------------------------------------------------------------------------------- /torchtext/prototype/asset/get_checksum.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | FILENAME=$1 3 | URL=$2 4 | 5 | wget -O - -o /dev/null $URL | sha256sum | head -c 64 > $FILENAME 6 | -------------------------------------------------------------------------------- /torchtext/prototype/asset/get_checksums_fast_text.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import subprocess 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def output_checksums_to_files(dir=".checksums"): 10 | if not os.path.exists(dir): 11 | os.makedirs(dir) 12 | 13 | processes = [] 14 | with open("languages_fast_text.txt", "r") as f: 15 | num_languages = 0 16 | for line in f: 17 | num_languages += 1 18 | language = line.strip() 19 | filepath = "{}/{}.txt".format(dir, language) 20 | url = "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec".format(language) 21 | processes.append(subprocess.Popen(["./get_checksum.sh", filepath, url])) 22 | 23 | print("Computing checksums") 24 | with tqdm(unit_scale=0, unit="files", total=num_languages) as t: 25 | for p in processes: 26 | p.wait() 27 | t.update(1) 28 | 29 | 30 | def process_checksums_to_json_file(dir=".checksums"): 31 | if not os.path.exists(dir): 32 | os.makedirs(dir) 33 | os.chdir(dir) 34 | 35 | checksums = {} 36 | for file_name in glob.glob("*.txt"): 37 | file_base_name = os.path.splitext(file_name)[0] 38 | url = "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec".format(file_base_name) 39 | 40 | with open(file_name, "r") as f: 41 | sha256hash = f.readline() 42 | checksums[url] = sha256hash 43 | checksums_json = json.dumps(checksums) 44 | 45 | with open("checksums_fast_text.json", "w") as f: 46 | f.write(checksums_json) 47 | 48 | 49 | def main(): 50 | dir = ".checksums" 51 | json_file_path = os.path.join(os.getcwd(), dir, "checksums_fast_text.json") 52 | output_checksums_to_files(dir=dir) 53 | process_checksums_to_json_file(dir=dir) 54 | 55 | print("Path to FastText checksum file: {}".format(json_file_path)) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /torchtext/vocab/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | import torchtext 3 | if torchtext._WARN: 4 | warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) 5 | 6 | from .vectors import CharNGram, FastText, GloVe, pretrained_aliases, Vectors 7 | from .vocab import Vocab 8 | from .vocab_factory import build_vocab_from_iterator, vocab 9 | 10 | __all__ = [ 11 | "Vocab", 12 | "vocab", 13 | "build_vocab_from_iterator", 14 | "GloVe", 15 | "FastText", 16 | "CharNGram", 17 | "pretrained_aliases", 18 | "Vectors", 19 | ] 20 | 
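# Illustrative usage sketch for the factory function exported above (comments
# only, not executed):
#
#   from torchtext.vocab import build_vocab_from_iterator
#   v = build_vocab_from_iterator([["hello", "world"]], specials=["<unk>"])
#   v.set_default_index(v["<unk>"])
#   v["hello"]  # -> integer id assigned to "hello"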
-------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.17.0a0 2 | --------------------------------------------------------------------------------