├── .travis.yml ├── CODEOWNERS ├── CONTRIBUTING-ARCHIVED.md ├── LICENSE ├── README.md ├── arguments.py ├── convert_to_logical_forms.py ├── decaNLP_logo.png ├── dockerfiles ├── cuda_torch03 ├── cuda_torch04 ├── torch03 └── torch04 ├── local_data ├── dev_fine_sent.csv ├── schema.txt ├── test_fine_sent.csv └── train_fine_sent.csv ├── metrics.py ├── models ├── __init__.py ├── coattentive_pointer_generator.py ├── common.py ├── multitask_question_answering_network.py ├── pointer_generator.py └── self_attentive_pointer_generator.py ├── multiprocess ├── __init__.py ├── distributed_data_parallel.py └── multiprocess.py ├── predict.py ├── text ├── .flake8 ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── __init__.py ├── build_tools │ └── travis │ │ ├── after_success.sh │ │ ├── install.sh │ │ └── test_script.sh ├── codecov.yml ├── docs │ ├── Makefile │ ├── make.bat │ └── source │ │ ├── _static │ │ ├── css │ │ │ └── pytorch_theme.css │ │ └── img │ │ │ ├── pytorch-logo-dark.png │ │ │ ├── pytorch-logo-dark.svg │ │ │ ├── pytorch-logo-flame.png │ │ │ └── pytorch-logo-flame.svg │ │ ├── conf.py │ │ ├── data.rst │ │ ├── datasets.rst │ │ └── index.rst ├── pytest.ini ├── setup.py ├── test │ ├── .gitignore │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── test_markers.py │ │ └── torchtext_test_case.py │ ├── conftest.py │ ├── data.py │ ├── data │ │ ├── __init__.py │ │ ├── test_dataset.py │ │ ├── test_field.py │ │ ├── test_pipeline.py │ │ ├── test_subword.py │ │ └── test_utils.py │ ├── imdb.py │ ├── language_modeling.py │ ├── sequence_tagging.py │ ├── snli.py │ ├── sst.py │ ├── test_vocab.py │ ├── translation.py │ └── trec.py └── torchtext │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── batch.py │ ├── dataset.py │ ├── example.py │ ├── field.py │ ├── iterator.py │ ├── pipeline.py │ └── utils.py │ ├── datasets │ ├── __init__.py │ ├── generic.py │ ├── imdb.py │ ├── language_modeling.py │ ├── sequence_tagging.py │ ├── snli.py │ ├── sst.py │ ├── translation.py │ └── trec.py │ ├── utils.py │ └── vocab.py ├── train.py ├── util.py └── validate.py /.travis.yml: -------------------------------------------------------------------------------- 1 | group: travis_latest 2 | language: python 3 | cache: pip 4 | python: 5 | #- 2.7 6 | - 3.6 7 | #- nightly 8 | #- pypy 9 | #- pypy3 10 | matrix: 11 | allow_failures: 12 | - python: nightly 13 | - python: pypy 14 | - python: pypy3 15 | install: 16 | #- pip install -r requirements.txt 17 | - pip install flake8 # pytest # add another testing frameworks later 18 | before_script: 19 | # stop the build if there are Python syntax errors or undefined names 20 | - flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics 21 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 22 | - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 23 | script: 24 | - true # pytest --capture=sys # add other tests here 25 | notifications: 26 | on_success: change 27 | on_failure: change # `always` will be the setting once code changes slow down 28 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related gus information. Please be careful while editing. 
2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CONTRIBUTING-ARCHIVED.md: -------------------------------------------------------------------------------- 1 | # ARCHIVED 2 | 3 | This project is `Archived` and is no longer actively maintained; 4 | We are not accepting contributions or Pull Requests. 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018, Salesforce 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | import types 4 | import sys 5 | from argparse import ArgumentParser 6 | import subprocess 7 | import json 8 | import datetime 9 | from dateutil import tz 10 | 11 | 12 | def get_commit(): 13 | directory = os.path.dirname(sys.argv[0]) 14 | return subprocess.Popen("cd {} && git log | head -n 1".format(directory), shell=True, stdout=subprocess.PIPE).stdout.read().split()[1].decode() 15 | 16 | 17 | def save_args(args): 18 | os.makedirs(args.log_dir, exist_ok=args.exist_ok) 19 | with open(os.path.join(args.log_dir, 'config.json'), 'wt') as f: 20 | json.dump(vars(args), f, indent=2) 21 | 22 | 23 | def parse(): 24 | """ 25 | Returns the arguments from the command line. 
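    Also post-processes them: fills val_tasks from train_tasks when unset, resolves the data,
    save, embeddings, and log directories relative to --root, and derives a log-directory name
    from the tasks, model, world size, and (optionally) the current git commit.

    Illustrative invocation (assumes train.py calls this parser; the task name is only an example):
        python train.py --train_tasks squad --devices 0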
26 | """ 27 | parser = ArgumentParser() 28 | parser.add_argument('--root', default='/decaNLP', type=str, help='root directory for data, results, embeddings, code, etc.') 29 | parser.add_argument('--data', default='.data/', type=str, help='where to load data from.') 30 | parser.add_argument('--save', default='results', type=str, help='where to save results.') 31 | parser.add_argument('--embeddings', default='.embeddings', type=str, help='where to save embeddings.') 32 | parser.add_argument('--name', default='', type=str, help='name of the experiment; if blank, a name is automatically generated from the arguments') 33 | 34 | parser.add_argument('--train_tasks', nargs='+', type=str, help='tasks to use for training', required=True) 35 | parser.add_argument('--train_iterations', nargs='+', type=int, help='number of iterations to focus on each task') 36 | parser.add_argument('--train_batch_tokens', nargs='+', default=[9000], type=int, help='Number of tokens to use for dynamic batching, corresponging to tasks in train tasks') 37 | parser.add_argument('--jump_start', default=0, type=int, help='number of iterations to give jump started tasks') 38 | parser.add_argument('--n_jump_start', default=0, type=int, help='how many tasks to jump start (presented in order)') 39 | parser.add_argument('--num_print', default=15, type=int, help='how many validation examples with greedy output to print to std out') 40 | 41 | parser.add_argument('--no_tensorboard', action='store_false', dest='tensorboard', help='Turn of tensorboard logging') 42 | parser.add_argument('--log_every', default=int(1e2), type=int, help='how often to log results in # of iterations') 43 | parser.add_argument('--save_every', default=int(1e3), type=int, help='how often to save a checkpoint in # of iterations') 44 | 45 | parser.add_argument('--val_tasks', nargs='+', type=str, help='tasks to collect evaluation metrics for') 46 | parser.add_argument('--val_every', default=int(1e3), type=int, help='how often to run validation in # of iterations') 47 | parser.add_argument('--val_no_filter', action='store_false', dest='val_filter', help='whether to allow filtering on the validation sets') 48 | parser.add_argument('--val_batch_size', nargs='+', default=[256], type=int, help='Batch size for validation corresponding to tasks in val tasks') 49 | 50 | parser.add_argument('--vocab_tasks', nargs='+', type=str, help='tasks to use in the construction of the vocabulary') 51 | parser.add_argument('--max_output_length', default=100, type=int, help='maximum output length for generation') 52 | parser.add_argument('--max_effective_vocab', default=int(1e6), type=int, help='max effective vocabulary size for pretrained embeddings') 53 | parser.add_argument('--max_generative_vocab', default=50000, type=int, help='max vocabulary for the generative softmax') 54 | parser.add_argument('--max_train_context_length', default=400, type=int, help='maximum length of the contexts during training') 55 | parser.add_argument('--max_val_context_length', default=400, type=int, help='maximum length of the contexts during validation') 56 | parser.add_argument('--max_answer_length', default=50, type=int, help='maximum length of answers during training and validation') 57 | parser.add_argument('--subsample', default=20000000, type=int, help='subsample the datasets') 58 | parser.add_argument('--preserve_case', action='store_false', dest='lower', help='whether to preserve casing for all text') 59 | 60 | parser.add_argument('--model', type=str, default='MultitaskQuestionAnsweringNetwork', 
help='which model to import') 61 | parser.add_argument('--dimension', default=200, type=int, help='output dimensions for all layers') 62 | parser.add_argument('--rnn_layers', default=1, type=int, help='number of layers for RNN modules') 63 | parser.add_argument('--transformer_layers', default=2, type=int, help='number of layers for transformer modules') 64 | parser.add_argument('--transformer_hidden', default=150, type=int, help='hidden size of the transformer modules') 65 | parser.add_argument('--transformer_heads', default=3, type=int, help='number of heads for transformer modules') 66 | parser.add_argument('--dropout_ratio', default=0.2, type=float, help='dropout for the model') 67 | parser.add_argument('--cove', action='store_true', help='whether to use contextualized word vectors (McCann et al. 2017)') 68 | parser.add_argument('--intermediate_cove', action='store_true', help='whether to use the intermediate layers of contextualized word vectors (McCann et al. 2017)') 69 | parser.add_argument('--elmo', default=[-1], nargs='+', type=int, help='which layer(s) (0, 1, or 2) of ELMo (Peters et al. 2018) to use; -1 for none ') 70 | parser.add_argument('--no_glove_and_char', action='store_false', dest='glove_and_char', help='turn off GloVe and CharNGram embeddings') 71 | 72 | parser.add_argument('--warmup', default=800, type=int, help='warmup for learning rate') 73 | parser.add_argument('--grad_clip', default=1.0, type=float, help='gradient clipping') 74 | parser.add_argument('--beta0', default=0.9, type=float, help='alternative momentum for Adam (only when not using transformer_lr)') 75 | parser.add_argument('--optimizer', default='adam', type=str, help='Adam or SGD') 76 | parser.add_argument('--no_transformer_lr', action='store_false', dest='transformer_lr', help='turns off the transformer learning rate strategy') 77 | parser.add_argument('--sgd_lr', default=1.0, type=float, help='learning rate for SGD (if not using Adam)') 78 | 79 | parser.add_argument('--load', default=None, type=str, help='path to checkpoint to load model from inside args.save') 80 | parser.add_argument('--resume', action='store_true', help='whether to resume training with past optimizers') 81 | 82 | parser.add_argument('--seed', default=123, type=int, help='Random seed.') 83 | parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used for training (multi-gpu currently WIP)') 84 | parser.add_argument('--backend', default='gloo', type=str, help='backend for distributed training') 85 | 86 | parser.add_argument('--no_commit', action='store_false', dest='commit', help='do not track the git commit associated with this training run') 87 | parser.add_argument('--exist_ok', action='store_true', help='Ok if the save directory already exists, i.e. 
overwrite is ok') 88 | parser.add_argument('--token_testing', action='store_true', help='if true, sorts all iterators') 89 | parser.add_argument('--reverse', action='store_true', help='if token_testing and true, sorts all iterators in reverse') 90 | 91 | args = parser.parse_args() 92 | if args.model is None: 93 | args.model = 'mcqa' 94 | if args.val_tasks is None: 95 | args.val_tasks = [] 96 | for t in args.train_tasks: 97 | if t not in args.val_tasks: 98 | args.val_tasks.append(t) 99 | 100 | if 'imdb' in args.val_tasks: 101 | args.val_tasks.remove('imdb') 102 | args.world_size = len(args.devices) if args.devices[0] > -1 else -1 103 | if args.world_size > 1: 104 | print('multi-gpu training is currently a work in progress') 105 | return 106 | args.timestamp = '-'.join(datetime.datetime.now(tz=tz.tzoffset(None, -8*60*60)).strftime("%y/%m/%d/%H/%M/%S.%f").split()) 107 | 108 | if len(args.train_tasks) > 1: 109 | if args.train_iterations is None: 110 | args.train_iterations = [1] 111 | if len(args.train_iterations) < len(args.train_tasks): 112 | args.train_iterations = len(args.train_tasks) * args.train_iterations 113 | if len(args.train_batch_tokens) < len(args.train_tasks): 114 | args.train_batch_tokens = len(args.train_tasks) * args.train_batch_tokens 115 | if len(args.val_batch_size) < len(args.val_tasks): 116 | args.val_batch_size = len(args.val_tasks) * args.val_batch_size 117 | 118 | # postprocess arguments 119 | if args.commit: 120 | args.commit = get_commit() 121 | else: 122 | args.commit = '' 123 | train_out = f'{",".join(args.train_tasks)}' 124 | if len(args.train_tasks) > 1: 125 | train_out += f'{"-".join([str(x) for x in args.train_iterations])}' 126 | args.log_dir = os.path.join(args.save, args.timestamp, 127 | f'{train_out}{(",val=" + ",".join(args.val_tasks)) if args.val_tasks != args.train_tasks else ""},{args.model},' \ 128 | f'{args.world_size}g', 129 | args.commit[:7]) 130 | if len(args.name) > 0: 131 | args.log_dir = os.path.join(args.save, args.name) 132 | args.dist_sync_file = os.path.join(args.log_dir, 'distributed_sync_file') 133 | 134 | for x in ['data', 'save', 'embeddings', 'log_dir', 'dist_sync_file']: 135 | setattr(args, x, os.path.join(args.root, getattr(args, x))) 136 | save_args(args) 137 | 138 | return args 139 | -------------------------------------------------------------------------------- /convert_to_logical_forms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from text.torchtext.datasets.generic import Query 3 | from argparse import ArgumentParser 4 | import os 5 | import re 6 | import ujson as json 7 | from metrics import to_lf 8 | 9 | 10 | def correct_format(x): 11 | if len(x.keys()) == 0: 12 | x = {'query': None, 'error': 'Invalid'} 13 | else: 14 | c = x['conds'] 15 | proper = True 16 | for cc in c: 17 | if len(cc) < 3: 18 | proper = False 19 | if proper: 20 | x = {'query': x, 'error': ''} 21 | else: 22 | x = {'query': None, 'error': 'Invalid'} 23 | return x 24 | 25 | 26 | def write_logical_forms(greedy, args): 27 | data_dir = os.path.join(args.data, 'wikisql', 'data') 28 | path = os.path.join(data_dir, 'dev.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.jsonl') 29 | table_path = os.path.join(data_dir, 'dev.tables.jsonl') if 'valid' in args.evaluate else os.path.join(data_dir, 'test.tables.jsonl') 30 | with open(table_path) as tables_file: 31 | tables = [json.loads(line) for line in tables_file] 32 | id_to_tables = {x['id']: x for x in tables} 33 | 34 | examples 
= [] 35 | with open(path) as example_file: 36 | for line in example_file: 37 | entry = json.loads(line) 38 | table = id_to_tables[entry['table_id']] 39 | sql = entry['sql'] 40 | header = table['header'] 41 | a = repr(Query.from_dict(entry['sql'], table['header'])) 42 | ex = {'sql': sql, 'header': header, 'answer': a, 'table': table} 43 | examples.append(ex) 44 | 45 | with open(args.output, 'a') as f: 46 | count = 0 47 | correct = 0 48 | text_answers = [] 49 | for idx, (g, ex) in enumerate(zip(greedy, examples)): 50 | count += 1 51 | text_answers.append([ex['answer'].lower()]) 52 | try: 53 | lf = to_lf(g, ex['table']) 54 | f.write(json.dumps(correct_format(lf)) + '\n') 55 | gt = ex['sql'] 56 | conds = gt['conds'] 57 | lower_conds = [] 58 | for c in conds: 59 | lc = c 60 | lc[2] = str(lc[2]).lower() 61 | lower_conds.append(lc) 62 | gt['conds'] = lower_conds 63 | correct += lf == gt 64 | except Exception as e: 65 | f.write(json.dumps(correct_format({})) + '\n') 66 | 67 | if __name__ == '__main__': 68 | parser = ArgumentParser() 69 | parser.add_argument('data', help='path to the directory containing data for WikiSQL') 70 | parser.add_argument('predictions', help='path to prediction file, containing one prediction per line') 71 | parser.add_argument('ids', help='path to file for indices, a list of integers indicating the index into the dev/test set of the predictions on the corresponding line in \'predicitons\'') 72 | parser.add_argument('output', help='path for logical forms output line by line') 73 | parser.add_argument('evaluate', help='running on the \'validation\' or \'test\' set') 74 | args = parser.parse_args() 75 | with open(args.predictions) as f: 76 | greedy = [l for l in f] 77 | if args.ids is not None: 78 | with open(args.ids) as f: 79 | ids = [int(l.strip()) for l in f] 80 | greedy = [x[1] for x in sorted([(i, g) for i, g in zip(ids, greedy)])] 81 | write_logical_forms(greedy, args) 82 | -------------------------------------------------------------------------------- /decaNLP_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/decaNLP_logo.png -------------------------------------------------------------------------------- /dockerfiles/cuda_torch03: -------------------------------------------------------------------------------- 1 | # docker build --no-cache multitasking . 
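# Build/run sketch (image tag and mount path are illustrative; the code defaults to /decaNLP as its root):
#   docker build --no-cache -t decanlp-cuda-torch03 -f dockerfiles/cuda_torch03 .
#   nvidia-docker run -it --rm -v $(pwd):/decaNLP decanlp-cuda-torch03 bash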
2 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 3 | 4 | RUN apt-get update \ 5 | && apt-get install -y --no-install-recommends \ 6 | git \ 7 | ssh \ 8 | build-essential \ 9 | locales \ 10 | ca-certificates \ 11 | curl \ 12 | unzip 13 | 14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 15 | chmod +x ~/miniconda.sh && \ 16 | ~/miniconda.sh -b -p /opt/conda && \ 17 | rm ~/miniconda.sh && \ 18 | /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl&& \ 19 | /opt/conda/bin/conda clean -ya 20 | ENV PATH /opt/conda/bin:$PATH 21 | 22 | # Default to utf-8 encodings in python 23 | # Can verify in container with: 24 | # python -c 'import locale; print(locale.getpreferredencoding(False))' 25 | RUN locale-gen en_US.UTF-8 26 | ENV LANG en_US.UTF-8 27 | ENV LANGUAGE en_US:en 28 | ENV LC_ALL en_US.UTF-8 29 | 30 | RUN conda install -c pytorch pytorch=0.3 cuda90 31 | 32 | # Revtok 33 | RUN pip install -e git+https://github.com/jekbradbury/revtok.git#egg=revtok 34 | 35 | # torchtext requirements 36 | RUN pip install tqdm 37 | RUN pip install nltk==3.2.5 38 | 39 | # tensorboard 40 | RUN pip install tensorboardX 41 | RUN pip install tensorboard 42 | RUN pip install tensorflow 43 | RUN pip install python-dateutil 44 | 45 | # additional python packages 46 | RUN pip install ujson 47 | RUN pip install -e git+git://github.com/andersjo/pyrouge.git#egg=pyrouge 48 | RUN cd /src/pyrouge/pyrouge/../tools/ROUGE-1.5.5/data/ && rm WordNet-2.0.exc.db && ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db && chmod 777 WordNet-2.0.exc.db 49 | #RUN pip install lxml 50 | RUN pip install sacrebleu 51 | 52 | # Install packages for XML processing 53 | RUN apt-get install --yes \ 54 | expat \ 55 | libexpat-dev \ 56 | libxml2-dev \ 57 | libxslt1-dev \ 58 | libgdbm-dev \ 59 | libxml-libxslt-perl \ 60 | libxml-libxml-perl \ 61 | python-lxml 62 | 63 | # WikISQL evaluation 64 | RUN pip install records 65 | RUN pip install babel 66 | RUN pip install tabulate 67 | 68 | RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove 69 | 70 | CMD bash 71 | -------------------------------------------------------------------------------- /dockerfiles/cuda_torch04: -------------------------------------------------------------------------------- 1 | # docker build --no-cache multitasking . 
2 | FROM nvidia/cuda:9.0-cudnn7-devel-ubuntu16.04 3 | 4 | RUN apt-get update \ 5 | && apt-get install -y --no-install-recommends \ 6 | git \ 7 | ssh \ 8 | build-essential \ 9 | locales \ 10 | ca-certificates \ 11 | curl \ 12 | unzip 13 | 14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 15 | chmod +x ~/miniconda.sh && \ 16 | ~/miniconda.sh -b -p /opt/conda && \ 17 | rm ~/miniconda.sh && \ 18 | /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl&& \ 19 | /opt/conda/bin/conda clean -ya 20 | ENV PATH /opt/conda/bin:$PATH 21 | 22 | # Default to utf-8 encodings in python 23 | # Can verify in container with: 24 | # python -c 'import locale; print(locale.getpreferredencoding(False))' 25 | RUN locale-gen en_US.UTF-8 26 | ENV LANG en_US.UTF-8 27 | ENV LANGUAGE en_US:en 28 | ENV LC_ALL en_US.UTF-8 29 | 30 | RUN conda install -c pytorch pytorch=0.4.1 cuda90 31 | 32 | # Revtok 33 | RUN pip install -e git+https://github.com/jekbradbury/revtok.git#egg=revtok 34 | 35 | # torchtext requirements 36 | RUN pip install tqdm 37 | RUN pip install nltk==3.2.5 38 | 39 | # tensorboard 40 | RUN pip install tensorboardX 41 | RUN pip install tensorboard 42 | RUN pip install tensorflow 43 | RUN pip install python-dateutil 44 | 45 | # additional python packages 46 | RUN pip install ujson 47 | RUN pip install -e git+git://github.com/andersjo/pyrouge.git#egg=pyrouge 48 | RUN cd /src/pyrouge/pyrouge/../tools/ROUGE-1.5.5/data/ && rm WordNet-2.0.exc.db && ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db && chmod 777 WordNet-2.0.exc.db 49 | #RUN pip install lxml 50 | RUN pip install sacrebleu 51 | 52 | # Install packages for XML processing 53 | RUN apt-get install --yes \ 54 | expat \ 55 | libexpat-dev \ 56 | libxml2-dev \ 57 | libxslt1-dev \ 58 | libgdbm-dev \ 59 | libxml-libxslt-perl \ 60 | libxml-libxml-perl \ 61 | python-lxml 62 | 63 | # WikISQL evaluation 64 | RUN pip install records 65 | RUN pip install babel 66 | RUN pip install tabulate 67 | 68 | RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove 69 | RUN pip install allennlp 70 | 71 | CMD bash 72 | -------------------------------------------------------------------------------- /dockerfiles/torch03: -------------------------------------------------------------------------------- 1 | # docker build --no-cache multitasking . 
2 | FROM ubuntu:16.04 3 | 4 | RUN apt-get update \ 5 | && apt-get install -y --no-install-recommends \ 6 | git \ 7 | ssh \ 8 | build-essential \ 9 | locales \ 10 | ca-certificates \ 11 | curl \ 12 | unzip 13 | 14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 15 | chmod +x ~/miniconda.sh && \ 16 | ~/miniconda.sh -b -p /opt/conda && \ 17 | rm ~/miniconda.sh && \ 18 | /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl&& \ 19 | /opt/conda/bin/conda clean -ya 20 | ENV PATH /opt/conda/bin:$PATH 21 | 22 | # Default to utf-8 encodings in python 23 | # Can verify in container with: 24 | # python -c 'import locale; print(locale.getpreferredencoding(False))' 25 | RUN locale-gen en_US.UTF-8 26 | ENV LANG en_US.UTF-8 27 | ENV LANGUAGE en_US:en 28 | ENV LC_ALL en_US.UTF-8 29 | 30 | RUN conda install -c pytorch pytorch=0.3 31 | 32 | # Revtok 33 | RUN pip install -e git+https://github.com/jekbradbury/revtok.git#egg=revtok 34 | 35 | # torchtext requirements 36 | RUN pip install tqdm 37 | RUN pip install nltk==3.2.5 38 | 39 | # tensorboard 40 | RUN pip install tensorboardX 41 | RUN pip install tensorboard 42 | RUN pip install tensorflow 43 | RUN pip install python-dateutil 44 | 45 | # additional python packages 46 | RUN pip install ujson 47 | RUN pip install -e git+git://github.com/andersjo/pyrouge.git#egg=pyrouge 48 | RUN cd /src/pyrouge/pyrouge/../tools/ROUGE-1.5.5/data/ && rm WordNet-2.0.exc.db && ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db && chmod 777 WordNet-2.0.exc.db 49 | #RUN pip install lxml 50 | RUN pip install sacrebleu 51 | 52 | # Install packages for XML processing 53 | RUN apt-get install --yes \ 54 | expat \ 55 | libexpat-dev \ 56 | libxml2-dev \ 57 | libxslt1-dev \ 58 | libgdbm-dev \ 59 | libxml-libxslt-perl \ 60 | libxml-libxml-perl \ 61 | python-lxml 62 | 63 | # WikISQL evaluation 64 | RUN pip install records 65 | RUN pip install babel 66 | RUN pip install tabulate 67 | RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove 68 | 69 | CMD bash 70 | -------------------------------------------------------------------------------- /dockerfiles/torch04: -------------------------------------------------------------------------------- 1 | # docker build --no-cache multitasking . 
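# CPU-only counterpart of cuda_torch04: plain ubuntu:16.04 base and pytorch 0.4.1 without CUDA, plus allennlp (presumably backing the --elmo option in arguments.py).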
2 | FROM ubuntu:16.04 3 | 4 | RUN apt-get update \ 5 | && apt-get install -y --no-install-recommends \ 6 | git \ 7 | ssh \ 8 | build-essential \ 9 | locales \ 10 | ca-certificates \ 11 | curl \ 12 | unzip 13 | 14 | RUN curl -o ~/miniconda.sh -O https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 15 | chmod +x ~/miniconda.sh && \ 16 | ~/miniconda.sh -b -p /opt/conda && \ 17 | rm ~/miniconda.sh && \ 18 | /opt/conda/bin/conda install conda-build python=3.6.3 numpy pyyaml mkl&& \ 19 | /opt/conda/bin/conda clean -ya 20 | ENV PATH /opt/conda/bin:$PATH 21 | 22 | # Default to utf-8 encodings in python 23 | # Can verify in container with: 24 | # python -c 'import locale; print(locale.getpreferredencoding(False))' 25 | RUN locale-gen en_US.UTF-8 26 | ENV LANG en_US.UTF-8 27 | ENV LANGUAGE en_US:en 28 | ENV LC_ALL en_US.UTF-8 29 | 30 | RUN conda install -c pytorch pytorch=0.4.1 31 | 32 | # Revtok 33 | RUN pip install -e git+https://github.com/jekbradbury/revtok.git#egg=revtok 34 | 35 | # torchtext requirements 36 | RUN pip install tqdm 37 | RUN pip install nltk==3.2.5 38 | 39 | # tensorboard 40 | RUN pip install tensorboardX 41 | RUN pip install tensorboard 42 | RUN pip install tensorflow 43 | RUN pip install python-dateutil 44 | 45 | # additional python packages 46 | RUN pip install ujson 47 | RUN pip install -e git+git://github.com/andersjo/pyrouge.git#egg=pyrouge 48 | RUN cd /src/pyrouge/pyrouge/../tools/ROUGE-1.5.5/data/ && rm WordNet-2.0.exc.db && ./WordNet-2.0-Exceptions/buildExeptionDB.pl ./WordNet-2.0-Exceptions ./smart_common_words.txt ./WordNet-2.0.exc.db && chmod 777 WordNet-2.0.exc.db 49 | #RUN pip install lxml 50 | RUN pip install sacrebleu 51 | 52 | # Install packages for XML processing 53 | RUN apt-get install --yes \ 54 | expat \ 55 | libexpat-dev \ 56 | libxml2-dev \ 57 | libxslt1-dev \ 58 | libgdbm-dev \ 59 | libxml-libxslt-perl \ 60 | libxml-libxml-perl \ 61 | python-lxml 62 | 63 | # WikISQL evaluation 64 | RUN pip install records 65 | RUN pip install babel 66 | RUN pip install tabulate 67 | 68 | RUN pip install -e git+git://github.com/salesforce/cove.git#egg=cove 69 | RUN pip install allennlp 70 | 71 | CMD bash 72 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .multitask_question_answering_network import MultitaskQuestionAnsweringNetwork 2 | from .coattentive_pointer_generator import CoattentivePointerGenerator 3 | from .self_attentive_pointer_generator import SelfAttentivePointerGenerator 4 | from .pointer_generator import PointerGenerator 5 | -------------------------------------------------------------------------------- /models/coattentive_pointer_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask, CoattentiveLayer 10 | 11 | 12 | class CoattentivePointerGenerator(nn.Module): 13 | 14 | def __init__(self, field, args): 15 | super().__init__() 16 | self.field = field 17 | self.args = args 18 | self.pad_idx = self.field.vocab.stoi[self.field.pad_token] 19 | 20 | self.encoder_embeddings = Embedding(field, args.dimension, 21 | dropout=args.dropout_ratio) 
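        # The decoder embeddings are built from the same field/vocabulary as the encoder's;
        # set_embeddings() below loads the same pretrained vectors into both tables.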
22 | self.decoder_embeddings = Embedding(field, args.dimension, 23 | dropout=args.dropout_ratio) 24 | 25 | 26 | self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension, 27 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=1) 28 | self.coattention = CoattentiveLayer(args.dimension, dropout=0.3) 29 | dim = 2*args.dimension + args.dimension + args.dimension 30 | 31 | self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension, 32 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, 33 | num_layers=args.rnn_layers) 34 | self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio) 35 | self.bilstm_context = PackedLSTM(args.dimension, args.dimension, 36 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, 37 | num_layers=args.rnn_layers) 38 | 39 | self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio) 40 | self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension, 41 | dropout=args.dropout_ratio, num_layers=args.rnn_layers) 42 | 43 | self.generative_vocab_size = min(len(field.vocab), args.max_generative_vocab) 44 | self.out = nn.Linear(args.dimension, self.generative_vocab_size) 45 | 46 | self.dropout = nn.Dropout(0.4) 47 | 48 | def set_embeddings(self, embeddings): 49 | self.encoder_embeddings.set_embeddings(embeddings) 50 | self.decoder_embeddings.set_embeddings(embeddings) 51 | 52 | 53 | def forward(self, batch): 54 | context, context_lengths, context_limited = batch.context, batch.context_lengths, batch.context_limited 55 | question, question_lengths, question_limited = batch.question, batch.question_lengths, batch.question_limited 56 | answer, answer_lengths, answer_limited = batch.answer, batch.answer_lengths, batch.answer_limited 57 | oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx 58 | 59 | def map_to_full(x): 60 | return limited_idx_to_full_idx[x] 61 | self.map_to_full = map_to_full 62 | 63 | context_embedded = self.encoder_embeddings(context) 64 | question_embedded = self.encoder_embeddings(question) 65 | 66 | context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] 67 | question_encoded = self.bilstm_before_coattention(question_embedded, question_lengths)[0] 68 | 69 | context_padding = context.data == self.pad_idx 70 | question_padding = question.data == self.pad_idx 71 | 72 | coattended_context = self.coattention(context_encoded, question_encoded, context_padding, question_padding) 73 | 74 | context_summary = torch.cat([coattended_context, context_encoded, context_embedded], -1) 75 | condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths) 76 | self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding) 77 | final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], context_lengths) 78 | context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)] 79 | 80 | context_indices = context_limited if context_limited is not None else context 81 | answer_indices = answer_limited if answer_limited is not None else answer 82 | 83 | pad_idx = self.field.decoder_stoi[self.field.pad_token] 84 | context_padding = context_indices.data == pad_idx 85 | 86 | 
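        # Mask the decoder's context attention at padding positions, then either run teacher-forced
        # decoding (training: NLL loss over the pointer-generator mixture from probs()) or greedy
        # decoding (evaluation).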
self.dual_ptr_rnn_decoder.applyMasks(context_padding) 87 | 88 | if self.training: 89 | answer_padding = answer_indices.data == pad_idx 90 | answer_embedded = self.decoder_embeddings(answer) 91 | self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), self_attended_context, context_padding=context_padding, answer_padding=answer_padding[:, :-1], positional_encodings=True) 92 | decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded, 93 | final_context, hidden=context_rnn_state) 94 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 95 | 96 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 97 | context_attention, 98 | context_indices, 99 | oov_to_limited_idx) 100 | 101 | probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx) 102 | loss = F.nll_loss(probs.log(), targets) 103 | return loss, None 104 | else: 105 | return None, self.greedy(self_attended_context, final_context, 106 | context_indices, 107 | oov_to_limited_idx, rnn_state=context_rnn_state).data 108 | 109 | def reshape_rnn_state(self, h): 110 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ 111 | .transpose(1, 2).contiguous() \ 112 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous() 113 | 114 | def probs(self, generator, outputs, vocab_pointer_switches, 115 | context_attention, 116 | context_indices, 117 | oov_to_limited_idx): 118 | 119 | 120 | size = list(outputs.size()) 121 | 122 | size[-1] = self.generative_vocab_size 123 | scores = generator(outputs.view(-1, outputs.size(-1))).view(size) 124 | p_vocab = F.softmax(scores, dim=scores.dim()-1) 125 | scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab 126 | 127 | effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx) 128 | if self.generative_vocab_size < effective_vocab_size: 129 | size[-1] = effective_vocab_size - self.generative_vocab_size 130 | buff = scaled_p_vocab.new_full(size, EPSILON) 131 | scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1) 132 | 133 | p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON) 134 | p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention) 135 | scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr 136 | 137 | probs = scaled_p_vocab + scaled_p_context_ptr 138 | return probs 139 | 140 | 141 | def greedy(self, self_attended_context, context, context_indices, oov_to_limited_idx, rnn_state=None): 142 | B, TC, C = context.size() 143 | T = self.args.max_output_length 144 | outs = context.new_full((B, T), self.field.decoder_stoi[''], dtype=torch.long) 145 | hiddens = [self_attended_context[0].new_zeros((B, T, C)) 146 | for l in range(len(self.self_attentive_decoder.layers) + 1)] 147 | 148 | hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) 149 | eos_yet = context.data.new(B).byte().zero_() 150 | 151 | rnn_output, context_alignment = None, None 152 | for t in range(T): 153 | if t == 0: 154 | embedding = self.decoder_embeddings( 155 | self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi[''], dtype=torch.long), [1]*B) 156 | else: 157 | embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B) 158 | hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1) 159 | for l in range(len(self.self_attentive_decoder.layers)): 160 
| hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward( 161 | self.self_attentive_decoder.layers[l].attention( 162 | self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1]) 163 | , self_attended_context[l], self_attended_context[l])) 164 | decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1), 165 | context, 166 | context_alignment=context_alignment, 167 | hidden=rnn_state, output=rnn_output) 168 | 169 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 170 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 171 | context_attention, 172 | context_indices, 173 | oov_to_limited_idx) 174 | pred_probs, preds = probs.max(-1) 175 | preds = preds.squeeze(1) 176 | eos_yet = eos_yet | (preds == self.field.decoder_stoi['']) 177 | outs[:, t] = preds.cpu().apply_(self.map_to_full) 178 | if eos_yet.all(): 179 | break 180 | return outs 181 | 182 | 183 | class DualPtrRNNDecoder(nn.Module): 184 | 185 | def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1): 186 | super().__init__() 187 | self.d_hid = d_hid 188 | self.d_in = d_in 189 | self.num_layers = num_layers 190 | self.dropout = nn.Dropout(dropout) 191 | 192 | self.input_feed = True 193 | if self.input_feed: 194 | d_in += 1 * d_hid 195 | 196 | self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout) 197 | self.context_attn = LSTMDecoderAttention(d_hid, dot=True) 198 | 199 | self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid()) 200 | 201 | def forward(self, input, context, output=None, hidden=None, context_alignment=None): 202 | context_output = output.squeeze(1) if output is not None else self.make_init_output(context) 203 | context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context) 204 | 205 | context_outputs, vocab_pointer_switches, context_attentions, context_alignments = [], [], [], [] 206 | for emb_t in input.split(1, dim=1): 207 | emb_t = emb_t.squeeze(1) 208 | context_output = self.dropout(context_output) 209 | if self.input_feed: 210 | emb_t = torch.cat([emb_t, context_output], 1) 211 | dec_state, hidden = self.rnn(emb_t, hidden) 212 | context_output, context_attention, context_alignment = self.context_attn(dec_state, context) 213 | vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1)) 214 | context_output = self.dropout(context_output) 215 | context_outputs.append(context_output) 216 | vocab_pointer_switches.append(vocab_pointer_switch) 217 | context_attentions.append(context_attention) 218 | context_alignments.append(context_alignment) 219 | context_outputs, vocab_pointer_switches, context_attention = [self.package_outputs(x) for x in [context_outputs, vocab_pointer_switches, context_attentions]] 220 | return context_outputs, context_attention, context_alignment, vocab_pointer_switches, hidden 221 | 222 | def applyMasks(self, context_mask): 223 | self.context_attn.applyMasks(context_mask) 224 | 225 | def make_init_output(self, context): 226 | batch_size = context.size(0) 227 | h_size = (batch_size, self.d_hid) 228 | return context.new_zeros(h_size) 229 | 230 | def package_outputs(self, outputs): 231 | outputs = torch.stack(outputs, dim=1) 232 | return outputs 233 | -------------------------------------------------------------------------------- /models/pointer_generator.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask 10 | 11 | 12 | class PointerGenerator(nn.Module): 13 | 14 | def __init__(self, field, args): 15 | super().__init__() 16 | self.field = field 17 | self.args = args 18 | self.pad_idx = self.field.vocab.stoi[self.field.pad_token] 19 | 20 | self.encoder_embeddings = Embedding(field, args.dimension, 21 | dropout=args.dropout_ratio) 22 | self.decoder_embeddings = Embedding(field, args.dimension, 23 | dropout=args.dropout_ratio) 24 | 25 | 26 | self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension, 27 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=args.rnn_layers) 28 | self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension, 29 | dropout=args.dropout_ratio, num_layers=args.rnn_layers) 30 | 31 | self.generative_vocab_size = min(len(field.vocab), args.max_generative_vocab) 32 | self.out = nn.Linear(args.dimension, self.generative_vocab_size) 33 | 34 | self.dropout = nn.Dropout(0.4) 35 | 36 | def set_embeddings(self, embeddings): 37 | self.encoder_embeddings.set_embeddings(embeddings) 38 | self.decoder_embeddings.set_embeddings(embeddings) 39 | 40 | 41 | def forward(self, batch): 42 | context, context_lengths, context_limited = batch.context_question, batch.context_question_lengths, batch.context_question_limited 43 | answer, answer_lengths, answer_limited = batch.answer, batch.answer_lengths, batch.answer_limited 44 | oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx 45 | 46 | def map_to_full(x): 47 | return limited_idx_to_full_idx[x] 48 | self.map_to_full = map_to_full 49 | 50 | context_embedded = self.encoder_embeddings(context) 51 | 52 | context_encoded, (context_rnn_h, context_rnn_c) = self.bilstm_before_coattention(context_embedded, context_lengths) 53 | context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)] 54 | 55 | context_padding = context.data == self.pad_idx 56 | context_indices = context_limited if context_limited is not None else context 57 | answer_indices = answer_limited if answer_limited is not None else answer 58 | 59 | pad_idx = self.field.decoder_stoi[self.field.pad_token] 60 | context_padding = context_indices.data == pad_idx 61 | self.dual_ptr_rnn_decoder.applyMasks(context_padding) 62 | 63 | if self.training: 64 | answer_embedded = self.decoder_embeddings(answer) 65 | decoder_outputs = self.dual_ptr_rnn_decoder(answer_embedded[:, :-1].contiguous(), 66 | context_encoded, hidden=context_rnn_state) 67 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 68 | 69 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 70 | context_attention, 71 | context_indices, 72 | oov_to_limited_idx) 73 | 74 | probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx) 75 | loss = F.nll_loss(probs.log(), targets) 76 | return loss, None 77 | else: 78 | return None, self.greedy(context_encoded, 79 | context_indices, 80 | oov_to_limited_idx, rnn_state=context_rnn_state).data 81 | 82 | def reshape_rnn_state(self, h): 83 | return h.view(h.size(0) // 2, 
2, h.size(1), h.size(2)) \ 84 | .transpose(1, 2).contiguous() \ 85 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous() 86 | 87 | def probs(self, generator, outputs, vocab_pointer_switches, 88 | context_attention, 89 | context_indices, 90 | oov_to_limited_idx): 91 | 92 | 93 | size = list(outputs.size()) 94 | 95 | size[-1] = self.generative_vocab_size 96 | scores = generator(outputs.view(-1, outputs.size(-1))).view(size) 97 | p_vocab = F.softmax(scores, dim=scores.dim()-1) 98 | scaled_p_vocab = vocab_pointer_switches.expand_as(p_vocab) * p_vocab 99 | 100 | effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx) 101 | if self.generative_vocab_size < effective_vocab_size: 102 | size[-1] = effective_vocab_size - self.generative_vocab_size 103 | buff = scaled_p_vocab.new_full(size, EPSILON) 104 | scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1) 105 | 106 | p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON) 107 | p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention) 108 | scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr 109 | probs = scaled_p_vocab + scaled_p_context_ptr 110 | return probs 111 | 112 | 113 | def greedy(self, context, context_indices, oov_to_limited_idx, rnn_state=None): 114 | B, TC, C = context.size() 115 | T = self.args.max_output_length 116 | outs = context.new_full((B, T), self.field.decoder_stoi[''], dtype=torch.long) 117 | eos_yet = context.data.new(B).byte().zero_() 118 | 119 | rnn_output, context_alignment = None, None 120 | for t in range(T): 121 | if t == 0: 122 | embedding = self.decoder_embeddings( 123 | context[-1].new_full((B, 1), self.field.vocab.stoi[''], dtype=torch.long), [1]*B) 124 | 125 | else: 126 | embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B) 127 | decoder_outputs = self.dual_ptr_rnn_decoder(embedding, #hiddens[-1][:, t].unsqueeze(1), 128 | context, 129 | context_alignment=context_alignment, 130 | hidden=rnn_state, output=rnn_output) 131 | 132 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 133 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 134 | context_attention, 135 | context_indices, 136 | oov_to_limited_idx) 137 | pred_probs, preds = probs.max(-1) 138 | preds = preds.squeeze(1) 139 | eos_yet = eos_yet | (preds == self.field.decoder_stoi['']) 140 | outs[:, t] = preds.cpu().apply_(self.map_to_full) 141 | if eos_yet.all(): 142 | break 143 | return outs 144 | 145 | 146 | class DualPtrRNNDecoder(nn.Module): 147 | 148 | def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1): 149 | super().__init__() 150 | self.d_hid = d_hid 151 | self.d_in = d_in 152 | self.num_layers = num_layers 153 | self.dropout = nn.Dropout(dropout) 154 | 155 | self.input_feed = True 156 | if self.input_feed: 157 | d_in += 1 * d_hid 158 | 159 | self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout) 160 | self.context_attn = LSTMDecoderAttention(d_hid, dot=True) 161 | 162 | self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid()) 163 | 164 | def forward(self, input, context, output=None, hidden=None, context_alignment=None): 165 | context_output = output.squeeze(1) if output is not None else self.make_init_output(context) 166 | context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context) 167 | 168 | 
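        # Input feeding: each step concatenates the previous attention-weighted context vector with
        # the token embedding before the LSTM step; the returned vocab_pointer_switch is used in
        # probs() to gate between generating from the softmax vocabulary and copying from the context.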
context_outputs, vocab_pointer_switches, context_attentions, context_alignments = [], [], [], [] 169 | for emb_t in input.split(1, dim=1): 170 | emb_t = emb_t.squeeze(1) 171 | context_output = self.dropout(context_output) 172 | if self.input_feed: 173 | emb_t = torch.cat([emb_t, context_output], 1) 174 | dec_state, hidden = self.rnn(emb_t, hidden) 175 | context_output, context_attention, context_alignment = self.context_attn(dec_state, context) 176 | vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1)) 177 | context_output = self.dropout(context_output) 178 | context_outputs.append(context_output) 179 | vocab_pointer_switches.append(vocab_pointer_switch) 180 | context_attentions.append(context_attention) 181 | context_alignments.append(context_alignment) 182 | context_outputs, vocab_pointer_switches, context_attention = [self.package_outputs(x) for x in [context_outputs, vocab_pointer_switches, context_attentions]] 183 | return context_outputs, context_attention, context_alignment, vocab_pointer_switches, hidden 184 | 185 | 186 | def applyMasks(self, context_mask): 187 | self.context_attn.applyMasks(context_mask) 188 | 189 | def make_init_output(self, context): 190 | batch_size = context.size(0) 191 | h_size = (batch_size, self.d_hid) 192 | return context.new_zeros(h_size) 193 | 194 | def package_outputs(self, outputs): 195 | outputs = torch.stack(outputs, dim=1) 196 | return outputs 197 | -------------------------------------------------------------------------------- /models/self_attentive_pointer_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | from .common import positional_encodings_like, INF, EPSILON, TransformerEncoder, TransformerDecoder, PackedLSTM, LSTMDecoderAttention, LSTMDecoder, Embedding, Feedforward, mask 10 | 11 | 12 | class SelfAttentivePointerGenerator(nn.Module): 13 | 14 | def __init__(self, field, args): 15 | super().__init__() 16 | self.field = field 17 | self.args = args 18 | self.pad_idx = self.field.vocab.stoi[self.field.pad_token] 19 | 20 | self.encoder_embeddings = Embedding(field, args.dimension, 21 | dropout=args.dropout_ratio) 22 | self.decoder_embeddings = Embedding(field, args.dimension, 23 | dropout=args.dropout_ratio) 24 | 25 | 26 | self.bilstm_before_coattention = PackedLSTM(args.dimension, args.dimension, 27 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, num_layers=1) 28 | dim = args.dimension + args.dimension 29 | 30 | self.context_bilstm_after_coattention = PackedLSTM(dim, args.dimension, 31 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, 32 | num_layers=args.rnn_layers) 33 | self.self_attentive_encoder_context = TransformerEncoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio) 34 | self.bilstm_context = PackedLSTM(args.dimension, args.dimension, 35 | batch_first=True, dropout=args.dropout_ratio, bidirectional=True, 36 | num_layers=args.rnn_layers) 37 | 38 | self.self_attentive_decoder = TransformerDecoder(args.dimension, args.transformer_heads, args.transformer_hidden, args.transformer_layers, args.dropout_ratio) 39 | self.dual_ptr_rnn_decoder = DualPtrRNNDecoder(args.dimension, args.dimension, 40 | dropout=args.dropout_ratio, num_layers=args.rnn_layers) 41 | 42 | self.generative_vocab_size = 
min(len(field.vocab), args.max_generative_vocab) 43 | self.out = nn.Linear(args.dimension, self.generative_vocab_size) 44 | 45 | self.dropout = nn.Dropout(0.4) 46 | 47 | def set_embeddings(self, embeddings): 48 | self.encoder_embeddings.set_embeddings(embeddings) 49 | self.decoder_embeddings.set_embeddings(embeddings) 50 | 51 | 52 | def forward(self, batch): 53 | context, context_lengths, context_limited = batch.context_question, batch.context_question_lengths, batch.context_question_limited 54 | answer, answer_lengths, answer_limited = batch.answer, batch.answer_lengths, batch.answer_limited 55 | oov_to_limited_idx, limited_idx_to_full_idx = batch.oov_to_limited_idx, batch.limited_idx_to_full_idx 56 | 57 | def map_to_full(x): 58 | return limited_idx_to_full_idx[x] 59 | self.map_to_full = map_to_full 60 | 61 | context_embedded = self.encoder_embeddings(context) 62 | 63 | context_encoded = self.bilstm_before_coattention(context_embedded, context_lengths)[0] 64 | 65 | context_padding = context.data == self.pad_idx 66 | 67 | context_summary = torch.cat([context_encoded, context_embedded], -1) 68 | condensed_context, _ = self.context_bilstm_after_coattention(context_summary, context_lengths) 69 | self_attended_context = self.self_attentive_encoder_context(condensed_context, padding=context_padding) 70 | final_context, (context_rnn_h, context_rnn_c) = self.bilstm_context(self_attended_context[-1], context_lengths) 71 | context_rnn_state = [self.reshape_rnn_state(x) for x in (context_rnn_h, context_rnn_c)] 72 | 73 | context_indices = context_limited if context_limited is not None else context 74 | answer_indices = answer_limited if answer_limited is not None else answer 75 | 76 | pad_idx = self.field.decoder_stoi[self.field.pad_token] 77 | context_padding = context_indices.data == pad_idx 78 | 79 | self.dual_ptr_rnn_decoder.applyMasks(context_padding) 80 | 81 | if self.training: 82 | answer_padding = answer_indices.data == pad_idx 83 | answer_embedded = self.decoder_embeddings(answer) 84 | self_attended_decoded = self.self_attentive_decoder(answer_embedded[:, :-1].contiguous(), self_attended_context, context_padding=context_padding, answer_padding=answer_padding[:, :-1], positional_encodings=True) 85 | decoder_outputs = self.dual_ptr_rnn_decoder(self_attended_decoded, 86 | final_context, hidden=context_rnn_state) 87 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 88 | 89 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 90 | context_attention, 91 | context_indices, 92 | oov_to_limited_idx) 93 | 94 | probs, targets = mask(answer_indices[:, 1:].contiguous(), probs.contiguous(), pad_idx=pad_idx) 95 | loss = F.nll_loss(probs.log(), targets) 96 | return loss, None 97 | else: 98 | return None, self.greedy(self_attended_context, final_context, 99 | context_indices, 100 | oov_to_limited_idx, rnn_state=context_rnn_state).data 101 | 102 | def reshape_rnn_state(self, h): 103 | return h.view(h.size(0) // 2, 2, h.size(1), h.size(2)) \ 104 | .transpose(1, 2).contiguous() \ 105 | .view(h.size(0) // 2, h.size(1), h.size(2) * 2).contiguous() 106 | 107 | def probs(self, generator, outputs, vocab_pointer_switches, 108 | context_attention, 109 | context_indices, 110 | oov_to_limited_idx): 111 | 112 | 113 | size = list(outputs.size()) 114 | 115 | size[-1] = self.generative_vocab_size 116 | scores = generator(outputs.view(-1, outputs.size(-1))).view(size) 117 | p_vocab = F.softmax(scores, dim=scores.dim()-1) 118 | scaled_p_vocab = 
vocab_pointer_switches.expand_as(p_vocab) * p_vocab 119 | 120 | effective_vocab_size = self.generative_vocab_size + len(oov_to_limited_idx) 121 | if self.generative_vocab_size < effective_vocab_size: 122 | size[-1] = effective_vocab_size - self.generative_vocab_size 123 | buff = scaled_p_vocab.new_full(size, EPSILON) 124 | scaled_p_vocab = torch.cat([scaled_p_vocab, buff], dim=buff.dim()-1) 125 | 126 | p_context_ptr = scaled_p_vocab.new_full(scaled_p_vocab.size(), EPSILON) 127 | p_context_ptr.scatter_add_(p_context_ptr.dim()-1, context_indices.unsqueeze(1).expand_as(context_attention), context_attention) 128 | scaled_p_context_ptr = (1 - vocab_pointer_switches).expand_as(p_context_ptr) * p_context_ptr 129 | 130 | probs = scaled_p_vocab + scaled_p_context_ptr 131 | return probs 132 | 133 | 134 | def greedy(self, self_attended_context, context, context_indices, oov_to_limited_idx, rnn_state=None): 135 | B, TC, C = context.size() 136 | T = self.args.max_output_length 137 | outs = context.new_full((B, T), self.field.decoder_stoi[''], dtype=torch.long) 138 | hiddens = [self_attended_context[0].new_zeros((B, T, C)) 139 | for l in range(len(self.self_attentive_decoder.layers) + 1)] 140 | hiddens[0] = hiddens[0] + positional_encodings_like(hiddens[0]) 141 | eos_yet = context.data.new(B).byte().zero_() 142 | 143 | rnn_output, context_alignment = None, None 144 | for t in range(T): 145 | if t == 0: 146 | embedding = self.decoder_embeddings( 147 | self_attended_context[-1].new_full((B, 1), self.field.vocab.stoi[''], dtype=torch.long), [1]*B) 148 | else: 149 | embedding = self.decoder_embeddings(outs[:, t - 1].unsqueeze(1), [1]*B) 150 | hiddens[0][:, t] = hiddens[0][:, t] + (math.sqrt(self.self_attentive_decoder.d_model) * embedding).squeeze(1) 151 | for l in range(len(self.self_attentive_decoder.layers)): 152 | hiddens[l + 1][:, t] = self.self_attentive_decoder.layers[l].feedforward( 153 | self.self_attentive_decoder.layers[l].attention( 154 | self.self_attentive_decoder.layers[l].selfattn(hiddens[l][:, t], hiddens[l][:, :t + 1], hiddens[l][:, :t + 1]) 155 | , self_attended_context[l], self_attended_context[l])) 156 | decoder_outputs = self.dual_ptr_rnn_decoder(hiddens[-1][:, t].unsqueeze(1), 157 | context, 158 | context_alignment=context_alignment, 159 | hidden=rnn_state, output=rnn_output) 160 | 161 | rnn_output, context_attention, context_alignment, vocab_pointer_switch, rnn_state = decoder_outputs 162 | probs = self.probs(self.out, rnn_output, vocab_pointer_switch, 163 | context_attention, 164 | context_indices, 165 | oov_to_limited_idx) 166 | pred_probs, preds = probs.max(-1) 167 | preds = preds.squeeze(1) 168 | eos_yet = eos_yet | (preds == self.field.decoder_stoi['']) 169 | outs[:, t] = preds.cpu().apply_(self.map_to_full) 170 | if eos_yet.all(): 171 | break 172 | return outs 173 | 174 | class DualPtrRNNDecoder(nn.Module): 175 | 176 | def __init__(self, d_in, d_hid, dropout=0.0, num_layers=1): 177 | super().__init__() 178 | self.d_hid = d_hid 179 | self.d_in = d_in 180 | self.num_layers = num_layers 181 | self.dropout = nn.Dropout(dropout) 182 | 183 | self.input_feed = True 184 | if self.input_feed: 185 | d_in += 1 * d_hid 186 | 187 | self.rnn = LSTMDecoder(self.num_layers, d_in, d_hid, dropout) 188 | self.context_attn = LSTMDecoderAttention(d_hid, dot=True) 189 | 190 | self.vocab_pointer_switch = nn.Sequential(Feedforward(2 * self.d_hid + d_in, 1), nn.Sigmoid()) 191 | 192 | def forward(self, input, context, output=None, hidden=None, context_alignment=None): 193 | context_output = 
output.squeeze(1) if output is not None else self.make_init_output(context) 194 | context_alignment = context_alignment if context_alignment is not None else self.make_init_output(context) 195 | 196 | context_outputs, vocab_pointer_switches, context_attentions, context_alignments = [], [], [], [] 197 | for emb_t in input.split(1, dim=1): 198 | emb_t = emb_t.squeeze(1) 199 | context_output = self.dropout(context_output) 200 | if self.input_feed: 201 | emb_t = torch.cat([emb_t, context_output], 1) 202 | dec_state, hidden = self.rnn(emb_t, hidden) 203 | context_output, context_attention, context_alignment = self.context_attn(dec_state, context) 204 | vocab_pointer_switch = self.vocab_pointer_switch(torch.cat([dec_state, context_output, emb_t], -1)) 205 | context_output = self.dropout(context_output) 206 | context_outputs.append(context_output) 207 | vocab_pointer_switches.append(vocab_pointer_switch) 208 | context_attentions.append(context_attention) 209 | context_alignments.append(context_alignment) 210 | context_outputs, vocab_pointer_switches, context_attention = [self.package_outputs(x) for x in [context_outputs, vocab_pointer_switches, context_attentions]] 211 | return context_outputs, context_attention, context_alignment, vocab_pointer_switches, hidden 212 | 213 | 214 | def applyMasks(self, context_mask): 215 | self.context_attn.applyMasks(context_mask) 216 | 217 | def make_init_output(self, context): 218 | batch_size = context.size(0) 219 | h_size = (batch_size, self.d_hid) 220 | return context.new_zeros(h_size) 221 | 222 | def package_outputs(self, outputs): 223 | outputs = torch.stack(outputs, dim=1) 224 | return outputs 225 | -------------------------------------------------------------------------------- /multiprocess/__init__.py: -------------------------------------------------------------------------------- 1 | from .multiprocess import Multiprocess 2 | from .distributed_data_parallel import DistributedDataParallel 3 | -------------------------------------------------------------------------------- /multiprocess/distributed_data_parallel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 3 | import torch.distributed as dist 4 | from torch.nn.modules import Module 5 | 6 | 7 | 8 | class DistributedDataParallel(Module): 9 | 10 | def __init__(self, module): 11 | super(DistributedDataParallel, self).__init__() 12 | self.warn_on_half = True#$ True if dist._backend == dist.dist_backend.GLOO else False 13 | 14 | self.module = module 15 | 16 | for p in self.module.state_dict().values(): 17 | if torch.is_tensor(p): 18 | dist.broadcast(p, 0) 19 | 20 | def allreduce_params(): 21 | if(self.needs_reduction): 22 | self.needs_reduction = False 23 | buckets = {} 24 | for param in self.module.parameters(): 25 | if param.requires_grad and param.grad is not None: 26 | tp = type(param.data) 27 | if tp not in buckets: 28 | buckets[tp] = [] 29 | buckets[tp].append(param) 30 | if self.warn_on_half: 31 | if torch.cuda.HalfTensor in buckets: 32 | print("WARNING: gloo dist backend for half parameters may be extremely slow." 
+ 33 | " It is recommended to use the NCCL backend in this case.") 34 | self.warn_on_half = False 35 | 36 | for tp in buckets: 37 | bucket = buckets[tp] 38 | grads = [param.grad.data for param in bucket] 39 | coalesced = _flatten_dense_tensors(grads) 40 | dist.all_reduce(coalesced) 41 | coalesced /= dist.get_world_size() 42 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 43 | buf.copy_(synced) 44 | 45 | for param in list(self.module.parameters()): 46 | if param.requires_grad: 47 | def allreduce_hook(*unused): 48 | param._execution_engine.queue_callback(allreduce_params) 49 | param.register_hook(allreduce_hook) 50 | 51 | def forward(self, *inputs, **kwargs): 52 | self.needs_reduction = True 53 | return self.module(*inputs, **kwargs) 54 | -------------------------------------------------------------------------------- /multiprocess/multiprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from torch.multiprocessing import Process 5 | 6 | 7 | class Multiprocess(): 8 | 9 | def __init__(self, fn, args): 10 | self.fn = fn 11 | self.args = args 12 | self.world_size = args.world_size 13 | 14 | if os.path.isfile(args.dist_sync_file): 15 | os.remove(args.dist_sync_file) 16 | 17 | def run(self, runtime_args): 18 | self.start(runtime_args) 19 | self.join() 20 | 21 | def start(self, runtime_args): 22 | self.processes = [] 23 | for rank in range(self.world_size): 24 | self.processes.append(Process(target=self.init_process, args=(rank, self.fn, self.args, runtime_args))) 25 | self.processes[-1].start() 26 | 27 | def init_process(self, rank, fn, args, runtime_args): 28 | torch.distributed.init_process_group(world_size=self.world_size, 29 | init_method='file://'+args.dist_sync_file, 30 | backend=args.backend, 31 | rank=rank) 32 | fn(args, runtime_args, rank, self.world_size) 33 | 34 | def join(self): 35 | for p in self.processes: 36 | p.join() 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | from text.torchtext.datasets.generic import Query 4 | from text import torchtext 5 | from argparse import ArgumentParser 6 | import ujson as json 7 | import torch 8 | import numpy as np 9 | import random 10 | from pprint import pformat 11 | 12 | from util import get_splits, set_seed, preprocess_examples 13 | from metrics import compute_metrics 14 | import models 15 | 16 | 17 | def get_all_splits(args, new_vocab): 18 | splits = [] 19 | for task in args.tasks: 20 | print(f'Loading {task}') 21 | kwargs = {} 22 | if not 'train' in args.evaluate: 23 | kwargs['train'] = None 24 | if not 'valid' in args.evaluate: 25 | kwargs['validation'] = None 26 | if not 'test' in args.evaluate: 27 | kwargs['test'] = None 28 | s = get_splits(args, task, new_vocab, **kwargs)[0] 29 | preprocess_examples(args, [task], [s], new_vocab, train=False) 30 | splits.append(s) 31 | return splits 32 | 33 | 34 | def prepare_data(args, FIELD): 35 | new_vocab = torchtext.data.ReversibleField(batch_first=True, init_token='', eos_token='', lower=args.lower, include_lengths=True) 36 | splits = get_all_splits(args, new_vocab) 37 | new_vocab.build_vocab(*splits) 38 | print(f'Vocabulary has {len(FIELD.vocab)} tokens from training') 39 | args.max_generative_vocab = min(len(FIELD.vocab), args.max_generative_vocab) 40 | 
FIELD.append_vocab(new_vocab) 41 | print(f'Vocabulary has expanded to {len(FIELD.vocab)} tokens') 42 | 43 | char_vectors = torchtext.vocab.CharNGram(cache=args.embeddings) 44 | glove_vectors = torchtext.vocab.GloVe(cache=args.embeddings) 45 | vectors = [char_vectors, glove_vectors] 46 | FIELD.vocab.load_vectors(vectors, True) 47 | FIELD.decoder_to_vocab = {idx: FIELD.vocab.stoi[word] for idx, word in enumerate(FIELD.decoder_itos)} 48 | FIELD.vocab_to_decoder = {idx: FIELD.decoder_stoi[word] for idx, word in enumerate(FIELD.vocab.itos) if word in FIELD.decoder_stoi} 49 | splits = get_all_splits(args, FIELD) 50 | 51 | return FIELD, splits 52 | 53 | 54 | def to_iter(data, bs, device): 55 | Iterator = torchtext.data.Iterator 56 | it = Iterator(data, batch_size=bs, 57 | device=device, batch_size_fn=None, 58 | train=False, repeat=False, sort=None, 59 | shuffle=None, reverse=False) 60 | 61 | return it 62 | 63 | 64 | def run(args, field, val_sets, model): 65 | device = set_seed(args) 66 | print(f'Preparing iterators') 67 | if len(args.val_batch_size) == 1 and len(val_sets) > 1: 68 | args.val_batch_size *= len(val_sets) 69 | iters = [(name, to_iter(x, bs, device)) for name, x, bs in zip(args.tasks, val_sets, args.val_batch_size)] 70 | 71 | def mult(ps): 72 | r = 0 73 | for p in ps: 74 | this_r = 1 75 | for s in p.size(): 76 | this_r *= s 77 | r += this_r 78 | return r 79 | params = list(filter(lambda p: p.requires_grad, model.parameters())) 80 | num_param = mult(params) 81 | print(f'{args.model} has {num_param:,} parameters') 82 | model.to(device) 83 | 84 | decaScore = [] 85 | model.eval() 86 | with torch.no_grad(): 87 | for task, it in iters: 88 | print(task) 89 | prediction_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.txt') 90 | answer_file_name = os.path.join(os.path.splitext(args.best_checkpoint)[0], args.evaluate, task + '.gold.txt') 91 | results_file_name = answer_file_name.replace('gold', 'results') 92 | if 'sql' in task or 'squad' in task: 93 | ids_file_name = answer_file_name.replace('gold', 'ids') 94 | if os.path.exists(prediction_file_name): 95 | print('** ', prediction_file_name, ' already exists -- this is where predictions are stored **') 96 | if args.overwrite: 97 | print('**** overwriting ', prediction_file_name, ' ****') 98 | if os.path.exists(answer_file_name): 99 | print('** ', answer_file_name, ' already exists -- this is where ground truth answers are stored **') 100 | if args.overwrite: 101 | print('**** overwriting ', answer_file_name, ' ****') 102 | if os.path.exists(results_file_name): 103 | print('** ', results_file_name, ' already exists -- this is where metrics are stored **') 104 | if args.overwrite: 105 | print('**** overwriting ', results_file_name, ' ****') 106 | else: 107 | with open(results_file_name) as results_file: 108 | if not args.silent: 109 | for l in results_file: 110 | print(l) 111 | metrics = json.loads(results_file.readlines()[0]) 112 | decaScore.append(metrics[args.task_to_metric[task]]) 113 | continue 114 | 115 | for x in [prediction_file_name, answer_file_name, results_file_name]: 116 | os.makedirs(os.path.dirname(x), exist_ok=True) 117 | 118 | if not os.path.exists(prediction_file_name) or args.overwrite: 119 | with open(prediction_file_name, 'w') as prediction_file: 120 | predictions = [] 121 | ids = [] 122 | for batch_idx, batch in enumerate(it): 123 | _, p = model(batch) 124 | p = field.reverse(p) 125 | for i, pp in enumerate(p): 126 | if 'sql' in task: 127 | ids.append(int(batch.wikisql_id[i])) 128 | 
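# for SQuAD, record each prediction's original question id (written to the ids file below)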
if 'squad' in task: 129 | ids.append(it.dataset.q_ids[int(batch.squad_id[i])]) 130 | prediction_file.write(pp + '\n') 131 | predictions.append(pp) 132 | if 'sql' in task: 133 | with open(ids_file_name, 'w') as id_file: 134 | for i in ids: 135 | id_file.write(json.dumps(i) + '\n') 136 | if 'squad' in task: 137 | with open(ids_file_name, 'w') as id_file: 138 | for i in ids: 139 | id_file.write(i + '\n') 140 | else: 141 | with open(prediction_file_name) as prediction_file: 142 | predictions = [x.strip() for x in prediction_file.readlines()] 143 | if 'sql' in task or 'squad' in task: 144 | with open(ids_file_name) as id_file: 145 | ids = [int(x.strip()) for x in id_file.readlines()] 146 | 147 | def from_all_answers(an): 148 | return [it.dataset.all_answers[sid] for sid in an.tolist()] 149 | 150 | if not os.path.exists(answer_file_name) or args.overwrite: 151 | with open(answer_file_name, 'w') as answer_file: 152 | answers = [] 153 | for batch_idx, batch in enumerate(it): 154 | if hasattr(batch, 'wikisql_id'): 155 | a = from_all_answers(batch.wikisql_id.data.cpu()) 156 | elif hasattr(batch, 'squad_id'): 157 | a = from_all_answers(batch.squad_id.data.cpu()) 158 | elif hasattr(batch, 'woz_id'): 159 | a = from_all_answers(batch.woz_id.data.cpu()) 160 | else: 161 | a = field.reverse(batch.answer.data) 162 | for aa in a: 163 | answers.append(aa) 164 | answer_file.write(json.dumps(aa) + '\n') 165 | else: 166 | with open(answer_file_name) as answer_file: 167 | answers = [json.loads(x.strip()) for x in answer_file.readlines()] 168 | 169 | if len(answers) > 0: 170 | if not os.path.exists(results_file_name) or args.overwrite: 171 | metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task or args.bleu, dialogue='woz' in task, 172 | rouge='cnn' in task or 'dailymail' in task or args.rouge, logical_form='sql' in task, corpus_f1='zre' in task, args=args) 173 | with open(results_file_name, 'w') as results_file: 174 | results_file.write(json.dumps(metrics) + '\n') 175 | else: 176 | with open(results_file_name) as results_file: 177 | metrics = json.loads(results_file.readlines()[0]) 178 | 179 | if not args.silent: 180 | for i, (p, a) in enumerate(zip(predictions, answers)): 181 | print(f'Prediction {i+1}: {p}\nAnswer {i+1}: {a}\n') 182 | print(metrics) 183 | decaScore.append(metrics[args.task_to_metric[task]]) 184 | 185 | print(f'Evaluated Tasks:\n') 186 | for i, (task, _) in enumerate(iters): 187 | print(f'{task}: {decaScore[i]}') 188 | print(f'-------------------') 189 | print(f'DecaScore: {sum(decaScore)}\n') 190 | print(f'\nSummary: | {sum(decaScore)} | {" | ".join([str(x) for x in decaScore])} |\n') 191 | 192 | 193 | def get_args(): 194 | parser = ArgumentParser() 195 | parser.add_argument('--path', required=True) 196 | parser.add_argument('--evaluate', type=str, required=True) 197 | parser.add_argument('--tasks', default=['squad', 'iwslt.en.de', 'cnn_dailymail', 'multinli.in.out', 'sst', 'srl', 'zre', 'woz.en', 'wikisql', 'schema'], nargs='+') 198 | parser.add_argument('--devices', default=[0], nargs='+', type=int, help='a list of devices that can be used (multi-gpu currently WIP)') 199 | parser.add_argument('--seed', default=123, type=int, help='Random seed.') 200 | parser.add_argument('--data', default='/decaNLP/.data/', type=str, help='where to load data from.') 201 | parser.add_argument('--embeddings', default='/decaNLP/.embeddings', type=str, help='where to save embeddings.') 202 | parser.add_argument('--checkpoint_name') 203 | parser.add_argument('--bleu', 
action='store_true', help='whether to use the bleu metric (always on for iwslt)') 204 | parser.add_argument('--rouge', action='store_true', help='whether to use the bleu metric (always on for cnn, dailymail, and cnn_dailymail)') 205 | parser.add_argument('--overwrite', action='store_true', help='whether to overwrite previously written predictions') 206 | parser.add_argument('--silent', action='store_true', help='whether to print predictions to stdout') 207 | 208 | args = parser.parse_args() 209 | 210 | with open(os.path.join(args.path, 'config.json')) as config_file: 211 | config = json.load(config_file) 212 | retrieve = ['model', 213 | 'transformer_layers', 'rnn_layers', 'transformer_hidden', 214 | 'dimension', 'load', 'max_val_context_length', 'val_batch_size', 215 | 'transformer_heads', 'max_output_length', 'max_generative_vocab', 216 | 'lower', 'cove', 'intermediate_cove', 'elmo', 'glove_and_char'] 217 | for r in retrieve: 218 | if r in config: 219 | setattr(args, r, config[r]) 220 | elif 'cove' in r: 221 | setattr(args, r, False) 222 | elif 'elmo' in r: 223 | setattr(args, r, [-1]) 224 | elif 'glove_and_char' in r: 225 | setattr(args, r, True) 226 | else: 227 | setattr(args, r, None) 228 | args.dropout_ratio = 0.0 229 | 230 | args.task_to_metric = {'cnn_dailymail': 'avg_rouge', 231 | 'iwslt.en.de': 'bleu', 232 | 'multinli.in.out': 'em', 233 | 'squad': 'nf1', 234 | 'srl': 'nf1', 235 | 'sst': 'em', 236 | 'wikisql': 'lfem', 237 | 'woz.en': 'joint_goal_em', 238 | 'zre': 'corpus_f1', 239 | 'schema': 'em'} 240 | 241 | if not args.checkpoint_name is None: 242 | args.best_checkpoint = os.path.join(args.path, args.checkpoint_name) 243 | else: 244 | assert os.path.exists(os.path.join(args.path, 'process_0.log')) 245 | args.best_checkpoint = get_best(args) 246 | 247 | return args 248 | 249 | 250 | def get_best(args): 251 | with open(os.path.join(args.path, 'config.json')) as f: 252 | save_every = json.load(f)['save_every'] 253 | 254 | with open(os.path.join(args.path, 'process_0.log')) as f: 255 | lines = f.readlines() 256 | 257 | best_score = 0 258 | best_it = 0 259 | deca_scores = {} 260 | for l in lines: 261 | if 'val' in l: 262 | try: 263 | task = l.split('val_')[1].split(':')[0] 264 | except Exception as e: 265 | print(e) 266 | continue 267 | it = int(l.split('iteration_')[1].split(':')[0]) 268 | metric = args.task_to_metric[task] 269 | score = float(l.split(metric+'_')[1].split(':')[0]) 270 | if it in deca_scores: 271 | deca_scores[it]['deca'] += score 272 | deca_scores[it][metric] = score 273 | else: 274 | deca_scores[it] = {'deca': score, metric: score} 275 | if deca_scores[it]['deca'] > best_score: 276 | best_score = deca_scores[it]['deca'] 277 | best_it = it 278 | print(best_it) 279 | print(best_score) 280 | return os.path.join(args.path, f'iteration_{int(best_it)}.pth') 281 | 282 | 283 | if __name__ == '__main__': 284 | args = get_args() 285 | print(f'Arguments:\n{pformat(vars(args))}') 286 | 287 | np.random.seed(args.seed) 288 | random.seed(args.seed) 289 | torch.manual_seed(args.seed) 290 | torch.cuda.manual_seed(args.seed) 291 | 292 | print(f'Loading from {args.best_checkpoint}') 293 | save_dict = torch.load(args.best_checkpoint) 294 | field = save_dict['field'] 295 | print(f'Initializing Model') 296 | Model = getattr(models, args.model) 297 | model = Model(field, args) 298 | model_dict = save_dict['model_state_dict'] 299 | backwards_compatible_cove_dict = {} 300 | for k, v in model_dict.items(): 301 | if 'cove.rnn.' 
in k: 302 | k = k.replace('cove.rnn.', 'cove.rnn1.') 303 | backwards_compatible_cove_dict[k] = v 304 | model_dict = backwards_compatible_cove_dict 305 | model.load_state_dict(model_dict) 306 | field, splits = prepare_data(args, field) 307 | model.set_embeddings(field.vocab.vectors) 308 | 309 | run(args, field, splits, model) 310 | -------------------------------------------------------------------------------- /text/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E402, E722 3 | max-line-length = 90 4 | -------------------------------------------------------------------------------- /text/.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | dist/ 3 | torchtext.egg-info/ 4 | *.txt 5 | *.zip 6 | */**/__pycache__ 7 | */**/*.pyc 8 | */**/*~ 9 | *~ 10 | .cache 11 | .vector_cache 12 | 13 | # Documentation 14 | docs/build 15 | 16 | # Download folder 17 | .data 18 | -------------------------------------------------------------------------------- /text/.travis.yml: -------------------------------------------------------------------------------- 1 | dist: trusty 2 | 3 | language: python 4 | 5 | cache: 6 | directories: 7 | - /home/travis/download 8 | - /home/travis/.cache/pip 9 | 10 | # This matrix tests that the code works on Python 2.7, 11 | # 2.7.9, 3.5, 3.6 (same versions as PyTorch CI), and passes lint. 12 | matrix: 13 | fast_finish: true 14 | include: 15 | - env: PYTHON_VERSION="2.7" COVERAGE="true" 16 | - env: PYTHON_VERSION="2.7.9" COVERAGE="true" 17 | - env: PYTHON_VERSION="3.5" COVERAGE="true" 18 | - env: PYTHON_VERSION="3.6" COVERAGE="true" 19 | - env: PYTHON_VERSION="2.7" RUN_FLAKE8="true" SKIP_TESTS="true" 20 | - env: PYTHON_VERSION="3.6" RUN_FLAKE8="true" SKIP_TESTS="true" 21 | - env: PYTHON_VERSION="2.7.9" RUN_SLOW="true" COVERAGE="true" 22 | sudo: required 23 | - env: PYTHON_VERSION="3.6" RUN_SLOW="true" COVERAGE="true" 24 | sudo: required 25 | allow_failures: 26 | - env: PYTHON_VERSION="2.7.9" RUN_SLOW="true" COVERAGE="true" 27 | - env: PYTHON_VERSION="3.6" RUN_SLOW="true" COVERAGE="true" 28 | 29 | notifications: 30 | email: false 31 | 32 | install: source build_tools/travis/install.sh 33 | script: bash build_tools/travis/test_script.sh 34 | after_success: source build_tools/travis/after_success.sh 35 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) James Bradbury and Soumith Chintala 2016, 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /text/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/pytorch/text.svg?branch=master)](https://travis-ci.org/pytorch/text) 2 | [![codecov](https://codecov.io/gh/pytorch/text/branch/master/graph/badge.svg)](https://codecov.io/gh/pytorch/text) 3 | 4 | # torchtext 5 | 6 | This repository consists of: 7 | 8 | - [torchtext.data](#data) : Generic data loaders, abstractions, and iterators for text (including vocabulary and word vectors) 9 | - [torchtext.datasets](#datasets) : Pre-built loaders for common NLP datasets 10 | 11 | # Data 12 | 13 | The data module provides the following: 14 | 15 | - Ability to describe declaratively how to load a custom NLP dataset that's in a "normal" format: 16 | ```python 17 | pos = data.TabularDataset( 18 | path='data/pos/pos_wsj_train.tsv', format='tsv', 19 | fields=[('text', data.Field()), 20 | ('labels', data.Field())]) 21 | 22 | sentiment = data.TabularDataset( 23 | path='data/sentiment/train.json', format='json', 24 | fields={'sentence_tokenized': ('text', data.Field(sequential=True)), 25 | 'sentiment_gold': ('labels', data.Field(sequential=False))}) 26 | ``` 27 | - Ability to define a preprocessing pipeline: 28 | ```python 29 | src = data.Field(tokenize=my_custom_tokenizer) 30 | trg = data.Field(tokenize=my_custom_tokenizer) 31 | mt_train = datasets.TranslationDataset( 32 | path='data/mt/wmt16-ende.train', exts=('.en', '.de'), 33 | fields=(src, trg)) 34 | ``` 35 | - Batching, padding, and numericalizing (including building a vocabulary object): 36 | ```python 37 | # continuing from above 38 | mt_dev = data.TranslationDataset( 39 | path='data/mt/newstest2014', exts=('.en', '.de'), 40 | fields=(src, trg)) 41 | src.build_vocab(mt_train, max_size=80000) 42 | trg.build_vocab(mt_train, max_size=40000) 43 | # mt_dev shares the fields, so it shares their vocab objects 44 | 45 | train_iter = data.BucketIterator( 46 | dataset=mt_train, batch_size=32, 47 | sort_key=lambda x: data.interleave_keys(len(x.src), len(x.trg))) 48 | # usage 49 | >>>next(iter(train_iter)) 50 | 51 | ``` 52 | - Wrapper for dataset splits (train, validation, test): 53 | ```python 54 | TEXT = data.Field() 55 | LABELS = data.Field() 56 | 57 | train, val, test = data.TabularDataset.splits( 58 | path='/data/pos_wsj/pos_wsj', train='_train.tsv', 59 | validation='_dev.tsv', test='_test.tsv', format='tsv', 60 | fields=[('text', TEXT), ('labels', LABELS)]) 61 | 62 | train_iter, val_iter, test_iter = data.BucketIterator.splits( 63 | (train, val, test), batch_sizes=(16, 256, 256), 64 | sort_key=lambda x: len(x.text), device=0) 65 | 66 | 
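# build the vocabularies from the training split; val and test reuse them through the shared Field objects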
TEXT.build_vocab(train) 67 | LABELS.build_vocab(train) 68 | ``` 69 | 70 | # Datasets 71 | 72 | The datasets module currently contains: 73 | 74 | - Sentiment analysis: SST and IMDb 75 | - Question classification: TREC 76 | - Entailment: SNLI 77 | - Language modeling: abstract class + WikiText-2 78 | - Machine translation: abstract class + Multi30k, IWSLT, WMT14 79 | - Sequence tagging (e.g. POS/NER): abstract class + UDPOS 80 | 81 | Others are planned or a work in progress: 82 | 83 | - Question answering: SQuAD 84 | 85 | See the "test" directory for examples of dataset usage. 86 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | from . import torchtext 2 | 3 | 4 | __all__ = ['torchtext'] 5 | -------------------------------------------------------------------------------- /text/build_tools/travis/after_success.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is meant to be called by the "after_success" step defined in 3 | # .travis.yml. See http://docs.travis-ci.com/ for more details. 4 | 5 | set -e 6 | 7 | if [[ "$COVERAGE" == "true" ]]; then 8 | # Ignore codecov failures as the codecov server is not 9 | # very reliable but we don't want travis to report a failure 10 | # in the github UI just because the coverage report failed to 11 | # be published. 12 | codecov || echo "codecov upload failed" 13 | fi 14 | -------------------------------------------------------------------------------- /text/build_tools/travis/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is meant to be called by the "install" step defined in 3 | # .travis.yml. See http://docs.travis-ci.com/ for more details. 4 | # The behavior of the script is controlled by environment variabled defined 5 | # in the .travis.yml in the top level folder of the project. 6 | 7 | set -e 8 | 9 | echo 'List files from cached directories' 10 | if [ -d $HOME/download ]; then 11 | echo 'download:' 12 | ls $HOME/download 13 | fi 14 | if [ -d $HOME/.cache/pip ]; then 15 | echo 'pip:' 16 | ls $HOME/.cache/pip 17 | fi 18 | 19 | # Deactivate the travis-provided virtual environment and setup a 20 | # conda-based environment instead 21 | deactivate 22 | 23 | # Add the miniconda bin directory to $PATH 24 | export PATH=/home/travis/miniconda3/bin:$PATH 25 | echo $PATH 26 | 27 | # Use the miniconda installer for setup of conda itself 28 | pushd . 29 | cd 30 | mkdir -p download 31 | cd download 32 | if [[ ! -f /home/travis/miniconda3/bin/activate ]] 33 | then 34 | if [[ ! -f miniconda.sh ]] 35 | then 36 | wget http://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 37 | -O miniconda.sh 38 | fi 39 | chmod +x miniconda.sh && ./miniconda.sh -b -f 40 | conda update --yes conda 41 | echo "Creating environment to run tests in." 42 | conda create -n testenv --yes python="$PYTHON_VERSION" 43 | fi 44 | cd .. 45 | popd 46 | 47 | # Activate the python environment we created. 
48 | source activate testenv 49 | 50 | # Install requirements via pip in our conda environment 51 | pip install -r requirements.txt 52 | 53 | # Install the following only if running tests 54 | if [[ "$SKIP_TESTS" != "true" ]]; then 55 | # SpaCy English models 56 | python -m spacy download en 57 | 58 | # NLTK data needed for Moses tokenizer 59 | python -m nltk.downloader perluniprops nonbreaking_prefixes 60 | 61 | # PyTorch 62 | conda install --yes pytorch torchvision -c soumith 63 | fi 64 | -------------------------------------------------------------------------------- /text/build_tools/travis/test_script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is meant to be called by the "script" step defined in 3 | # .travis.yml. See http://docs.travis-ci.com/ for more details. 4 | # The behavior of the script is controlled by environment variabled defined 5 | # in the .travis.yml in the top level folder of the project. 6 | 7 | set -e 8 | 9 | python --version 10 | 11 | run_tests() { 12 | if [[ "$RUN_SLOW" == "true" ]]; then 13 | TEST_CMD="py.test --runslow -s -v --cov=torchtext --durations=20" 14 | else 15 | TEST_CMD="py.test -v --cov=torchtext --durations=20" 16 | fi 17 | $TEST_CMD 18 | } 19 | 20 | if [[ "$RUN_FLAKE8" == "true" ]]; then 21 | flake8 22 | fi 23 | 24 | if [[ "$SKIP_TESTS" != "true" ]]; then 25 | run_tests 26 | fi 27 | -------------------------------------------------------------------------------- /text/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 0 3 | round: down 4 | status: 5 | patch: 6 | default: 7 | target: 90 8 | project: 9 | default: 10 | threshold: 1% 11 | changes: false 12 | comment: false 13 | ignore: 14 | - "test/" 15 | -------------------------------------------------------------------------------- /text/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = torchtext 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | docset: html 16 | doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/pytorch-logo-flame.png --enable-js --online-redirect-url http://pytorch.org/text/ --force $(BUILDDIR)/html/ 17 | 18 | # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. 19 | cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png 20 | convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png 21 | 22 | .PHONY: help Makefile 23 | 24 | # Catch-all target: route all unknown targets to Sphinx using the new 25 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
26 | %: Makefile 27 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 28 | -------------------------------------------------------------------------------- /text/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=torchtext 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /text/docs/source/_static/css/pytorch_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-nav-content-wrap, .wy-menu li.current > a { 16 | background-color: #fff; 17 | } 18 | 19 | @media screen and (min-width: 1400px) { 20 | .wy-nav-content-wrap { 21 | background-color: rgba(0, 0, 0, 0.0470588); 22 | } 23 | 24 | .wy-nav-content { 25 | background-color: #fff; 26 | } 27 | } 28 | 29 | /* Fixes for mobile */ 30 | .wy-nav-top { 31 | background-color: #fff; 32 | background-image: url('../img/pytorch-logo-dark.svg'); 33 | background-repeat: no-repeat; 34 | background-position: center; 35 | padding: 0; 36 | margin: 0.4045em 0.809em; 37 | color: #333; 38 | } 39 | 40 | .wy-nav-top > a { 41 | display: none; 42 | } 43 | 44 | @media screen and (max-width: 768px) { 45 | .wy-side-nav-search>a img.logo { 46 | height: 60px; 47 | } 48 | } 49 | 50 | /* This is needed to ensure that logo above search scales properly */ 51 | .wy-side-nav-search a { 52 | display: block; 53 | } 54 | 55 | /* This ensures that multiple constructors will remain in separate lines. 
*/ 56 | .rst-content dl:not(.docutils) dt { 57 | display: table; 58 | } 59 | 60 | /* Use our red for literals (it's very similar to the original color) */ 61 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 62 | color: #F05732; 63 | } 64 | 65 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 66 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 67 | color: #404040; 68 | } 69 | 70 | /* Change link colors (except for the menu) */ 71 | 72 | a { 73 | color: #F05732; 74 | } 75 | 76 | a:hover { 77 | color: #F05732; 78 | } 79 | 80 | 81 | a:visited { 82 | color: #D44D2C; 83 | } 84 | 85 | .wy-menu a { 86 | color: #b3b3b3; 87 | } 88 | 89 | .wy-menu a:hover { 90 | color: #b3b3b3; 91 | } 92 | 93 | /* Default footer text is quite big */ 94 | footer { 95 | font-size: 80%; 96 | } 97 | 98 | footer .rst-footer-buttons { 99 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 100 | } 101 | 102 | footer p { 103 | font-size: 100%; 104 | } 105 | 106 | /* For hidden headers that appear in TOC tree */ 107 | /* see http://stackoverflow.com/a/32363545/3343043 */ 108 | .rst-content .hidden-section { 109 | display: none; 110 | } 111 | 112 | nav .hidden-section { 113 | display: inherit; 114 | } 115 | 116 | .wy-side-nav-search>div.version { 117 | color: #000; 118 | } 119 | -------------------------------------------------------------------------------- /text/docs/source/_static/img/pytorch-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/docs/source/_static/img/pytorch-logo-dark.png -------------------------------------------------------------------------------- /text/docs/source/_static/img/pytorch-logo-dark.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 10 | 13 | 14 | 16 | 17 | 18 | 20 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /text/docs/source/_static/img/pytorch-logo-flame.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/docs/source/_static/img/pytorch-logo-flame.png -------------------------------------------------------------------------------- /text/docs/source/_static/img/pytorch-logo-flame.svg: -------------------------------------------------------------------------------- 1 | 2 | image/svg+xml -------------------------------------------------------------------------------- /text/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # torchtext documentation build configuration file, created by 5 | # sphinx-quickstart on Thu Nov 16 01:05:05 2017. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
19 | # 20 | # import os 21 | # import sys 22 | # sys.path.insert(0, os.path.abspath('.')) 23 | import torchtext 24 | import sphinx_rtd_theme 25 | 26 | # -- General configuration ------------------------------------------------ 27 | 28 | # If your documentation needs a minimal Sphinx version, state it here. 29 | # 30 | # needs_sphinx = '1.0' 31 | 32 | # Add any Sphinx extension module names here, as strings. They can be 33 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 34 | # ones. 35 | extensions = [ 36 | 'sphinx.ext.autodoc', 37 | 'sphinx.ext.autosummary', 38 | 'sphinx.ext.doctest', 39 | 'sphinx.ext.intersphinx', 40 | 'sphinx.ext.todo', 41 | 'sphinx.ext.coverage', 42 | 'sphinx.ext.mathjax', 43 | 'sphinx.ext.napoleon', 44 | 'sphinx.ext.viewcode', 45 | ] 46 | 47 | napoleon_use_ivar = True 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['_templates'] 51 | 52 | # The suffix(es) of source filenames. 53 | # You can specify multiple suffix as a list of string: 54 | # 55 | # source_suffix = ['.rst', '.md'] 56 | source_suffix = '.rst' 57 | 58 | # The master toctree document. 59 | master_doc = 'index' 60 | 61 | # General information about the project. 62 | project = 'torchtext' 63 | copyright = '2017, Torch Contributors' 64 | author = 'Torch Contributors' 65 | 66 | # The version info for the project you're documenting, acts as replacement for 67 | # |version| and |release|, also used in various other places throughout the 68 | # built documents. 69 | # 70 | # The short X.Y version. 71 | version = torchtext.__version__ 72 | # The full version, including alpha/beta/rc tags. 73 | release = torchtext.__version__ 74 | 75 | # The language for content autogenerated by Sphinx. Refer to documentation 76 | # for a list of supported languages. 77 | # 78 | # This is also used if you do content translation via gettext catalogs. 79 | # Usually you set "language" from the command line for these cases. 80 | language = None 81 | 82 | # List of patterns, relative to source directory, that match files and 83 | # directories to ignore when looking for source files. 84 | # This patterns also effect to html_static_path and html_extra_path 85 | exclude_patterns = [] 86 | 87 | # The name of the Pygments (syntax highlighting) style to use. 88 | pygments_style = 'sphinx' 89 | 90 | # If true, `todo` and `todoList` produce output, else they produce nothing. 91 | todo_include_todos = True 92 | 93 | 94 | # -- Options for HTML output ---------------------------------------------- 95 | 96 | # The theme to use for HTML and HTML Help pages. See the documentation for 97 | # a list of builtin themes. 98 | # 99 | html_theme = 'sphinx_rtd_theme' 100 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 101 | 102 | # Theme options are theme-specific and customize the look and feel of a theme 103 | # further. For a list of options available for each theme, see the 104 | # documentation. 105 | # 106 | html_theme_options = { 107 | 'collapse_navigation': False, 108 | 'display_version': True, 109 | 'logo_only': True, 110 | } 111 | 112 | html_logo = '_static/img/pytorch-logo-dark.svg' 113 | 114 | # Add any paths that contain custom static files (such as style sheets) here, 115 | # relative to this directory. They are copied after the builtin static files, 116 | # so a file named "default.css" will overwrite the builtin "default.css". 
117 | html_static_path = ['_static'] 118 | 119 | html_context = { 120 | 'css_files': [ 121 | 'https://fonts.googleapis.com/css?family=Lato', 122 | '_static/css/pytorch_theme.css' 123 | ], 124 | } 125 | 126 | 127 | # -- Options for HTMLHelp output ------------------------------------------ 128 | 129 | # Output file base name for HTML help builder. 130 | htmlhelp_basename = 'PyTorchdoc' 131 | 132 | 133 | # -- Options for LaTeX output --------------------------------------------- 134 | 135 | latex_elements = { 136 | # The paper size ('letterpaper' or 'a4paper'). 137 | # 138 | # 'papersize': 'letterpaper', 139 | 140 | # The font size ('10pt', '11pt' or '12pt'). 141 | # 142 | # 'pointsize': '10pt', 143 | 144 | # Additional stuff for the LaTeX preamble. 145 | # 146 | # 'preamble': '', 147 | 148 | # Latex figure (float) alignment 149 | # 150 | # 'figure_align': 'htbp', 151 | } 152 | 153 | # Grouping the document tree into LaTeX files. List of tuples 154 | # (source start file, target name, title, 155 | # author, documentclass [howto, manual, or own class]). 156 | latex_documents = [ 157 | (master_doc, 'pytorch.tex', 'torchtext Documentation', 158 | 'Torch Contributors', 'manual'), 159 | ] 160 | 161 | 162 | # -- Options for manual page output --------------------------------------- 163 | 164 | # One entry per manual page. List of tuples 165 | # (source start file, name, description, authors, manual section). 166 | man_pages = [ 167 | (master_doc, 'torchtext', 'torchtext Documentation', 168 | [author], 1) 169 | ] 170 | 171 | 172 | # -- Options for Texinfo output ------------------------------------------- 173 | 174 | # Grouping the document tree into Texinfo files. List of tuples 175 | # (source start file, target name, title, author, 176 | # dir menu entry, description, category) 177 | texinfo_documents = [ 178 | (master_doc, 'torchtext', 'torchtext Documentation', 179 | author, 'torchtext', 'One line description of project.', 180 | 'Miscellaneous'), 181 | ] 182 | 183 | 184 | # Example configuration for intersphinx: refer to the Python standard library. 185 | intersphinx_mapping = { 186 | 'python': ('https://docs.python.org/', None), 187 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 188 | } 189 | 190 | # -- A patch that prevents Sphinx from cross-referencing ivar tags ------- 191 | # See http://stackoverflow.com/a/41184353/3343043 192 | 193 | from docutils import nodes 194 | from sphinx.util.docfields import TypedField 195 | from sphinx import addnodes 196 | 197 | 198 | def patched_make_field(self, types, domain, items, **kw): 199 | # `kw` catches `env=None` needed for newer sphinx while maintaining 200 | # backwards compatibility when passed along further down! 
201 | 202 | # type: (List, unicode, Tuple) -> nodes.field 203 | def handle_item(fieldarg, content): 204 | par = nodes.paragraph() 205 | par += addnodes.literal_strong('', fieldarg) # Patch: this line added 206 | # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, 207 | # addnodes.literal_strong)) 208 | if fieldarg in types: 209 | par += nodes.Text(' (') 210 | # NOTE: using .pop() here to prevent a single type node to be 211 | # inserted twice into the doctree, which leads to 212 | # inconsistencies later when references are resolved 213 | fieldtype = types.pop(fieldarg) 214 | if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): 215 | typename = u''.join(n.astext() for n in fieldtype) 216 | typename = typename.replace('int', 'python:int') 217 | typename = typename.replace('long', 'python:long') 218 | typename = typename.replace('float', 'python:float') 219 | typename = typename.replace('type', 'python:type') 220 | par.extend(self.make_xrefs(self.typerolename, domain, typename, 221 | addnodes.literal_emphasis, **kw)) 222 | else: 223 | par += fieldtype 224 | par += nodes.Text(')') 225 | par += nodes.Text(' -- ') 226 | par += content 227 | return par 228 | 229 | fieldname = nodes.field_name('', self.label) 230 | if len(items) == 1 and self.can_collapse: 231 | fieldarg, content = items[0] 232 | bodynode = handle_item(fieldarg, content) 233 | else: 234 | bodynode = self.list_type() 235 | for fieldarg, content in items: 236 | bodynode += nodes.list_item('', handle_item(fieldarg, content)) 237 | fieldbody = nodes.field_body('', bodynode) 238 | return nodes.field('', fieldname, fieldbody) 239 | 240 | 241 | TypedField.make_field = patched_make_field 242 | -------------------------------------------------------------------------------- /text/docs/source/data.rst: -------------------------------------------------------------------------------- 1 | torchtext.data 2 | ================= 3 | 4 | .. currentmodule:: torchtext.data 5 | 6 | The data module provides the following: 7 | 8 | - Ability to define a preprocessing pipeline 9 | - Batching, padding, and numericalizing (including building a vocabulary object) 10 | - Wrapper for dataset splits (train, validation, test) 11 | - Loader a custom NLP dataset 12 | 13 | .. contents:: Data 14 | :local: 15 | 16 | 17 | 18 | Batch 19 | ~~~~~ 20 | 21 | .. autoclass:: Batch 22 | 23 | Dataset 24 | ~~~~~~~ 25 | 26 | .. autoclass:: Dataset 27 | 28 | Example 29 | ~~~~~~~ 30 | 31 | .. autoclass:: Example 32 | 33 | Field 34 | ~~~~~ 35 | 36 | .. autoclass:: Field 37 | 38 | Iterator 39 | ~~~~~~~~ 40 | 41 | .. autoclass:: Iterator 42 | 43 | Pipeline 44 | ~~~~~~~~ 45 | 46 | .. autoclass:: Pipeline 47 | -------------------------------------------------------------------------------- /text/docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | torchtext.datasets 2 | ==================== 3 | 4 | .. currentmodule:: torchtext.datasets 5 | 6 | All datasets are subclasses of :class:`torchtext.data.Dataset`, which 7 | inherits from :class:`torch.utils.data.Dataset` i.e, they have ``split`` and 8 | ``iters`` methods implemented. 
9 | 10 | General use cases are as follows: 11 | 12 | Approach 1, ``splits``: :: 13 | 14 | # set up fields 15 | TEXT = data.Field(lower=True, include_lengths=True, batch_first=True) 16 | LABEL = data.Field(sequential=False) 17 | 18 | # make splits for data 19 | train, test = datasets.IMDB.splits(TEXT, LABEL) 20 | 21 | # build the vocabulary 22 | TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) 23 | LABEL.build_vocab(train) 24 | 25 | # make iterator for splits 26 | train_iter, test_iter = data.BucketIterator.splits( 27 | (train, test), batch_size=3, device=0) 28 | 29 | Approach 2, ``iters``: :: 30 | 31 | # use default configurations 32 | train_iter, test_iter = datasets.IMDB.iters(batch_size=4) 33 | 34 | The following datasets are available: 35 | 36 | .. contents:: Datasets 37 | :local: 38 | 39 | 40 | Sentiment Analysis 41 | ^^^^^^^^^^^^^^^^^^ 42 | 43 | SST 44 | ~~~ 45 | 46 | .. autoclass:: SST 47 | :members: splits, iters 48 | 49 | IMDb 50 | ~~~~ 51 | 52 | .. autoclass:: IMDB 53 | :members: splits, iters 54 | 55 | 56 | 57 | Question Classification 58 | ^^^^^^^^^^^^^^^^^^^^^^^ 59 | 60 | TREC 61 | ~~~~ 62 | 63 | .. autoclass:: TREC 64 | :members: splits, iters 65 | 66 | Entailment 67 | ^^^^^^^^^^ 68 | 69 | SNLI 70 | ~~~~ 71 | 72 | .. autoclass:: SNLI 73 | :members: splits, iters 74 | 75 | 76 | 77 | Language Modeling 78 | ^^^^^^^^^^^^^^^^^ 79 | 80 | Language modeling datasets are subclasses of ``LanguageModelingDataset`` class. 81 | 82 | .. autoclass:: LanguageModelingDataset 83 | :members: __init__ 84 | 85 | 86 | WikiText-2 87 | ~~~~~~~~~~ 88 | 89 | .. autoclass:: WikiText2 90 | :members: splits, iters 91 | 92 | 93 | 94 | Machine Translation 95 | ^^^^^^^^^^^^^^^^^^^ 96 | 97 | Machine translation datasets are subclasses of ``TranslationDataset`` class. 98 | 99 | .. autoclass:: TranslationDataset 100 | :members: __init__ 101 | 102 | 103 | Multi30k 104 | ~~~~~~~~ 105 | 106 | .. autoclass:: Multi30k 107 | :members: splits 108 | 109 | IWSLT 110 | ~~~~~ 111 | 112 | .. autoclass:: IWSLT 113 | :members: splits 114 | 115 | WMT14 116 | ~~~~~ 117 | 118 | .. autoclass:: WMT14 119 | :members: splits 120 | -------------------------------------------------------------------------------- /text/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | torchtext 2 | =========== 3 | 4 | The :mod:`torchtext` package consists of data processing utilities and 5 | popular datasets for natural language. 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | :caption: Package Reference 10 | 11 | data 12 | datasets 13 | 14 | .. 
automodule:: torchtext 15 | :members: 16 | -------------------------------------------------------------------------------- /text/pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = test/ 3 | python_paths = ./ -------------------------------------------------------------------------------- /text/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import io 4 | import re 5 | from setuptools import setup, find_packages 6 | 7 | 8 | def read(*names, **kwargs): 9 | with io.open( 10 | os.path.join(os.path.dirname(__file__), *names), 11 | encoding=kwargs.get("encoding", "utf8") 12 | ) as fp: 13 | return fp.read() 14 | 15 | 16 | def find_version(*file_paths): 17 | version_file = read(*file_paths) 18 | version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", 19 | version_file, re.M) 20 | if version_match: 21 | return version_match.group(1) 22 | raise RuntimeError("Unable to find version string.") 23 | 24 | 25 | VERSION = find_version('torchtext', '__init__.py') 26 | 27 | long_description = '''torch-text provides text and NLP data utilities 28 | and datasets for torch''' 29 | 30 | setup_info = dict( 31 | # Metadata 32 | name='torchtext', 33 | version=VERSION, 34 | author='PyTorch core devs and James Bradbury', 35 | author_email='jekbradbury@gmail.com', 36 | url='https://github.com/pytorch/text', 37 | description='text utilities and datasets for torch deep learning', 38 | long_description=long_description, 39 | license='BSD', 40 | 41 | install_requires=[ 42 | 'tqdm', 'requests' 43 | ], 44 | 45 | # Package info 46 | packages=find_packages(exclude=('test',)), 47 | 48 | zip_safe=True, 49 | ) 50 | 51 | setup(**setup_info) 52 | -------------------------------------------------------------------------------- /text/test/.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/test/.gitignore -------------------------------------------------------------------------------- /text/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/test/__init__.py -------------------------------------------------------------------------------- /text/test/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/test/common/__init__.py -------------------------------------------------------------------------------- /text/test/common/test_markers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | slow = pytest.mark.skipif( 4 | not pytest.config.getoption("--runslow"), 5 | reason="This test is slow. Set --runslow flag to run." 
6 | ) 7 | -------------------------------------------------------------------------------- /text/test/common/torchtext_test_case.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from unittest import TestCase 3 | import json 4 | import logging 5 | import os 6 | import shutil 7 | import subprocess 8 | import tempfile 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class TorchtextTestCase(TestCase): 14 | def setUp(self): 15 | logging.basicConfig(format=('%(asctime)s - %(levelname)s - ' 16 | '%(name)s - %(message)s'), 17 | level=logging.INFO) 18 | # Directory where everything temporary and test-related is written 19 | self.project_root = os.path.abspath(os.path.realpath(os.path.join( 20 | os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir))) 21 | self.test_dir = tempfile.mkdtemp() 22 | self.test_ppid_dataset_path = os.path.join(self.test_dir, "test_ppid_dataset") 23 | self.test_numerical_features_dataset_path = os.path.join( 24 | self.test_dir, "test_numerical_features_dataset") 25 | 26 | def tearDown(self): 27 | try: 28 | shutil.rmtree(self.test_dir) 29 | except: 30 | subprocess.call(["rm", "-rf", self.test_dir]) 31 | 32 | def write_test_ppid_dataset(self, data_format="csv"): 33 | data_format = data_format.lower() 34 | if data_format == "csv": 35 | delim = "," 36 | elif data_format == "tsv": 37 | delim = "\t" 38 | dict_dataset = [ 39 | {"id": "0", "question1": "When do you use シ instead of し?", 40 | "question2": "When do you use \"&\" instead of \"and\"?", 41 | "label": "0"}, 42 | {"id": "1", "question1": "Where was Lincoln born?", 43 | "question2": "Which location was Abraham Lincoln born?", 44 | "label": "1"}, 45 | {"id": "2", "question1": "What is 2+2", 46 | "question2": "2+2=?", 47 | "label": "1"}, 48 | ] 49 | with open(self.test_ppid_dataset_path, "w") as test_ppid_dataset_file: 50 | for example in dict_dataset: 51 | if data_format == "json": 52 | test_ppid_dataset_file.write(json.dumps(example) + "\n") 53 | elif data_format == "csv" or data_format == "tsv": 54 | test_ppid_dataset_file.write("{}\n".format( 55 | delim.join([example["id"], example["question1"], 56 | example["question2"], example["label"]]))) 57 | else: 58 | raise ValueError("Invalid format {}".format(data_format)) 59 | 60 | def write_test_numerical_features_dataset(self): 61 | with open(self.test_numerical_features_dataset_path, 62 | "w") as test_numerical_features_dataset_file: 63 | test_numerical_features_dataset_file.write("0.1\t1\tteststring1\n") 64 | test_numerical_features_dataset_file.write("0.5\t12\tteststring2\n") 65 | test_numerical_features_dataset_file.write("0.2\t0\tteststring3\n") 66 | test_numerical_features_dataset_file.write("0.4\t12\tteststring4\n") 67 | test_numerical_features_dataset_file.write("0.9\t9\tteststring5\n") 68 | 69 | 70 | def verify_numericalized_example(field, test_example_data, 71 | test_example_numericalized, 72 | test_example_lengths=None, 73 | batch_first=False, train=True): 74 | """ 75 | Function to verify that numericalized example is correct 76 | with respect to the Field's Vocab. 
77 | """ 78 | if isinstance(test_example_numericalized, tuple): 79 | test_example_numericalized, lengths = test_example_numericalized 80 | assert test_example_lengths == lengths.tolist() 81 | if batch_first: 82 | test_example_numericalized.data.t_() 83 | # Transpose numericalized example so we can compare over batches 84 | for example_idx, numericalized_single_example in enumerate( 85 | test_example_numericalized.t()): 86 | assert len(test_example_data[example_idx]) == len(numericalized_single_example) 87 | assert numericalized_single_example.volatile is not train 88 | for token_idx, numericalized_token in enumerate( 89 | numericalized_single_example): 90 | # Convert from Variable to int 91 | numericalized_token = numericalized_token.data[0] 92 | test_example_token = test_example_data[example_idx][token_idx] 93 | # Check if the numericalized example is correct, taking into 94 | # account unknown tokens. 95 | if field.vocab.stoi[test_example_token] != 0: 96 | # token is in-vocabulary 97 | assert (field.vocab.itos[numericalized_token] == 98 | test_example_token) 99 | else: 100 | # token is OOV and always has an index of 0 101 | assert numericalized_token == 0 102 | -------------------------------------------------------------------------------- /text/test/conftest.py: -------------------------------------------------------------------------------- 1 | def pytest_addoption(parser): 2 | parser.addoption("--runslow", action="store_true", 3 | help="Run slow tests") 4 | -------------------------------------------------------------------------------- /text/test/data.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | 3 | 4 | TEXT = data.Field() 5 | LABELS = data.Field() 6 | 7 | train, val, test = data.TabularDataset.splits( 8 | path='~/chainer-research/jmt-data/pos_wsj/pos_wsj', train='.train', 9 | validation='.dev', test='.test', format='tsv', 10 | fields=[('text', TEXT), ('labels', LABELS)]) 11 | 12 | print(train.fields) 13 | print(len(train)) 14 | print(vars(train[0])) 15 | 16 | train_iter, val_iter, test_iter = data.BucketIterator.splits( 17 | (train, val, test), batch_size=3, sort_key=lambda x: len(x.text), device=0) 18 | 19 | LABELS.build_vocab(train.labels) 20 | TEXT.build_vocab(train.text) 21 | 22 | print(TEXT.vocab.freqs.most_common(10)) 23 | print(LABELS.vocab.itos) 24 | 25 | batch = next(iter(train_iter)) 26 | print(batch.text) 27 | print(batch.labels) 28 | -------------------------------------------------------------------------------- /text/test/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/decaNLP/f1d474b0ff7c7c45a325177401d46b8d0dd16b38/text/test/data/__init__.py -------------------------------------------------------------------------------- /text/test/data/test_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | import torchtext.data as data 4 | 5 | from ..common.torchtext_test_case import TorchtextTestCase 6 | 7 | 8 | class TestDataset(TorchtextTestCase): 9 | def test_tabular_simple_data(self): 10 | for data_format in ["csv", "tsv", "json"]: 11 | self.write_test_ppid_dataset(data_format=data_format) 12 | 13 | if data_format == "json": 14 | question_field = data.Field(sequential=True) 15 | label_field = data.Field(sequential=False) 16 | fields = {"question1": ("q1", question_field), 17 | "question2": 
("q2", question_field), 18 | "label": ("label", label_field)} 19 | else: 20 | question_field = data.Field(sequential=True) 21 | label_field = data.Field(sequential=False) 22 | fields = [("id", None), ("q1", question_field), 23 | ("q2", question_field), ("label", label_field)] 24 | 25 | dataset = data.TabularDataset( 26 | path=self.test_ppid_dataset_path, format=data_format, fields=fields) 27 | 28 | assert len(dataset) == 3 29 | 30 | expected_examples = [ 31 | (["When", "do", "you", "use", "シ", "instead", "of", "し?"], 32 | ["When", "do", "you", "use", "\"&\"", 33 | "instead", "of", "\"and\"?"], "0"), 34 | (["Where", "was", "Lincoln", "born?"], 35 | ["Which", "location", "was", "Abraham", "Lincoln", "born?"], "1"), 36 | (["What", "is", "2+2"], ["2+2=?"], "1")] 37 | 38 | # Ensure examples have correct contents / test __getitem__ 39 | for i in range(len(dataset)): 40 | self.assertEqual(dataset[i].q1, expected_examples[i][0]) 41 | self.assertEqual(dataset[i].q2, expected_examples[i][1]) 42 | self.assertEqual(dataset[i].label, expected_examples[i][2]) 43 | 44 | # Test __getattr__ 45 | for i, (q1, q2, label) in enumerate(zip(dataset.q1, dataset.q2, 46 | dataset.label)): 47 | self.assertEqual(q1, expected_examples[i][0]) 48 | self.assertEqual(q2, expected_examples[i][1]) 49 | self.assertEqual(label, expected_examples[i][2]) 50 | 51 | # Test __iter__ 52 | for i, example in enumerate(dataset): 53 | self.assertEqual(example.q1, expected_examples[i][0]) 54 | self.assertEqual(example.q2, expected_examples[i][1]) 55 | self.assertEqual(example.label, expected_examples[i][2]) 56 | 57 | def test_json_dataset_one_key_multiple_fields(self): 58 | self.write_test_ppid_dataset(data_format="json") 59 | 60 | question_field = data.Field(sequential=True) 61 | spacy_tok_question_field = data.Field(sequential=True, tokenize="spacy") 62 | label_field = data.Field(sequential=False) 63 | fields = {"question1": [("q1", question_field), 64 | ("q1_spacy", spacy_tok_question_field)], 65 | "question2": [("q2", question_field), 66 | ("q2_spacy", spacy_tok_question_field)], 67 | "label": ("label", label_field)} 68 | dataset = data.TabularDataset( 69 | path=self.test_ppid_dataset_path, format="json", fields=fields) 70 | expected_examples = [ 71 | (["When", "do", "you", "use", "シ", "instead", "of", "し?"], 72 | ["When", "do", "you", "use", "シ", "instead", "of", "し", "?"], 73 | ["When", "do", "you", "use", "\"&\"", 74 | "instead", "of", "\"and\"?"], 75 | ["When", "do", "you", "use", "\"", "&", "\"", 76 | "instead", "of", "\"", "and", "\"", "?"], "0"), 77 | (["Where", "was", "Lincoln", "born?"], 78 | ["Where", "was", "Lincoln", "born", "?"], 79 | ["Which", "location", "was", "Abraham", "Lincoln", "born?"], 80 | ["Which", "location", "was", "Abraham", "Lincoln", "born", "?"], 81 | "1"), 82 | (["What", "is", "2+2"], ["What", "is", "2", "+", "2"], 83 | ["2+2=?"], ["2", "+", "2=", "?"], "1")] 84 | for i, example in enumerate(dataset): 85 | self.assertEqual(example.q1, expected_examples[i][0]) 86 | self.assertEqual(example.q1_spacy, expected_examples[i][1]) 87 | self.assertEqual(example.q2, expected_examples[i][2]) 88 | self.assertEqual(example.q2_spacy, expected_examples[i][3]) 89 | self.assertEqual(example.label, expected_examples[i][4]) 90 | 91 | def test_errors(self): 92 | # Ensure that trying to retrieve a key not in JSON data errors 93 | self.write_test_ppid_dataset(data_format="json") 94 | 95 | question_field = data.Field(sequential=True) 96 | label_field = data.Field(sequential=False) 97 | fields = {"qeustion1": ("q1", 
question_field), 98 | "question2": ("q2", question_field), 99 | "label": ("label", label_field)} 100 | 101 | with self.assertRaises(ValueError): 102 | data.TabularDataset( 103 | path=self.test_ppid_dataset_path, format="json", fields=fields) 104 | -------------------------------------------------------------------------------- /text/test/data/test_pipeline.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | import six 4 | import torchtext.data as data 5 | 6 | from ..common.torchtext_test_case import TorchtextTestCase 7 | 8 | 9 | class TestPipeline(TorchtextTestCase): 10 | @staticmethod 11 | def repeat_n(x, n=3): 12 | """ 13 | Given a sequence, repeat it n times. 14 | """ 15 | return x * n 16 | 17 | def test_pipeline(self): 18 | id_pipeline = data.Pipeline() 19 | assert id_pipeline("Test STring") == "Test STring" 20 | assert id_pipeline("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T" 21 | assert id_pipeline(["1241", "Some String"]) == ["1241", "Some String"] 22 | 23 | pipeline = data.Pipeline(six.text_type.lower) 24 | assert pipeline("Test STring") == "test string" 25 | assert pipeline("ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T") == "ᑌᑎiᑕoᗪᕮ_tᕮ᙭t" 26 | assert pipeline(["1241", "Some String"]) == ["1241", "some string"] 27 | 28 | args_pipeline = data.Pipeline(TestPipeline.repeat_n) 29 | assert args_pipeline("test", 5) == "testtesttesttesttest" 30 | assert args_pipeline(["ele1", "ele2"], 2) == ["ele1ele1", "ele2ele2"] 31 | 32 | def test_composition(self): 33 | id_pipeline = data.Pipeline() 34 | pipeline = data.Pipeline(TestPipeline.repeat_n) 35 | pipeline.add_before(id_pipeline) 36 | pipeline.add_after(id_pipeline) 37 | pipeline.add_before(six.text_type.lower) 38 | pipeline.add_after(six.text_type.capitalize) 39 | 40 | other_pipeline = data.Pipeline(six.text_type.swapcase) 41 | other_pipeline.add_before(pipeline) 42 | 43 | # Assert pipeline gives proper results after composition 44 | # (test that we aren't modfifying pipes member) 45 | assert pipeline("teST") == "Testtesttest" 46 | assert pipeline(["ElE1", "eLe2"]) == ["Ele1ele1ele1", "Ele2ele2ele2"] 47 | 48 | # Assert pipeline that we added to gives proper results 49 | assert other_pipeline("teST") == "tESTTESTTEST" 50 | assert other_pipeline(["ElE1", "eLe2"]) == ["eLE1ELE1ELE1", "eLE2ELE2ELE2"] 51 | 52 | def test_exceptions(self): 53 | with self.assertRaises(ValueError): 54 | data.Pipeline("Not Callable") 55 | -------------------------------------------------------------------------------- /text/test/data/test_subword.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from torchtext import data 4 | from torchtext.datasets import TREC 5 | 6 | 7 | class TestSubword(unittest.TestCase): 8 | def test_subword_trec(self): 9 | TEXT = data.SubwordField() 10 | LABEL = data.Field(sequential=False) 11 | RAW = data.Field(sequential=False, use_vocab=False) 12 | raw, = TREC.splits(RAW, LABEL, train=None) 13 | cooked, = TREC.splits(TEXT, LABEL, train=None) 14 | LABEL.build_vocab(cooked) 15 | TEXT.build_vocab(cooked, max_size=100) 16 | TEXT.segment(cooked) 17 | print(cooked[0].text) 18 | batch = next(iter(data.Iterator(cooked, 1, shuffle=False, device=-1))) 19 | self.assertEqual(TEXT.reverse(batch.text.data)[0], raw[0].text) 20 | 21 | 22 | if __name__ == '__main__': 23 | unittest.main() 24 | -------------------------------------------------------------------------------- /text/test/data/test_utils.py: 
-------------------------------------------------------------------------------- 1 | import six 2 | import torchtext.data as data 3 | 4 | from ..common.torchtext_test_case import TorchtextTestCase 5 | 6 | 7 | class TestUtils(TorchtextTestCase): 8 | def test_get_tokenizer(self): 9 | # Test the default case with str.split 10 | assert data.get_tokenizer(str.split) == str.split 11 | test_str = "A string, particularly one with slightly complex punctuation." 12 | assert data.get_tokenizer(str.split)(test_str) == str.split(test_str) 13 | 14 | # Test SpaCy option, and verify it properly handles punctuation. 15 | assert data.get_tokenizer("spacy")(six.text_type(test_str)) == [ 16 | "A", "string", ",", "particularly", "one", "with", "slightly", 17 | "complex", "punctuation", "."] 18 | 19 | # Test Moses option. Test strings taken from NLTK doctests. 20 | # Note that internally, MosesTokenizer converts to unicode if applicable 21 | moses_tokenizer = data.get_tokenizer("moses") 22 | assert moses_tokenizer(test_str) == [ 23 | "A", "string", ",", "particularly", "one", "with", "slightly", 24 | "complex", "punctuation", "."] 25 | 26 | # Nonbreaking prefixes should tokenize the final period. 27 | assert moses_tokenizer(six.text_type("abc def.")) == ["abc", "def", "."] 28 | 29 | # Test that errors are raised for invalid input arguments. 30 | with self.assertRaises(ValueError): 31 | data.get_tokenizer(1) 32 | with self.assertRaises(ValueError): 33 | data.get_tokenizer("some other string") 34 | -------------------------------------------------------------------------------- /text/test/imdb.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | from torchtext.vocab import GloVe 4 | 5 | 6 | # Approach 1: 7 | # set up fields 8 | TEXT = data.Field(lower=True, include_lengths=True, batch_first=True) 9 | LABEL = data.Field(sequential=False) 10 | 11 | 12 | # make splits for data 13 | train, test = datasets.IMDB.splits(TEXT, LABEL) 14 | 15 | # print information about the data 16 | print('train.fields', train.fields) 17 | print('len(train)', len(train)) 18 | print('vars(train[0])', vars(train[0])) 19 | 20 | # build the vocabulary 21 | TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) 22 | LABEL.build_vocab(train) 23 | 24 | # print vocab information 25 | print('len(TEXT.vocab)', len(TEXT.vocab)) 26 | print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) 27 | 28 | # make iterator for splits 29 | train_iter, test_iter = data.BucketIterator.splits( 30 | (train, test), batch_size=3, device=0) 31 | 32 | # print batch information 33 | batch = next(iter(train_iter)) 34 | print(batch.text) 35 | print(batch.label) 36 | 37 | # Approach 2: 38 | train_iter, test_iter = datasets.IMDB.iters(batch_size=4) 39 | 40 | # print batch information 41 | batch = next(iter(train_iter)) 42 | print(batch.text) 43 | print(batch.label) 44 | -------------------------------------------------------------------------------- /text/test/language_modeling.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | from torchtext.vocab import GloVe 4 | 5 | # Approach 1: 6 | # set up fields 7 | TEXT = data.Field(lower=True, batch_first=True) 8 | 9 | # make splits for data 10 | train, valid, test = datasets.WikiText2.splits(TEXT) 11 | 12 | # print information about the data 13 | print('train.fields', train.fields) 14 | print('len(train)', 
len(train)) 15 | print('vars(train[0])', vars(train[0])['text'][0:10]) 16 | 17 | # build the vocabulary 18 | TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) 19 | 20 | # print vocab information 21 | print('len(TEXT.vocab)', len(TEXT.vocab)) 22 | 23 | # make iterator for splits 24 | train_iter, valid_iter, test_iter = data.BPTTIterator.splits( 25 | (train, valid, test), batch_size=3, bptt_len=30, device=0) 26 | 27 | # print batch information 28 | batch = next(iter(train_iter)) 29 | print(batch.text) 30 | print(batch.target) 31 | 32 | # Approach 2: 33 | train_iter, valid_iter, test_iter = datasets.WikiText2.iters(batch_size=4, bptt_len=30) 34 | 35 | # print batch information 36 | batch = next(iter(train_iter)) 37 | print(batch.text) 38 | print(batch.target) 39 | -------------------------------------------------------------------------------- /text/test/sequence_tagging.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | 4 | # Define the fields associated with the sequences. 5 | WORD = data.Field(init_token="", eos_token="") 6 | UD_TAG = data.Field(init_token="", eos_token="") 7 | 8 | # Download and the load default data. 9 | train, val, test = datasets.UDPOS.splits( 10 | fields=(('word', WORD), ('udtag', UD_TAG), (None, None))) 11 | 12 | print(train.fields) 13 | print(len(train)) 14 | print(vars(train[0])) 15 | 16 | # We can also define more than two columns. 17 | WORD = data.Field(init_token="", eos_token="") 18 | UD_TAG = data.Field(init_token="", eos_token="") 19 | PTB_TAG = data.Field(init_token="", eos_token="") 20 | 21 | # Load the specified data. 22 | train, val, test = datasets.UDPOS.splits( 23 | fields=(('word', WORD), ('udtag', UD_TAG), ('ptbtag', PTB_TAG)), 24 | path=".data/sequence-labeling/en-ud-v2", 25 | train="en-ud-tag.v2.train.txt", 26 | validation="en-ud-tag.v2.dev.txt", 27 | test="en-ud-tag.v2.test.txt") 28 | 29 | print(train.fields) 30 | print(len(train)) 31 | print(vars(train[0])) 32 | 33 | WORD.build_vocab(train.word, min_freq=3) 34 | UD_TAG.build_vocab(train.udtag) 35 | PTB_TAG.build_vocab(train.ptbtag) 36 | 37 | print(UD_TAG.vocab.freqs) 38 | print(PTB_TAG.vocab.freqs) 39 | 40 | train_iter, val_iter = data.BucketIterator.splits( 41 | (train, val), batch_size=3, device=0) 42 | 43 | batch = next(iter(train_iter)) 44 | 45 | print("words", batch.word) 46 | print("udtags", batch.udtag) 47 | print("ptbtags", batch.ptbtag) 48 | -------------------------------------------------------------------------------- /text/test/snli.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | 4 | TEXT = data.Field() 5 | LABEL = data.Field(sequential=False) 6 | 7 | train, val, test = datasets.SNLI.splits(TEXT, LABEL) 8 | 9 | print(train.fields) 10 | print(len(train)) 11 | print(vars(train[0])) 12 | 13 | TEXT.build_vocab(train) 14 | LABEL.build_vocab(train) 15 | 16 | train_iter, val_iter, test_iter = data.BucketIterator.splits( 17 | (train, val, test), batch_size=3, device=0) 18 | 19 | batch = next(iter(train_iter)) 20 | print(batch.premise) 21 | print(batch.hypothesis) 22 | print(batch.label) 23 | 24 | train_iter, val_iter, test_iter = datasets.SNLI.iters(batch_size=4) 25 | 26 | batch = next(iter(train_iter)) 27 | print(batch.premise) 28 | print(batch.hypothesis) 29 | print(batch.label) 30 | -------------------------------------------------------------------------------- 
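The snli.py script above exercises only the plain-text path of the SNLI loader. The bundled loader (text/torchtext/datasets/snli.py, later in this listing) also accepts a parse_field for shift-reduce transitions. A minimal sketch of that variant follows, assuming the upstream torchtext 0.2.x behaviour these test scripts target; note that the vendored Batch class in this fork expects extra decaNLP-specific attributes on Fields, so iterating with stock Fields may require those hooks.

from torchtext import data
from torchtext import datasets
from torchtext.datasets.snli import ParsedTextField, ShiftReduceField

# Parsed premise/hypothesis text plus shift/reduce transition sequences.
TEXT = ParsedTextField(lower=True)
TRANSITIONS = ShiftReduceField()
LABEL = data.Field(sequential=False)

# splits() downloads and extracts snli_1.0.zip into .data/ on first use.
train, val, test = datasets.SNLI.splits(TEXT, LABEL, TRANSITIONS)

TEXT.build_vocab(train)
LABEL.build_vocab(train)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test), batch_size=3, device=-1)  # -1 keeps batches on CPU

batch = next(iter(train_iter))
print(batch.premise)
print(batch.premise_transitions)
print(batch.label)

# Equivalent one-liner helper with parse trees enabled:
# train_iter, val_iter, test_iter = datasets.SNLI.iters(batch_size=4, trees=True)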
/text/test/sst.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | from torchtext.vocab import Vectors, GloVe, CharNGram, FastText 4 | 5 | 6 | # Approach 1: 7 | # set up fields 8 | TEXT = data.Field() 9 | LABEL = data.Field(sequential=False) 10 | 11 | # make splits for data 12 | train, val, test = datasets.SST.splits( 13 | TEXT, LABEL, fine_grained=True, train_subtrees=True, 14 | filter_pred=lambda ex: ex.label != 'neutral') 15 | 16 | # print information about the data 17 | print('train.fields', train.fields) 18 | print('len(train)', len(train)) 19 | print('vars(train[0])', vars(train[0])) 20 | 21 | # build the vocabulary 22 | url = 'https://s3-us-west-1.amazonaws.com/fasttext-vectors/wiki.simple.vec' 23 | TEXT.build_vocab(train, vectors=Vectors('wiki.simple.vec', url=url)) 24 | LABEL.build_vocab(train) 25 | 26 | # print vocab information 27 | print('len(TEXT.vocab)', len(TEXT.vocab)) 28 | print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) 29 | 30 | # make iterator for splits 31 | train_iter, val_iter, test_iter = data.BucketIterator.splits( 32 | (train, val, test), batch_size=3, device=0) 33 | 34 | # print batch information 35 | batch = next(iter(train_iter)) 36 | print(batch.text) 37 | print(batch.label) 38 | 39 | # Approach 2: 40 | TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram(), FastText()]) 41 | LABEL.build_vocab(train) 42 | 43 | # print vocab information 44 | print('len(TEXT.vocab)', len(TEXT.vocab)) 45 | print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) 46 | 47 | train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4) 48 | 49 | # print batch information 50 | batch = next(iter(train_iter)) 51 | print(batch.text) 52 | print(batch.label) 53 | 54 | # Approach 3: 55 | f = FastText() 56 | TEXT.build_vocab(train, vectors=f) 57 | TEXT.vocab.extend(f) 58 | LABEL.build_vocab(train) 59 | 60 | # print vocab information 61 | print('len(TEXT.vocab)', len(TEXT.vocab)) 62 | print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) 63 | 64 | train_iter, val_iter, test_iter = datasets.SST.iters(batch_size=4) 65 | 66 | # print batch information 67 | batch = next(iter(train_iter)) 68 | print(batch.text) 69 | print(batch.label) 70 | -------------------------------------------------------------------------------- /text/test/test_vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | from collections import Counter 4 | import os 5 | import pickle 6 | 7 | 8 | import numpy as np 9 | from numpy.testing import assert_allclose 10 | import torch 11 | from torchtext import vocab 12 | from torchtext.vocab import Vectors, FastText, GloVe, CharNGram 13 | 14 | from .common.test_markers import slow 15 | from .common.torchtext_test_case import TorchtextTestCase 16 | 17 | 18 | def conditional_remove(f): 19 | if os.path.isfile(f): 20 | os.remove(f) 21 | 22 | 23 | class TestVocab(TorchtextTestCase): 24 | 25 | def test_vocab_basic(self): 26 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 27 | v = vocab.Vocab(c, min_freq=3, specials=['', '', '']) 28 | 29 | expected_itos = ['', '', '', 30 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] 31 | expected_stoi = {x: index for index, x in enumerate(expected_itos)} 32 | self.assertEqual(v.itos, expected_itos) 33 | self.assertEqual(dict(v.stoi), expected_stoi) 34 | 35 | def 
test_vocab_set_vectors(self): 36 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 37 | 'test': 4, 'freq_too_low': 2}) 38 | v = vocab.Vocab(c, min_freq=3, specials=['', '', '']) 39 | stoi = {"hello": 0, "world": 1, "test": 2} 40 | vectors = torch.FloatTensor([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]) 41 | dim = 2 42 | v.set_vectors(stoi, vectors, dim) 43 | expected_vectors = np.array([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], 44 | [0.0, 0.0], [0.1, 0.2], [0.5, 0.6], 45 | [0.3, 0.4]]) 46 | assert_allclose(v.vectors.numpy(), expected_vectors) 47 | 48 | def test_vocab_download_fasttext_vectors(self): 49 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 50 | # Build a vocab and get vectors twice to test caching, then once more 51 | # to test string aliases. 52 | for i in range(3): 53 | if i == 2: 54 | vectors = str("fasttext.simple.300d") # must handle str on Py2 55 | else: 56 | vectors = FastText(language='simple') 57 | 58 | v = vocab.Vocab(c, min_freq=3, specials=['', '', ''], 59 | vectors=vectors) 60 | 61 | expected_itos = ['', '', '', 62 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] 63 | expected_stoi = {x: index for index, x in enumerate(expected_itos)} 64 | self.assertEqual(v.itos, expected_itos) 65 | self.assertEqual(dict(v.stoi), expected_stoi) 66 | vectors = v.vectors.numpy() 67 | 68 | # The first 5 entries in each vector. 69 | expected_fasttext_simple_en = { 70 | 'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645], 71 | 'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165], 72 | } 73 | 74 | for word in expected_fasttext_simple_en: 75 | assert_allclose(vectors[v.stoi[word], :5], 76 | expected_fasttext_simple_en[word]) 77 | 78 | assert_allclose(vectors[v.stoi['']], np.zeros(300)) 79 | assert_allclose(vectors[v.stoi['OOV token']], np.zeros(300)) 80 | # Delete the vectors after we're done to save disk space on CI 81 | if os.environ.get("TRAVIS") == "true": 82 | vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") 83 | conditional_remove(vec_file) 84 | 85 | def test_vocab_extend(self): 86 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 87 | # Build a vocab and get vectors twice to test caching. 88 | for i in range(2): 89 | f = FastText(language='simple') 90 | v = vocab.Vocab(c, min_freq=3, specials=['', '', ''], 91 | vectors=f) 92 | n_vocab = len(v) 93 | v.extend(f) # extend the vocab with the words contained in f.itos 94 | self.assertGreater(len(v), n_vocab) 95 | 96 | self.assertEqual(v.itos[:6], ['', '', '', 97 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) 98 | vectors = v.vectors.numpy() 99 | 100 | # The first 5 entries in each vector. 101 | expected_fasttext_simple_en = { 102 | 'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645], 103 | 'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165], 104 | } 105 | 106 | for word in expected_fasttext_simple_en: 107 | assert_allclose(vectors[v.stoi[word], :5], 108 | expected_fasttext_simple_en[word]) 109 | 110 | assert_allclose(vectors[v.stoi['']], np.zeros(300)) 111 | # Delete the vectors after we're done to save disk space on CI 112 | if os.environ.get("TRAVIS") == "true": 113 | vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") 114 | conditional_remove(vec_file) 115 | 116 | def test_vocab_download_custom_vectors(self): 117 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 118 | # Build a vocab and get vectors twice to test caching. 
119 | for i in range(2): 120 | v = vocab.Vocab(c, min_freq=3, specials=['', '', ''], 121 | vectors=Vectors('wiki.simple.vec', 122 | url=FastText.url_base.format('simple'))) 123 | 124 | self.assertEqual(v.itos, ['', '', '', 125 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world']) 126 | vectors = v.vectors.numpy() 127 | 128 | # The first 5 entries in each vector. 129 | expected_fasttext_simple_en = { 130 | 'hello': [0.39567, 0.21454, -0.035389, -0.24299, -0.095645], 131 | 'world': [0.10444, -0.10858, 0.27212, 0.13299, -0.33165], 132 | } 133 | 134 | for word in expected_fasttext_simple_en: 135 | assert_allclose(vectors[v.stoi[word], :5], 136 | expected_fasttext_simple_en[word]) 137 | 138 | assert_allclose(vectors[v.stoi['']], np.zeros(300)) 139 | # Delete the vectors after we're done to save disk space on CI 140 | if os.environ.get("TRAVIS") == "true": 141 | vec_file = os.path.join(self.project_root, ".vector_cache", "wiki.simple.vec") 142 | conditional_remove(vec_file) 143 | 144 | @slow 145 | def test_vocab_download_glove_vectors(self): 146 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 147 | 148 | # Build a vocab and get vectors twice to test caching, then once more 149 | # to test string aliases. 150 | for i in range(3): 151 | if i == 2: 152 | vectors = "glove.twitter.27B.25d" 153 | else: 154 | vectors = GloVe(name='twitter.27B', dim='25') 155 | v = vocab.Vocab(c, min_freq=3, specials=['', '', ''], 156 | vectors=vectors) 157 | 158 | expected_itos = ['', '', '', 159 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] 160 | expected_stoi = {x: index for index, x in enumerate(expected_itos)} 161 | self.assertEqual(v.itos, expected_itos) 162 | self.assertEqual(dict(v.stoi), expected_stoi) 163 | 164 | vectors = v.vectors.numpy() 165 | 166 | # The first 5 entries in each vector. 167 | expected_twitter = { 168 | 'hello': [-0.77069, 0.12827, 0.33137, 0.0050893, -0.47605], 169 | 'world': [0.10301, 0.095666, -0.14789, -0.22383, -0.14775], 170 | } 171 | 172 | for word in expected_twitter: 173 | assert_allclose(vectors[v.stoi[word], :5], 174 | expected_twitter[word]) 175 | 176 | assert_allclose(vectors[v.stoi['']], np.zeros(25)) 177 | assert_allclose(vectors[v.stoi['OOV token']], np.zeros(25)) 178 | # Delete the vectors after we're done to save disk space on CI 179 | if os.environ.get("TRAVIS") == "true": 180 | zip_file = os.path.join(self.project_root, ".vector_cache", 181 | "glove.twitter.27B.zip") 182 | conditional_remove(zip_file) 183 | for dim in ["25", "50", "100", "200"]: 184 | conditional_remove(os.path.join(self.project_root, ".vector_cache", 185 | "glove.twitter.27B.{}d.txt".format(dim))) 186 | 187 | @slow 188 | def test_vocab_download_charngram_vectors(self): 189 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 190 | # Build a vocab and get vectors twice to test caching, then once more 191 | # to test string aliases. 192 | for i in range(3): 193 | if i == 2: 194 | vectors = "charngram.100d" 195 | else: 196 | vectors = CharNGram() 197 | v = vocab.Vocab(c, min_freq=3, specials=['', '', ''], 198 | vectors=vectors) 199 | expected_itos = ['', '', '', 200 | 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T', 'hello', 'world'] 201 | expected_stoi = {x: index for index, x in enumerate(expected_itos)} 202 | self.assertEqual(v.itos, expected_itos) 203 | self.assertEqual(dict(v.stoi), expected_stoi) 204 | vectors = v.vectors.numpy() 205 | 206 | # The first 5 entries in each vector. 
207 | expected_charngram = { 208 | 'hello': [-0.44782442, -0.08937783, -0.34227219, 209 | -0.16233221, -0.39343098], 210 | 'world': [-0.29590717, -0.05275926, -0.37334684, 0.27117205, -0.3868292], 211 | } 212 | 213 | for word in expected_charngram: 214 | assert_allclose(vectors[v.stoi[word], :5], 215 | expected_charngram[word]) 216 | 217 | assert_allclose(vectors[v.stoi['']], np.zeros(100)) 218 | assert_allclose(vectors[v.stoi['OOV token']], np.zeros(100)) 219 | # Delete the vectors after we're done to save disk space on CI 220 | if os.environ.get("TRAVIS") == "true": 221 | conditional_remove( 222 | os.path.join(self.project_root, ".vector_cache", "charNgram.txt")) 223 | conditional_remove( 224 | os.path.join(self.project_root, ".vector_cache", 225 | "jmt_pre-trained_embeddings.tar.gz")) 226 | 227 | def test_errors(self): 228 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 229 | with self.assertRaises(ValueError): 230 | # Test proper error raised when using unknown string alias 231 | vocab.Vocab(c, min_freq=3, specials=['', '', ''], 232 | vectors=["fasttext.english.300d"]) 233 | vocab.Vocab(c, min_freq=3, specials=['', '', ''], 234 | vectors="fasttext.english.300d") 235 | with self.assertRaises(ValueError): 236 | # Test proper error is raised when vectors argument is 237 | # non-string or non-Vectors 238 | vocab.Vocab(c, min_freq=3, specials=['', '', ''], 239 | vectors={"word": [1, 2, 3]}) 240 | 241 | def test_serialization(self): 242 | c = Counter({'hello': 4, 'world': 3, 'ᑌᑎIᑕOᗪᕮ_Tᕮ᙭T': 5, 'freq_too_low': 2}) 243 | v = vocab.Vocab(c, min_freq=3, specials=['', '', '']) 244 | pickle_path = os.path.join(self.test_dir, "vocab.pkl") 245 | pickle.dump(v, open(pickle_path, "wb")) 246 | v_loaded = pickle.load(open(pickle_path, "rb")) 247 | assert v == v_loaded 248 | -------------------------------------------------------------------------------- /text/test/translation.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | 4 | import re 5 | import spacy 6 | 7 | spacy_de = spacy.load('de') 8 | spacy_en = spacy.load('en') 9 | 10 | url = re.compile('(.*)') 11 | 12 | 13 | def tokenize_de(text): 14 | return [tok.text for tok in spacy_de.tokenizer(url.sub('@URL@', text))] 15 | 16 | 17 | def tokenize_en(text): 18 | return [tok.text for tok in spacy_en.tokenizer(url.sub('@URL@', text))] 19 | 20 | 21 | # Testing IWSLT 22 | DE = data.Field(tokenize=tokenize_de) 23 | EN = data.Field(tokenize=tokenize_en) 24 | 25 | train, val, test = datasets.IWSLT.splits(exts=('.de', '.en'), fields=(DE, EN)) 26 | 27 | print(train.fields) 28 | print(len(train)) 29 | print(vars(train[0])) 30 | print(vars(train[100])) 31 | 32 | DE.build_vocab(train.src, min_freq=3) 33 | EN.build_vocab(train.trg, max_size=50000) 34 | 35 | train_iter, val_iter = data.BucketIterator.splits( 36 | (train, val), batch_size=3, device=0) 37 | 38 | print(DE.vocab.freqs.most_common(10)) 39 | print(len(DE.vocab)) 40 | print(EN.vocab.freqs.most_common(10)) 41 | print(len(EN.vocab)) 42 | 43 | batch = next(iter(train_iter)) 44 | print(batch.src) 45 | print(batch.trg) 46 | 47 | 48 | # Testing Multi30k 49 | DE = data.Field(tokenize=tokenize_de) 50 | EN = data.Field(tokenize=tokenize_en) 51 | 52 | train, val, test = datasets.Multi30k.splits(exts=('.de', '.en'), fields=(DE, EN)) 53 | 54 | print(train.fields) 55 | print(len(train)) 56 | print(vars(train[0])) 57 | print(vars(train[100])) 58 | 59 | DE.build_vocab(train.src, min_freq=3) 
60 | EN.build_vocab(train.trg, max_size=50000) 61 | 62 | train_iter, val_iter = data.BucketIterator.splits( 63 | (train, val), batch_size=3, device=0) 64 | 65 | print(DE.vocab.freqs.most_common(10)) 66 | print(len(DE.vocab)) 67 | print(EN.vocab.freqs.most_common(10)) 68 | print(len(EN.vocab)) 69 | 70 | batch = next(iter(train_iter)) 71 | print(batch.src) 72 | print(batch.trg) 73 | 74 | 75 | # Testing custom paths 76 | DE = data.Field(tokenize=tokenize_de) 77 | EN = data.Field(tokenize=tokenize_en) 78 | 79 | train, val = datasets.TranslationDataset.splits( 80 | path='.data/multi30k/', train='train', 81 | validation='val', exts=('.de', '.en'), 82 | fields=(DE, EN)) 83 | 84 | print(train.fields) 85 | print(len(train)) 86 | print(vars(train[0])) 87 | print(vars(train[100])) 88 | 89 | DE.build_vocab(train.src, min_freq=3) 90 | EN.build_vocab(train.trg, max_size=50000) 91 | 92 | train_iter, val_iter = data.BucketIterator.splits( 93 | (train, val), batch_size=3, device=0) 94 | 95 | print(DE.vocab.freqs.most_common(10)) 96 | print(len(DE.vocab)) 97 | print(EN.vocab.freqs.most_common(10)) 98 | print(len(EN.vocab)) 99 | 100 | batch = next(iter(train_iter)) 101 | print(batch.src) 102 | print(batch.trg) 103 | -------------------------------------------------------------------------------- /text/test/trec.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | from torchtext import datasets 3 | from torchtext.vocab import GloVe, CharNGram 4 | 5 | 6 | # Approach 1: 7 | # set up fields 8 | TEXT = data.Field(lower=True, include_lengths=True, batch_first=True) 9 | LABEL = data.Field(sequential=False) 10 | 11 | 12 | # make splits for data 13 | train, test = datasets.TREC.splits(TEXT, LABEL, fine_grained=True) 14 | 15 | # print information about the data 16 | print('train.fields', train.fields) 17 | print('len(train)', len(train)) 18 | print('vars(train[0])', vars(train[0])) 19 | 20 | # build the vocabulary 21 | TEXT.build_vocab(train, vectors=GloVe(name='6B', dim=300)) 22 | LABEL.build_vocab(train) 23 | 24 | # print vocab information 25 | print('len(TEXT.vocab)', len(TEXT.vocab)) 26 | print('TEXT.vocab.vectors.size()', TEXT.vocab.vectors.size()) 27 | 28 | # make iterator for splits 29 | train_iter, test_iter = data.BucketIterator.splits( 30 | (train, test), batch_size=3, device=0) 31 | 32 | # print batch information 33 | batch = next(iter(train_iter)) 34 | print(batch.text) 35 | print(batch.label) 36 | 37 | # Approach 2: 38 | TEXT.build_vocab(train, vectors=[GloVe(name='840B', dim='300'), CharNGram()]) 39 | LABEL.build_vocab(train) 40 | 41 | train_iter, test_iter = datasets.TREC.iters(batch_size=4) 42 | 43 | # print batch information 44 | batch = next(iter(train_iter)) 45 | print(batch.text) 46 | print(batch.label) 47 | -------------------------------------------------------------------------------- /text/torchtext/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import datasets 3 | from . 
import utils 4 | 5 | __version__ = '0.2.1' 6 | 7 | __all__ = ['data', 8 | 'datasets', 9 | 'utils'] 10 | -------------------------------------------------------------------------------- /text/torchtext/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .batch import Batch 2 | from .dataset import Dataset, TabularDataset 3 | from .example import Example 4 | from .field import RawField, Field, ReversibleField, SubwordField 5 | from .iterator import (batch, BucketIterator, Iterator, BPTTIterator, 6 | pool) 7 | from .pipeline import Pipeline 8 | from .utils import get_tokenizer, interleave_keys 9 | 10 | __all__ = ["Batch", 11 | "Dataset", "TabularDataset", "ZipDataset", 12 | "Example", 13 | "RawField", "Field", "ReversibleField", "SubwordField", 14 | "batch", "BucketIterator", "Iterator", "BPTTIterator", 15 | "pool", 16 | "Pipeline", 17 | "get_tokenizer", "interleave_keys"] 18 | -------------------------------------------------------------------------------- /text/torchtext/data/batch.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from torch.autograd import Variable 3 | from copy import deepcopy 4 | 5 | 6 | 7 | class Batch(object): 8 | """Defines a batch of examples along with its Fields. 9 | 10 | Attributes: 11 | batch_size: Number of examples in the batch. 12 | dataset: A reference to the dataset object the examples come from 13 | (which itself contains the dataset's Field objects). 14 | train: Whether the batch is from a training set. 15 | 16 | Also stores the Variable for each column in the batch as an attribute. 17 | """ 18 | 19 | def __init__(self, data=None, dataset=None, device=None, train=True): 20 | """Create a Batch from a list of examples.""" 21 | if data is not None: 22 | self.batch_size = len(data) 23 | self.dataset = dataset 24 | self.train = train 25 | field = list(dataset.fields.values())[0] 26 | limited_idx_to_full_idx = deepcopy(field.decoder_to_vocab) # should avoid this with a conditional in map to full 27 | oov_to_limited_idx = {} 28 | for (name, field) in dataset.fields.items(): 29 | if field is not None: 30 | batch = [x.__dict__[name] for x in data] 31 | if not field.include_lengths: 32 | setattr(self, name, field.process(batch, device=device, train=train)) 33 | else: 34 | entry, lengths, limited_entry, raw = field.process(batch, device=device, train=train, 35 | limited=field.decoder_stoi, l2f=limited_idx_to_full_idx, oov2l=oov_to_limited_idx) 36 | setattr(self, name, entry) 37 | setattr(self, f'{name}_lengths', lengths) 38 | setattr(self, f'{name}_limited', limited_entry) 39 | setattr(self, f'{name}_elmo', [[s.strip() for s in l] for l in raw]) 40 | setattr(self, f'limited_idx_to_full_idx', limited_idx_to_full_idx) 41 | setattr(self, f'oov_to_limited_idx', oov_to_limited_idx) 42 | 43 | 44 | @classmethod 45 | def fromvars(cls, dataset, batch_size, train=True, **kwargs): 46 | """Create a Batch directly from a number of Variables.""" 47 | batch = cls() 48 | batch.batch_size = batch_size 49 | batch.dataset = dataset 50 | batch.train = train 51 | for k, v in kwargs.items(): 52 | setattr(batch, k, v) 53 | return batch 54 | -------------------------------------------------------------------------------- /text/torchtext/data/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import zipfile 4 | import tarfile 5 | 6 | import torch.utils.data 7 | 8 | from .example import 
Example 9 | from ..utils import download_from_url 10 | 11 | 12 | class Dataset(torch.utils.data.Dataset): 13 | """Defines a dataset composed of Examples along with its Fields. 14 | 15 | Attributes: 16 | sort_key (callable): A key to use for sorting dataset examples for batching 17 | together examples with similar lengths to minimize padding. 18 | examples (list(Example)): The examples in this dataset. 19 | fields: A dictionary containing the name of each column together with 20 | its corresponding Field object. Two columns with the same Field 21 | object will share a vocabulary. 22 | fields (dict[str, Field]): Contains the name of each column or field, together 23 | with the corresponding Field object. Two fields with the same Field object 24 | will have a shared vocabulary. 25 | """ 26 | sort_key = None 27 | 28 | def __init__(self, examples, fields, filter_pred=None): 29 | """Create a dataset from a list of Examples and Fields. 30 | 31 | Arguments: 32 | examples: List of Examples. 33 | fields (List(tuple(str, Field))): The Fields to use in this tuple. The 34 | string is a field name, and the Field is the associated field. 35 | filter_pred (callable or None): Use only examples for which 36 | filter_pred(example) is True, or use all examples if None. 37 | Default is None. 38 | """ 39 | if filter_pred is not None: 40 | make_list = isinstance(examples, list) 41 | examples = filter(filter_pred, examples) 42 | if make_list: 43 | examples = list(examples) 44 | self.examples = examples 45 | self.fields = dict(fields) 46 | 47 | @classmethod 48 | def splits(cls, path=None, root='.data', train=None, validation=None, 49 | test=None, **kwargs): 50 | """Create Dataset objects for multiple splits of a dataset. 51 | 52 | Arguments: 53 | path (str): Common prefix of the splits' file paths, or None to use 54 | the result of cls.download(root). 55 | root (str): Root dataset storage directory. Default is '.data'. 56 | train (str): Suffix to add to path for the train set, or None for no 57 | train set. Default is None. 58 | validation (str): Suffix to add to path for the validation set, or None 59 | for no validation set. Default is None. 60 | test (str): Suffix to add to path for the test set, or None for no test 61 | set. Default is None. 62 | Remaining keyword arguments: Passed to the constructor of the 63 | Dataset (sub)class being used. 64 | 65 | Returns: 66 | split_datasets (tuple(Dataset)): Datasets for train, validation, and 67 | test splits in that order, if provided. 68 | """ 69 | if path is None: 70 | path = cls.download(root) 71 | train_data = None if train is None else cls( 72 | os.path.join(path, train), **kwargs) 73 | val_data = None if validation is None else cls( 74 | os.path.join(path, validation), **kwargs) 75 | test_data = None if test is None else cls( 76 | os.path.join(path, test), **kwargs) 77 | return tuple(d for d in (train_data, val_data, test_data) 78 | if d is not None) 79 | 80 | def __getitem__(self, i): 81 | return self.examples[i] 82 | 83 | def __len__(self): 84 | try: 85 | return len(self.examples) 86 | except TypeError: 87 | return 2**32 88 | 89 | def __iter__(self): 90 | for x in self.examples: 91 | yield x 92 | 93 | def __getattr__(self, attr): 94 | if attr in self.fields: 95 | for x in self.examples: 96 | yield getattr(x, attr) 97 | 98 | @classmethod 99 | def download(cls, root, check=None): 100 | """Download and unzip an online archive (.zip, .gz, or .tgz). 101 | 102 | Arguments: 103 | root (str): Folder to download data to. 
104 | check (str or None): Folder whose existence indicates 105 | that the dataset has already been downloaded, or 106 | None to check the existence of root/{cls.name}. 107 | 108 | Returns: 109 | dataset_path (str): Path to extracted dataset. 110 | """ 111 | path = os.path.join(root, cls.name) 112 | check = path if check is None else check 113 | if not os.path.isdir(check): 114 | for url in cls.urls: 115 | if isinstance(url, tuple): 116 | url, filename = url 117 | else: 118 | filename = os.path.basename(url) 119 | zpath = os.path.join(path, filename) 120 | if not os.path.isfile(zpath): 121 | if not os.path.exists(os.path.dirname(zpath)): 122 | os.makedirs(os.path.dirname(zpath)) 123 | print('downloading {}'.format(filename)) 124 | download_from_url(url, zpath) 125 | ext = os.path.splitext(filename)[-1] 126 | if ext == '.zip': 127 | with zipfile.ZipFile(zpath, 'r') as zfile: 128 | print('extracting') 129 | zfile.extractall(path) 130 | elif ext in ['.gz', '.tgz']: 131 | with tarfile.open(zpath, 'r:gz') as tar: 132 | dirs = [member for member in tar.getmembers()] 133 | tar.extractall(path=path, members=dirs) 134 | elif ext in ['.bz2', '.tar']: 135 | with tarfile.open(zpath) as tar: 136 | dirs = [member for member in tar.getmembers()] 137 | tar.extractall(path=path, members=dirs) 138 | 139 | return os.path.join(path, cls.dirname) 140 | 141 | 142 | class TabularDataset(Dataset): 143 | """Defines a Dataset of columns stored in CSV, TSV, or JSON format.""" 144 | 145 | def __init__(self, path, format, fields, skip_header=False, subsample=False, **kwargs): 146 | """Create a TabularDataset given a path, file format, and field list. 147 | 148 | Arguments: 149 | path (str): Path to the data file. 150 | format (str): The format of the data file. One of "CSV", "TSV", or 151 | "JSON" (case-insensitive). 152 | fields (list(tuple(str, Field)) or dict[str: tuple(str, Field)]: For CSV and 153 | TSV formats, list of tuples of (name, field). The list should be in 154 | the same order as the columns in the CSV or TSV file, while tuples of 155 | (name, None) represent columns that will be ignored. For JSON format, 156 | dictionary whose keys are the JSON keys and whose values are tuples of 157 | (name, field). This allows the user to rename columns from their JSON key 158 | names and also enables selecting a subset of columns to load 159 | (since JSON keys not present in the input dictionary are ignored). 160 | skip_header (bool): Whether to skip the first line of the input file. 
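        A minimal usage sketch, mirroring text/test/data/test_dataset.py in this
        repo (the path and attribute names below are illustrative only, and
        "import torchtext.data as data" is assumed):

            question = data.Field(sequential=True)
            label = data.Field(sequential=False)
            fields = [("id", None), ("q1", question),
                      ("q2", question), ("label", label)]
            dataset = data.TabularDataset(
                path="pairs.tsv", format="tsv", fields=fields)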
161 | """ 162 | make_example = { 163 | 'json': Example.fromJSON, 'dict': Example.fromdict, 164 | 'tsv': Example.fromTSV, 'csv': Example.fromCSV}[format.lower()] 165 | 166 | examples = [] 167 | with io.open(os.path.expanduser(path), encoding="utf8") as f: 168 | if skip_header: 169 | next(f) 170 | for line in f: 171 | examples.append(make_example(line, fields)) 172 | 173 | if make_example in (Example.fromdict, Example.fromJSON): 174 | fields, field_dict = [], fields 175 | for field in field_dict.values(): 176 | if isinstance(field, list): 177 | fields.extend(field) 178 | else: 179 | fields.append(field) 180 | 181 | super(TabularDataset, self).__init__(examples, fields, **kwargs) 182 | -------------------------------------------------------------------------------- /text/torchtext/data/example.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import sys 3 | import json 4 | 5 | import six 6 | 7 | 8 | def intern_strings(x): 9 | if isinstance(x, (list, tuple)): 10 | r = [] 11 | for y in x: 12 | if isinstance(y, str): 13 | r.append(sys.intern(y)) 14 | else: 15 | r.append(y) 16 | return r 17 | return x 18 | 19 | 20 | class Example(object): 21 | """Defines a single training or test example. 22 | 23 | Stores each column of the example as an attribute. 24 | """ 25 | 26 | @classmethod 27 | def fromJSON(cls, data, fields): 28 | return cls.fromdict(json.loads(data), fields) 29 | 30 | @classmethod 31 | def fromdict(cls, data, fields): 32 | ex = cls() 33 | for key, vals in fields.items(): 34 | if key not in data: 35 | raise ValueError("Specified key {} was not found in " 36 | "the input data".format(key)) 37 | if vals is not None: 38 | if not isinstance(vals, list): 39 | vals = [vals] 40 | for val in vals: 41 | name, field = val 42 | setattr(ex, name, intern_strings(field.preprocess(data[key]))) 43 | return ex 44 | 45 | @classmethod 46 | def fromTSV(cls, data, fields): 47 | return cls.fromlist(data.split('\t'), fields) 48 | 49 | @classmethod 50 | def fromCSV(cls, data, fields): 51 | data = data.rstrip("\n") 52 | # If Python 2, encode to utf-8 since CSV doesn't take unicode input 53 | if six.PY2: 54 | data = data.encode('utf-8') 55 | # Use Python CSV module to parse the CSV line 56 | parsed_csv_lines = csv.reader([data]) 57 | 58 | # If Python 2, decode back to unicode (the original input format). 59 | if six.PY2: 60 | for line in parsed_csv_lines: 61 | parsed_csv_line = [six.text_type(col, 'utf-8') for col in line] 62 | break 63 | else: 64 | parsed_csv_line = list(parsed_csv_lines)[0] 65 | return cls.fromlist(parsed_csv_line, fields) 66 | 67 | @classmethod 68 | def fromlist(cls, data, fields): 69 | ex = cls() 70 | for (name, field), val in zip(fields, data): 71 | if field is not None: 72 | if isinstance(val, six.string_types): 73 | val = val.rstrip('\n') 74 | setattr(ex, name, intern_strings(field.preprocess(val))) 75 | return ex 76 | 77 | @classmethod 78 | def fromtree(cls, data, fields, subtrees=False): 79 | try: 80 | from nltk.tree import Tree 81 | except ImportError: 82 | print("Please install NLTK. 
" 83 | "See the docs at http://nltk.org for more information.") 84 | raise 85 | tree = Tree.fromstring(data) 86 | if subtrees: 87 | return [cls.fromlist( 88 | [' '.join(t.leaves()), t.label()], fields) for t in tree.subtrees()] 89 | return cls.fromlist([' '.join(tree.leaves()), tree.label()], fields) 90 | -------------------------------------------------------------------------------- /text/torchtext/data/pipeline.py: -------------------------------------------------------------------------------- 1 | class Pipeline(object): 2 | """Defines a pipeline for transforming sequence data. 3 | 4 | The input is assumed to be utf-8 encoded `str` (Python 3) or 5 | `unicode` (Python 2). 6 | 7 | Attributes: 8 | convert_token: The function to apply to input sequence data. 9 | pipes: The Pipelines that will be applid to input sequence 10 | data in order. 11 | """ 12 | def __init__(self, convert_token=None): 13 | """Create a pipeline. 14 | 15 | Arguments: 16 | convert_token: The function to apply to input sequence data. 17 | If None, the identity function is used. Default: None 18 | """ 19 | if convert_token is None: 20 | self.convert_token = Pipeline.identity 21 | elif callable(convert_token): 22 | self.convert_token = convert_token 23 | else: 24 | raise ValueError("Pipeline input convert_token {} is not None " 25 | "or callable".format(convert_token)) 26 | self.pipes = [self] 27 | 28 | def __call__(self, x, *args): 29 | """Apply the the current Pipeline(s) to an input. 30 | 31 | Arguments: 32 | x: The input to process with the Pipeline(s). 33 | Positional arguments: Forwarded to the `call` function 34 | of the Pipeline(s). 35 | """ 36 | for pipe in self.pipes: 37 | x = pipe.call(x, *args) 38 | return x 39 | 40 | def call(self, x, *args): 41 | """Apply _only_ the convert_token function of the current pipeline 42 | to the input. If the input is a list, a list with the results of 43 | applying the `convert_token` function to all input elements is 44 | returned. 45 | 46 | Arguments: 47 | x: The input to apply the convert_token function to. 48 | Positional arguments: Forwarded to the `convert_token` function 49 | of the current Pipeline. 50 | """ 51 | if isinstance(x, list): 52 | return [self.convert_token(tok, *args) for tok in x] 53 | return self.convert_token(x, *args) 54 | 55 | def add_before(self, pipeline): 56 | """Add a Pipeline to be applied before this processing pipeline. 57 | 58 | Arguments: 59 | pipeline: The Pipeline or callable to apply before this 60 | Pipeline. 61 | """ 62 | if not isinstance(pipeline, Pipeline): 63 | pipeline = Pipeline(pipeline) 64 | self.pipes = pipeline.pipes[:] + self.pipes[:] 65 | return self 66 | 67 | def add_after(self, pipeline): 68 | """Add a Pipeline to be applied after this processing pipeline. 69 | 70 | Arguments: 71 | pipeline: The Pipeline or callable to apply after this 72 | Pipeline. 73 | """ 74 | if not isinstance(pipeline, Pipeline): 75 | pipeline = Pipeline(pipeline) 76 | self.pipes = self.pipes[:] + pipeline.pipes[:] 77 | return self 78 | 79 | @staticmethod 80 | def identity(x): 81 | """Return a copy of the input. 82 | 83 | This is here for serialization compatibility with pickle. 
84 | """ 85 | return x 86 | -------------------------------------------------------------------------------- /text/torchtext/data/utils.py: -------------------------------------------------------------------------------- 1 | def get_tokenizer(tokenizer, decap=False): 2 | if callable(tokenizer): 3 | return tokenizer 4 | if tokenizer == "spacy": 5 | try: 6 | import spacy 7 | spacy_en = spacy.load('en') 8 | return lambda s: [tok.text for tok in spacy_en.tokenizer(s)] 9 | except ImportError: 10 | print("Please install SpaCy and the SpaCy English tokenizer. " 11 | "See the docs at https://spacy.io for more information.") 12 | raise 13 | except AttributeError: 14 | print("Please install SpaCy and the SpaCy English tokenizer. " 15 | "See the docs at https://spacy.io for more information.") 16 | raise 17 | elif tokenizer == "moses": 18 | try: 19 | from nltk.tokenize.moses import MosesTokenizer 20 | moses_tokenizer = MosesTokenizer() 21 | return moses_tokenizer.tokenize 22 | except ImportError: 23 | print("Please install NLTK. " 24 | "See the docs at http://nltk.org for more information.") 25 | raise 26 | except LookupError: 27 | print("Please install the necessary NLTK corpora. " 28 | "See the docs at http://nltk.org for more information.") 29 | raise 30 | elif tokenizer == 'revtok': 31 | try: 32 | import revtok 33 | return revtok.tokenize 34 | except ImportError: 35 | print("Please install revtok.") 36 | raise 37 | elif tokenizer == 'subword': 38 | try: 39 | import revtok 40 | return revtok.tokenize 41 | except ImportError: 42 | print("Please install revtok.") 43 | raise 44 | raise ValueError("Requested tokenizer {}, valid choices are a " 45 | "callable that takes a single string as input, " 46 | "\"revtok\" for the revtok reversible tokenizer, " 47 | "\"subword\" for the revtok caps-aware tokenizer, " 48 | "\"spacy\" for the SpaCy English tokenizer, or " 49 | "\"moses\" for the NLTK port of the Moses tokenization " 50 | "script.".format(tokenizer)) 51 | 52 | 53 | def interleave_keys(a, b): 54 | """Interleave bits from two sort keys to form a joint sort key. 55 | 56 | Examples that are similar in both of the provided keys will have similar 57 | values for the key defined by this function. Useful for tasks with two 58 | text fields like machine translation or natural language inference. 59 | """ 60 | def interleave(args): 61 | return ''.join([x for t in zip(*args) for x in t]) 62 | return int(''.join(interleave(format(x, '016b') for x in (a, b))), base=2) 63 | -------------------------------------------------------------------------------- /text/torchtext/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_modeling import LanguageModelingDataset, WikiText2 # NOQA 2 | from .snli import SNLI 3 | from .sst import SST 4 | from .translation import TranslationDataset, Multi30k, IWSLT, WMT14 # NOQA 5 | from .sequence_tagging import SequenceTaggingDataset, UDPOS # NOQA 6 | from .trec import TREC 7 | from .imdb import IMDb 8 | from . import generic 9 | 10 | 11 | __all__ = ['LanguageModelingDataset', 12 | 'SNLI', 13 | 'SST', 14 | 'TranslationDataset', 15 | 'Multi30k', 16 | 'IWSLT', 17 | 'WMT14' 18 | 'WikiText2', 19 | 'TREC', 20 | 'IMDb', 21 | 'SequenceTaggingDataset', 22 | 'UDPOS', 23 | ] 24 | -------------------------------------------------------------------------------- /text/torchtext/datasets/imdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | from .. 
import data 5 | 6 | 7 | class IMDb(data.Dataset): 8 | 9 | urls = ['http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz'] 10 | name = 'imdb' 11 | dirname = 'aclImdb' 12 | 13 | @staticmethod 14 | def sort_key(ex): 15 | return len(ex.text) 16 | 17 | def __init__(self, path, text_field, label_field, **kwargs): 18 | """Create an IMDB dataset instance given a path and fields. 19 | 20 | Arguments: 21 | path: Path to the dataset's highest level directory 22 | text_field: The field that will be used for text data. 23 | label_field: The field that will be used for label data. 24 | Remaining keyword arguments: Passed to the constructor of 25 | data.Dataset. 26 | """ 27 | fields = [('text', text_field), ('label', label_field)] 28 | examples = [] 29 | 30 | for label in ['pos', 'neg']: 31 | for fname in glob.iglob(os.path.join(path, label, '*.txt')): 32 | with open(fname, 'r') as f: 33 | text = f.readline() 34 | examples.append(data.Example.fromlist([text, label], fields)) 35 | 36 | super(IMDb, self).__init__(examples, fields, **kwargs) 37 | 38 | @classmethod 39 | def splits(cls, text_field, label_field, root='.data', 40 | train='train', test='test', **kwargs): 41 | """Create dataset objects for splits of the IMDB dataset. 42 | 43 | Arguments: 44 | text_field: The field that will be used for the sentence. 45 | label_field: The field that will be used for label data. 46 | root: Root dataset storage directory. Default is '.data'. 47 | train: The directory that contains the training examples 48 | test: The directory that contains the test examples 49 | Remaining keyword arguments: Passed to the splits method of 50 | Dataset. 51 | """ 52 | return super(IMDb, cls).splits( 53 | root=root, text_field=text_field, label_field=label_field, 54 | train=train, validation=None, test=test, **kwargs) 55 | 56 | @classmethod 57 | def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs): 58 | """Creater iterator objects for splits of the IMDB dataset. 59 | 60 | Arguments: 61 | batch_size: Batch_size 62 | device: Device to create batches on. Use - 1 for CPU and None for 63 | the currently active GPU device. 64 | root: The root directory that contains the imdb dataset subdirectory 65 | vectors: one of the available pretrained vectors or a list with each 66 | element one of the available pretrained vectors (see Vocab.load_vectors) 67 | 68 | Remaining keyword arguments: Passed to the splits method. 69 | """ 70 | TEXT = data.Field() 71 | LABEL = data.Field(sequential=False) 72 | 73 | train, test = cls.splits(TEXT, LABEL, root=root, **kwargs) 74 | 75 | TEXT.build_vocab(train, vectors=vectors) 76 | LABEL.build_vocab(train) 77 | 78 | return data.BucketIterator.splits( 79 | (train, test), batch_size=batch_size, device=device) 80 | -------------------------------------------------------------------------------- /text/torchtext/datasets/language_modeling.py: -------------------------------------------------------------------------------- 1 | from .. import data 2 | 3 | 4 | class LanguageModelingDataset(data.Dataset): 5 | """Defines a dataset for language modeling.""" 6 | 7 | def __init__(self, path, text_field, newline_eos=True, **kwargs): 8 | """Create a LanguageModelingDataset given a path and a field. 9 | 10 | Arguments: 11 | path: Path to the data file. 12 | text_field: The field that will be used for text data. 13 | newline_eos: Whether to add an token for every newline in the 14 | data file. Default: True. 15 | Remaining keyword arguments: Passed to the constructor of 16 | data.Dataset. 
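        A minimal sketch (the corpus path is illustrative only; the WikiText2
        subclass below wraps this same constructor with download helpers):

            from torchtext import data
            from torchtext.datasets import LanguageModelingDataset
            TEXT = data.Field(lower=True)
            corpus = LanguageModelingDataset("corpus.txt", TEXT)
            TEXT.build_vocab(corpus)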
17 | """ 18 | fields = [('text', text_field)] 19 | text = [] 20 | with open(path) as f: 21 | for line in f: 22 | text += text_field.preprocess(line) 23 | if newline_eos: 24 | text.append('') 25 | 26 | examples = [data.Example.fromlist([text], fields)] 27 | super(LanguageModelingDataset, self).__init__( 28 | examples, fields, **kwargs) 29 | 30 | 31 | class WikiText2(LanguageModelingDataset): 32 | 33 | urls = ['https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'] 34 | name = 'wikitext-2' 35 | dirname = 'wikitext-2' 36 | 37 | @classmethod 38 | def splits(cls, text_field, root='.data', train='wiki.train.tokens', 39 | validation='wiki.valid.tokens', test='wiki.test.tokens', 40 | **kwargs): 41 | """Create dataset objects for splits of the WikiText-2 dataset. 42 | 43 | This is the most flexible way to use the dataset. 44 | 45 | Arguments: 46 | text_field: The field that will be used for text data. 47 | root: The root directory that the dataset's zip archive will be 48 | expanded into; therefore the directory in whose wikitext-2 49 | subdirectory the data files will be stored. 50 | train: The filename of the train data. Default: 'wiki.train.tokens'. 51 | validation: The filename of the validation data, or None to not 52 | load the validation set. Default: 'wiki.valid.tokens'. 53 | test: The filename of the test data, or None to not load the test 54 | set. Default: 'wiki.test.tokens'. 55 | """ 56 | return super(WikiText2, cls).splits( 57 | root=root, train=train, validation=validation, test=test, 58 | text_field=text_field, **kwargs) 59 | 60 | @classmethod 61 | def iters(cls, batch_size=32, bptt_len=35, device=0, root='.data', 62 | vectors=None, **kwargs): 63 | """Create iterator objects for splits of the WikiText-2 dataset. 64 | 65 | This is the simplest way to use the dataset, and assumes common 66 | defaults for field, vocabulary, and iterator parameters. 67 | 68 | Arguments: 69 | batch_size: Batch size. 70 | bptt_len: Length of sequences for backpropagation through time. 71 | device: Device to create batches on. Use -1 for CPU and None for 72 | the currently active GPU device. 73 | root: The root directory that the dataset's zip archive will be 74 | expanded into; therefore the directory in whose wikitext-2 75 | subdirectory the data files will be stored. 76 | wv_dir, wv_type, wv_dim: Passed to the Vocab constructor for the 77 | text field. The word vectors are accessible as 78 | train.dataset.fields['text'].vocab.vectors. 79 | Remaining keyword arguments: Passed to the splits method. 80 | """ 81 | TEXT = data.Field() 82 | 83 | train, val, test = cls.splits(TEXT, root=root, **kwargs) 84 | 85 | TEXT.build_vocab(train, vectors=vectors) 86 | 87 | return data.BPTTIterator.splits( 88 | (train, val, test), batch_size=batch_size, bptt_len=bptt_len, 89 | device=device) 90 | -------------------------------------------------------------------------------- /text/torchtext/datasets/sequence_tagging.py: -------------------------------------------------------------------------------- 1 | from .. import data 2 | 3 | 4 | class SequenceTaggingDataset(data.Dataset): 5 | """Defines a dataset for sequence tagging. Examples in this dataset 6 | contain paired lists -- paired list of words and tags. 7 | 8 | For example, in the case of part-of-speech tagging, an example is of the 9 | form 10 | [I, love, PyTorch, .] paired with [PRON, VERB, PROPN, PUNCT] 11 | 12 | See torchtext/test/sequence_tagging.py on how to use this class. 
13 | """ 14 | 15 | @staticmethod 16 | def sort_key(example): 17 | for attr in dir(example): 18 | if not callable(getattr(example, attr)) and \ 19 | not attr.startswith("__"): 20 | return len(getattr(example, attr)) 21 | return 0 22 | 23 | def __init__(self, path, fields, **kwargs): 24 | examples = [] 25 | columns = [] 26 | 27 | with open(path) as input_file: 28 | for line in input_file: 29 | line = line.strip() 30 | if line == "": 31 | if columns: 32 | examples.append(data.Example.fromlist(columns, fields)) 33 | columns = [] 34 | else: 35 | for i, column in enumerate(line.split("\t")): 36 | if len(columns) < i + 1: 37 | columns.append([]) 38 | columns[i].append(column) 39 | 40 | if columns: 41 | examples.append(data.Example.fromlist(columns, fields)) 42 | super(SequenceTaggingDataset, self).__init__(examples, fields, 43 | **kwargs) 44 | 45 | 46 | class UDPOS(SequenceTaggingDataset): 47 | 48 | # Universal Dependencies English Web Treebank. 49 | # Download original at http://universaldependencies.org/ 50 | # License: http://creativecommons.org/licenses/by-sa/4.0/ 51 | urls = ['https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip'] 52 | dirname = 'en-ud-v2' 53 | name = 'udpos' 54 | 55 | @classmethod 56 | def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt", 57 | validation="en-ud-tag.v2.dev.txt", 58 | test="en-ud-tag.v2.test.txt", **kwargs): 59 | """Downloads and loads the Universal Dependencies Version 2 POS Tagged 60 | data. 61 | """ 62 | 63 | return super(UDPOS, cls).splits( 64 | fields=fields, root=root, train=train, validation=validation, 65 | test=test, **kwargs) 66 | -------------------------------------------------------------------------------- /text/torchtext/datasets/snli.py: -------------------------------------------------------------------------------- 1 | from .. import data 2 | 3 | 4 | class ShiftReduceField(data.Field): 5 | 6 | def __init__(self): 7 | 8 | super(ShiftReduceField, self).__init__(preprocessing=lambda parse: [ 9 | 'reduce' if t == ')' else 'shift' for t in parse if t != '(']) 10 | 11 | self.build_vocab([['reduce'], ['shift']]) 12 | 13 | 14 | class ParsedTextField(data.Field): 15 | 16 | def __init__(self, eos_token='', lower=False): 17 | 18 | super(ParsedTextField, self).__init__( 19 | eos_token=eos_token, lower=lower, preprocessing=lambda parse: [ 20 | t for t in parse if t not in ('(', ')')], 21 | postprocessing=lambda parse, _, __: [ 22 | list(reversed(p)) for p in parse]) 23 | 24 | 25 | class SNLI(data.TabularDataset): 26 | 27 | urls = ['http://nlp.stanford.edu/projects/snli/snli_1.0.zip'] 28 | dirname = 'snli_1.0' 29 | name = 'snli' 30 | 31 | @staticmethod 32 | def sort_key(ex): 33 | return data.interleave_keys( 34 | len(ex.premise), len(ex.hypothesis)) 35 | 36 | @classmethod 37 | def splits(cls, text_field, label_field, parse_field=None, root='.data', 38 | train='snli_1.0_train.jsonl', validation='snli_1.0_dev.jsonl', 39 | test='snli_1.0_test.jsonl'): 40 | """Create dataset objects for splits of the SNLI dataset. 41 | 42 | This is the most flexible way to use the dataset. 43 | 44 | Arguments: 45 | text_field: The field that will be used for premise and hypothesis 46 | data. 47 | label_field: The field that will be used for label data. 48 | parse_field: The field that will be used for shift-reduce parser 49 | transitions, or None to not include them. 50 | root: The root directory that the dataset's zip archive will be 51 | expanded into; therefore the directory in whose snli_1.0 52 | subdirectory the data files will be stored. 
53 | train: The filename of the train data. Default: 'train.jsonl'. 54 | validation: The filename of the validation data, or None to not 55 | load the validation set. Default: 'dev.jsonl'. 56 | test: The filename of the test data, or None to not load the test 57 | set. Default: 'test.jsonl'. 58 | """ 59 | path = cls.download(root) 60 | 61 | if parse_field is None: 62 | return super(SNLI, cls).splits( 63 | path, root, train, validation, test, 64 | format='json', fields={'sentence1': ('premise', text_field), 65 | 'sentence2': ('hypothesis', text_field), 66 | 'gold_label': ('label', label_field)}, 67 | filter_pred=lambda ex: ex.label != '-') 68 | return super(SNLI, cls).splits( 69 | path, root, train, validation, test, 70 | format='json', fields={'sentence1_binary_parse': 71 | [('premise', text_field), 72 | ('premise_transitions', parse_field)], 73 | 'sentence2_binary_parse': 74 | [('hypothesis', text_field), 75 | ('hypothesis_transitions', parse_field)], 76 | 'gold_label': ('label', label_field)}, 77 | filter_pred=lambda ex: ex.label != '-') 78 | 79 | @classmethod 80 | def iters(cls, batch_size=32, device=0, root='.data', 81 | vectors=None, trees=False, **kwargs): 82 | """Create iterator objects for splits of the SNLI dataset. 83 | 84 | This is the simplest way to use the dataset, and assumes common 85 | defaults for field, vocabulary, and iterator parameters. 86 | 87 | Arguments: 88 | batch_size: Batch size. 89 | device: Device to create batches on. Use -1 for CPU and None for 90 | the currently active GPU device. 91 | root: The root directory that the dataset's zip archive will be 92 | expanded into; therefore the directory in whose wikitext-2 93 | subdirectory the data files will be stored. 94 | vectors: one of the available pretrained vectors or a list with each 95 | element one of the available pretrained vectors (see Vocab.load_vectors) 96 | trees: Whether to include shift-reduce parser transitions. 97 | Default: False. 98 | Remaining keyword arguments: Passed to the splits method. 99 | """ 100 | if trees: 101 | TEXT = ParsedTextField() 102 | TRANSITIONS = ShiftReduceField() 103 | else: 104 | TEXT = data.Field(tokenize='spacy') 105 | TRANSITIONS = None 106 | LABEL = data.Field(sequential=False) 107 | 108 | train, val, test = cls.splits( 109 | TEXT, LABEL, TRANSITIONS, root=root, **kwargs) 110 | 111 | TEXT.build_vocab(train, vectors=vectors) 112 | LABEL.build_vocab(train) 113 | 114 | return data.BucketIterator.splits( 115 | (train, val, test), batch_size=batch_size, device=device) 116 | -------------------------------------------------------------------------------- /text/torchtext/datasets/sst.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from .. import data 4 | 5 | 6 | class SST(data.Dataset): 7 | 8 | urls = ['http://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'] 9 | dirname = 'trees' 10 | name = 'sst' 11 | 12 | @staticmethod 13 | def sort_key(ex): 14 | return len(ex.text) 15 | 16 | def __init__(self, path, text_field, label_field, subtrees=False, 17 | fine_grained=False, **kwargs): 18 | """Create an SST dataset instance given a path and fields. 19 | 20 | Arguments: 21 | path: Path to the data file 22 | text_field: The field that will be used for text data. 23 | label_field: The field that will be used for label data. 24 | subtrees: Whether to include sentiment-tagged subphrases 25 | in addition to complete examples. Default: False. 26 | fine_grained: Whether to use 5-class instead of 3-class 27 | labeling. 
Default: False. 28 | Remaining keyword arguments: Passed to the constructor of 29 | data.Dataset. 30 | """ 31 | fields = [('text', text_field), ('label', label_field)] 32 | 33 | def get_label_str(label): 34 | pre = 'very ' if fine_grained else '' 35 | return {'0': pre + 'negative', '1': 'negative', '2': 'neutral', 36 | '3': 'positive', '4': pre + 'positive', None: None}[label] 37 | label_field.preprocessing = data.Pipeline(get_label_str) 38 | with open(os.path.expanduser(path)) as f: 39 | if subtrees: 40 | examples = [ex for line in f for ex in 41 | data.Example.fromtree(line, fields, True)] 42 | else: 43 | examples = [data.Example.fromtree(line, fields) for line in f] 44 | super(SST, self).__init__(examples, fields, **kwargs) 45 | 46 | @classmethod 47 | def splits(cls, text_field, label_field, root='.data', 48 | train='train.txt', validation='dev.txt', test='test.txt', 49 | train_subtrees=False, **kwargs): 50 | """Create dataset objects for splits of the SST dataset. 51 | 52 | Arguments: 53 | text_field: The field that will be used for the sentence. 54 | label_field: The field that will be used for label data. 55 | root: The root directory that the dataset's zip archive will be 56 | expanded into; therefore the directory in whose trees 57 | subdirectory the data files will be stored. 58 | train: The filename of the train data. Default: 'train.txt'. 59 | validation: The filename of the validation data, or None to not 60 | load the validation set. Default: 'dev.txt'. 61 | test: The filename of the test data, or None to not load the test 62 | set. Default: 'test.txt'. 63 | train_subtrees: Whether to use all subtrees in the training set. 64 | Default: False. 65 | Remaining keyword arguments: Passed to the splits method of 66 | Dataset. 67 | """ 68 | path = cls.download(root) 69 | 70 | train_data = None if train is None else cls( 71 | os.path.join(path, train), text_field, label_field, subtrees=train_subtrees, 72 | **kwargs) 73 | val_data = None if validation is None else cls( 74 | os.path.join(path, validation), text_field, label_field, **kwargs) 75 | test_data = None if test is None else cls( 76 | os.path.join(path, test), text_field, label_field, **kwargs) 77 | return tuple(d for d in (train_data, val_data, test_data) 78 | if d is not None) 79 | 80 | @classmethod 81 | def iters(cls, batch_size=32, device=0, root='.data', vectors=None, **kwargs): 82 | """Create iterator objects for splits of the SST dataset. 83 | 84 | Arguments: 85 | batch_size: Batch size. 86 | device: Device to create batches on. Use -1 for CPU and None for 87 | the currently active GPU device. 88 | root: The root directory that the dataset's zip archive will be 89 | expanded into; therefore the directory in whose trees 90 | subdirectory the data files will be stored. 91 | vectors: one of the available pretrained vectors or a list with each 92 | element one of the available pretrained vectors (see Vocab.load_vectors) 93 | Remaining keyword arguments: Passed to the splits method. 
94 | """ 95 | TEXT = data.Field() 96 | LABEL = data.Field(sequential=False) 97 | 98 | train, val, test = cls.splits(TEXT, LABEL, root=root, **kwargs) 99 | 100 | TEXT.build_vocab(train, vectors=vectors) 101 | LABEL.build_vocab(train) 102 | 103 | return data.BucketIterator.splits( 104 | (train, val, test), batch_size=batch_size, device=device) 105 | -------------------------------------------------------------------------------- /text/torchtext/datasets/translation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import xml.etree.ElementTree as ET 3 | import glob 4 | import io 5 | 6 | from .. import data 7 | 8 | 9 | class TranslationDataset(data.Dataset): 10 | """Defines a dataset for machine translation.""" 11 | 12 | @staticmethod 13 | def sort_key(ex): 14 | return data.interleave_keys(len(ex.src), len(ex.trg)) 15 | 16 | def __init__(self, path, exts, fields, **kwargs): 17 | """Create a TranslationDataset given paths and fields. 18 | 19 | Arguments: 20 | path: Common prefix of paths to the data files for both languages. 21 | exts: A tuple containing the extension to path for each language. 22 | fields: A tuple containing the fields that will be used for data 23 | in each language. 24 | Remaining keyword arguments: Passed to the constructor of 25 | data.Dataset. 26 | """ 27 | if not isinstance(fields[0], (tuple, list)): 28 | fields = [('src', fields[0]), ('trg', fields[1])] 29 | 30 | src_path, trg_path = tuple(os.path.expanduser(path + x) for x in exts) 31 | 32 | examples = [] 33 | with open(src_path) as src_file, open(trg_path) as trg_file: 34 | for src_line, trg_line in zip(src_file, trg_file): 35 | src_line, trg_line = src_line.strip(), trg_line.strip() 36 | if src_line != '' and trg_line != '': 37 | examples.append(data.Example.fromlist( 38 | [src_line, trg_line], fields)) 39 | 40 | super(TranslationDataset, self).__init__(examples, fields, **kwargs) 41 | 42 | @classmethod 43 | def splits(cls, exts, fields, root='.data', 44 | train='train', validation='val', test='test', **kwargs): 45 | """Create dataset objects for splits of a TranslationDataset. 46 | 47 | Arguments: 48 | 49 | root: Root dataset storage directory. Default is '.data'. 50 | exts: A tuple containing the extension to path for each language. 51 | fields: A tuple containing the fields that will be used for data 52 | in each language. 53 | train: The prefix of the train data. Default: 'train'. 54 | validation: The prefix of the validation data. Default: 'val'. 55 | test: The prefix of the test data. Default: 'test'. 56 | Remaining keyword arguments: Passed to the splits method of 57 | Dataset. 
58 | """ 59 | path = cls.download(root) 60 | 61 | train_data = None if train is None else cls( 62 | os.path.join(path, train), exts, fields, **kwargs) 63 | val_data = None if validation is None else cls( 64 | os.path.join(path, validation), exts, fields, **kwargs) 65 | test_data = None if test is None else cls( 66 | os.path.join(path, test), exts, fields, **kwargs) 67 | return tuple(d for d in (train_data, val_data, test_data) 68 | if d is not None) 69 | 70 | 71 | class Multi30k(TranslationDataset): 72 | """The small-dataset WMT 2016 multimodal task, also known as Flickr30k""" 73 | 74 | urls = ['http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/training.tar.gz', 75 | 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz', 76 | 'http://www.quest.dcs.shef.ac.uk/' 77 | 'wmt17_files_mmt/mmt_task1_test2016.tar.gz'] 78 | name = 'multi30k' 79 | dirname = '' 80 | 81 | @classmethod 82 | def splits(cls, exts, fields, root='.data', 83 | train='train', validation='val', test='test2016', **kwargs): 84 | """Create dataset objects for splits of the Multi30k dataset. 85 | 86 | Arguments: 87 | 88 | root: Root dataset storage directory. Default is '.data'. 89 | exts: A tuple containing the extension to path for each language. 90 | fields: A tuple containing the fields that will be used for data 91 | in each language. 92 | train: The prefix of the train data. Default: 'train'. 93 | validation: The prefix of the validation data. Default: 'val'. 94 | test: The prefix of the test data. Default: 'test'. 95 | Remaining keyword arguments: Passed to the splits method of 96 | Dataset. 97 | """ 98 | return super(Multi30k, cls).splits( 99 | exts, fields, root, train, validation, test, **kwargs) 100 | 101 | 102 | class IWSLT(TranslationDataset): 103 | """The IWSLT 2016 TED talk translation task""" 104 | 105 | base_url = 'https://wit3.fbk.eu/archive/2016-01//texts/{}/{}/{}.tgz' 106 | name = 'iwslt' 107 | base_dirname = '{}-{}' 108 | 109 | @classmethod 110 | def splits(cls, exts, fields, root='.data', 111 | train='train', validation='IWSLT16.TED.tst2013', 112 | test='IWSLT16.TED.tst2014', **kwargs): 113 | """Create dataset objects for splits of the IWSLT dataset. 114 | 115 | Arguments: 116 | 117 | root: Root dataset storage directory. Default is '.data'. 118 | exts: A tuple containing the extension to path for each language. 119 | fields: A tuple containing the fields that will be used for data 120 | in each language. 121 | train: The prefix of the train data. Default: 'train'. 122 | validation: The prefix of the validation data. Default: 'val'. 123 | test: The prefix of the test data. Default: 'test'. 124 | Remaining keyword arguments: Passed to the splits method of 125 | Dataset. 
126 | """ 127 | cls.dirname = cls.base_dirname.format(exts[0][1:], exts[1][1:]) 128 | cls.urls = [cls.base_url.format(exts[0][1:], exts[1][1:], cls.dirname)] 129 | check = os.path.join(root, cls.name, cls.dirname) 130 | path = cls.download(root, check=check) 131 | 132 | if train is not None: 133 | train = '.'.join([train, cls.dirname]) 134 | if validation is not None: 135 | validation = '.'.join([validation, cls.dirname]) 136 | if test is not None: 137 | test = '.'.join([test, cls.dirname]) 138 | 139 | if not os.path.exists(os.path.join(path, '.'.join(['train', cls.dirname])) + exts[0]): 140 | cls.clean(path) 141 | 142 | train_data = None if train is None else cls( 143 | os.path.join(path, train), exts, fields, **kwargs) 144 | val_data = None if validation is None else cls( 145 | os.path.join(path, validation), exts, fields, **kwargs) 146 | test_data = None if test is None else cls( 147 | os.path.join(path, test), exts, fields, **kwargs) 148 | return tuple(d for d in (train_data, val_data, test_data) 149 | if d is not None) 150 | 151 | @staticmethod 152 | def clean(path): 153 | for f_xml in glob.iglob(os.path.join(path, '*.xml')): 154 | print(f_xml) 155 | f_txt = os.path.splitext(f_xml)[0] 156 | with io.open(f_txt, mode='w', encoding='utf-8') as fd_txt: 157 | root = ET.parse(f_xml).getroot()[0] 158 | for doc in root.findall('doc'): 159 | for e in doc.findall('seg'): 160 | fd_txt.write(e.text.strip() + '\n') 161 | 162 | xml_tags = ['args.max_answer_length or 18 | len(ex.context)>max_context_length) 19 | is_too_short = lambda ex: (len(ex.answer) {len(s.examples)}') 35 | 36 | l = len(s.examples) 37 | s.examples = [ex for ex in s.examples if not is_too_short(ex)] 38 | if len(s.examples) < l: 39 | if logger is not None: 40 | logger.info(f'Filtering out short {task} examples: {l} -> {len(s.examples)}') 41 | 42 | l = len(s.examples) 43 | s.examples = [ex for ex in s.examples if 'This page includes the show' not in ex.answer] 44 | if len(s.examples) < l: 45 | if logger is not None: 46 | logger.info(f'Filtering {task} examples with a dummy summary: {l} -> {len(s.examples)} ') 47 | 48 | if logger is not None: 49 | context_lengths = [len(ex.context) for ex in s.examples] 50 | question_lengths = [len(ex.question) for ex in s.examples] 51 | answer_lengths = [len(ex.answer) for ex in s.examples] 52 | 53 | logger.info(f'{task} context lengths (min, mean, max): {np.min(context_lengths)}, {int(np.mean(context_lengths))}, {np.max(context_lengths)}') 54 | logger.info(f'{task} question lengths (min, mean, max): {np.min(question_lengths)}, {int(np.mean(question_lengths))}, {np.max(question_lengths)}') 55 | logger.info(f'{task} answer lengths (min, mean, max): {np.min(answer_lengths)}, {int(np.mean(answer_lengths))}, {np.max(answer_lengths)}') 56 | 57 | for x in s.examples: 58 | x.context_question = get_context_question(x, x.context, x.question, field) 59 | 60 | if logger is not None: 61 | logger.info('Tokenized examples:') 62 | for ex in s.examples[:10]: 63 | logger.info('Context: ' + ' '.join(ex.context)) 64 | logger.info('Question: ' + ' '.join(ex.question)) 65 | logger.info(' '.join(ex.context_question)) 66 | logger.info('Answer: ' + ' '.join(ex.answer)) 67 | 68 | 69 | 70 | def set_seed(args, rank=None): 71 | if rank is None and len(args.devices) > 0: 72 | ordinal = args.devices[0] 73 | else: 74 | ordinal = args.devices[rank] 75 | device = torch.device(f'cuda:{ordinal}' if ordinal > -1 else 'cpu') 76 | print(f'device: {device}') 77 | np.random.seed(args.seed) 78 | random.seed(args.seed) 79 | 
torch.manual_seed(args.seed) 80 | with torch.cuda.device(ordinal): 81 | torch.cuda.manual_seed(args.seed) 82 | return device 83 | 84 | 85 | def count_params(params): 86 | def mult(ps): 87 | r = 0 88 | for p in ps: 89 | this_r = 1 90 | for s in p.size(): 91 | this_r *= s 92 | r += this_r 93 | return r 94 | return mult(params) 95 | 96 | 97 | def get_trainable_params(model): 98 | return list(filter(lambda p: p.requires_grad, model.parameters())) 99 | 100 | 101 | def elapsed_time(log): 102 | t = time.time() - log.start 103 | day = int(t // (24 * 3600)) 104 | t = t % (24 * 3600) 105 | hour = int(t // 3600) 106 | t %= 3600 107 | minutes = int(t // 60) 108 | t %= 60 109 | seconds = int(t) 110 | return f'{day:02}:{hour:02}:{minutes:02}:{seconds:02}' 111 | 112 | 113 | def get_splits(args, task, FIELD, **kwargs): 114 | if 'multi30k' in task: 115 | src, trg = ['.'+x for x in task.split('.')[1:]] 116 | split = torchtext.datasets.generic.Multi30k.splits(exts=(src, trg), 117 | fields=FIELD, root=args.data, **kwargs) 118 | elif 'iwslt' in task: 119 | src, trg = ['.'+x for x in task.split('.')[1:]] 120 | split = torchtext.datasets.generic.IWSLT.splits(exts=(src, trg), 121 | fields=FIELD, root=args.data, **kwargs) 122 | elif 'squad' in task: 123 | split = torchtext.datasets.generic.SQuAD.splits( 124 | fields=FIELD, root=args.data, description=task, **kwargs) 125 | elif 'wikisql' in task: 126 | split = torchtext.datasets.generic.WikiSQL.splits( 127 | fields=FIELD, root=args.data, query_as_question='query_as_question' in task, **kwargs) 128 | elif 'ontonotes.ner' in task: 129 | split_task = task.split('.') 130 | _, _, subtask, nones, counting = split_task 131 | split = torchtext.datasets.generic.OntoNotesNER.splits( 132 | subtask=subtask, nones=True if nones == 'nones' else False, 133 | fields=FIELD, root=args.data, **kwargs) 134 | elif 'woz' in task: 135 | split = torchtext.datasets.generic.WOZ.splits(description=task, 136 | fields=FIELD, root=args.data, **kwargs) 137 | elif 'multinli' in task: 138 | split = torchtext.datasets.generic.MultiNLI.splits(description=task, 139 | fields=FIELD, root=args.data, **kwargs) 140 | elif 'srl' in task: 141 | split = torchtext.datasets.generic.SRL.splits( 142 | fields=FIELD, root=args.data, **kwargs) 143 | elif 'snli' in task: 144 | split = torchtext.datasets.generic.SNLI.splits( 145 | fields=FIELD, root=args.data, **kwargs) 146 | elif 'schema' in task: 147 | split = torchtext.datasets.generic.WinogradSchema.splits( 148 | fields=FIELD, root=args.data, **kwargs) 149 | elif task == 'cnn': 150 | split = torchtext.datasets.generic.CNN.splits( 151 | fields=FIELD, root=args.data, **kwargs) 152 | elif task == 'dailymail': 153 | split = torchtext.datasets.generic.DailyMail.splits( 154 | fields=FIELD, root=args.data, **kwargs) 155 | elif task == 'cnn_dailymail': 156 | split_cnn = torchtext.datasets.generic.CNN.splits( 157 | fields=FIELD, root=args.data, **kwargs) 158 | split_dm = torchtext.datasets.generic.DailyMail.splits( 159 | fields=FIELD, root=args.data, **kwargs) 160 | for scnn, sdm in zip(split_cnn, split_dm): 161 | scnn.examples.extend(sdm) 162 | split = split_cnn 163 | elif 'sst' in task: 164 | split = torchtext.datasets.generic.SST.splits( 165 | fields=FIELD, root=args.data, **kwargs) 166 | elif 'imdb' in task: 167 | kwargs['validation'] = None 168 | split = torchtext.datasets.generic.IMDb.splits( 169 | fields=FIELD, root=args.data, **kwargs) 170 | elif 'zre' in task: 171 | split = torchtext.datasets.generic.ZeroShotRE.splits( 172 | fields=FIELD, root=args.data, 
**kwargs) 173 | elif os.path.exists(os.path.join(args.data, task)): 174 | split = torchtext.datasets.generic.JSON.splits( 175 | fields=FIELD, root=args.data, name=task, **kwargs) 176 | return split 177 | 178 | 179 | def batch_fn(new, i, sofar): 180 | prev_max_len = sofar / (i - 1) if i > 1 else 0 181 | return max(len(new.context), 5*len(new.answer), prev_max_len) * i 182 | 183 | 184 | def pad(x, new_channel, dim, val=None): 185 | if x.size(dim) > new_channel: 186 | x = x.narrow(dim, 0, new_channel) 187 | channels = x.size() 188 | assert (new_channel >= channels[dim]) 189 | if new_channel == channels[dim]: 190 | return x 191 | size = list(channels) 192 | size[dim] = new_channel - size[dim] 193 | padding = x.new(*size).fill_(val) 194 | return torch.cat([x, padding], dim) 195 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util import pad 3 | from metrics import compute_metrics 4 | 5 | def compute_validation_outputs(model, val_iter, field, optional_names=[]): 6 | loss, predictions, answers = [], [], [] 7 | outputs = [[] for _ in range(len(optional_names))] 8 | for batch_idx, batch in enumerate(val_iter): 9 | l, p = model(batch) 10 | loss.append(l) 11 | predictions.append(pad(p, 150, dim=-1, val=field.vocab.stoi['<pad>'])) 12 | a = None 13 | if hasattr(batch, 'wikisql_id'): 14 | a = batch.wikisql_id.data.cpu() 15 | elif hasattr(batch, 'squad_id'): 16 | a = batch.squad_id.data.cpu() 17 | elif hasattr(batch, 'woz_id'): 18 | a = batch.woz_id.data.cpu() 19 | else: 20 | a = pad(batch.answer.data.cpu(), 150, dim=-1, val=field.vocab.stoi['<pad>']) 21 | answers.append(a) 22 | for opt_idx, optional_name in enumerate(optional_names): 23 | outputs[opt_idx].append(getattr(batch, optional_name).data.cpu()) 24 | loss = torch.cat(loss, 0) if loss[0] is not None else None 25 | predictions = torch.cat(predictions, 0) 26 | answers = torch.cat(answers, 0) 27 | return loss, predictions, answers, [torch.cat([pad(x, 150, dim=-1, val=field.vocab.stoi['<pad>']) for x in output], 0) for output in outputs] 28 | 29 | 30 | def get_clip(val_iter): 31 | return -val_iter.extra if val_iter.extra > 0 else None 32 | 33 | 34 | def all_reverse(tensor, world_size, field, clip, dim=0): 35 | if world_size > 1: 36 | tensor = tensor.float() # tensors must be on cpu and float for all_gather 37 | all_tensors = [torch.zeros_like(tensor) for _ in range(world_size)] 38 | torch.distributed.barrier() # all_gather is experimental for gloo, found that these barriers were necessary 39 | torch.distributed.all_gather(all_tensors, tensor) 40 | torch.distributed.barrier() 41 | tensor = torch.cat(all_tensors, 0).long() # tensors must be long for reverse 42 | # for distributed training, dev sets are padded with extra examples so that the 43 | # tensors are all of a predictable size for all_gather. 
This line removes those extra examples 44 | return field.reverse(tensor)[:clip] 45 | 46 | 47 | def gather_results(model, val_iter, field, world_size, optional_names=[]): 48 | loss, predictions, answers, outputs = compute_validation_outputs(model, val_iter, field, optional_names=optional_names) 49 | clip = get_clip(val_iter) 50 | if not hasattr(val_iter.dataset.examples[0], 'squad_id') and not hasattr(val_iter.dataset.examples[0], 'wikisql_id') and not hasattr(val_iter.dataset.examples[0], 'woz_id'): 51 | answers = all_reverse(answers, world_size, field, clip) 52 | return loss, all_reverse(predictions, world_size, field, clip), answers, [all_reverse(x, world_size, field, clip) for x in outputs], 53 | 54 | 55 | def print_results(keys, values, rank=None, num_print=1): 56 | print() 57 | start = rank * num_print if rank is not None else 0 58 | end = start + num_print 59 | values = [val[start:end] for val in values] 60 | for ex_idx in range(len(values[0])): 61 | for key_idx, key in enumerate(keys): 62 | value = values[key_idx][ex_idx] 63 | v = value[0] if isinstance(value, list) else value 64 | print(f'{key}: {repr(v)}') 65 | print() 66 | 67 | 68 | def validate(task, val_iter, model, logger, field, world_size, rank, num_print=10, args=None): 69 | with torch.no_grad(): 70 | model.eval() 71 | required_names = ['greedy', 'answer'] 72 | optional_names = ['context', 'question'] 73 | loss, predictions, answers, results = gather_results(model, val_iter, field, world_size, optional_names=optional_names) 74 | predictions = [p.replace('UNK', 'OOV') for p in predictions] 75 | names = required_names + optional_names 76 | if hasattr(val_iter.dataset.examples[0], 'wikisql_id') or hasattr(val_iter.dataset.examples[0], 'squad_id') or hasattr(val_iter.dataset.examples[0], 'woz_id'): 77 | answers = [val_iter.dataset.all_answers[sid] for sid in answers.tolist()] 78 | metrics, answers = compute_metrics(predictions, answers, bleu='iwslt' in task or 'multi30k' in task, dialogue='woz' in task, 79 | rouge='cnn' in task, logical_form='sql' in task, corpus_f1='zre' in task, args=args) 80 | results = [predictions, answers] + results 81 | print_results(names, results, rank=rank, num_print=num_print) 82 | 83 | return loss, metrics 84 | --------------------------------------------------------------------------------
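Most of the torchtext-style dataset classes above (IMDb, SST, SNLI, WikiText2, Multi30k, IWSLT) expose the same two entry points: splits() downloads the archive into root and returns one Dataset per split, while iters() wraps that in default fields, vocabularies, and a batch iterator. The following is a minimal usage sketch for SST only; the import paths, the '.data' root, and device=-1 (CPU) are assumptions about how this repository's bundled text package is put on the Python path, not something prescribed by the source.

# Minimal sketch: explicit fields + splits(), mirroring what SST.iters() does internally.
from text.torchtext import data                    # assumed import path for the bundled torchtext fork
from text.torchtext.datasets.sst import SST        # the SST class shown above

TEXT = data.Field()                                # tokenized sentence field
LABEL = data.Field(sequential=False)               # single-token label field

train, val, test = SST.splits(TEXT, LABEL, root='.data')   # downloads the trees archive under .data/sst

TEXT.build_vocab(train)                            # vocabularies are built from the training split only
LABEL.build_vocab(train)

train_iter, val_iter, test_iter = data.BucketIterator.splits(
    (train, val, test), batch_size=32, device=-1)  # device=-1 keeps batches on the CPU

batch = next(iter(train_iter))
print(batch.text.size(), batch.label.size())       # (seq_len, batch) and (batch,) LongTensors

The one-line alternative, SST.iters(batch_size=32, device=-1), performs the same steps with default fields, as the class itself shows.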
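In validate.py, compute_validation_outputs depends on util.pad to make variable-length prediction and answer tensors concatenable (and, under distributed evaluation, gatherable) by truncating or right-padding each one to a fixed width of 150 before torch.cat. Below is a small self-contained sketch of that behaviour: the pad function is copied from util.py, while pad_idx = 1 is a hypothetical stand-in for the pad-token index that the real code looks up in the field vocabulary.

import torch

def pad(x, new_channel, dim, val=None):
    # Same behaviour as util.pad: truncate x along dim if it is too long,
    # otherwise right-pad it with val until it has size new_channel on that dim.
    if x.size(dim) > new_channel:
        x = x.narrow(dim, 0, new_channel)
    channels = x.size()
    assert (new_channel >= channels[dim])
    if new_channel == channels[dim]:
        return x
    size = list(channels)
    size[dim] = new_channel - size[dim]
    padding = x.new(*size).fill_(val)
    return torch.cat([x, padding], dim)

pad_idx = 1                              # hypothetical pad-token index
a = torch.tensor([[4, 9, 7]])            # a 3-token prediction
b = torch.tensor([[5, 2, 8, 8, 3]])      # a 5-token prediction from another batch
stacked = torch.cat([pad(a, 6, dim=-1, val=pad_idx),
                     pad(b, 6, dim=-1, val=pad_idx)], 0)
print(stacked)                           # both rows now have width 6, padded on the right with pad_idx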