├── scripts ├── __init__.py ├── __init__.pyc ├── download_ag.sh ├── download_imdb.sh ├── figures │ ├── curve_plot.py │ ├── misc.py │ └── regplot.py ├── verify.py ├── train.py └── preprocess_data.py ├── vampire ├── __init__.py ├── tests │ ├── __init__.py │ ├── fixtures │ │ ├── imdb │ │ │ ├── vocabulary │ │ │ │ ├── non_padded_namespaces.txt │ │ │ │ └── vampire.txt │ │ │ ├── test.npz │ │ │ ├── train.npz │ │ │ ├── vampire.bgfreq │ │ │ ├── train.jsonl │ │ │ └── test.jsonl │ │ ├── vae │ │ │ ├── vocabulary │ │ │ │ ├── non_padded_namespaces.txt │ │ │ │ └── vampire.txt │ │ │ └── model.tar.gz │ │ ├── reference_corpus │ │ │ └── dev.npz │ │ ├── stopwords │ │ │ ├── snowball_stopwords.txt │ │ │ └── mallet_stopwords.txt │ │ ├── unsupervised │ │ │ └── experiment.json │ │ └── classifier │ │ │ ├── experiment_seq2seq.json │ │ │ └── experiment_seq2vec.json │ ├── models │ │ ├── classifier_test.py │ │ └── vampire_test.py │ ├── modules │ │ └── token_embedders │ │ │ └── vampire_token_embedder_test.py │ └── data │ │ └── dataset_readers │ │ └── semisupervised_text_classification_json_test.py ├── common │ ├── testing │ │ └── __init__.py │ ├── __init__.py │ ├── stopwords │ │ ├── snowball_stopwords.txt │ │ └── mallet_stopwords.txt │ ├── allennlp_bridge.py │ └── util.py ├── data │ ├── __init__.py │ └── dataset_readers │ │ ├── __init__.py │ │ ├── vampire_reader.py │ │ └── semisupervised_text_classification_json.py ├── models │ ├── __init__.py │ └── classifier.py └── modules │ ├── token_embedders │ ├── __init__.py │ └── vampire_token_embedder.py │ ├── vae │ ├── __init__.py │ ├── vae.py │ └── logistic_normal.py │ ├── __init__.py │ ├── pretrained_vae.py │ └── encoder.py ├── .DS_Store ├── figures └── bat.png ├── environments ├── __init__.py ├── datasets.py ├── random_search.py └── environments.py ├── .gitignore ├── requirements.txt ├── codecov.yml ├── Dockerfile ├── TROUBLESHOOTING.md ├── search_spaces ├── vampire_imdb_search.json ├── vampire_ag_search.json ├── vampire_yahoo_search.json ├── vampire_hatespeech_search.json ├── pretraining_search.jsonnet ├── nll_classifier_search.json ├── long_classifier_search.json ├── classifier_ag_search.json ├── classifier_hatespeech_search.json ├── classifier_yahoo_search.json └── classifier_imdb_search.json ├── training_config └── vampire.jsonnet ├── colab └── VAMPIRE_AGNews.ipynb ├── README.md └── LICENSE /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vampire/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /vampire/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/vampire/HEAD/.DS_Store -------------------------------------------------------------------------------- /figures/bat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/vampire/HEAD/figures/bat.png -------------------------------------------------------------------------------- /scripts/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/vampire/HEAD/scripts/__init__.pyc -------------------------------------------------------------------------------- /vampire/common/testing/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.common.testing.test_case import VAETestCase 2 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | vampire 4 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/vae/vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | vampire 2 | *tags 3 | *labels 4 | -------------------------------------------------------------------------------- /vampire/data/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.data.dataset_readers import SemiSupervisedTextClassificationJsonReader 2 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/test.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/vampire/HEAD/vampire/tests/fixtures/imdb/test.npz -------------------------------------------------------------------------------- /vampire/models/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.models.classifier import Classifier 2 | from vampire.models.vampire import VAMPIRE 3 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/train.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/vampire/HEAD/vampire/tests/fixtures/imdb/train.npz -------------------------------------------------------------------------------- /environments/__init__.py: -------------------------------------------------------------------------------- 1 | from environments.environments import ENVIRONMENTS 2 | from environments.random_search import RandomSearch -------------------------------------------------------------------------------- /vampire/tests/fixtures/vae/model.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/vampire/HEAD/vampire/tests/fixtures/vae/model.tar.gz -------------------------------------------------------------------------------- /vampire/modules/token_embedders/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.modules.token_embedders.vampire_token_embedder import VampireTokenEmbedder 2 | -------------------------------------------------------------------------------- /vampire/modules/vae/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.modules.vae.vae import VAE 2 | from vampire.modules.vae.logistic_normal import LogisticNormal 3 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/reference_corpus/dev.npz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/vampire/HEAD/vampire/tests/fixtures/reference_corpus/dev.npz -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .mypy_cache/ 2 | .pytest_cache/ 3 | __pycache__/ 4 | .vscode/ 5 | model_logs/ 6 | .coverage* 7 | .ipynb_checkpoints/ 8 | datasets/ 9 | s3/ 10 | .hyperparameter_search_results -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/vampire.bgfreq: -------------------------------------------------------------------------------- 1 | { 2 | "abandon": 0.0055, 3 | "absolutely": 0.0045, 4 | "academy": 0.0038, 5 | "access": 0.0051, 6 | "ache": 0.0075, 7 | "acting": 0.0079 8 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==0.9.0 2 | pandas 3 | pytest 4 | tabulate 5 | regex 6 | # Checks style, syntax, and other useful errors. 7 | pylint==1.8.1 8 | # Static type checking 9 | mypy==0.521 10 | scipy>=1.3.0 -------------------------------------------------------------------------------- /vampire/common/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.common.testing.test_case import VAETestCase 2 | from vampire.common.allennlp_bridge import ExtendedVocabulary, VocabularyWithPretrainedVAE 3 | from vampire.common.util import * 4 | -------------------------------------------------------------------------------- /vampire/data/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.data.dataset_readers.semisupervised_text_classification_json import ( 2 | SemiSupervisedTextClassificationJsonReader) 3 | from vampire.data.dataset_readers.vampire_reader import VampireReader 4 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 0 3 | round: down 4 | status: 5 | patch: 6 | default: 7 | target: 90 8 | project: 9 | default: 10 | threshold: 1% 11 | changes: false 12 | comment: false 13 | ignore: 14 | - "vae/tests" -------------------------------------------------------------------------------- /vampire/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from vampire.modules.encoder import * 2 | from vampire.modules.pretrained_vae import PretrainedVAE 3 | from vampire.modules.token_embedders.vampire_token_embedder import VampireTokenEmbedder 4 | from vampire.modules.vae import LogisticNormal 5 | from vampire.modules.vae import VAE 6 | -------------------------------------------------------------------------------- /scripts/download_ag.sh: -------------------------------------------------------------------------------- 1 | mkdir -p $(pwd)/examples/ag 2 | curl -Lo $(pwd)/examples/ag/train.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/train.jsonl 3 | curl -Lo $(pwd)/examples/ag/dev.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/dev.jsonl 4 | curl -Lo $(pwd)/examples/ag/test.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/test.jsonl -------------------------------------------------------------------------------- /scripts/download_imdb.sh: 
-------------------------------------------------------------------------------- 1 | mkdir -p $(pwd)/examples/imdb 2 | curl -Lo $(pwd)/examples/imdb/train.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/imdb/train.jsonl 3 | curl -Lo $(pwd)/examples/imdb/dev.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/imdb/dev.jsonl 4 | curl -Lo $(pwd)/examples/imdb/test.jsonl https://s3-us-west-2.amazonaws.com/allennlp/datasets/imdb/test.jsonl -------------------------------------------------------------------------------- /scripts/figures/curve_plot.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | sns.set(font_scale=1.3, style='white') 7 | 8 | 9 | if __name__ == '__main__': 10 | fig, ax = plt.subplots(1, 2, sharex=True, figsize=(6,3)) 11 | df = pd.read_csv("~/Downloads/run-vampire_AG_log_validation-tag-npmi.csv") 12 | df1 = pd.read_csv("~/Downloads/run-vampire_AG_log_validation-tag-nll.csv") 13 | sns.lineplot(df1['Step'], df1['Value'], ax=ax[0]) 14 | sns.lineplot(df['Step'], df['Value'], ax=ax[1]) 15 | ax[1].set_xlabel("Epoch") 16 | ax[0].set_ylabel("NLL") 17 | ax[1].set_ylabel("NPMI") 18 | plt.tight_layout() 19 | plt.savefig("curves.pdf") -------------------------------------------------------------------------------- /environments/datasets.py: -------------------------------------------------------------------------------- 1 | DATASETS = { 2 | "imdb": { 3 | "train": "s3://suching-dev/final-datasets/imdb/train_pretokenized.jsonl", 4 | "dev": "s3://suching-dev/final-datasets/imdb/dev_pretokenized.jsonl", 5 | "test": "s3://suching-dev/final-datasets/imdb/test_pretokenized.jsonl", 6 | "unlabeled": "s3://suching-dev/final-datasets/imdb/unlabeled_pretokenized.jsonl", 7 | "reference_counts": "s3://suching-dev/final-datasets/imdb/valid_npmi_reference/train.npz", 8 | "reference_vocabulary": "s3://suching-dev/final-datasets/imdb/valid_npmi_reference/train.vocab.json", 9 | "stopword_path": "s3://suching-dev/stopwords/snowball_stopwords.txt" 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM allennlp/commit:234fb18fc253d8118308da31c9d3bfaa9e346861 2 | 3 | LABEL maintainer="suching@allenai.org" 4 | 5 | WORKDIR /vampire 6 | 7 | RUN pip install pandas 8 | RUN pip install pytest 9 | RUN pip install torchvision 10 | RUN pip install tabulate 11 | RUN pip install regex 12 | RUN pip install pylint==1.8.1 13 | RUN pip install mypy==0.521 14 | RUN pip install codecov 15 | RUN pip install pytest-cov 16 | 17 | RUN python -m spacy download en 18 | 19 | COPY scripts/ scripts/ 20 | COPY environments/ environments/ 21 | COPY vampire/ vampire/ 22 | COPY training_config/ training_config/ 23 | COPY .pylintrc .pylintrc 24 | 25 | # Optional argument to set an environment variable with the Git SHA 26 | ARG SOURCE_COMMIT 27 | ENV ALLENAI_VAMPIRE_SOURCE_COMMIT $SOURCE_COMMIT 28 | 29 | EXPOSE 8000 30 | 31 | ENTRYPOINT ["/bin/bash"] -------------------------------------------------------------------------------- /TROUBLESHOOTING.md: -------------------------------------------------------------------------------- 1 | ## Troubleshooting 2 | 3 | We've gathered a few insights from playing around with the model since publication, including some methods to mitigate training instability, especially when training on
larger corpora. 4 | 5 | Training instability usually manifests as NaN loss errors. To circumvent this, here are some easy things to try: 6 | 7 | * Use TF-IDF as input instead of raw word frequencies. You can do this by setting the `--tfidf` flag in `scripts/preprocess_data.py` 8 | 9 | * Increase the batch size to at least 256 10 | 11 | * Reduce the LR to 1e-4 or 1e-5. If you are training over a very large corpus, this shouldn’t affect representation quality much. 12 | 13 | * Use a learning rate scheduler; the slanted triangular scheduler has worked well for me. Make sure you tinker with the total number of epochs you train over. 14 | 15 | * Clamp the KLD to some max value (e.g. 1000) so it doesn’t diverge 16 | 17 | * Use a different KLD annealing scheduler (i.e. sigmoid) 18 | 19 | If you still have issues after trying these modifications, please submit an issue! 20 | -------------------------------------------------------------------------------- /scripts/figures/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | miscellaneous scripts for the paper 3 | """ 4 | 5 | # sample over vampires 6 | import glob 7 | import json 8 | configs = glob.glob("logs/vampire_yahoo_search/*/trial/config*") 9 | configs = [(x, json.load(open(x, 'r'))) for x in configs] 10 | hidden_dims = [(x, y['model']['vae']['encoder']['hidden_dims'][0]) for x,y in configs] 11 | hidden_dims = [(x.replace('/trial/config.json', ''), y) for x,y in hidden_dims] 12 | hidden_dims = [ (x, y) for x,y in hidden_dims if x + "/trial" in glob.glob(x + "/*")] 13 | hidden_dims = [ (x, y) for x,y in hidden_dims if x + "/trial/model.tar.gz" in glob.glob(x + "/trial/*")] 14 | hidden_dims = [" ".join([str(y) for y in x]) for x in hidden_dims] 15 | 16 | 17 | # join on VAMPIRE search 18 | import pandas as pd 19 | df = pd.read_json("/home/suching/vampire/logs/hatespeech_classifier_search/results.jsonl", lines=True) 20 | df1 = pd.read_json("/home/suching/vampire/logs/vampire_hatespeech_search/results.jsonl", lines=True) 21 | df['vampire_directory'] = df['model.input_embedder.token_embedders.vampire_tokens.model_archive'].str.replace('model.tar.gz', '') 22 | master = df.merge(df1, left_on = 'vampire_directory', right_on='directory') 23 | master.to_json("hyperparameter_search_results/hatespeech_vampire_classifier_search.jsonl", lines=True, orient='records') 24 | 25 | -------------------------------------------------------------------------------- /scripts/verify.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | # pylint: disable=invalid-name 3 | 4 | """Script that runs all verification steps.
5 | """ 6 | 7 | import argparse 8 | import sys 9 | from subprocess import CalledProcessError, run 10 | 11 | 12 | def main(arguments): 13 | try: 14 | print("Verifying with " + str(arguments)) 15 | if "pytest" in arguments: 16 | print("Tests (pytest):", flush=True) 17 | run("pytest -v --cov=vampire --color=yes vampire", shell=True, check=True) 18 | 19 | if "pylint" in arguments: 20 | print("Linter (pylint):", flush=True) 21 | run("pylint -d locally-disabled,locally-enabled -f colorized vampire", shell=True, check=True) 22 | print("pylint checks passed") 23 | 24 | if "mypy" in arguments: 25 | print("Typechecker (mypy):", flush=True) 26 | run("mypy vampire --ignore-missing-imports", shell=True, check=True) 27 | print("mypy checks passed") 28 | 29 | if "check-large-files" in arguments: 30 | print("Checking all added files have size <= 5MB", flush=True) 31 | run("./scripts/check_large_files.sh 5", shell=True, check=True) 32 | print("check large files passed") 33 | 34 | except CalledProcessError: 35 | # squelch the exception stacktrace 36 | sys.exit(1) 37 | 38 | if __name__ == "__main__": 39 | 40 | checks = ['pytest', 'pylint', 'mypy', 'check-large-files'] 41 | 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--checks', type=str, required=False, nargs='+', choices=checks) 44 | 45 | args = parser.parse_args() 46 | 47 | if args.checks: 48 | run_checks = args.checks 49 | else: 50 | run_checks = checks 51 | 52 | main(run_checks) 53 | -------------------------------------------------------------------------------- /search_spaces/vampire_imdb_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "KL_ANNEALING": { 4 | "sampling strategy": "choice", 5 | "choices": ["sigmoid", "linear", "constant"] 6 | }, 7 | "SIGMOID_WEIGHT_1": 0.25, 8 | "SIGMOID_WEIGHT_2": 15, 9 | "LINEAR_SCALING": 1000, 10 | "VAE_HIDDEN_DIM": { 11 | "sampling strategy": "integer", 12 | "bounds": [64, 128] 13 | }, 14 | "TRAIN_PATH": "/home/suching/vampire/data/imdb/train_unlabeled.npz", 15 | "DEV_PATH": "/home/suching/vampire/data/imdb/dev.npz", 16 | "VOCABULARY_DIRECTORY": "/home/suching/vampire/data/imdb/vocab/", 17 | "ADDITIONAL_UNLABELED_DATA_PATH": null, 18 | "REFERENCE_COUNTS": "s3://suching-dev/final-datasets/imdb/valid_npmi_reference/train.npz", 19 | "REFERENCE_VOCAB": "s3://suching-dev/final-datasets/imdb/valid_npmi_reference/train.vocab.json", 20 | "STOPWORDS_PATH": "s3://suching-dev/stopwords/snowball_stopwords.txt", 21 | "TRACK_NPMI": true, 22 | "NUM_ENCODER_LAYERS": 2, 23 | "ENCODER_ACTIVATION": { 24 | "sampling strategy": "choice", 25 | "choices": ["relu", "tanh", "softplus"] 26 | }, 27 | "NUM_MEAN_PROJECTION_LAYERS": 1, 28 | "MEAN_PROJECTION_ACTIVATION": "linear", 29 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 30 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 31 | "SEED": { 32 | "sampling strategy": "integer", 33 | "bounds": [0, 100000] 34 | }, 35 | "Z_DROPOUT": { 36 | "sampling strategy": "uniform", 37 | "bounds": [0, 0.5] 38 | }, 39 | "LEARNING_RATE": { 40 | "sampling strategy": "loguniform", 41 | "bounds": [1e-4, 1e-2] 42 | }, 43 | "CUDA_DEVICE": 0, 44 | "THROTTLE": null, 45 | "ADD_ELMO": 0, 46 | "USE_SPACY_TOKENIZER": 0, 47 | "UPDATE_BACKGROUND_FREQUENCY": 0, 48 | "VOCAB_SIZE": 30000, 49 | "APPLY_BATCHNORM": 1, 50 | "APPLY_BATCHNORM_1": 0, 51 | "SEQUENCE_LENGTH": 400, 52 | "BATCH_SIZE": 64, 53 | "VALIDATION_METRIC": "+npmi" 54 | } --------------------------------------------------------------------------------
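The "sampling strategy" blocks that appear throughout the search_spaces/*.json files ("choice", "integer", "uniform", "loguniform", "subset") correspond to the samplers defined in environments/random_search.py later in this listing. The script that actually consumes these JSON files is not included here, so the following is only a minimal sketch of how such a spec could be resolved into one concrete configuration; the SAMPLERS dispatch and sample_config helper below are hypothetical names, not part of the repo:

import json
import numpy as np

# Hypothetical dispatch from a "sampling strategy" spec to a single drawn value,
# mirroring the RandomSearch helpers in environments/random_search.py.
SAMPLERS = {
    "choice": lambda spec: np.random.choice(spec["choices"]),
    "integer": lambda spec: int(np.random.randint(spec["bounds"][0], spec["bounds"][1])),
    "uniform": lambda spec: float(np.random.uniform(spec["bounds"][0], spec["bounds"][1])),
    "loguniform": lambda spec: float(np.exp(np.random.uniform(np.log(spec["bounds"][0]),
                                                              np.log(spec["bounds"][1])))),
    "subset": lambda spec: list(np.random.choice(spec["choices"],
                                                 np.random.randint(1, len(spec["choices"]) + 1),
                                                 replace=False)),
}

def sample_config(path):
    """Draw one concrete hyperparameter assignment from a search-space JSON."""
    with open(path) as f:
        space = json.load(f)
    config = {}
    for key, spec in space.items():
        if isinstance(spec, dict) and "sampling strategy" in spec:
            config[key] = SAMPLERS[spec["sampling strategy"]](spec)
        else:
            config[key] = spec  # constants (paths, flags, nulls) pass through unchanged
    return config

# e.g. sample_config("search_spaces/vampire_imdb_search.json")["LEARNING_RATE"]
--------------------------------------------------------------------------------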
/vampire/common/stopwords/snowball_stopwords.txt: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | your 11 | yours 12 | yourself 13 | yourselves 14 | he 15 | him 16 | his 17 | himself 18 | she 19 | her 20 | hers 21 | herself 22 | it 23 | its 24 | itself 25 | they 26 | them 27 | their 28 | theirs 29 | themselves 30 | what 31 | which 32 | who 33 | whom 34 | this 35 | that 36 | these 37 | those 38 | am 39 | is 40 | are 41 | was 42 | were 43 | be 44 | been 45 | being 46 | have 47 | has 48 | had 49 | having 50 | do 51 | does 52 | did 53 | doing 54 | will 55 | would 56 | shall 57 | should 58 | can 59 | could 60 | may 61 | might 62 | must 63 | ought 64 | a 65 | an 66 | the 67 | and 68 | but 69 | if 70 | or 71 | because 72 | as 73 | until 74 | while 75 | of 76 | at 77 | by 78 | for 79 | with 80 | about 81 | against 82 | between 83 | into 84 | through 85 | during 86 | before 87 | after 88 | above 89 | below 90 | to 91 | from 92 | up 93 | down 94 | in 95 | out 96 | on 97 | off 98 | over 99 | under 100 | again 101 | further 102 | then 103 | once 104 | here 105 | there 106 | when 107 | where 108 | why 109 | how 110 | all 111 | any 112 | both 113 | each 114 | few 115 | more 116 | most 117 | other 118 | some 119 | such 120 | no 121 | nor 122 | not 123 | only 124 | own 125 | same 126 | so 127 | than 128 | too 129 | very 130 | i_m 131 | you_re 132 | he_s 133 | she_s 134 | it_s 135 | we_re 136 | they_re 137 | i_ve 138 | you_ve 139 | we_ve 140 | they_ve 141 | i_d 142 | you_d 143 | he_d 144 | she_d 145 | we_d 146 | they_d 147 | i_ll 148 | you_ll 149 | he_ll 150 | she_ll 151 | we_ll 152 | they_ll 153 | isn_t 154 | aren_t 155 | wasn_t 156 | weren_t 157 | hasn_t 158 | haven_t 159 | hadn_t 160 | doesn_t 161 | don_t 162 | didn_t 163 | won_t 164 | wouldn_t 165 | shan_t 166 | shouldn_t 167 | can_t 168 | cannot 169 | couldn_t 170 | mustn_t 171 | let_s 172 | that_s 173 | who_s 174 | what_s 175 | here_s 176 | there_s 177 | when_s 178 | where_s 179 | why_s 180 | how_s -------------------------------------------------------------------------------- /vampire/tests/fixtures/stopwords/snowball_stopwords.txt: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | your 11 | yours 12 | yourself 13 | yourselves 14 | he 15 | him 16 | his 17 | himself 18 | she 19 | her 20 | hers 21 | herself 22 | it 23 | its 24 | itself 25 | they 26 | them 27 | their 28 | theirs 29 | themselves 30 | what 31 | which 32 | who 33 | whom 34 | this 35 | that 36 | these 37 | those 38 | am 39 | is 40 | are 41 | was 42 | were 43 | be 44 | been 45 | being 46 | have 47 | has 48 | had 49 | having 50 | do 51 | does 52 | did 53 | doing 54 | will 55 | would 56 | shall 57 | should 58 | can 59 | could 60 | may 61 | might 62 | must 63 | ought 64 | a 65 | an 66 | the 67 | and 68 | but 69 | if 70 | or 71 | because 72 | as 73 | until 74 | while 75 | of 76 | at 77 | by 78 | for 79 | with 80 | about 81 | against 82 | between 83 | into 84 | through 85 | during 86 | before 87 | after 88 | above 89 | below 90 | to 91 | from 92 | up 93 | down 94 | in 95 | out 96 | on 97 | off 98 | over 99 | under 100 | again 101 | further 102 | then 103 | once 104 | here 105 | there 106 | when 107 | where 108 | why 109 | how 110 | all 111 | any 112 | both 113 | each 114 | few 115 | more 116 | most 117 | other 118 | some 119 | such 120 | no 121 | nor 122 | not 123 | only 
124 | own 125 | same 126 | so 127 | than 128 | too 129 | very 130 | i_m 131 | you_re 132 | he_s 133 | she_s 134 | it_s 135 | we_re 136 | they_re 137 | i_ve 138 | you_ve 139 | we_ve 140 | they_ve 141 | i_d 142 | you_d 143 | he_d 144 | she_d 145 | we_d 146 | they_d 147 | i_ll 148 | you_ll 149 | he_ll 150 | she_ll 151 | we_ll 152 | they_ll 153 | isn_t 154 | aren_t 155 | wasn_t 156 | weren_t 157 | hasn_t 158 | haven_t 159 | hadn_t 160 | doesn_t 161 | don_t 162 | didn_t 163 | won_t 164 | wouldn_t 165 | shan_t 166 | shouldn_t 167 | can_t 168 | cannot 169 | couldn_t 170 | mustn_t 171 | let_s 172 | that_s 173 | who_s 174 | what_s 175 | here_s 176 | there_s 177 | when_s 178 | where_s 179 | why_s 180 | how_s -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/train.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "train_5011", "orig": "aclImdb/train/neg/5011_1.txt", "rating": 1, "label": "neg", "text": "...And I never thought a movie deserved to be awarded a 1! But this one is honestly the worst movie I've ever watched. My wife picked it up because of the cast, but the storyline right since the DVD box seemed quite predictable. It is not a mystery, nor a juvenile-catching film. It does not include any sensuality, if that's what the title could remotely have suggest any of you. This is just a total no-no. Don't waste your time or money unless you feel like watching a bunch of youngsters in a as-grown-up kind of Gothic setting, where a killer is going after them. Nothing new, nothing interesting, nothing worth watching. Max Makowski makes the worst of Nick Stahl."} 2 | {"id": "train_10433", "orig": "aclImdb/train/neg/10433_4.txt", "rating": 4, "label": "pos", "text": "The fight scenes were great. Loved the old and newer cylons and how they painted the ones on their side. It was the ending that I hated. I was disappointed that it was earth but 150k years back. But to travel all that way just to start over? Are you kidding me? 38k people that fought for their very existence and once they get to paradise, they abandon technology? No way. Sure they were eating paper and rationing food, but that is over. They can live like humans again. They only have one good doctor. What are they going to do when someone has a tooth ache never mind giving birth... yea right. No one would have made that choice."} 3 | {"id": "train_11872", "orig": "aclImdb/train/neg/11872_1.txt", "rating": 1, "label": "neg", "text": "The only way this is a family drama is if parents explain everything wrong with its message.

SPOILER: they feed a deer for a year and then kill it for eating their food after killing its mother and at first pontificating about taking responsibility for their actions. They blame bears and deer for \"misbehaving\" by eating while they take no responsibility to use adequate locks and fences or even learn to shoot instead of twice maiming animals and letting them linger."} -------------------------------------------------------------------------------- /search_spaces/vampire_ag_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "KL_ANNEALING": { 4 | "sampling strategy": "choice", 5 | "choices": ["sigmoid", "linear", "constant"] 6 | }, 7 | "SIGMOID_WEIGHT_1": 0.25, 8 | "SIGMOID_WEIGHT_2": 15, 9 | "LINEAR_SCALING": 1000, 10 | "VAE_HIDDEN_DIM": { 11 | "sampling strategy": "integer", 12 | "bounds": [64, 128] 13 | }, 14 | "TRAIN_PATH": "/home/suching/vampire/data/ag/train.npz", 15 | "DEV_PATH": "/home/suching/vampire/data/ag/dev.npz", 16 | "VOCABULARY_DIRECTORY": "/home/suching/vampire/data/ag/vocab/", 17 | "ADDITIONAL_UNLABELED_DATA_PATH": null, 18 | "REFERENCE_COUNTS": "s3://suching-dev/final-datasets/ag-news/valid_npmi_reference/train.npz", 19 | "REFERENCE_VOCAB": "s3://suching-dev/final-datasets/ag-news/valid_npmi_reference/train.vocab.json", 20 | "STOPWORDS_PATH": "s3://suching-dev/stopwords/snowball_stopwords.txt", 21 | "TRACK_NPMI": true, 22 | "NUM_ENCODER_LAYERS": { 23 | "sampling strategy": "choice", 24 | "choices": [1, 2, 3] 25 | }, 26 | "ENCODER_ACTIVATION": { 27 | "sampling strategy": "choice", 28 | "choices": ["relu", "tanh", "softplus"] 29 | }, 30 | "NUM_MEAN_PROJECTION_LAYERS": 1, 31 | "MEAN_PROJECTION_ACTIVATION": "linear", 32 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 33 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 34 | "SEED": { 35 | "sampling strategy": "integer", 36 | "bounds": [0, 100000] 37 | }, 38 | "Z_DROPOUT": { 39 | "sampling strategy": "uniform", 40 | "bounds": [0, 0.5] 41 | }, 42 | "LEARNING_RATE": { 43 | "sampling strategy": "loguniform", 44 | "bounds": [1e-4, 1e-2] 45 | }, 46 | "CUDA_DEVICE": 0, 47 | "THROTTLE": null, 48 | "ADD_ELMO": 0, 49 | "USE_SPACY_TOKENIZER": 0, 50 | "UPDATE_BACKGROUND_FREQUENCY": 0, 51 | "VOCAB_SIZE": 30000, 52 | "APPLY_BATCHNORM": 1, 53 | "APPLY_BATCHNORM_1": 0, 54 | "SEQUENCE_LENGTH": 400, 55 | "BATCH_SIZE": 64, 56 | "VALIDATION_METRIC": "+npmi" 57 | } -------------------------------------------------------------------------------- /search_spaces/vampire_yahoo_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "KL_ANNEALING": { 4 | "sampling strategy": "choice", 5 | "choices": ["sigmoid", "linear", "constant"] 6 | }, 7 | "SIGMOID_WEIGHT_1": 0.25, 8 | "SIGMOID_WEIGHT_2": 15, 9 | "LINEAR_SCALING": 1000, 10 | "VAE_HIDDEN_DIM": { 11 | "sampling strategy": "integer", 12 | "bounds": [64, 128] 13 | }, 14 | "TRAIN_PATH": "/home/suching/vampire/data/yahoo/train.npz", 15 | "DEV_PATH": "/home/suching/vampire/data/yahoo/dev.npz", 16 | "VOCABULARY_DIRECTORY": "/home/suching/vampire/data/yahoo/vocab/", 17 | "ADDITIONAL_UNLABELED_DATA_PATH": null, 18 | "REFERENCE_COUNTS": "s3://suching-dev/final-datasets/yahoo/valid_npmi_reference/train.npz", 19 | "REFERENCE_VOCAB": "s3://suching-dev/final-datasets/yahoo/valid_npmi_reference/train.vocab.json", 20 | "STOPWORDS_PATH": "s3://suching-dev/stopwords/snowball_stopwords.txt", 21 | "TRACK_NPMI": true, 22 | "NUM_ENCODER_LAYERS": { 23 | "sampling 
strategy": "choice", 24 | "choices": [1, 2, 3] 25 | }, 26 | "ENCODER_ACTIVATION": { 27 | "sampling strategy": "choice", 28 | "choices": ["relu", "tanh", "softplus"] 29 | }, 30 | "NUM_MEAN_PROJECTION_LAYERS": 1, 31 | "MEAN_PROJECTION_ACTIVATION": "linear", 32 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 33 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 34 | "SEED": { 35 | "sampling strategy": "integer", 36 | "bounds": [0, 100000] 37 | }, 38 | "Z_DROPOUT": { 39 | "sampling strategy": "uniform", 40 | "bounds": [0, 0.5] 41 | }, 42 | "LEARNING_RATE": { 43 | "sampling strategy": "loguniform", 44 | "bounds": [1e-4, 1e-2] 45 | }, 46 | "CUDA_DEVICE": 0, 47 | "THROTTLE": null, 48 | "ADD_ELMO": 0, 49 | "USE_SPACY_TOKENIZER": 0, 50 | "UPDATE_BACKGROUND_FREQUENCY": 0, 51 | "VOCAB_SIZE": 30000, 52 | "APPLY_BATCHNORM": 1, 53 | "APPLY_BATCHNORM_1": 0, 54 | "SEQUENCE_LENGTH": 400, 55 | "BATCH_SIZE": 64, 56 | "VALIDATION_METRIC": "+npmi" 57 | } -------------------------------------------------------------------------------- /search_spaces/vampire_hatespeech_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "KL_ANNEALING": { 4 | "sampling strategy": "choice", 5 | "choices": ["sigmoid", "linear", "constant"] 6 | }, 7 | "SIGMOID_WEIGHT_1": 0.25, 8 | "SIGMOID_WEIGHT_2": 15, 9 | "LINEAR_SCALING": 1000, 10 | "VAE_HIDDEN_DIM": { 11 | "sampling strategy": "integer", 12 | "bounds": [64, 128] 13 | }, 14 | "TRAIN_PATH": "/home/suching/vampire/data/hatespeech/train.npz", 15 | "DEV_PATH": "/home/suching/vampire/data/hatespeech/dev.npz", 16 | "VOCABULARY_DIRECTORY": "/home/suching/vampire/data/hatespeech/vocab/", 17 | "ADDITIONAL_UNLABELED_DATA_PATH": null, 18 | "REFERENCE_COUNTS": "s3://suching-dev/final-datasets/hatespeech/valid_npmi_reference/train.npz", 19 | "REFERENCE_VOCAB": "s3://suching-dev/final-datasets/hatespeech/valid_npmi_reference/train.vocab.json", 20 | "STOPWORDS_PATH": "s3://suching-dev/stopwords/snowball_stopwords.txt", 21 | "TRACK_NPMI": true, 22 | "NUM_ENCODER_LAYERS": { 23 | "sampling strategy": "choice", 24 | "choices": [1, 2, 3] 25 | }, 26 | "ENCODER_ACTIVATION": { 27 | "sampling strategy": "choice", 28 | "choices": ["relu", "tanh", "softplus"] 29 | }, 30 | "NUM_MEAN_PROJECTION_LAYERS": 1, 31 | "MEAN_PROJECTION_ACTIVATION": "linear", 32 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 33 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 34 | "SEED": { 35 | "sampling strategy": "integer", 36 | "bounds": [0, 100000] 37 | }, 38 | "Z_DROPOUT": { 39 | "sampling strategy": "uniform", 40 | "bounds": [0, 0.5] 41 | }, 42 | "LEARNING_RATE": { 43 | "sampling strategy": "loguniform", 44 | "bounds": [1e-4, 1e-2] 45 | }, 46 | "CUDA_DEVICE": 0, 47 | "THROTTLE": null, 48 | "ADD_ELMO": 0, 49 | "USE_SPACY_TOKENIZER": 0, 50 | "UPDATE_BACKGROUND_FREQUENCY": 0, 51 | "VOCAB_SIZE": 30000, 52 | "APPLY_BATCHNORM": 1, 53 | "APPLY_BATCHNORM_1": 0, 54 | "SEQUENCE_LENGTH": 400, 55 | "BATCH_SIZE": 64, 56 | "VALIDATION_METRIC": "+npmi" 57 | } -------------------------------------------------------------------------------- /vampire/tests/fixtures/unsupervised/experiment.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "type": "vampire_reader" 5 | }, 6 | "train_data_path": "vampire/tests/fixtures/imdb/train.npz", 7 | "validation_data_path": "vampire/tests/fixtures/imdb/test.npz", 8 | "vocabulary":{ 9 | "type": "extended_vocabulary", 10 | "directory_path": 
"vampire/tests/fixtures/imdb/vocabulary/" 11 | }, 12 | "model": { 13 | "type": "vampire", 14 | "update_background_freq": true, 15 | "kl_weight_annealing": "constant", 16 | "reference_counts": "vampire/tests/fixtures/reference_corpus/dev.npz", 17 | "reference_vocabulary": "vampire/tests/fixtures/reference_corpus/dev.vocab.json", 18 | "background_data_path": "vampire/tests/fixtures/imdb/vampire.bgfreq", 19 | "bow_embedder": { 20 | "type": "bag_of_word_counts", 21 | "vocab_namespace": "vampire", 22 | "ignore_oov": true 23 | }, 24 | "vae": { 25 | "type": "logistic_normal", 26 | "encoder": { 27 | "input_dim": 279, 28 | "num_layers": 2, 29 | "hidden_dims": [10, 10], 30 | "activations": ["relu", "relu"] 31 | }, 32 | "mean_projection": { 33 | "input_dim": 10, 34 | "num_layers": 1, 35 | "hidden_dims": [10], 36 | "activations": ["linear"] 37 | }, 38 | "log_variance_projection": { 39 | "input_dim": 10, 40 | "num_layers": 1, 41 | "hidden_dims": [10], 42 | "activations": ["linear"] 43 | }, 44 | "decoder": { 45 | "input_dim": 10, 46 | "num_layers": 1, 47 | "hidden_dims": [279], 48 | "activations": ["tanh"] 49 | }, 50 | "z_dropout": 0.2 51 | } 52 | }, 53 | "iterator": { 54 | "type": "basic", 55 | "batch_size": 100, 56 | "track_epoch": true 57 | }, 58 | "trainer": { 59 | "validation_metric": "-nll", 60 | "num_epochs": 5, 61 | "patience": 5, 62 | "cuda_device": -1, 63 | "optimizer": { 64 | "type": "adam", 65 | "lr": 0.001, 66 | "weight_decay": 0.001 67 | } 68 | } 69 | } 70 | 71 | -------------------------------------------------------------------------------- /search_spaces/pretraining_search.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 1, 3 | "KL_ANNEALING": { 4 | "sampling strategy": "choice", 5 | "choices": ["sigmoid", "linear"] 6 | }, 7 | "SIGMOID_WEIGHT_1": 0.25, 8 | "SIGMOID_WEIGHT_2": 15, 9 | "LINEAR_SCALING": 1000, 10 | "VAE_HIDDEN_DIM": { 11 | "sampling strategy": "integer", 12 | "bounds": [128, 1024] 13 | }, 14 | "TRAIN_PATH": "/home/suching/vampire/examples/tweets_sample/train.npz", 15 | "DEV_PATH": "/home/suching/vampire/examples/tweets_sample/dev.npz", 16 | "BACKGROUND_DATA_PATH": "/home/suching/vampire/examples/tweets_sample/vampire.bgfreq", 17 | "VOCABULARY_DIRECTORY": "/home/suching/vampire/examples/tweets_sample/vocabulary/", 18 | "ADDITIONAL_UNLABELED_DATA_PATH": null, 19 | "REFERENCE_COUNTS": "/home/suching/vampire/examples/tweets_sample/reference/ref.npz", 20 | "REFERENCE_VOCAB": "/home/suching/vampire/examples/tweets_sample/reference/ref.vocab.json", 21 | "STOPWORDS_PATH": "s3://suching-dev/stopwords/snowball_stopwords.txt", 22 | "TRACK_NPMI": true, 23 | "NUM_ENCODER_LAYERS": { 24 | "sampling strategy": "choice", 25 | "choices": [2, 3] 26 | }, 27 | "ENCODER_ACTIVATION": { 28 | "sampling strategy": "choice", 29 | "choices": ["relu", "tanh", "softplus"] 30 | }, 31 | "NUM_MEAN_PROJECTION_LAYERS": 1, 32 | "MEAN_PROJECTION_ACTIVATION": "linear", 33 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 34 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 35 | "SEED": { 36 | "sampling strategy": "integer", 37 | "bounds": [0, 100000] 38 | }, 39 | "Z_DROPOUT": { 40 | "sampling strategy": "uniform", 41 | "bounds": [0, 0.5] 42 | }, 43 | "LEARNING_RATE": { 44 | "sampling strategy": "loguniform", 45 | "bounds": [1e-3, 4e-3] 46 | }, 47 | "CUDA_DEVICE": 0, 48 | "THROTTLE": null, 49 | "ADD_ELMO": 0, 50 | "PATIENCE": 10, 51 | "NUM_EPOCHS": 10, 52 | "USE_SPACY_TOKENIZER": 0, 53 | "UPDATE_BACKGROUND_FREQUENCY": 0, 54 | "VOCAB_SIZE": 30000, 55 | 
"APPLY_BATCHNORM": 1, 56 | "SEQUENCE_LENGTH": 400, 57 | "BATCH_SIZE": 64, 58 | "KLD_CLAMP": { 59 | "sampling strategy": "uniform", 60 | "bounds": [1, 100000] 61 | }, 62 | "VALIDATION_METRIC": "+npmi" 63 | } -------------------------------------------------------------------------------- /vampire/modules/vae/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from allennlp.models import Model 3 | 4 | 5 | class VAE(Model): 6 | 7 | def __init__(self, vocab): 8 | super(VAE, self).__init__(vocab) 9 | 10 | def estimate_params(self, input_repr): 11 | """ 12 | Estimate the parameters of distribution given an input representation 13 | 14 | Parameters 15 | ---------- 16 | input_repr: ``torch.FloatTensor`` 17 | input representation 18 | 19 | Returns 20 | ------- 21 | params : ``Dict[str, torch.Tensor]`` 22 | estimated parameters after feedforward projection 23 | """ 24 | raise NotImplementedError 25 | 26 | def compute_negative_kld(self, params): 27 | """ 28 | Compute the KL divergence given posteriors. 29 | """ 30 | raise NotImplementedError 31 | 32 | def generate_latent_code(self, input_repr: torch.Tensor): # pylint: disable=W0221 33 | """ 34 | Given an input representation, produces the latent variables from the VAE. 35 | 36 | Parameters 37 | ---------- 38 | input_repr : ``torch.Tensor`` 39 | Input in which the VAE will use to re-create the original text. 40 | This can either be x_enc (the latent representation of x after 41 | being encoded) or x_bow: the Bag-of-Word-Counts representation of x. 42 | 43 | Returns 44 | ------- 45 | A ``Dict[str, torch.Tensor]`` containing 46 | theta: 47 | the latent variable produced by the VAE 48 | parameters: 49 | A dictionary containing the parameters produces by the 50 | distribution 51 | negative_kl_divergence: 52 | The negative KL=divergence specific to the distribution this 53 | VAE implements 54 | """ 55 | raise NotImplementedError 56 | 57 | def get_beta(self): 58 | """ 59 | Returns 60 | ------- 61 | The topics x vocabulary tensor representing word strengths for each topic. 62 | """ 63 | raise NotImplementedError 64 | 65 | def encode(self, input_vector: torch.Tensor): 66 | """ 67 | Encode the input_vector to the VAE's internal representation. 
68 | """ 69 | raise NotImplementedError 70 | -------------------------------------------------------------------------------- /search_spaces/nll_classifier_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "CUDA_DEVICE": 0, 4 | "EVALUATE_ON_TEST": 0, 5 | "NUM_EPOCHS": 50, 6 | "SEED": { 7 | "sampling strategy": "integer", 8 | "bounds": [0, 100000] 9 | }, 10 | "TRAIN_PATH": "s3://suching-dev/final-datasets/imdb/train_pretokenized.jsonl", 11 | "DEV_PATH": "s3://suching-dev/final-datasets/imdb/dev_pretokenized.jsonl", 12 | "TEST_PATH": "s3://suching-dev/final-datasets/imdb/test_pretokenized.jsonl", 13 | "THROTTLE": 200, 14 | "USE_SPACY_TOKENIZER": 0, 15 | "FREEZE_EMBEDDINGS": "VAMPIRE", 16 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 17 | "VAMPIRE_DIRECTORY": "logs/vampire_search_nll/run_8_2019-05-26_18-02-5081g6qx40 57", 18 | "ENCODER": { 19 | "sampling strategy": "choice", 20 | "choices": ["CNN", "LSTM", "AVERAGE"] 21 | }, 22 | "EMBEDDING_DROPOUT": 0.26941597325945665, 23 | "LEARNING_RATE": 0.004847983603406938, 24 | "DROPOUT": 0.10581295186904283, 25 | "BATCH_SIZE": 16, 26 | "NUM_ENCODER_LAYERS": { 27 | "sampling strategy": "choice", 28 | "choices": [1, 2, 3] 29 | }, 30 | "NUM_OUTPUT_LAYERS": { 31 | "sampling strategy": "choice", 32 | "choices": [1, 2, 3] 33 | }, 34 | "MAX_FILTER_SIZE": { 35 | "sampling strategy": "integer", 36 | "bounds": [3, 6] 37 | }, 38 | "NUM_FILTERS": { 39 | "sampling strategy": "integer", 40 | "bounds": [64, 512] 41 | }, 42 | "HIDDEN_SIZE": { 43 | "sampling strategy": "integer", 44 | "bounds": [64, 512] 45 | }, 46 | "AGGREGATIONS": { 47 | "sampling strategy": "subset", 48 | "choices": ["maxpool", "meanpool", "attention", "final_state"] 49 | }, 50 | "MAX_CHARACTER_FILTER_SIZE": { 51 | "sampling strategy": "integer", 52 | "bounds": [3, 6] 53 | }, 54 | "NUM_CHARACTER_FILTERS": { 55 | "sampling strategy": "integer", 56 | "bounds": [16, 64] 57 | }, 58 | "CHARACTER_HIDDEN_SIZE": { 59 | "sampling strategy": "integer", 60 | "bounds": [16, 128] 61 | }, 62 | "CHARACTER_EMBEDDING_DIM": { 63 | "sampling strategy": "integer", 64 | "bounds": [16, 128] 65 | }, 66 | "CHARACTER_ENCODER": { 67 | "sampling strategy": "choice", 68 | "choices": ["LSTM", "CNN", "AVERAGE"] 69 | }, 70 | "NUM_CHARACTER_ENCODER_LAYERS": { 71 | "sampling strategy": "choice", 72 | "choices": [1, 2] 73 | } 74 | } -------------------------------------------------------------------------------- /environments/random_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict 3 | 4 | import numpy as np 5 | 6 | 7 | class RandomSearch: 8 | 9 | @staticmethod 10 | def random_choice(*args): 11 | choices = [] 12 | for arg in args: 13 | choices.append(arg) 14 | return lambda: np.random.choice(choices) 15 | 16 | @staticmethod 17 | def random_integer(low, high): 18 | return lambda: int(np.random.randint(low, high)) 19 | 20 | @staticmethod 21 | def random_loguniform(low, high): 22 | return lambda: str(np.exp(np.random.uniform(np.log(low), np.log(high)))) 23 | 24 | @staticmethod 25 | def random_subset(*args): 26 | choices = [] 27 | for arg in args: 28 | choices.append(arg) 29 | func = lambda: np.random.choice(choices, np.random.randint(1, len(choices)+1), replace=False) 30 | return func 31 | 32 | @staticmethod 33 | def random_pair(*args): 34 | choices = [] 35 | for arg in args: 36 | choices.append(arg) 37 | func = lambda: np.random.choice(choices, 2, replace=False) 38 | 
return func 39 | 40 | @staticmethod 41 | def random_uniform(low, high): 42 | return lambda: np.random.uniform(low, high) 43 | 44 | 45 | class HyperparameterSearch: 46 | 47 | def __init__(self, **kwargs): 48 | self.search_space = {} 49 | self.lambda_ = lambda: 0 50 | for key, val in kwargs.items(): 51 | self.search_space[key] = val 52 | 53 | def parse(self, val: Any): 54 | if isinstance(val, type(self.lambda_)) and val.__name__ == self.lambda_.__name__: 55 | val = val() 56 | if isinstance(val, (int, np.int)): 57 | return int(val) 58 | elif isinstance(val, (float, np.float)): 59 | return float(val) 60 | elif isinstance(val, (np.ndarray, list)): 61 | return " ".join(val) 62 | else: 63 | return val 64 | elif isinstance(val, (int, np.int)): 65 | return int(val) 66 | elif isinstance(val, (float, np.float)): 67 | return float(val) 68 | elif isinstance(val, (np.ndarray, list)): 69 | return " ".join(val) 70 | elif val is None: 71 | return None 72 | else: 73 | return val 74 | 75 | 76 | def sample(self) -> Dict: 77 | res = {} 78 | for key, val in self.search_space.items(): 79 | res[key] = self.parse(val) 80 | return res 81 | 82 | def update_environment(self, sample) -> None: 83 | for key, val in sample.items(): 84 | os.environ[key] = str(val) 85 | -------------------------------------------------------------------------------- /vampire/tests/models/classifier_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from allennlp.common import Params 4 | from allennlp.common.testing import ModelTestCase 5 | from allennlp.data.dataset import Batch 6 | 7 | from vampire.common.testing.test_case import VAETestCase 8 | from vampire.models import classifier 9 | from vampire.modules.token_embedders import VampireTokenEmbedder 10 | 11 | 12 | class TestClassifiers(ModelTestCase): 13 | def setUp(self): 14 | super().setUp() 15 | 16 | def test_seq2seq_clf_with_vae_token_embedder_can_train_save_and_load(self): 17 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'classifier' / 'experiment_seq2seq.json', 18 | VAETestCase.FIXTURES_ROOT / "imdb" / "train.jsonl") 19 | self.ensure_model_can_train_save_and_load(self.param_file) 20 | 21 | def test_seq2seq_clf_with_vae_token_embedder_batch_predictions_are_consistent(self): 22 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'classifier' / 'experiment_seq2seq.json', 23 | VAETestCase.FIXTURES_ROOT / "imdb" / "train.jsonl") 24 | self.ensure_batch_predictions_are_consistent() 25 | 26 | def test_seq2vec_clf_with_vae_token_embedder_can_train_save_and_load(self): 27 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'classifier' / 'experiment_seq2vec.json', 28 | VAETestCase.FIXTURES_ROOT / "imdb" / "train.jsonl") 29 | self.ensure_model_can_train_save_and_load(self.param_file) 30 | 31 | 32 | def test_seq2seq_clf_with_vae_token_embedder_forward_pass_runs_correctly(self): 33 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'classifier' / 'experiment_seq2seq.json', 34 | VAETestCase.FIXTURES_ROOT / "imdb" / "train.jsonl") 35 | dataset = Batch(self.instances) 36 | dataset.index_instances(self.vocab) 37 | training_tensors = dataset.as_tensor_dict() 38 | output_dict = self.model(**training_tensors) 39 | assert output_dict['logits'].shape == (3, 2) 40 | assert output_dict['probs'].shape == (3, 2) 41 | assert output_dict['loss'] 42 | 43 | def test_seq2vec_clf_with_vae_token_embedder_forward_pass_runs_correctly(self): 44 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'classifier' / 'experiment_seq2vec.json', 45 | 
VAETestCase.FIXTURES_ROOT / "imdb" / "train.jsonl") 46 | dataset = Batch(self.instances) 47 | dataset.index_instances(self.vocab) 48 | training_tensors = dataset.as_tensor_dict() 49 | output_dict = self.model(**training_tensors) 50 | assert output_dict['logits'].shape == (3, 2) 51 | assert output_dict['probs'].shape == (3, 2) 52 | assert output_dict['loss'] 53 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/imdb/test.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "test_596", "orig": "aclImdb/test/pos/596_10.txt", "rating": 10, "label": "pos", "text": "I have not seen this movie since 1979 when I was a teenager. I grew up with the Sesame street muppets and later realized how much effort and time went into bringing these characters to life. Jim Hensen was a genius and master muppeteer. When I watched this movie the other day it took me back in time when I was younger and things seemed so much simpler. For this bit of time travel I rate this movie a 10.

The plot line explores how Kermit goes from the swamp to Hollywood. The laughs and gags are classic muppetism. I am glad these films are still around for the next generation. I hope I never out grow the magic of the muppets."} 2 | {"id": "test_6464", "orig": "aclImdb/test/pos/6464_9.txt", "rating": 9, "label": "pos", "text": "This is an excellent movie that tackles the issue of racism in a delicate and balanced way. Great performances all round but absolutely outstanding acting by Sidney Poitier.

He makes this movie breathe and alive. His portrayal of a guy who struggles against discrimination and violence is simply mind blowing. His acting is forceful and delicate and subtle at the same time. Truly worthy of an Oscar, Poitier had to wait (because of his skin colour) for many more years before the sheer brilliance of his acting was recognised by the Academy.

Cassavetes turns in a great performance too, withdrawn, troubled and realistic as it has become his hallmark. He and Poitier contrast inimitably the forces of cowardice, courage and human transformation through friendship.

The movie is enjoyable and at the same time deeply haunting in its portrayal of racism in the US. The irony is that it somehow mirrors the realities under which Poitier had to work."} 3 | {"id": "test_10401", "orig": "aclImdb/test/neg/10401_1.txt", "rating": 1, "label": "neg", "text": "Kubrick may have been the greatest director of all times. He may have made more classics than anyone else. He may have been a perfectionist. But man, was his first attempt ever bad!

Kubrick had good reason to try to make this film dissappear from the map: it looks like an Ed Wood film. It has strange narration, cheap shots, bad dialogue, ominous music reminiscent of your 50s sci-fi/horror flick, and what looks like relatives of the cast of \"Reefer Madness\" going insane for no reason.

Sure, you can see an undeveloped Kubrick in there. It is a psychological/horror study of war. The characters became dehumanized and insane. There are people playing more than one role. There are constant shots of the faces and particular facial expressions of different people. And there are a few interesting shots around there. But really, this is a mess.

Of course, I am not discouraging you from watching it. If you get a hold of it, you are joining a select group of myself and a few thousand people world wide who have had access to it."} 4 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | import shutil 6 | import subprocess 7 | import tempfile 8 | from typing import Any, Dict 9 | 10 | from allennlp.common.params import Params 11 | 12 | from environments import ENVIRONMENTS 13 | from environments.random_search import HyperparameterSearch 14 | 15 | random_int = random.randint(0, 2**32) 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser() # pylint: disable=invalid-name 19 | parser.add_argument('-o', 20 | '--override', 21 | action="store_true", 22 | help='remove the specified serialization dir before training') 23 | parser.add_argument('-c', '--config', type=str, help='training config', required=True) 24 | parser.add_argument('-s', '--serialization-dir', type=str, help='model serialization directory', required=True) 25 | parser.add_argument('-e', '--environment', type=str, help='hyperparameter environment', required=True) 26 | parser.add_argument('-r', '--recover', action='store_true', help = "recover saved model") 27 | parser.add_argument('-d', '--device', type=str, required=False, help = "device to run model on") 28 | parser.add_argument('-x', '--seed', type=str, required=False, help = "seed to run on") 29 | 30 | 31 | args = parser.parse_args() 32 | 33 | env = ENVIRONMENTS[args.environment.upper()] 34 | 35 | 36 | space = HyperparameterSearch(**env) 37 | 38 | sample = space.sample() 39 | 40 | for key, val in sample.items(): 41 | os.environ[key] = str(val) 42 | 43 | if args.device: 44 | os.environ['CUDA_DEVICE'] = args.device 45 | 46 | if args.seed: 47 | os.environ['SEED'] = args.seed 48 | 49 | 50 | allennlp_command = [ 51 | "allennlp", 52 | "train", 53 | "--include-package", 54 | "vampire", 55 | args.config, 56 | "-s", 57 | args.serialization_dir 58 | ] 59 | 60 | if args.seed: 61 | allennlp_command[-1] = allennlp_command[-1] + "_" + args.seed 62 | 63 | if args.recover: 64 | def append_seed_to_config(seed, serialization_dir): 65 | seed = str(seed) 66 | seed_dict = {"pytorch_seed": seed, 67 | "random_seed": seed, 68 | "numpy_seed": seed} 69 | config_path = os.path.join(serialization_dir, 'config.json') 70 | with open(config_path, 'r+') as f: 71 | config_dict = json.load(f) 72 | seed_dict.update(config_dict) 73 | f.seek(0) 74 | json.dump(seed_dict, f, indent=4) 75 | 76 | append_seed_to_config(seed=args.seed, serialization_dir=allennlp_command[-1]) 77 | 78 | allennlp_command.append("--recover") 79 | 80 | if os.path.exists(allennlp_command[-1]) and args.override: 81 | print(f"overriding {allennlp_command[-1]}") 82 | shutil.rmtree(allennlp_command[-1]) 83 | 84 | 85 | subprocess.run(" ".join(allennlp_command), shell=True, check=True) 86 | 87 | 88 | if __name__ == '__main__': 89 | main() 90 | -------------------------------------------------------------------------------- /vampire/data/dataset_readers/vampire_reader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | import numpy as np 5 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 6 | from allennlp.data.fields import ArrayField, Field 7 | from allennlp.data.instance import 
Instance 8 | from overrides import overrides 9 | 10 | from vampire.common.util import load_sparse 11 | 12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 13 | 14 | 15 | @DatasetReader.register("vampire_reader") 16 | class VampireReader(DatasetReader): 17 | """ 18 | Reads bag-of-word vectors from sparse matrices representing training and validation data. 19 | 20 | Expects a sparse matrix of size N documents x vocab size, which can be created via 21 | the scripts/preprocess_data.py file. 22 | 23 | The output of ``read`` is a list of ``Instances`` with the field: 24 | vec: ``ArrayField`` 25 | 26 | Parameters 27 | ---------- 28 | lazy : ``bool``, optional, (default = ``False``) 29 | Whether or not instances can be read lazily. 30 | sample : ``int``, optional, (default = ``None``) 31 | If specified, we will randomly sample the provided 32 | number of lines from the dataset. Useful for debugging. 33 | min_sequence_length : ``int`` (default = ``0``) 34 | Only consider examples whose total token count is greater 35 | than the supplied minimum sequence length. 36 | """ 37 | def __init__(self, 38 | lazy: bool = False, 39 | sample: int = None, 40 | min_sequence_length: int = 0) -> None: 41 | super().__init__(lazy=lazy) 42 | self._sample = sample 43 | self._min_sequence_length = min_sequence_length 44 | 45 | @overrides 46 | def _read(self, file_path): 47 | # load sparse matrix 48 | mat = load_sparse(file_path) 49 | # convert to lil format for row-wise iteration 50 | mat = mat.tolil() 51 | 52 | # optionally sample the matrix 53 | if self._sample: 54 | indices = np.random.choice(range(mat.shape[0]), self._sample) 55 | else: 56 | indices = range(mat.shape[0]) 57 | 58 | for index in indices: 59 | instance = self.text_to_instance(vec=mat[index].toarray().squeeze()) 60 | if instance is not None and mat[index].toarray().sum() > self._min_sequence_length: 61 | yield instance 62 | 63 | @overrides 64 | def text_to_instance(self, vec: np.ndarray = None) -> Instance: # type: ignore 65 | """ 66 | Parameters 67 | ---------- 68 | vec : ``np.ndarray``, required. 69 | The bag-of-word counts for a single document, as produced 70 | by scripts/preprocess_data.py. 71 | 72 | Returns 73 | ------- 74 | An ``Instance`` containing the following field: 75 | tokens : ``ArrayField`` 76 | The bag-of-word counts for the document, wrapped 77 | in an ``ArrayField`` so the model can consume 78 | them directly. 79 |
80 | """ 81 | # pylint: disable=arguments-differ 82 | fields: Dict[str, Field] = {} 83 | fields['tokens'] = ArrayField(vec) 84 | return Instance(fields) 85 | -------------------------------------------------------------------------------- /environments/environments.py: -------------------------------------------------------------------------------- 1 | from environments.random_search import RandomSearch 2 | from environments.datasets import DATASETS 3 | import os 4 | 5 | 6 | CLASSIFIER = { 7 | "LAZY_DATASET_READER": 0, 8 | "CUDA_DEVICE": 0, 9 | "EVALUATE_ON_TEST": 0, 10 | "NUM_EPOCHS": 50, 11 | "SEED": RandomSearch.random_integer(0, 10000), 12 | "SEQUENCE_LENGTH": 400, 13 | "TRAIN_PATH": os.environ["DATA_DIR"] + "/train.jsonl", 14 | "DEV_PATH": os.environ["DATA_DIR"] + "/dev.jsonl", 15 | "TEST_PATH": os.environ["DATA_DIR"] + "/test.jsonl", 16 | "THROTTLE": os.environ.get("THROTTLE", None), 17 | "USE_SPACY_TOKENIZER": 1, 18 | "FREEZE_EMBEDDINGS": ["VAMPIRE"], 19 | "EMBEDDINGS": ["VAMPIRE", "RANDOM"], 20 | "ENCODER": "AVERAGE", 21 | "EMBEDDING_DROPOUT": 0.5, 22 | "LEARNING_RATE": 0.001, 23 | "DROPOUT": 0.3, 24 | "VAMPIRE_DIRECTORY": os.environ.get("VAMPIRE_DIR", None), 25 | "VAMPIRE_DIM": os.environ.get("VAMPIRE_DIM", None), 26 | "BATCH_SIZE": 32, 27 | "NUM_ENCODER_LAYERS": 1, 28 | "NUM_OUTPUT_LAYERS": 2, 29 | "MAX_FILTER_SIZE": RandomSearch.random_integer(3, 6), 30 | "NUM_FILTERS": RandomSearch.random_integer(64, 512), 31 | "HIDDEN_SIZE": RandomSearch.random_integer(64, 512), 32 | "AGGREGATIONS": RandomSearch.random_subset("maxpool", "meanpool", "attention", "final_state"), 33 | "MAX_CHARACTER_FILTER_SIZE": RandomSearch.random_integer(3, 6), 34 | "NUM_CHARACTER_FILTERS": RandomSearch.random_integer(16, 64), 35 | "CHARACTER_HIDDEN_SIZE": RandomSearch.random_integer(16, 128), 36 | "CHARACTER_EMBEDDING_DIM": RandomSearch.random_integer(16, 64), 37 | "CHARACTER_ENCODER": RandomSearch.random_choice("LSTM", "CNN", "AVERAGE"), 38 | "NUM_CHARACTER_ENCODER_LAYERS": RandomSearch.random_choice(1, 2), 39 | } 40 | 41 | VAMPIRE = { 42 | "LAZY_DATASET_READER": os.environ.get("LAZY", 0), 43 | "KL_ANNEALING": "linear", 44 | "KLD_CLAMP": None, 45 | "SIGMOID_WEIGHT_1": 0.25, 46 | "SIGMOID_WEIGHT_2": 15, 47 | "LINEAR_SCALING": 1000, 48 | "VAE_HIDDEN_DIM": 81, 49 | "TRAIN_PATH": os.environ["DATA_DIR"] + "/train.npz", 50 | "DEV_PATH": os.environ["DATA_DIR"] + "/dev.npz", 51 | "REFERENCE_COUNTS": os.environ["DATA_DIR"] + "/reference/ref.npz", 52 | "REFERENCE_VOCAB": os.environ["DATA_DIR"] + "/reference/ref.vocab.json", 53 | "VOCABULARY_DIRECTORY": os.environ["DATA_DIR"] + "/vocabulary/", 54 | "BACKGROUND_DATA_PATH": os.environ["DATA_DIR"] + "/vampire.bgfreq", 55 | "NUM_ENCODER_LAYERS": 2, 56 | "ENCODER_ACTIVATION": "relu", 57 | "MEAN_PROJECTION_ACTIVATION": "linear", 58 | "NUM_MEAN_PROJECTION_LAYERS": 1, 59 | "LOG_VAR_PROJECTION_ACTIVATION": "linear", 60 | "NUM_LOG_VAR_PROJECTION_LAYERS": 1, 61 | "SEED": RandomSearch.random_integer(0, 100000), 62 | "Z_DROPOUT": 0.49, 63 | "LEARNING_RATE": 1e-3, 64 | "TRACK_NPMI": True, 65 | "CUDA_DEVICE": 0, 66 | "UPDATE_BACKGROUND_FREQUENCY": 0, 67 | "VOCAB_SIZE": os.environ.get("VOCAB_SIZE", 30000), 68 | "BATCH_SIZE": 64, 69 | "MIN_SEQUENCE_LENGTH": 3, 70 | "NUM_EPOCHS": 50, 71 | "PATIENCE": 5, 72 | "VALIDATION_METRIC": "+npmi" 73 | } 74 | 75 | 76 | 77 | ENVIRONMENTS = { 78 | 'VAMPIRE': VAMPIRE, 79 | "CLASSIFIER": CLASSIFIER, 80 | } 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- 
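The ENVIRONMENTS dictionary above is what scripts/train.py selects with its `--environment` flag: each entry mixes fixed hyperparameters with RandomSearch lambdas, HyperparameterSearch.sample() resolves the lambdas into concrete values, and those values are exported as environment variables for the training config to read (presumably via std.extVar in training_config/vampire.jsonnet, which is not shown in this listing). A minimal sketch of that flow, assuming DATA_DIR points at a preprocessed dataset directory such as examples/ag:

import os

# environments/environments.py reads DATA_DIR at import time, so set it first.
os.environ.setdefault("DATA_DIR", "examples/ag")

from environments import ENVIRONMENTS
from environments.random_search import HyperparameterSearch

search = HyperparameterSearch(**ENVIRONMENTS["VAMPIRE"])
sample = search.sample()           # lambdas (e.g. SEED) are drawn; constants pass through
search.update_environment(sample)  # os.environ[key] = str(val), the same loop train.py runs

print(sample["SEED"], sample["VALIDATION_METRIC"])
--------------------------------------------------------------------------------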
/vampire/tests/fixtures/imdb/vocabulary/vampire.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | abandon 3 | absolutely 4 | academy 5 | access 6 | ache 7 | acting 8 | actions 9 | adequate 10 | alive 11 | animals 12 | attempt 13 | awarded 14 | bad 15 | balanced 16 | bears 17 | birth 18 | bit 19 | blame 20 | blowing 21 | box 22 | breathe 23 | brilliance 24 | bringing 25 | bunch 26 | cassavetes 27 | cast 28 | catching 29 | characters 30 | cheap 31 | choice 32 | classic 33 | classics 34 | colour 35 | constant 36 | contrast 37 | courage 38 | course 39 | cowardice 40 | cylons 41 | day 42 | deeply 43 | deer 44 | dehumanized 45 | delicate 46 | deserved 47 | dialogue 48 | different 49 | director 50 | disappointed 51 | discouraging 52 | discrimination 53 | dissappear 54 | doctor 55 | does 56 | don 57 | drama 58 | dvd 59 | earth 60 | eating 61 | effort 62 | ending 63 | enjoyable 64 | excellent 65 | existence 66 | explain 67 | explores 68 | expressions 69 | faces 70 | facial 71 | family 72 | feed 73 | feel 74 | fences 75 | fight 76 | film 77 | films 78 | flick 79 | food 80 | forceful 81 | forces 82 | fought 83 | friendship 84 | gags 85 | generation 86 | genius 87 | giving 88 | glad 89 | goes 90 | going 91 | good 92 | gothic 93 | great 94 | greatest 95 | grew 96 | group 97 | grow 98 | grown 99 | guy 100 | hallmark 101 | hated 102 | haunting 103 | hensen 104 | hold 105 | hollywood 106 | honestly 107 | hope 108 | horror 109 | human 110 | humans 111 | include 112 | inimitably 113 | insane 114 | instead 115 | interesting 116 | irony 117 | issue 118 | jim 119 | joining 120 | just 121 | juvenile 122 | kermit 123 | kidding 124 | kill 125 | killer 126 | killing 127 | kind 128 | kubrick 129 | later 130 | laughs 131 | learn 132 | letting 133 | life 134 | like 135 | line 136 | linger 137 | live 138 | locks 139 | looks 140 | loved 141 | madness 142 | magic 143 | maiming 144 | make 145 | makes 146 | makowski 147 | man 148 | map 149 | master 150 | max 151 | mess 152 | message 153 | mind 154 | mirrors 155 | misbehaving 156 | money 157 | mother 158 | movie 159 | muppeteer 160 | muppetism 161 | muppets 162 | music 163 | mystery 164 | narration 165 | new 166 | newer 167 | nick 168 | old 169 | ominous 170 | ones 171 | oscar 172 | outstanding 173 | painted 174 | paper 175 | paradise 176 | parents 177 | particular 178 | people 179 | perfectionist 180 | performance 181 | performances 182 | picked 183 | playing 184 | plot 185 | poitier 186 | pontificating 187 | portrayal 188 | predictable 189 | psychological 190 | quite 191 | racism 192 | rate 193 | rationing 194 | realistic 195 | realities 196 | realized 197 | really 198 | reason 199 | recognised 200 | reefer 201 | relatives 202 | reminiscent 203 | remotely 204 | responsibility 205 | right 206 | role 207 | round 208 | scenes 209 | sci 210 | seen 211 | select 212 | sensuality 213 | sesame 214 | setting 215 | sheer 216 | shoot 217 | shots 218 | sidney 219 | simpler 220 | simply 221 | skin 222 | spoiler 223 | stahl 224 | start 225 | storyline 226 | strange 227 | street 228 | struggles 229 | study 230 | subtle 231 | suggest 232 | sure 233 | swamp 234 | tackles 235 | taking 236 | technology 237 | teenager 238 | things 239 | thought 240 | thousand 241 | time 242 | times 243 | title 244 | took 245 | tooth 246 | total 247 | transformation 248 | travel 249 | troubled 250 | truly 251 | try 252 | turns 253 | twice 254 | undeveloped 255 | unless 256 | use 257 | violence 258 | wait 259 | war 260 | waste 261 | watched 262 | 
watching 263 | way 264 | went 265 | wide 266 | wife 267 | withdrawn 268 | wood 269 | work 270 | world 271 | worst 272 | worth 273 | worthy 274 | wrong 275 | yea 276 | year 277 | years 278 | younger 279 | youngsters 280 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/vae/vocabulary/vampire.txt: -------------------------------------------------------------------------------- 1 | @@UNKNOWN@@ 2 | abandon 3 | absolutely 4 | academy 5 | access 6 | ache 7 | acting 8 | actions 9 | adequate 10 | alive 11 | animals 12 | attempt 13 | awarded 14 | bad 15 | balanced 16 | bears 17 | birth 18 | bit 19 | blame 20 | blowing 21 | box 22 | breathe 23 | brilliance 24 | bringing 25 | bunch 26 | cassavetes 27 | cast 28 | catching 29 | characters 30 | cheap 31 | choice 32 | classic 33 | classics 34 | colour 35 | constant 36 | contrast 37 | courage 38 | course 39 | cowardice 40 | cylons 41 | day 42 | deeply 43 | deer 44 | dehumanized 45 | delicate 46 | deserved 47 | dialogue 48 | different 49 | director 50 | disappointed 51 | discouraging 52 | discrimination 53 | dissappear 54 | doctor 55 | does 56 | don 57 | drama 58 | dvd 59 | earth 60 | eating 61 | effort 62 | ending 63 | enjoyable 64 | excellent 65 | existence 66 | explain 67 | explores 68 | expressions 69 | faces 70 | facial 71 | family 72 | feed 73 | feel 74 | fences 75 | fight 76 | film 77 | films 78 | flick 79 | food 80 | forceful 81 | forces 82 | fought 83 | friendship 84 | gags 85 | generation 86 | genius 87 | giving 88 | glad 89 | goes 90 | going 91 | good 92 | gothic 93 | great 94 | greatest 95 | grew 96 | group 97 | grow 98 | grown 99 | guy 100 | hallmark 101 | hated 102 | haunting 103 | hensen 104 | hold 105 | hollywood 106 | honestly 107 | hope 108 | horror 109 | human 110 | humans 111 | include 112 | inimitably 113 | insane 114 | instead 115 | interesting 116 | irony 117 | issue 118 | jim 119 | joining 120 | just 121 | juvenile 122 | kermit 123 | kidding 124 | kill 125 | killer 126 | killing 127 | kind 128 | kubrick 129 | later 130 | laughs 131 | learn 132 | letting 133 | life 134 | like 135 | line 136 | linger 137 | live 138 | locks 139 | looks 140 | loved 141 | madness 142 | magic 143 | maiming 144 | make 145 | makes 146 | makowski 147 | man 148 | map 149 | master 150 | max 151 | mess 152 | message 153 | mind 154 | mirrors 155 | misbehaving 156 | money 157 | mother 158 | movie 159 | muppeteer 160 | muppetism 161 | muppets 162 | music 163 | mystery 164 | narration 165 | new 166 | newer 167 | nick 168 | old 169 | ominous 170 | ones 171 | oscar 172 | outstanding 173 | painted 174 | paper 175 | paradise 176 | parents 177 | particular 178 | people 179 | perfectionist 180 | performance 181 | performances 182 | picked 183 | playing 184 | plot 185 | poitier 186 | pontificating 187 | portrayal 188 | predictable 189 | psychological 190 | quite 191 | racism 192 | rate 193 | rationing 194 | realistic 195 | realities 196 | realized 197 | really 198 | reason 199 | recognised 200 | reefer 201 | relatives 202 | reminiscent 203 | remotely 204 | responsibility 205 | right 206 | role 207 | round 208 | scenes 209 | sci 210 | seen 211 | select 212 | sensuality 213 | sesame 214 | setting 215 | sheer 216 | shoot 217 | shots 218 | sidney 219 | simpler 220 | simply 221 | skin 222 | spoiler 223 | stahl 224 | start 225 | storyline 226 | strange 227 | street 228 | struggles 229 | study 230 | subtle 231 | suggest 232 | sure 233 | swamp 234 | tackles 235 | taking 236 | technology 237 | teenager 238 | things 239 | 
thought 240 | thousand 241 | time 242 | times 243 | title 244 | took 245 | tooth 246 | total 247 | transformation 248 | travel 249 | troubled 250 | truly 251 | try 252 | turns 253 | twice 254 | undeveloped 255 | unless 256 | use 257 | violence 258 | wait 259 | war 260 | waste 261 | watched 262 | watching 263 | way 264 | went 265 | wide 266 | wife 267 | withdrawn 268 | wood 269 | work 270 | world 271 | worst 272 | worth 273 | worthy 274 | wrong 275 | yea 276 | year 277 | years 278 | younger 279 | youngsters 280 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/classifier/experiment_seq2seq.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "type": "semisupervised_text_classification_json", 5 | "tokenizer": { 6 | "word_splitter": "spacy" 7 | }, 8 | "token_indexers": { 9 | "tokens": { 10 | "type": "single_id", 11 | "namespace": "tokens", 12 | "lowercase_tokens": true 13 | }, 14 | "vae_tokens": { 15 | "type": "single_id", 16 | "namespace": "vae", 17 | "lowercase_tokens": true 18 | } 19 | }, 20 | "ignore_labels": false, 21 | "max_sequence_length": 400 22 | }, 23 | "validation_dataset_reader": { 24 | "lazy": false, 25 | "type": "semisupervised_text_classification_json", 26 | "tokenizer": { 27 | "word_splitter": "spacy" 28 | }, 29 | "token_indexers": { 30 | "tokens": { 31 | "type": "single_id", 32 | "namespace": "tokens", 33 | "lowercase_tokens": true 34 | }, 35 | "vae_tokens": { 36 | "type": "single_id", 37 | "namespace": "vae", 38 | "lowercase_tokens": true 39 | } 40 | }, 41 | "ignore_labels": false, 42 | "max_sequence_length": 400 43 | }, 44 | "vocabulary":{ 45 | "type": "vocabulary_with_vampire", 46 | "vampire_vocab_file": "vampire/tests/fixtures/imdb/vocabulary/vampire.txt" 47 | }, 48 | "datasets_for_vocab_creation": ["train"], 49 | "train_data_path": "vampire/tests/fixtures/imdb/train.jsonl", 50 | "validation_data_path": "vampire/tests/fixtures/imdb/test.jsonl", 51 | "model": { 52 | "type": "classifier", 53 | "input_embedder": { 54 | "token_embedders": { 55 | "tokens": { 56 | "type": "embedding", 57 | "embedding_dim": 10, 58 | "trainable": true 59 | }, 60 | "vae_tokens": { 61 | "type": "vampire_token_embedder", 62 | "expand_dim": true, 63 | "device": -1, 64 | "model_archive": "vampire/tests/fixtures/vae/model.tar.gz", 65 | "background_frequency": "vampire/tests/fixtures/imdb/vampire.bgfreq", 66 | "dropout": 0.2 67 | } 68 | } 69 | }, 70 | "encoder": { 71 | "type": "seq2seq", 72 | "architecture": { 73 | "type": "lstm", 74 | "num_layers": 1, 75 | "bidirectional": false, 76 | "input_size": 20, 77 | "hidden_size": 128 78 | }, 79 | "aggregations": ["maxpool" , "attention"] 80 | } 81 | }, 82 | "iterator": { 83 | "type": "bucket", 84 | "sorting_keys": [["tokens", "num_tokens"]], 85 | "batch_size": 32 86 | }, 87 | "trainer": { 88 | "optimizer": { 89 | "type": "adam", 90 | "lr": 0.0004 91 | }, 92 | "validation_metric": "+accuracy", 93 | "num_serialized_models_to_keep": 0, 94 | "num_epochs": 75, 95 | "grad_norm": 10.0, 96 | "patience": 5, 97 | "cuda_device": -1, 98 | "learning_rate_scheduler": { 99 | "type": "reduce_on_plateau", 100 | "factor": 0.5, 101 | "mode": "max", 102 | "patience": 0 103 | } 104 | } 105 | } 106 | 107 | -------------------------------------------------------------------------------- /vampire/tests/models/vampire_test.py: -------------------------------------------------------------------------------- 1 | # pylint: 
disable=no-self-use,invalid-name,unused-import 2 | import numpy as np 3 | from allennlp.commands.train import train_model_from_file 4 | from allennlp.common.testing import ModelTestCase 5 | 6 | from vampire.common.allennlp_bridge import ExtendedVocabulary 7 | from vampire.common.testing.test_case import VAETestCase 8 | from vampire.data.dataset_readers import VampireReader 9 | from vampire.models import VAMPIRE 10 | 11 | 12 | class TestVampire(ModelTestCase): 13 | def setUp(self): 14 | super(TestVampire, self).setUp() 15 | self.set_up_model(VAETestCase.FIXTURES_ROOT / 'unsupervised' / 'experiment.json', 16 | VAETestCase.FIXTURES_ROOT / "imdb" / "train.npz") 17 | 18 | def test_model_can_train_save_and_load_unsupervised(self): 19 | self.ensure_model_can_train_save_and_load(self.param_file) 20 | 21 | def test_npmi_computed_correctly(self): 22 | save_dir = self.TEST_DIR / "save_and_load_test" 23 | model = train_model_from_file(self.param_file, save_dir, overrides="") 24 | 25 | topics = [(1, ["great", "movie", "film", "amazing", "wow", "best", "ridiculous", "ever", "good", "incredible", "positive"]), 26 | (2, ["bad", "film", "worst", "negative", "movie", "ever", "not", "any", "gross", "boring"])] 27 | npmi = model.compute_npmi(topics, num_words=10) 28 | 29 | ref_vocab = model._ref_vocab 30 | ref_counts = model._ref_count_mat 31 | 32 | vocab_index = dict(zip(ref_vocab, range(len(ref_vocab)))) 33 | n_docs, _ = ref_counts.shape 34 | 35 | npmi_means = [] 36 | for topic in topics: 37 | words = topic[1] 38 | npmi_vals = [] 39 | for word_i, word1 in enumerate(words[:10]): 40 | if word1 in vocab_index: 41 | index1 = vocab_index[word1] 42 | else: 43 | index1 = None 44 | for word2 in words[word_i+1:10]: 45 | if word2 in vocab_index: 46 | index2 = vocab_index[word2] 47 | else: 48 | index2 = None 49 | if index1 is None or index2 is None: 50 | _npmi = 0.0 51 | else: 52 | col1 = np.array(ref_counts[:, index1].todense() > 0, dtype=int) 53 | col2 = np.array(ref_counts[:, index2].todense() > 0, dtype=int) 54 | sum1 = col1.sum() 55 | sum2 = col2.sum() 56 | interaction = np.sum(col1 * col2) 57 | if interaction == 0: 58 | assert model._npmi_numerator[index1, index2] == 0.0 and model._npmi_denominator[index1, index2] == 0.0 59 | _npmi = 0.0 60 | else: 61 | assert model._ref_interaction[index1, index2] == np.log10(interaction) 62 | assert model._ref_doc_sum[index1] == sum1 63 | assert model._ref_doc_sum[index2] == sum2 64 | expected_numerator = np.log10(n_docs) + np.log10(interaction) - np.log10(sum1) - np.log10(sum2) 65 | numerator = np.log10(model.n_docs) + model._npmi_numerator[index1, index2] 66 | assert np.isclose(expected_numerator, numerator) 67 | expected_denominator = np.log10(n_docs) - np.log10(interaction) 68 | denominator = np.log10(model.n_docs) - model._npmi_denominator[index1, index2] 69 | assert np.isclose(expected_denominator, denominator) 70 | _npmi = expected_numerator / expected_denominator 71 | npmi_vals.append(_npmi) 72 | npmi_means.append(np.mean(npmi_vals)) 73 | assert np.isclose(npmi, np.mean(npmi_means)) 74 | -------------------------------------------------------------------------------- /training_config/vampire.jsonnet: -------------------------------------------------------------------------------- 1 | local CUDA_DEVICE = std.parseInt(std.extVar("CUDA_DEVICE")); 2 | 3 | local BASE_READER(LAZY, SAMPLE, MIN_SEQUENCE_LENGTH) = { 4 | "lazy": LAZY == 1, 5 | "sample": SAMPLE, 6 | "type": "vampire_reader", 7 | "min_sequence_length": MIN_SEQUENCE_LENGTH 8 | }; 9 | 10 | { 11 | "numpy_seed": 
std.extVar("SEED"), 12 | "pytorch_seed": std.extVar("SEED"), 13 | "random_seed": std.extVar("SEED"), 14 | "dataset_reader": BASE_READER(std.parseInt(std.extVar("LAZY_DATASET_READER")), null, std.parseInt(std.extVar("MIN_SEQUENCE_LENGTH"))), 15 | "validation_dataset_reader": BASE_READER(std.parseInt(std.extVar("LAZY_DATASET_READER")), null,std.parseInt(std.extVar("MIN_SEQUENCE_LENGTH"))), 16 | "train_data_path": std.extVar("TRAIN_PATH"), 17 | "validation_data_path": std.extVar("DEV_PATH"), 18 | "vocabulary": { 19 | "type": "extended_vocabulary", 20 | "directory_path": std.extVar("VOCABULARY_DIRECTORY") 21 | }, 22 | "model": { 23 | "type": "vampire", 24 | "bow_embedder": { 25 | "type": "bag_of_word_counts", 26 | "vocab_namespace": "vampire", 27 | "ignore_oov": true 28 | }, 29 | "kl_weight_annealing": std.extVar("KL_ANNEALING"), 30 | "sigmoid_weight_1": std.extVar("SIGMOID_WEIGHT_1"), 31 | "sigmoid_weight_2": std.extVar("SIGMOID_WEIGHT_2"), 32 | "linear_scaling": std.extVar("LINEAR_SCALING"), 33 | "reference_counts": std.extVar("REFERENCE_COUNTS"), 34 | "reference_vocabulary": std.extVar("REFERENCE_VOCAB"), 35 | "update_background_freq": std.parseInt(std.extVar("UPDATE_BACKGROUND_FREQUENCY")) == 1, 36 | "track_npmi": std.parseInt(std.extVar("TRACK_NPMI")) == 1, 37 | "background_data_path": std.extVar("BACKGROUND_DATA_PATH"), 38 | "vae": { 39 | "z_dropout": std.extVar("Z_DROPOUT"), 40 | "kld_clamp": std.extVar("KLD_CLAMP"), 41 | "encoder": { 42 | "activations": std.makeArray(std.parseInt(std.extVar("NUM_ENCODER_LAYERS")), function(i) std.extVar("ENCODER_ACTIVATION")), 43 | "hidden_dims": std.makeArray(std.parseInt(std.extVar("NUM_ENCODER_LAYERS")), function(i) std.parseInt(std.extVar("VAE_HIDDEN_DIM"))), 44 | "input_dim": std.parseInt(std.extVar("VOCAB_SIZE")) + 1, 45 | "num_layers": std.parseInt(std.extVar("NUM_ENCODER_LAYERS")) 46 | }, 47 | "mean_projection": { 48 | "activations": std.extVar("MEAN_PROJECTION_ACTIVATION"), 49 | "hidden_dims": std.makeArray(std.parseInt(std.extVar("NUM_MEAN_PROJECTION_LAYERS")), function(i) std.parseInt(std.extVar("VAE_HIDDEN_DIM"))), 50 | "input_dim": std.extVar("VAE_HIDDEN_DIM"), 51 | "num_layers": std.parseInt(std.extVar("NUM_MEAN_PROJECTION_LAYERS")) 52 | }, 53 | "log_variance_projection": { 54 | "activations": std.extVar("LOG_VAR_PROJECTION_ACTIVATION"), 55 | "hidden_dims": std.makeArray(std.parseInt(std.extVar("NUM_LOG_VAR_PROJECTION_LAYERS")), function(i) std.parseInt(std.extVar("VAE_HIDDEN_DIM"))), 56 | "input_dim": std.parseInt(std.extVar("VAE_HIDDEN_DIM")), 57 | "num_layers": std.parseInt(std.extVar("NUM_LOG_VAR_PROJECTION_LAYERS")) 58 | }, 59 | "decoder": { 60 | "activations": "linear", 61 | "hidden_dims": [std.parseInt(std.extVar("VOCAB_SIZE")) + 1], 62 | "input_dim": std.parseInt(std.extVar("VAE_HIDDEN_DIM")), 63 | "num_layers": 1 64 | }, 65 | "type": "logistic_normal" 66 | } 67 | }, 68 | "iterator": { 69 | "batch_size": std.parseInt(std.extVar("BATCH_SIZE")), 70 | "track_epoch": true, 71 | "type": "basic" 72 | }, 73 | "trainer": { 74 | "cuda_device": CUDA_DEVICE, 75 | "num_serialized_models_to_keep": 1, 76 | "num_epochs": std.parseInt(std.extVar("NUM_EPOCHS")), 77 | "patience": std.parseInt(std.extVar("PATIENCE")), 78 | "optimizer": { 79 | "lr": std.extVar("LEARNING_RATE"), 80 | "type": "adam" 81 | }, 82 | "validation_metric": std.extVar("VALIDATION_METRIC") 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /vampire/tests/fixtures/classifier/experiment_seq2vec.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "type": "semisupervised_text_classification_json", 5 | "tokenizer": { 6 | "word_splitter": "spacy" 7 | }, 8 | "token_indexers": { 9 | "tokens": { 10 | "type": "single_id", 11 | "namespace": "tokens", 12 | "lowercase_tokens": true, 13 | "end_tokens": ["@@PADDING@@", "@@PADDING@@"], 14 | "start_tokens": ["@@PADDING@@", "@@PADDING@@"] 15 | }, 16 | "vae_tokens": { 17 | "type": "single_id", 18 | "namespace": "vae", 19 | "lowercase_tokens": true, 20 | "end_tokens": ["@@PADDING@@", "@@PADDING@@"], 21 | "start_tokens": ["@@PADDING@@", "@@PADDING@@"] 22 | } 23 | }, 24 | "ignore_labels": false, 25 | "max_sequence_length": 400 26 | }, 27 | "validation_dataset_reader": { 28 | "lazy": false, 29 | "type": "semisupervised_text_classification_json", 30 | "tokenizer": { 31 | "word_splitter": "spacy" 32 | }, 33 | "token_indexers": { 34 | "tokens": { 35 | "type": "single_id", 36 | "namespace": "tokens", 37 | "lowercase_tokens": true, 38 | "end_tokens": ["@@PADDING@@", "@@PADDING@@"], 39 | "start_tokens": ["@@PADDING@@", "@@PADDING@@"] 40 | }, 41 | "vae_tokens": { 42 | "type": "single_id", 43 | "namespace": "vae", 44 | "lowercase_tokens": true, 45 | "end_tokens": ["@@PADDING@@", "@@PADDING@@"], 46 | "start_tokens": ["@@PADDING@@", "@@PADDING@@"] 47 | } 48 | }, 49 | "ignore_labels": false, 50 | "max_sequence_length": 400 51 | }, 52 | "vocabulary":{ 53 | "type": "vocabulary_with_vampire", 54 | "vampire_vocab_file": "vampire/tests/fixtures/imdb/vocabulary/vampire.txt" 55 | }, 56 | "datasets_for_vocab_creation": ["train"], 57 | "train_data_path": "vampire/tests/fixtures/imdb/train.jsonl", 58 | "validation_data_path": "vampire/tests/fixtures/imdb/test.jsonl", 59 | "model": { 60 | "type": "classifier", 61 | "input_embedder": { 62 | "token_embedders": { 63 | "tokens": { 64 | "type": "embedding", 65 | "embedding_dim": 300, 66 | "trainable": true 67 | }, 68 | "vae_tokens": { 69 | "type": "vampire_token_embedder", 70 | "expand_dim": true, 71 | "device": -1, 72 | "model_archive": "vampire/tests/fixtures/vae/model.tar.gz", 73 | "background_frequency": "vampire/tests/fixtures/imdb/vampire.bgfreq", 74 | "dropout": 0.2 75 | } 76 | 77 | } 78 | }, 79 | "encoder": { 80 | "type": "seq2vec", 81 | "architecture":{ 82 | "type": "cnn", 83 | "num_filters": 100, 84 | "embedding_dim": 310, 85 | "output_dim": 512 86 | } 87 | } 88 | }, 89 | "iterator": { 90 | "type": "bucket", 91 | "sorting_keys": [["tokens", "num_tokens"]], 92 | "batch_size": 32 93 | }, 94 | "trainer": { 95 | "optimizer": { 96 | "type": "adam", 97 | "lr": 0.0004 98 | }, 99 | "validation_metric": "+accuracy", 100 | "num_serialized_models_to_keep": 0, 101 | "num_epochs": 75, 102 | "grad_norm": 10.0, 103 | "patience": 5, 104 | "cuda_device": -1, 105 | "learning_rate_scheduler": { 106 | "type": "reduce_on_plateau", 107 | "factor": 0.5, 108 | "mode": "max", 109 | "patience": 0 110 | } 111 | } 112 | } 113 | 114 | -------------------------------------------------------------------------------- /search_spaces/long_classifier_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "CUDA_DEVICE": 0, 4 | "EVALUATE_ON_TEST": 0, 5 | "NUM_EPOCHS": 50, 6 | "SEED": 93408, 7 | "TRAIN_PATH": "s3://suching-dev/final-datasets/imdb/train_pretokenized.jsonl", 8 | "DEV_PATH": "s3://suching-dev/final-datasets/imdb/dev_pretokenized.jsonl", 9 | "TEST_PATH": 
"s3://suching-dev/final-datasets/imdb/test_pretokenized.jsonl", 10 | "THROTTLE": 200, 11 | "USE_SPACY_TOKENIZER": 0, 12 | "FREEZE_EMBEDDINGS": "VAMPIRE", 13 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 14 | "VAMPIRE_DIRECTORY": { 15 | "sampling strategy": "choice", 16 | "choices": ["logs/vampire_search_long/run_29_2019-05-28_21-41-39d62snk66 32", "logs/vampire_search_long/run_15_2019-05-28_09-16-14so3s8wwc 117", "logs/vampire_search_long/run_14_2019-05-28_08-46-20q9w0xf_6 121", "logs/vampire_search_long/run_27_2019-05-28_20-29-160qqbwwlu 87", "logs/vampire_search_long/run_12_2019-05-28_04-56-42gvcdrud1 68", "logs/vampire_search_long/run_3_2019-05-27_23-31-17uuvsjrxl 44", "logs/vampire_search_long/run_5_2019-05-27_23-31-18jgagzjpg 127", "logs/vampire_search_long/run_19_2019-05-28_13-18-134ux0uqs7 66", "logs/vampire_search_long/run_35_2019-05-29_04-25-37tdw9fby_ 122", "logs/vampire_search_long/run_21_2019-05-28_15-10-30tanv07k2 84", "logs/vampire_search_long/run_4_2019-05-27_23-31-18fl0jodph 37", "logs/vampire_search_long/run_11_2019-05-28_04-27-430ckqrt6v 66", "logs/vampire_search_long/run_30_2019-05-29_00-06-46ewzmttzl 44", "logs/vampire_search_long/run_22_2019-05-28_15-16-08aemlqxro 119", "logs/vampire_search_long/run_8_2019-05-28_02-36-52t6s382e9 119", "logs/vampire_search_long/run_26_2019-05-28_16-20-08blyxnip1 32", "logs/vampire_search_long/run_28_2019-05-28_20-53-44a3qt5tva 35", "logs/vampire_search_long/run_24_2019-05-28_15-45-57kn4quf2j 127", "logs/vampire_search_long/run_9_2019-05-28_03-51-03d4xh6ski 59", "logs/vampire_search_long/run_13_2019-05-28_06-55-17n38yt5od 111", "logs/vampire_search_long/run_0_2019-05-27_23-31-174uu6yvzu 32", "logs/vampire_search_long/run_1_2019-05-27_23-31-17jxqhgjse 64", "logs/vampire_search_long/run_25_2019-05-28_15-54-25r_sg9hnu 38", "logs/vampire_search_long/run_23_2019-05-28_15-40-27j8gt3hh4 41", "logs/vampire_search_long/run_6_2019-05-27_23-31-18jlytp_il 61", "logs/vampire_search_long/run_20_2019-05-28_14-59-29wvt6p71u 58", "logs/vampire_search_long/run_7_2019-05-27_23-31-185mvrmabt 83", "logs/vampire_search_long/run_10_2019-05-28_04-27-36inyo33mz 104", "logs/vampire_search_long/run_17_2019-05-28_10-50-56m_gzto1c 59", "logs/vampire_search_long/run_31_2019-05-29_02-27-39zrfoqyct 62", "logs/vampire_search_long/run_18_2019-05-28_10-51-08b3hxpe75 121"] 17 | }, 18 | "ENCODER": { 19 | "sampling strategy": "choice", 20 | "choices": ["AVERAGE"] 21 | }, 22 | "EMBEDDING_DROPOUT": 0.26941597325945665, 23 | "LEARNING_RATE": 0.004847983603406938, 24 | "DROPOUT": 0.10581295186904283, 25 | "BATCH_SIZE": 16, 26 | "NUM_ENCODER_LAYERS": { 27 | "sampling strategy": "choice", 28 | "choices": [1, 2, 3] 29 | }, 30 | "NUM_OUTPUT_LAYERS": { 31 | "sampling strategy": "choice", 32 | "choices": [1, 2, 3] 33 | }, 34 | "MAX_FILTER_SIZE": { 35 | "sampling strategy": "integer", 36 | "bounds": [3, 6] 37 | }, 38 | "NUM_FILTERS": { 39 | "sampling strategy": "integer", 40 | "bounds": [64, 512] 41 | }, 42 | "HIDDEN_SIZE": { 43 | "sampling strategy": "integer", 44 | "bounds": [64, 512] 45 | }, 46 | "AGGREGATIONS": { 47 | "sampling strategy": "subset", 48 | "choices": ["maxpool", "meanpool", "attention", "final_state"] 49 | }, 50 | "MAX_CHARACTER_FILTER_SIZE": { 51 | "sampling strategy": "integer", 52 | "bounds": [3, 6] 53 | }, 54 | "NUM_CHARACTER_FILTERS": { 55 | "sampling strategy": "integer", 56 | "bounds": [16, 64] 57 | }, 58 | "CHARACTER_HIDDEN_SIZE": { 59 | "sampling strategy": "integer", 60 | "bounds": [16, 128] 61 | }, 62 | "CHARACTER_EMBEDDING_DIM": { 63 | "sampling strategy": 
"integer", 64 | "bounds": [16, 128] 65 | }, 66 | "CHARACTER_ENCODER": { 67 | "sampling strategy": "choice", 68 | "choices": ["LSTM", "CNN", "AVERAGE"] 69 | }, 70 | "NUM_CHARACTER_ENCODER_LAYERS": { 71 | "sampling strategy": "choice", 72 | "choices": [1, 2] 73 | } 74 | } -------------------------------------------------------------------------------- /vampire/models/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from allennlp.data import Vocabulary 5 | from allennlp.models.model import Model 6 | from allennlp.modules import TextFieldEmbedder 7 | from allennlp.nn import InitializerApplicator 8 | from allennlp.nn.util import get_text_field_mask 9 | from allennlp.training.metrics import CategoricalAccuracy 10 | 11 | from vampire.modules.encoder import Encoder 12 | 13 | 14 | @Model.register("classifier") 15 | class Classifier(Model): 16 | """ 17 | Generic classifier model. Differs from allennlp's basic_classifier 18 | in the fact that it uses a custom Encoder, which wraps all seq2vec 19 | and seq2seq encoders to easily switch between them during 20 | experimentation. 21 | """ 22 | def __init__(self, 23 | vocab: Vocabulary, 24 | input_embedder: TextFieldEmbedder, 25 | encoder: Encoder = None, 26 | dropout: float = None, 27 | initializer: InitializerApplicator = InitializerApplicator() 28 | ) -> None: 29 | """ 30 | Parameters 31 | ---------- 32 | vocab: `Vocabulary` 33 | vocab to use 34 | input_embedder: `TextFieldEmbedder` 35 | generic embedder of tokens 36 | encoder: `Encoder`, optional (default = None) 37 | Seq2Vec or Seq2Seq Encoder wrapper. If no encoder is provided, 38 | assume that the input is a bag of word counts, for linear classification. 39 | dropout: `float`, optional (default = None) 40 | if set, will apply dropout to output of encoder. 41 | initializer: `InitializerApplicator` 42 | generic initializer 43 | """ 44 | super().__init__(vocab) 45 | self._input_embedder = input_embedder 46 | if dropout: 47 | self._dropout = torch.nn.Dropout(dropout) 48 | else: 49 | self._dropout = None 50 | self._encoder = encoder 51 | self._num_labels = vocab.get_vocab_size(namespace="labels") 52 | if self._encoder: 53 | self._clf_input_dim = self._encoder.get_output_dim() 54 | else: 55 | self._clf_input_dim = self._input_embedder.get_output_dim() 56 | self._classification_layer = torch.nn.Linear(self._clf_input_dim, 57 | self._num_labels) 58 | self._accuracy = CategoricalAccuracy() 59 | self._loss = torch.nn.CrossEntropyLoss() 60 | initializer(self) 61 | 62 | def forward(self, # type: ignore 63 | tokens: Dict[str, torch.LongTensor], 64 | label: torch.IntTensor = None) -> Dict[str, torch.Tensor]: 65 | # pylint: disable=arguments-differ 66 | """ 67 | Parameters 68 | ---------- 69 | tokens : Dict[str, torch.LongTensor] 70 | From a ``TextField`` 71 | label : torch.IntTensor, optional (default = None) 72 | From a ``LabelField`` 73 | Returns 74 | ------- 75 | An output dictionary consisting of: 76 | 77 | logits : torch.FloatTensor 78 | A tensor of shape ``(batch_size, num_labels)`` representing 79 | unnormalized log probabilities of the label. 80 | probs : torch.FloatTensor 81 | A tensor of shape ``(batch_size, num_labels)`` representing 82 | probabilities of the label. 83 | loss : torch.FloatTensor, optional 84 | A scalar loss to be optimised. 
85 | """ 86 | embedded_text = self._input_embedder(tokens) 87 | mask = get_text_field_mask(tokens).float() 88 | 89 | if self._encoder: 90 | embedded_text = self._encoder(embedded_text=embedded_text, 91 | mask=mask) 92 | 93 | if self._dropout: 94 | embedded_text = self._dropout(embedded_text) 95 | 96 | logits = self._classification_layer(embedded_text) 97 | probs = torch.nn.functional.softmax(logits, dim=-1) 98 | 99 | output_dict = {"logits": logits, "probs": probs} 100 | 101 | if label is not None: 102 | loss = self._loss(logits, label.long().view(-1)) 103 | output_dict["loss"] = loss 104 | self._accuracy(logits, label) 105 | 106 | return output_dict 107 | 108 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 109 | metrics = {'accuracy': self._accuracy.get_metric(reset)} 110 | return metrics 111 | -------------------------------------------------------------------------------- /vampire/common/allennlp_bridge.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import logging 3 | import os 4 | from typing import Iterable 5 | 6 | from allennlp.common.file_utils import cached_path 7 | from allennlp.common.params import Params 8 | from allennlp.common.util import namespace_match 9 | from allennlp.data import instance as adi # pylint: disable=unused-import 10 | from allennlp.data.vocabulary import Vocabulary 11 | from overrides import overrides 12 | 13 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 14 | 15 | DEFAULT_NON_PADDED_NAMESPACES = ("*tags", "*labels") 16 | DEFAULT_PADDING_TOKEN = "@@PADDING@@" 17 | DEFAULT_OOV_TOKEN = "@@UNKNOWN@@" 18 | NAMESPACE_PADDING_FILE = 'non_padded_namespaces.txt' 19 | 20 | 21 | @Vocabulary.register("extended_vocabulary") 22 | class ExtendedVocabulary(Vocabulary): 23 | """ 24 | Augment the allennlp Vocabulary with the ability to dump background 25 | frequencies. 26 | """ 27 | 28 | @classmethod 29 | def from_files(cls, directory: str) -> 'Vocabulary': 30 | """ 31 | Loads a ``Vocabulary`` that was serialized using ``save_to_files``. 32 | Parameters 33 | ---------- 34 | directory : ``str`` 35 | The directory containing the serialized vocabulary. 36 | """ 37 | 38 | logger.info("Loading token dictionary from %s.", directory) 39 | with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'r', 'utf-8') as namespace_file: 40 | non_padded_namespaces = [namespace_str.strip() for namespace_str in namespace_file] 41 | 42 | vocab = cls(non_padded_namespaces=non_padded_namespaces) 43 | vocab.serialization_dir = directory # pylint: disable=W0201 44 | # Check every file in the directory. 45 | for namespace_filename in os.listdir(directory): 46 | if namespace_filename == NAMESPACE_PADDING_FILE: 47 | continue 48 | if namespace_filename.startswith("."): 49 | continue 50 | namespace = namespace_filename.replace('.txt', '') 51 | if any(namespace_match(pattern, namespace) for pattern in non_padded_namespaces): 52 | is_padded = False 53 | else: 54 | is_padded = True 55 | filename = os.path.join(directory, namespace_filename) 56 | vocab.set_from_file(filename, is_padded, namespace=namespace) 57 | 58 | return vocab 59 | 60 | @overrides 61 | def save_to_files(self, directory: str) -> None: 62 | """ 63 | Persist this Vocabulary to files so it can be reloaded later. 64 | Each namespace corresponds to one file. 65 | Parameters 66 | ---------- 67 | directory : ``str`` 68 | The directory where we save the serialized vocabulary.
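
Example (a short sketch of the save/load round trip)::

    vocab.save_to_files("model_logs/vocabulary")
    vocab = ExtendedVocabulary.from_files("model_logs/vocabulary")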
69 | """ 70 | self.serialization_dir = directory # pylint: disable=W0201 71 | os.makedirs(directory, exist_ok=True) 72 | if os.listdir(directory): 73 | logger.warning("vocabulary serialization directory %s is not empty", directory) 74 | 75 | with codecs.open(os.path.join(directory, NAMESPACE_PADDING_FILE), 'w', 'utf-8') as namespace_file: 76 | for namespace_str in self._non_padded_namespaces: 77 | print(namespace_str, file=namespace_file) 78 | 79 | for namespace, mapping in self._index_to_token.items(): 80 | # Each namespace gets written to its own file, in index order. 81 | with codecs.open(os.path.join(directory, namespace + '.txt'), 'w', 'utf-8') as token_file: 82 | num_tokens = len(mapping) 83 | start_index = 1 if mapping[0] == self._padding_token else 0 84 | for i in range(start_index, num_tokens): 85 | print(mapping[i].replace('\n', '@@NEWLINE@@'), file=token_file) 86 | 87 | @Vocabulary.register("vocabulary_with_vampire") 88 | class VocabularyWithPretrainedVAE(Vocabulary): 89 | """ 90 | Augment the allennlp Vocabulary with a filtered vocabulary. 91 | Idea: override from_params to "set" the vocab from a file before 92 | constructing it in the normal fashion. 93 | """ 94 | 95 | @classmethod 96 | def from_params(cls, params: Params, instances: Iterable['adi.Instance'] = None): 97 | vampire_vocab_file = params.pop('vampire_vocab_file') 98 | # from_instances is a classmethod, so call it on the class directly. 99 | vocab = cls.from_instances(instances=instances, 100 | tokens_to_add={"classifier": ["@@UNKNOWN@@"]}) 101 | vampire_vocab_file = cached_path(vampire_vocab_file) 102 | vocab.set_from_file(filename=vampire_vocab_file, 103 | namespace="vampire", 104 | oov_token="@@UNKNOWN@@", 105 | is_padded=False) 106 | return vocab 107 | -------------------------------------------------------------------------------- /vampire/modules/pretrained_vae.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Union, List, Dict 3 | import torch 4 | from overrides import overrides 5 | from allennlp.common import Params 6 | from allennlp.models.archival import load_archive 7 | from allennlp.common.file_utils import cached_path 8 | from allennlp.modules.scalar_mix import ScalarMix 9 | 10 | 11 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | class _PretrainedVAE: 15 | def __init__(self, 16 | model_archive: str, 17 | device: int, 18 | background_frequency: str, 19 | requires_grad: bool = False) -> None: 20 | 21 | super(_PretrainedVAE, self).__init__() 22 | logger.info("Initializing pretrained VAMPIRE") 23 | self.cuda_device = device if torch.cuda.is_available() else -1 24 | archive = load_archive(cached_path(model_archive), cuda_device=self.cuda_device) 25 | self.vae = archive.model 26 | if not requires_grad: 27 | self.vae.eval() 28 | self.vae.freeze_weights() 29 | self.vae.initialize_bg_from_file(cached_path(background_frequency)) 30 | self._requires_grad = requires_grad 31 | 32 | 33 | class PretrainedVAE(torch.nn.Module): 34 | """ 35 | Core Pretrained VAMPIRE module 36 | """ 37 | def __init__(self, 38 | model_archive: str, 39 | device: int, 40 | background_frequency: str, 41 | requires_grad: bool = False, 42 | scalar_mix: List[int] = None, 43 | dropout: float = None) -> None: 44 | 45 | super(PretrainedVAE, self).__init__() 46 | logger.info("Initializing pretrained VAMPIRE") 47 | self._pretrained_model = _PretrainedVAE(model_archive=model_archive, 48 | device=device, 49 | background_frequency=background_frequency, 50 | requires_grad=requires_grad) 51 |
self._requires_grad = requires_grad 52 | if dropout: 53 | self._dropout = torch.nn.Dropout(dropout) 54 | else: 55 | self._dropout = None 56 | num_layers = len(self._pretrained_model.vae.vae.encoder._linear_layers) + 1 # pylint: disable=protected-access 57 | if not scalar_mix: 58 | initial_params = [1] + [-20] * (num_layers - 2) + [1] 59 | else: 60 | initial_params = scalar_mix 61 | self.scalar_mix = ScalarMix( 62 | num_layers, 63 | do_layer_norm=False, 64 | initial_scalar_parameters=initial_params, 65 | trainable=not scalar_mix) 66 | self.add_module('scalar_mix', self.scalar_mix) 67 | 68 | def get_output_dim(self) -> int: 69 | output_dim = self._pretrained_model.vae.vae.encoder.get_output_dim() 70 | return output_dim 71 | 72 | @overrides 73 | def forward(self, # pylint: disable=arguments-differ 74 | inputs: torch.Tensor) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]: 75 | """ 76 | Parameters 77 | ---------- 78 | inputs: ``torch.Tensor``, required. 79 | Shape ``(batch_size, timesteps)`` of word ids representing the current batch. 80 | Returns 81 | ------- 82 | Dict with keys: 83 | ``'vae_representation'``: ``torch.Tensor`` 84 | The scalar mix of the VAE's layer activations for the input batch, 85 | shape ``(batch_size, timesteps, embedding_dim)`` or ``(batch_size, embedding_dim)`` 86 | depending on the VAE representation being used. 87 | ``'layers'``: ``Tuple[str]`` 88 | The names of the layers whose activations were mixed, in order. 89 | """ 90 | vae_output = self._pretrained_model.vae(tokens={'tokens': inputs}) 91 | 92 | layers, layer_activations = zip(*vae_output['activations']) 93 | 94 | scalar_mix = getattr(self, 'scalar_mix') 95 | representation = scalar_mix(layer_activations) 96 | 97 | if self._dropout: 98 | representation = self._dropout(representation) 99 | 100 | return {'vae_representation': representation, 'layers': layers} 101 | 102 | @classmethod 103 | def from_params(cls, params: Params) -> 'PretrainedVAE': 104 | # Add files to archive 105 | params.add_file_to_archive('model_archive') 106 | model_archive = params.pop('model_archive') 107 | device = params.pop('device') 108 | background_frequency = params.pop('background_frequency') 109 | requires_grad = params.pop('requires_grad', False) 110 | dropout = params.pop_float('dropout', None) 111 | scalar_mix = params.pop('scalar_mix', None) 112 | params.assert_empty(cls.__name__) 113 | return cls(model_archive=model_archive, 114 | device=device, 115 | background_frequency=background_frequency, 116 | requires_grad=requires_grad, 117 | scalar_mix=scalar_mix, 118 | dropout=dropout) 119 | -------------------------------------------------------------------------------- /vampire/modules/encoder.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=arguments-differ 2 | from typing import List 3 | import torch 4 | from overrides import overrides 5 | from allennlp.common import Registrable 6 | from allennlp.modules import FeedForward, Seq2SeqEncoder, Seq2VecEncoder 7 | from allennlp.nn.util import (get_final_encoder_states, masked_max, masked_mean, masked_log_softmax) 8 | from allennlp.common.checks import ConfigurationError 9 | 10 | class Encoder(Registrable, torch.nn.Module): 11 | """ 12 | This module is a wrapper over AllenNLP encoders, to make it easy to switch 13 | between them in the training config when doing things like hyperparameter search.
14 | 15 | It's the same interface as the AllenNLP encoders, except the encoder architecture is 16 | nested one-level deep (under the field ``architecture``). 17 | """ 18 | default_implementation = 'feedforward' 19 | 20 | def __init__(self, architecture: torch.nn.Module) -> None: 21 | super(Encoder, self).__init__() 22 | self._architecture = architecture 23 | 24 | def get_output_dim(self) -> int: 25 | return self._architecture.get_output_dim() 26 | 27 | def forward(self, **kwargs) -> torch.FloatTensor: 28 | raise NotImplementedError 29 | 30 | @Encoder.register("feedforward") 31 | class MLP(Encoder): 32 | 33 | def __init__(self, architecture: FeedForward) -> None: 34 | super(MLP, self).__init__(architecture) 35 | self._architecture = architecture 36 | 37 | @overrides 38 | def forward(self, **kwargs) -> torch.FloatTensor: 39 | return self._architecture(kwargs['embedded_text']) 40 | 41 | @Seq2VecEncoder.register("maxpool") 42 | class MaxPoolEncoder(Seq2VecEncoder): 43 | def __init__(self, 44 | embedding_dim: int) -> None: 45 | super(MaxPoolEncoder, self).__init__() 46 | self._embedding_dim = embedding_dim 47 | 48 | def get_input_dim(self) -> int: 49 | return self._embedding_dim 50 | 51 | def get_output_dim(self) -> int: 52 | return self._embedding_dim 53 | 54 | def forward(self, tokens: torch.Tensor, mask: torch.Tensor): #pylint: disable=arguments-differ 55 | broadcast_mask = mask.unsqueeze(-1).float() 56 | one_minus_mask = (1.0 - broadcast_mask).byte() 57 | replaced = tokens.masked_fill(one_minus_mask, -1e7) # fill padding with a large negative value so it never wins the max 58 | max_value, _ = replaced.max(dim=1, keepdim=False) 59 | return max_value 60 | 61 | @Encoder.register("seq2vec") 62 | class Seq2Vec(Encoder): 63 | 64 | def __init__(self, architecture: Seq2VecEncoder) -> None: 65 | super(Seq2Vec, self).__init__(architecture) 66 | self._architecture = architecture 67 | 68 | @overrides 69 | def forward(self, **kwargs) -> torch.FloatTensor: 70 | return self._architecture(kwargs['embedded_text'], kwargs['mask']) 71 | 72 | 73 | @Encoder.register("seq2seq") 74 | class Seq2Seq(Encoder): 75 | 76 | def __init__(self, architecture: Seq2SeqEncoder, aggregations: List[str]) -> None: 77 | super(Seq2Seq, self).__init__(architecture) 78 | self._architecture = architecture 79 | self._aggregations = aggregations 80 | if "attention" in self._aggregations: 81 | self._attention_layer = torch.nn.Linear(self._architecture.get_output_dim(), 82 | 1) 83 | 84 | @overrides 85 | def get_output_dim(self): 86 | return self._architecture.get_output_dim() * len(self._aggregations) 87 | 88 | @overrides 89 | def forward(self, **kwargs) -> torch.FloatTensor: 90 | mask = kwargs['mask'] 91 | embedded_text = kwargs['embedded_text'] 92 | encoded_output = self._architecture(embedded_text, mask) 93 | encoded_repr = [] 94 | for aggregation in self._aggregations: 95 | if aggregation == "meanpool": 96 | broadcast_mask = mask.unsqueeze(-1).float() 97 | context_vectors = encoded_output * broadcast_mask 98 | encoded_text = masked_mean(context_vectors, 99 | broadcast_mask, 100 | dim=1, 101 | keepdim=False) 102 | elif aggregation == 'maxpool': 103 | broadcast_mask = mask.unsqueeze(-1).float() 104 | context_vectors = encoded_output * broadcast_mask 105 | encoded_text = masked_max(context_vectors, 106 | broadcast_mask, 107 | dim=1) 108 | elif aggregation == 'final_state': 109 | is_bi = self._architecture.is_bidirectional() 110 | encoded_text = get_final_encoder_states(encoded_output, 111 | mask, 112 | is_bi) 113 | elif aggregation == 'attention': 114 | alpha = self._attention_layer(encoded_output) 115 |
alpha = masked_log_softmax(alpha, mask.unsqueeze(-1), dim=1).exp() 116 | encoded_text = alpha * encoded_output 117 | encoded_text = encoded_text.sum(dim=1) 118 | else: 119 | raise ConfigurationError(f"{aggregation} aggregation not available.") 120 | encoded_repr.append(encoded_text) 121 | 122 | encoded_repr = torch.cat(encoded_repr, 1) 123 | return encoded_repr 124 | -------------------------------------------------------------------------------- /colab/VAMPIRE_AGNews.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "### VAMPIRE Example: AG News Corpus" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": {}, 23 | "source": [ 24 | "In this notebook, we run through the example in the README. Since VAMPIRE is a low resource method, it can be run on the CPU or Colab GPU. Before starting, open this notebook in a Colab environment, and connect to a GPU instance." 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "Clone the repository and cd into working directory:" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 0, 37 | "metadata": { 38 | "colab": {}, 39 | "colab_type": "code", 40 | "id": "wEFivVCCpEUn" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "!git clone https://github.com/allenai/vampire\n" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 0, 50 | "metadata": { 51 | "colab": {}, 52 | "colab_type": "code", 53 | "id": "rMDq7NeCrHB3" 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "%cd vampire" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "Install requirements:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 0, 70 | "metadata": { 71 | "colab": {}, 72 | "colab_type": "code", 73 | "id": "f52p11vJpJW8" 74 | }, 75 | "outputs": [], 76 | "source": [ 77 | "!pip install -r requirements.txt" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 0, 83 | "metadata": { 84 | "colab": {}, 85 | "colab_type": "code", 86 | "id": "JXC-bTeIptAF" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "!python -m spacy download en" 91 | ] 92 | }, 93 | { 94 | "cell_type": "markdown", 95 | "metadata": {}, 96 | "source": [ 97 | "All tests should pass:" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 0, 103 | "metadata": { 104 | "colab": {}, 105 | "colab_type": "code", 106 | "id": "mTii7MynrbLn" 107 | }, 108 | "outputs": [], 109 | "source": [ 110 | "!SEED=42 python -m pytest -v --color=yes vampire" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "Download preprocessed AG News corpus, ready to run with VAMPIRE:" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": 0, 123 | "metadata": { 124 | "colab": {}, 125 | "colab_type": "code", 126 | "id": "zSVpbMX1rQl7" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "!sh scripts/download_ag.sh\n", 131 | "!curl -Lo ag.tar https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/vampire_preprocessed_example.tar\n", 132 | "!tar -xvf ag.tar -C examples/\n", 133 | "!rm ag.tar" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | 
"Run VAMPIRE:" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 0, 146 | "metadata": { 147 | "colab": {}, 148 | "colab_type": "code", 149 | "id": "d_oLBgaYtKsL" 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "!DATA_DIR=\"$(pwd)/examples/ag\" VOCAB_SIZE=30000 LAZY=1 python -m scripts.train \\\n", 154 | " --config training_config/vampire.jsonnet \\\n", 155 | " --serialization-dir model_logs/vampire \\\n", 156 | " --environment VAMPIRE \\\n", 157 | " --device 0 -o" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "After VAMPIRE has trained, we can run a downstream classifier on the AG News corpus with just 200 examples. We report an average of 83.9% accuracy in the paper over five seeds under this setting:" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 0, 170 | "metadata": { 171 | "colab": {}, 172 | "colab_type": "code", 173 | "id": "D0T6rYRLtTU_" 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "!DATA_DIR=\"$(pwd)/examples/ag\" VAMPIRE_DIR=\"$(pwd)/model_logs/vampire\" VAMPIRE_DIM=81 THROTTLE=200 EVALUATE_ON_TEST=0 python -m scripts.train \\\n", 178 | " --config training_config/classifier.jsonnet \\\n", 179 | " --serialization-dir model_logs/clf \\\n", 180 | " --environment CLASSIFIER \\\n", 181 | " --device 0" 182 | ] 183 | } 184 | ], 185 | "metadata": { 186 | "accelerator": "GPU", 187 | "colab": { 188 | "include_colab_link": true, 189 | "name": "VAMPIRE_AGNews.ipynb", 190 | "provenance": [], 191 | "version": "0.3.2" 192 | }, 193 | "kernelspec": { 194 | "display_name": "Python 3", 195 | "language": "python", 196 | "name": "python3" 197 | } 198 | }, 199 | "nbformat": 4, 200 | "nbformat_minor": 1 201 | } 202 | -------------------------------------------------------------------------------- /vampire/modules/vae/logistic_normal.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List, Tuple 2 | import os 3 | import torch 4 | from allennlp.modules import FeedForward 5 | from overrides import overrides 6 | 7 | from vampire.modules.vae.vae import VAE 8 | 9 | 10 | @VAE.register("logistic_normal") 11 | class LogisticNormal(VAE): 12 | """ 13 | A Variational Autoencoder with a Logistic Normal prior 14 | """ 15 | def __init__(self, 16 | vocab, 17 | encoder: FeedForward, 18 | mean_projection: FeedForward, 19 | log_variance_projection: FeedForward, 20 | decoder: FeedForward, 21 | kld_clamp: Optional[float] = None, 22 | z_dropout: float = 0.2) -> None: 23 | super(LogisticNormal, self).__init__(vocab) 24 | self.encoder = encoder 25 | self.mean_projection = mean_projection 26 | self.log_variance_projection = log_variance_projection 27 | self._kld_clamp = kld_clamp 28 | self._decoder = torch.nn.Linear(decoder.get_input_dim(), decoder.get_output_dim(), 29 | bias=False) 30 | self._z_dropout = torch.nn.Dropout(z_dropout) 31 | 32 | self.latent_dim = mean_projection.get_output_dim() 33 | 34 | @overrides 35 | def forward(self, input_repr: torch.FloatTensor): # pylint: disable = W0221 36 | """ 37 | Given the input representation, produces the reconstruction from theta 38 | as well as the negative KL-divergence, theta itself, and the parameters 39 | of the distribution. 
40 | """ 41 | activations: List[Tuple[str, torch.FloatTensor]] = [] 42 | intermediate_input = input_repr 43 | for layer_index, layer in enumerate(self.encoder._linear_layers): # pylint: disable=protected-access 44 | intermediate_input = layer(intermediate_input) 45 | activations.append((f"encoder_layer_{layer_index}", intermediate_input)) 46 | output = self.generate_latent_code(intermediate_input) 47 | theta = output["theta"] 48 | activations.append(('theta', theta)) 49 | reconstruction = self._decoder(theta) 50 | output["reconstruction"] = reconstruction 51 | output['activations'] = activations 52 | 53 | return output 54 | 55 | @overrides 56 | def estimate_params(self, input_repr: torch.FloatTensor): 57 | """ 58 | Estimate the parameters for the logistic normal. 59 | """ 60 | mean = self.mean_projection(input_repr) # pylint: disable=C0103 61 | log_var = self.log_variance_projection(input_repr) 62 | sigma = torch.sqrt(torch.exp(log_var)).clamp(max=10) # exp(log_var) is the variance, so sigma is the standard deviation (clamped for stability). 63 | return { 64 | "mean": mean, 65 | "variance": sigma, # note: this key actually stores the standard deviation. 66 | "log_variance": log_var 67 | } 68 | 69 | @overrides 70 | def compute_negative_kld(self, params: Dict): 71 | """ 72 | Compute the closed-form negative KL-divergence between a diagonal Gaussian and the standard normal prior. 73 | """ 74 | mu, sigma = params["mean"], params["variance"] # pylint: disable=C0103 75 | negative_kl_divergence = 1 + torch.log(sigma ** 2) - mu ** 2 - sigma ** 2 76 | if self._kld_clamp: 77 | negative_kl_divergence = torch.clamp(negative_kl_divergence, 78 | min=-1 * self._kld_clamp, 79 | max=self._kld_clamp) 80 | negative_kl_divergence = 0.5 * negative_kl_divergence.sum(dim=-1) # Shape: (batch, ) 81 | return negative_kl_divergence 82 | 83 | @overrides 84 | def generate_latent_code(self, input_repr: torch.Tensor): 85 | """ 86 | Given an input vector, produces the latent encoding z, followed by the 87 | mean and log variance of the variational distribution produced. 88 | 89 | z is the result of the reparameterization trick. 90 | (https://arxiv.org/abs/1312.6114) 91 | """ 92 | params = self.estimate_params(input_repr) 93 | negative_kl_divergence = self.compute_negative_kld(params) 94 | mu, sigma = params["mean"], params["variance"] # pylint: disable=C0103 95 | 96 | # Generate random noise and sample theta. 97 | # Shape: (batch, latent_dim) 98 | batch_size = params["mean"].size(0) 99 | 100 | # Enable reparameterization for training only. 101 | if self.training: 102 | seed = int(os.environ['SEED']) # torch.manual_seed expects an integer. 103 | torch.manual_seed(seed) 104 | # Seed all GPUs with the same seed if available. 105 | if torch.cuda.is_available(): 106 | torch.cuda.manual_seed_all(seed) 107 | epsilon = torch.randn(batch_size, self.latent_dim).to(device=mu.device) 108 | z = mu + sigma * epsilon # pylint: disable=C0103 109 | else: 110 | z = mu # pylint: disable=C0103 111 | 112 | # Apply dropout to theta. 113 | theta = self._z_dropout(z) 114 | 115 | # Normalize theta.
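# The softmax maps the (dropped-out) Gaussian sample onto the probability
# simplex, which is what makes theta a logistic-normal topic distribution.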
116 | theta = torch.softmax(theta, dim=-1) 117 | 118 | return { 119 | "theta": theta, 120 | "params": params, 121 | "negative_kl_divergence": negative_kl_divergence 122 | } 123 | 124 | @overrides 125 | def encode(self, input_vector: torch.Tensor): 126 | return self.encoder(input_vector) 127 | 128 | @overrides 129 | def get_beta(self): 130 | return self._decoder._parameters['weight'].data.transpose(0, 1) # pylint: disable=W0212 131 | -------------------------------------------------------------------------------- /vampire/modules/token_embedders/vampire_token_embedder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | from allennlp.common import Params 5 | from allennlp.data import Vocabulary 6 | from allennlp.modules.time_distributed import TimeDistributed 7 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder 8 | 9 | from vampire.modules.pretrained_vae import PretrainedVAE 10 | 11 | 12 | @TokenEmbedder.register("vampire_token_embedder") 13 | class VampireTokenEmbedder(TokenEmbedder): 14 | """ 15 | Compute VAMPIRE representations for use with downstream models. 16 | 17 | Parameters 18 | ---------- 19 | model_archive : ``str``, required. 20 | A path to the pretrained VAMPIRE model archive. 21 | device : ``int``, required. 22 | The device you'd like to load the VAE on. 23 | background_frequency : ``str``, required. 24 | Path to the precomputed background frequency file used with this VAMPIRE model. 25 | scalar_mix : ``List[int]``, optional, (default=None) 26 | If not ``None``, use these scalar mix parameters to weight the representations 27 | produced by different layers. These mixing weights are not updated during 28 | training. 29 | dropout : ``float``, optional. 30 | The dropout value to be applied to the VAMPIRE representations. 31 | requires_grad : ``bool``, optional 32 | If True, compute gradient of VAMPIRE parameters for fine tuning. 33 | projection_dim : ``int``, optional 34 | If given, we will project the VAMPIRE embedding down to this dimension. We recommend that you 35 | try using VAE with a lot of dropout and no projection first, but we have found a few cases 36 | where projection helps (particularly where there is very limited training data). 37 | expand_dim : ``bool``, optional 38 | If True, expand the dimensions of the output to a 3-dimensional matrix that can be concatenated with 39 | word vectors.
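
Example configuration (a sketch mirroring
vampire/tests/fixtures/classifier/experiment_seq2seq.json; the paths are
placeholders)::

    "vae_tokens": {
        "type": "vampire_token_embedder",
        "expand_dim": true,
        "device": -1,
        "model_archive": "path/to/model.tar.gz",
        "background_frequency": "path/to/vampire.bgfreq",
        "dropout": 0.2
    }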
40 | """ 41 | def __init__(self, 42 | model_archive: str, 43 | device: int, 44 | background_frequency: str, 45 | scalar_mix: List[int] = None, 46 | dropout: float = None, 47 | requires_grad: bool = False, 48 | projection_dim: int = None, 49 | expand_dim: bool = False) -> None: 50 | super(VampireTokenEmbedder, self).__init__() 51 | 52 | self._vae = PretrainedVAE(model_archive, 53 | device, 54 | background_frequency, 55 | requires_grad, 56 | scalar_mix, 57 | dropout) 58 | self._expand_dim = expand_dim 59 | self._layers = None 60 | if projection_dim: 61 | self._projection = torch.nn.Linear(self._vae.get_output_dim(), projection_dim) 62 | self.output_dim = projection_dim 63 | else: 64 | self._projection = None 65 | self.output_dim = self._vae.get_output_dim() 66 | 67 | def get_output_dim(self) -> int: 68 | return self.output_dim 69 | 70 | def forward(self, # pylint: disable=arguments-differ 71 | inputs: torch.Tensor) -> torch.Tensor: 72 | """ 73 | Parameters 74 | ---------- 75 | inputs: ``torch.Tensor`` 76 | Shape ``(batch_size, timesteps)`` of word ids representing the current batch. 77 | Returns 78 | ------- 79 | The VAMPIRE representations for the input sequence, shape 80 | ``(batch_size, timesteps, embedding_dim)`` or ``(batch_size, embedding_dim)`` 81 | depending on whether ``expand_dim`` is set to ``True``. 82 | """ 83 | vae_output = self._vae(inputs) 84 | embedded = vae_output['vae_representation'] 85 | self._layers = vae_output['layers'] 86 | if self._expand_dim: 87 | embedded = (embedded.unsqueeze(0) 88 | .expand(inputs.shape[1], inputs.shape[0], -1) 89 | .permute(1, 0, 2).contiguous()) 90 | if self._projection: 91 | projection = self._projection 92 | for _ in range(embedded.dim() - 2): 93 | projection = TimeDistributed(projection) 94 | embedded = projection(embedded) 95 | return embedded 96 | 97 | # Custom vocab_to_cache logic requires a from_params implementation.
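# (params.add_file_to_archive ensures the pretrained VAMPIRE archive is
# bundled into any downstream model archive that uses this embedder.)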
98 | @classmethod 99 | def from_params(cls, 100 | vocab: Vocabulary, # pylint: disable=unused-argument 101 | params: Params) -> 'VampireTokenEmbedder': # type: ignore 102 | # pylint: disable=arguments-differ 103 | params.add_file_to_archive('model_archive') 104 | model_archive = params.pop('model_archive') 105 | device = params.pop_int('device') 106 | background_frequency = params.pop('background_frequency') 107 | requires_grad = params.pop('requires_grad', False) 108 | scalar_mix = params.pop("scalar_mix", None) 109 | dropout = params.pop_float("dropout", None) 110 | expand_dim = params.pop_bool("expand_dim", False) 111 | projection_dim = params.pop_int("projection_dim", None) 112 | params.assert_empty(cls.__name__) 113 | return cls(expand_dim=expand_dim, 114 | scalar_mix=scalar_mix, 115 | background_frequency=background_frequency, 116 | device=device, 117 | model_archive=model_archive, 118 | dropout=dropout, 119 | requires_grad=requires_grad, 120 | projection_dim=projection_dim) 121 | -------------------------------------------------------------------------------- /vampire/tests/modules/token_embedders/vampire_token_embedder_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import torch 3 | import numpy as np 4 | from allennlp.common import Params 5 | from allennlp.common.testing import ModelTestCase 6 | from allennlp.data.dataset import Batch 7 | from vampire.common.testing.test_case import VAETestCase 8 | from vampire.modules.token_embedders import VampireTokenEmbedder 9 | 10 | 11 | class TestVampireTokenEmbedder(ModelTestCase): 12 | def setUp(self): 13 | super().setUp() 14 | 15 | def test_forward_works_with_encoder_output_and_projection(self): 16 | params = Params({ 17 | 'model_archive': VAETestCase.FIXTURES_ROOT / 'vae' / 'model.tar.gz', 18 | 'background_frequency': VAETestCase.FIXTURES_ROOT / 'imdb' / 'vampire.bgfreq', 19 | 'device': -1, 20 | 'projection_dim': 20 21 | }) 22 | word1 = [0] * 50 23 | word2 = [0] * 50 24 | word1[0] = 6 25 | word1[1] = 5 26 | word1[2] = 4 27 | word1[3] = 3 28 | word2[0] = 3 29 | word2[1] = 2 30 | word2[2] = 1 31 | word2[3] = 0 32 | embedding_layer = VampireTokenEmbedder.from_params(vocab=None, params=params) 33 | assert embedding_layer.get_output_dim() == 20 34 | input_tensor = torch.LongTensor([word1, word2]) 35 | embedded = embedding_layer(input_tensor).data.numpy() 36 | assert embedded.shape == (2, 20) 37 | 38 | def test_forward_encoder_output_with_expansion_works(self): 39 | params = Params({ 40 | 'model_archive': VAETestCase.FIXTURES_ROOT / 'vae' / 'model.tar.gz', 41 | 'background_frequency': VAETestCase.FIXTURES_ROOT / 'imdb' / 'vampire.bgfreq', 42 | 'device': -1, 43 | "expand_dim": True, 44 | "dropout": 0.0 45 | }) 46 | word1 = [0] * 50 47 | word2 = [0] * 50 48 | word1[0] = 6 49 | word1[1] = 5 50 | word1[2] = 4 51 | word1[3] = 3 52 | word2[0] = 3 53 | word2[1] = 2 54 | word2[2] = 1 55 | word2[3] = 0 56 | embedding_layer = VampireTokenEmbedder.from_params(vocab=None, params=params) 57 | input_tensor = torch.LongTensor([word1, word2]) 58 | expected_vectors = embedding_layer._vae(input_tensor)['vae_representation'].detach().data.numpy() 59 | embedded = embedding_layer(input_tensor).detach().data.numpy() 60 | for row in range(input_tensor.shape[0]): 61 | for col in range(input_tensor.shape[1]): 62 | np.testing.assert_allclose(embedded[row, col, :], expected_vectors[row, :]) 63 | 64 | def test_projection_works_with_encoder_weight_representations(self): 65 |
params = Params({ 66 | 'model_archive': VAETestCase.FIXTURES_ROOT / 'vae' / 'model.tar.gz', 67 | 'background_frequency': VAETestCase.FIXTURES_ROOT / 'imdb' / 'vampire.bgfreq', 68 | 'device': -1, 69 | 'projection_dim': 20, 70 | 'expand_dim': True 71 | }) 72 | word1 = [0] * 50 73 | word2 = [0] * 50 74 | word1[0] = 6 75 | word1[1] = 5 76 | word1[2] = 4 77 | word1[3] = 3 78 | word2[0] = 3 79 | word2[1] = 2 80 | word2[2] = 1 81 | word2[3] = 0 82 | embedding_layer = VampireTokenEmbedder.from_params(vocab=None, params=params) 83 | assert embedding_layer.get_output_dim() == 20 84 | input_tensor = torch.LongTensor([word1, word2]) 85 | embedded = embedding_layer(input_tensor).data.numpy() 86 | assert embedded.shape == (2, 50, 20) 87 | 88 | def test_forward_works_with_encoder_weight_and_projection(self): 89 | params = Params({ 90 | 'model_archive': VAETestCase.FIXTURES_ROOT / 'vae' / 'model.tar.gz', 91 | 'background_frequency': VAETestCase.FIXTURES_ROOT / 'imdb' / 'vampire.bgfreq', 92 | 'device': -1, 93 | 'projection_dim': 20, 94 | 'expand_dim': True 95 | }) 96 | word1 = [0] * 50 97 | word2 = [0] * 50 98 | word1[0] = 6 99 | word1[1] = 5 100 | word1[2] = 4 101 | word1[3] = 3 102 | word2[0] = 3 103 | word2[1] = 2 104 | word2[2] = 1 105 | word2[3] = 0 106 | embedding_layer = VampireTokenEmbedder.from_params(vocab=None, params=params) 107 | assert embedding_layer.get_output_dim() == 20 108 | input_tensor = torch.LongTensor([word1, word2]) 109 | embedded = embedding_layer(input_tensor).data.numpy() 110 | assert embedded.shape == (2, 50, 20) 111 | 112 | def test_forward_works_with_encoder_output_expand_and_projection(self): 113 | params = Params({ 114 | 'model_archive': VAETestCase.FIXTURES_ROOT / 'vae' / 'model.tar.gz', 115 | 'background_frequency': VAETestCase.FIXTURES_ROOT / 'imdb' / 'vampire.bgfreq', 116 | 'device': -1, 117 | 'projection_dim': 20, 118 | 'expand_dim': True 119 | }) 120 | word1 = [0] * 50 121 | word2 = [0] * 50 122 | word1[0] = 6 123 | word1[1] = 5 124 | word1[2] = 4 125 | word1[3] = 3 126 | word2[0] = 3 127 | word2[1] = 2 128 | word2[2] = 1 129 | word2[3] = 0 130 | embedding_layer = VampireTokenEmbedder.from_params(vocab=None, params=params) 131 | assert embedding_layer.get_output_dim() == 20 132 | input_tensor = torch.LongTensor([word1, word2]) 133 | embedded = embedding_layer(input_tensor).data.numpy() 134 | assert embedded.shape == (2, 50, 20) -------------------------------------------------------------------------------- /vampire/common/stopwords/mallet_stopwords.txt: -------------------------------------------------------------------------------- 1 | a 2 | able 3 | about 4 | above 5 | according 6 | accordingly 7 | across 8 | actually 9 | after 10 | afterwards 11 | again 12 | against 13 | all 14 | allow 15 | allows 16 | almost 17 | alone 18 | along 19 | already 20 | also 21 | although 22 | always 23 | am 24 | among 25 | amongst 26 | an 27 | and 28 | another 29 | any 30 | anybody 31 | anyhow 32 | anyone 33 | anything 34 | anyway 35 | anyways 36 | anywhere 37 | apart 38 | appear 39 | appreciate 40 | appropriate 41 | are 42 | around 43 | as 44 | aside 45 | ask 46 | asking 47 | associated 48 | at 49 | available 50 | away 51 | awfully 52 | b 53 | be 54 | became 55 | because 56 | become 57 | becomes 58 | becoming 59 | been 60 | before 61 | beforehand 62 | behind 63 | being 64 | believe 65 | below 66 | beside 67 | besides 68 | best 69 | better 70 | between 71 | beyond 72 | both 73 | brief 74 | but 75 | by 76 | c 77 | came 78 | can 79 | cannot 80 | cant 81 | cause 82 | causes 83 | certain 
84 | certainly 85 | changes 86 | clearly 87 | co 88 | com 89 | come 90 | comes 91 | concerning 92 | consequently 93 | consider 94 | considering 95 | contain 96 | containing 97 | contains 98 | corresponding 99 | could 100 | course 101 | currently 102 | d 103 | definitely 104 | described 105 | despite 106 | did 107 | different 108 | do 109 | does 110 | doing 111 | done 112 | down 113 | downwards 114 | during 115 | e 116 | each 117 | edu 118 | eg 119 | eight 120 | either 121 | else 122 | elsewhere 123 | enough 124 | entirely 125 | especially 126 | et 127 | etc 128 | even 129 | ever 130 | every 131 | everybody 132 | everyone 133 | everything 134 | everywhere 135 | ex 136 | exactly 137 | example 138 | except 139 | f 140 | far 141 | few 142 | fifth 143 | first 144 | five 145 | followed 146 | following 147 | follows 148 | for 149 | former 150 | formerly 151 | forth 152 | four 153 | from 154 | further 155 | furthermore 156 | g 157 | get 158 | gets 159 | getting 160 | given 161 | gives 162 | go 163 | goes 164 | going 165 | gone 166 | got 167 | gotten 168 | greetings 169 | h 170 | had 171 | happens 172 | hardly 173 | has 174 | have 175 | having 176 | he 177 | hello 178 | help 179 | hence 180 | her 181 | here 182 | hereafter 183 | hereby 184 | herein 185 | hereupon 186 | hers 187 | herself 188 | hi 189 | him 190 | himself 191 | his 192 | hither 193 | hopefully 194 | how 195 | howbeit 196 | however 197 | i 198 | ie 199 | if 200 | ignored 201 | immediate 202 | in 203 | inasmuch 204 | inc 205 | indeed 206 | indicate 207 | indicated 208 | indicates 209 | inner 210 | insofar 211 | instead 212 | into 213 | inward 214 | is 215 | it 216 | its 217 | itself 218 | j 219 | just 220 | k 221 | keep 222 | keeps 223 | kept 224 | know 225 | knows 226 | known 227 | l 228 | last 229 | lately 230 | later 231 | latter 232 | latterly 233 | least 234 | less 235 | lest 236 | let 237 | like 238 | liked 239 | likely 240 | little 241 | look 242 | looking 243 | looks 244 | ltd 245 | m 246 | mainly 247 | many 248 | may 249 | maybe 250 | me 251 | mean 252 | meanwhile 253 | merely 254 | might 255 | more 256 | moreover 257 | most 258 | mostly 259 | much 260 | must 261 | my 262 | myself 263 | n 264 | name 265 | namely 266 | nd 267 | near 268 | nearly 269 | necessary 270 | need 271 | needs 272 | neither 273 | never 274 | nevertheless 275 | new 276 | next 277 | nine 278 | no 279 | nobody 280 | non 281 | none 282 | noone 283 | nor 284 | normally 285 | not 286 | nothing 287 | novel 288 | now 289 | nowhere 290 | o 291 | obviously 292 | of 293 | off 294 | often 295 | oh 296 | ok 297 | okay 298 | old 299 | on 300 | once 301 | one 302 | ones 303 | only 304 | onto 305 | or 306 | other 307 | others 308 | otherwise 309 | ought 310 | our 311 | ours 312 | ourselves 313 | out 314 | outside 315 | over 316 | overall 317 | own 318 | p 319 | particular 320 | particularly 321 | per 322 | perhaps 323 | placed 324 | please 325 | plus 326 | possible 327 | presumably 328 | probably 329 | provides 330 | q 331 | que 332 | quite 333 | qv 334 | r 335 | rather 336 | rd 337 | re 338 | really 339 | reasonably 340 | regarding 341 | regardless 342 | regards 343 | relatively 344 | respectively 345 | right 346 | s 347 | said 348 | same 349 | saw 350 | say 351 | saying 352 | says 353 | second 354 | secondly 355 | see 356 | seeing 357 | seem 358 | seemed 359 | seeming 360 | seems 361 | seen 362 | self 363 | selves 364 | sensible 365 | sent 366 | serious 367 | seriously 368 | seven 369 | several 370 | shall 371 | she 372 | should 373 | since 374 | six 375 | so 376 | 
some 377 | somebody 378 | somehow 379 | someone 380 | something 381 | sometime 382 | sometimes 383 | somewhat 384 | somewhere 385 | soon 386 | sorry 387 | specified 388 | specify 389 | specifying 390 | still 391 | sub 392 | such 393 | sup 394 | sure 395 | t 396 | take 397 | taken 398 | tell 399 | tends 400 | th 401 | than 402 | thank 403 | thanks 404 | thanx 405 | that 406 | thats 407 | the 408 | their 409 | theirs 410 | them 411 | themselves 412 | then 413 | thence 414 | there 415 | thereafter 416 | thereby 417 | therefore 418 | therein 419 | theres 420 | thereupon 421 | these 422 | they 423 | think 424 | third 425 | this 426 | thorough 427 | thoroughly 428 | those 429 | though 430 | three 431 | through 432 | throughout 433 | thru 434 | thus 435 | to 436 | together 437 | too 438 | took 439 | toward 440 | towards 441 | tried 442 | tries 443 | truly 444 | try 445 | trying 446 | twice 447 | two 448 | u 449 | un 450 | under 451 | unfortunately 452 | unless 453 | unlikely 454 | until 455 | unto 456 | up 457 | upon 458 | us 459 | use 460 | used 461 | useful 462 | uses 463 | using 464 | usually 465 | uucp 466 | v 467 | value 468 | various 469 | very 470 | via 471 | viz 472 | vs 473 | w 474 | want 475 | wants 476 | was 477 | way 478 | we 479 | welcome 480 | well 481 | went 482 | were 483 | what 484 | whatever 485 | when 486 | whence 487 | whenever 488 | where 489 | whereafter 490 | whereas 491 | whereby 492 | wherein 493 | whereupon 494 | wherever 495 | whether 496 | which 497 | while 498 | whither 499 | who 500 | whoever 501 | whole 502 | whom 503 | whose 504 | why 505 | will 506 | willing 507 | wish 508 | with 509 | within 510 | without 511 | wonder 512 | would 513 | would 514 | x 515 | y 516 | yes 517 | yet 518 | you 519 | your 520 | yours 521 | yourself 522 | yourselves 523 | z 524 | zero -------------------------------------------------------------------------------- /vampire/tests/fixtures/stopwords/mallet_stopwords.txt: -------------------------------------------------------------------------------- 1 | a 2 | able 3 | about 4 | above 5 | according 6 | accordingly 7 | across 8 | actually 9 | after 10 | afterwards 11 | again 12 | against 13 | all 14 | allow 15 | allows 16 | almost 17 | alone 18 | along 19 | already 20 | also 21 | although 22 | always 23 | am 24 | among 25 | amongst 26 | an 27 | and 28 | another 29 | any 30 | anybody 31 | anyhow 32 | anyone 33 | anything 34 | anyway 35 | anyways 36 | anywhere 37 | apart 38 | appear 39 | appreciate 40 | appropriate 41 | are 42 | around 43 | as 44 | aside 45 | ask 46 | asking 47 | associated 48 | at 49 | available 50 | away 51 | awfully 52 | b 53 | be 54 | became 55 | because 56 | become 57 | becomes 58 | becoming 59 | been 60 | before 61 | beforehand 62 | behind 63 | being 64 | believe 65 | below 66 | beside 67 | besides 68 | best 69 | better 70 | between 71 | beyond 72 | both 73 | brief 74 | but 75 | by 76 | c 77 | came 78 | can 79 | cannot 80 | cant 81 | cause 82 | causes 83 | certain 84 | certainly 85 | changes 86 | clearly 87 | co 88 | com 89 | come 90 | comes 91 | concerning 92 | consequently 93 | consider 94 | considering 95 | contain 96 | containing 97 | contains 98 | corresponding 99 | could 100 | course 101 | currently 102 | d 103 | definitely 104 | described 105 | despite 106 | did 107 | different 108 | do 109 | does 110 | doing 111 | done 112 | down 113 | downwards 114 | during 115 | e 116 | each 117 | edu 118 | eg 119 | eight 120 | either 121 | else 122 | elsewhere 123 | enough 124 | entirely 125 | especially 126 | et 127 | etc 128 
| even 129 | ever 130 | every 131 | everybody 132 | everyone 133 | everything 134 | everywhere 135 | ex 136 | exactly 137 | example 138 | except 139 | f 140 | far 141 | few 142 | fifth 143 | first 144 | five 145 | followed 146 | following 147 | follows 148 | for 149 | former 150 | formerly 151 | forth 152 | four 153 | from 154 | further 155 | furthermore 156 | g 157 | get 158 | gets 159 | getting 160 | given 161 | gives 162 | go 163 | goes 164 | going 165 | gone 166 | got 167 | gotten 168 | greetings 169 | h 170 | had 171 | happens 172 | hardly 173 | has 174 | have 175 | having 176 | he 177 | hello 178 | help 179 | hence 180 | her 181 | here 182 | hereafter 183 | hereby 184 | herein 185 | hereupon 186 | hers 187 | herself 188 | hi 189 | him 190 | himself 191 | his 192 | hither 193 | hopefully 194 | how 195 | howbeit 196 | however 197 | i 198 | ie 199 | if 200 | ignored 201 | immediate 202 | in 203 | inasmuch 204 | inc 205 | indeed 206 | indicate 207 | indicated 208 | indicates 209 | inner 210 | insofar 211 | instead 212 | into 213 | inward 214 | is 215 | it 216 | its 217 | itself 218 | j 219 | just 220 | k 221 | keep 222 | keeps 223 | kept 224 | know 225 | knows 226 | known 227 | l 228 | last 229 | lately 230 | later 231 | latter 232 | latterly 233 | least 234 | less 235 | lest 236 | let 237 | like 238 | liked 239 | likely 240 | little 241 | look 242 | looking 243 | looks 244 | ltd 245 | m 246 | mainly 247 | many 248 | may 249 | maybe 250 | me 251 | mean 252 | meanwhile 253 | merely 254 | might 255 | more 256 | moreover 257 | most 258 | mostly 259 | much 260 | must 261 | my 262 | myself 263 | n 264 | name 265 | namely 266 | nd 267 | near 268 | nearly 269 | necessary 270 | need 271 | needs 272 | neither 273 | never 274 | nevertheless 275 | new 276 | next 277 | nine 278 | no 279 | nobody 280 | non 281 | none 282 | noone 283 | nor 284 | normally 285 | not 286 | nothing 287 | novel 288 | now 289 | nowhere 290 | o 291 | obviously 292 | of 293 | off 294 | often 295 | oh 296 | ok 297 | okay 298 | old 299 | on 300 | once 301 | one 302 | ones 303 | only 304 | onto 305 | or 306 | other 307 | others 308 | otherwise 309 | ought 310 | our 311 | ours 312 | ourselves 313 | out 314 | outside 315 | over 316 | overall 317 | own 318 | p 319 | particular 320 | particularly 321 | per 322 | perhaps 323 | placed 324 | please 325 | plus 326 | possible 327 | presumably 328 | probably 329 | provides 330 | q 331 | que 332 | quite 333 | qv 334 | r 335 | rather 336 | rd 337 | re 338 | really 339 | reasonably 340 | regarding 341 | regardless 342 | regards 343 | relatively 344 | respectively 345 | right 346 | s 347 | said 348 | same 349 | saw 350 | say 351 | saying 352 | says 353 | second 354 | secondly 355 | see 356 | seeing 357 | seem 358 | seemed 359 | seeming 360 | seems 361 | seen 362 | self 363 | selves 364 | sensible 365 | sent 366 | serious 367 | seriously 368 | seven 369 | several 370 | shall 371 | she 372 | should 373 | since 374 | six 375 | so 376 | some 377 | somebody 378 | somehow 379 | someone 380 | something 381 | sometime 382 | sometimes 383 | somewhat 384 | somewhere 385 | soon 386 | sorry 387 | specified 388 | specify 389 | specifying 390 | still 391 | sub 392 | such 393 | sup 394 | sure 395 | t 396 | take 397 | taken 398 | tell 399 | tends 400 | th 401 | than 402 | thank 403 | thanks 404 | thanx 405 | that 406 | thats 407 | the 408 | their 409 | theirs 410 | them 411 | themselves 412 | then 413 | thence 414 | there 415 | thereafter 416 | thereby 417 | therefore 418 | therein 419 | theres 420 | 
thereupon 421 | these 422 | they 423 | think 424 | third 425 | this 426 | thorough 427 | thoroughly 428 | those 429 | though 430 | three 431 | through 432 | throughout 433 | thru 434 | thus 435 | to 436 | together 437 | too 438 | took 439 | toward 440 | towards 441 | tried 442 | tries 443 | truly 444 | try 445 | trying 446 | twice 447 | two 448 | u 449 | un 450 | under 451 | unfortunately 452 | unless 453 | unlikely 454 | until 455 | unto 456 | up 457 | upon 458 | us 459 | use 460 | used 461 | useful 462 | uses 463 | using 464 | usually 465 | uucp 466 | v 467 | value 468 | various 469 | very 470 | via 471 | viz 472 | vs 473 | w 474 | want 475 | wants 476 | was 477 | way 478 | we 479 | welcome 480 | well 481 | went 482 | were 483 | what 484 | whatever 485 | when 486 | whence 487 | whenever 488 | where 489 | whereafter 490 | whereas 491 | whereby 492 | wherein 493 | whereupon 494 | wherever 495 | whether 496 | which 497 | while 498 | whither 499 | who 500 | whoever 501 | whole 502 | whom 503 | whose 504 | why 505 | will 506 | willing 507 | wish 508 | with 509 | within 510 | without 511 | wonder 512 | would 513 | would 514 | x 515 | y 516 | yes 517 | yet 518 | you 519 | your 520 | yours 521 | yourself 522 | yourselves 523 | z 524 | zero -------------------------------------------------------------------------------- /search_spaces/classifier_ag_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "CUDA_DEVICE": 0, 4 | "EVALUATE_ON_TEST": 0, 5 | "NUM_EPOCHS": 50, 6 | "SEED": { 7 | "sampling strategy": "integer", 8 | "bounds": [0, 100000] 9 | }, 10 | "TRAIN_PATH": "s3://suching-dev/final-datasets/ag-news/train_pretokenized.jsonl", 11 | "DEV_PATH": "s3://suching-dev/final-datasets/ag-news/dev_pretokenized.jsonl", 12 | "TEST_PATH": "s3://suching-dev/final-datasets/ag-news/test_pretokenized.jsonl", 13 | "THROTTLE": 200, 14 | "USE_SPACY_TOKENIZER": 0, 15 | "FREEZE_EMBEDDINGS": ["VAMPIRE"], 16 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 17 | "VAMPIRE_DIRECTORY": { 18 | "sampling strategy": "choice", 19 | "choices": ["logs/vampire_ag_search/run_6_2019-06-01_16-50-51_lw8tglh 77", "logs/vampire_ag_search/run_48_2019-06-01_18-10-064v2134xc 127", "logs/vampire_ag_search/run_52_2019-06-01_18-15-34p4_ylqex 105", "logs/vampire_ag_search/run_42_2019-06-01_17-55-223lebucd9 114", "logs/vampire_ag_search/run_12_2019-06-01_17-05-124e6o13du 93", "logs/vampire_ag_search/run_50_2019-06-01_18-15-10b8ozax_v 113", "logs/vampire_ag_search/run_38_2019-06-01_17-51-066u0sco8w 97", "logs/vampire_ag_search/run_31_2019-06-01_17-38-59dcgmcvg_ 99", "logs/vampire_ag_search/run_36_2019-06-01_17-45-175d3yuh9y 88", "logs/vampire_ag_search/run_27_2019-06-01_17-33-32oizpves4 96", "logs/vampire_ag_search/run_14_2019-06-01_17-13-240j6p2bpb 69", "logs/vampire_ag_search/run_45_2019-06-01_18-05-120d2sbx4f 95", "logs/vampire_ag_search/run_2_2019-06-01_16-50-503rc003fm 88", "logs/vampire_ag_search/run_54_2019-06-01_18-16-07b6jfwvlx 113", "logs/vampire_ag_search/run_0_2019-06-01_16-50-50c6qnyx9m 121", "logs/vampire_ag_search/run_15_2019-06-01_17-14-26_7i_g9ex 82", "logs/vampire_ag_search/run_19_2019-06-01_17-22-014kxltauj 104", "logs/vampire_ag_search/run_39_2019-06-01_17-52-23_nocxfga 105", "logs/vampire_ag_search/run_79_2019-06-01_19-00-1926urotg7 74", "logs/vampire_ag_search/run_13_2019-06-01_17-09-2943_hr960 68", "logs/vampire_ag_search/run_64_2019-06-01_18-34-58p11abjbk 107", "logs/vampire_ag_search/run_35_2019-06-01_17-45-071k298op4 97", 
"logs/vampire_ag_search/run_33_2019-06-01_17-43-295jofjqwm 80", "logs/vampire_ag_search/run_47_2019-06-01_18-09-312nca6jun 96", "logs/vampire_ag_search/run_3_2019-06-01_16-50-5067zjl3dy 104", "logs/vampire_ag_search/run_5_2019-06-01_16-50-50n33wveam 120", "logs/vampire_ag_search/run_17_2019-06-01_17-20-58b3x9ma1z 84", "logs/vampire_ag_search/run_77_2019-06-01_18-57-4126y9wekx 71", "logs/vampire_ag_search/run_73_2019-06-01_18-46-55xinkp182 69", "logs/vampire_ag_search/run_59_2019-06-01_18-23-30y8jwi0im 93", "logs/vampire_ag_search/run_46_2019-06-01_18-09-24rxb3m4nz 100", "logs/vampire_ag_search/run_11_2019-06-01_17-03-303g8rtce8 111", "logs/vampire_ag_search/run_7_2019-06-01_16-50-51tgfowrvv 76", "logs/vampire_ag_search/run_37_2019-06-01_17-45-27ffpz7o2j 90", "logs/vampire_ag_search/run_62_2019-06-01_18-28-38du0d4wdj 102", "logs/vampire_ag_search/run_30_2019-06-01_17-37-23j0l69pyw 80", "logs/vampire_ag_search/run_58_2019-06-01_18-22-20xadnv7vg 81", "logs/vampire_ag_search/run_20_2019-06-01_17-26-530uexw33x 68", "logs/vampire_ag_search/run_16_2019-06-01_17-15-53a0vk1h_p 96", "logs/vampire_ag_search/run_70_2019-06-01_18-43-53cnhyfkes 115", "logs/vampire_ag_search/run_66_2019-06-01_18-38-07g2wxxe_2 116", "logs/vampire_ag_search/run_60_2019-06-01_18-27-45yse23gm4 106", "logs/vampire_ag_search/run_43_2019-06-01_18-00-31f887nssb 115", "logs/vampire_ag_search/run_4_2019-06-01_16-50-50tv6jwd7c 121", "logs/vampire_ag_search/run_22_2019-06-01_17-27-49bestv0jj 81", "logs/vampire_ag_search/run_49_2019-06-01_18-11-52s_hdznyh 87", "logs/vampire_ag_search/run_78_2019-06-01_19-00-12dmq3pqw7 104", "logs/vampire_ag_search/run_53_2019-06-01_18-16-06dur56w5r 85", "logs/vampire_ag_search/run_29_2019-06-01_17-35-02r0kldtpb 75", "logs/vampire_ag_search/run_24_2019-06-01_17-29-33oo5p2t_6 84", "logs/vampire_ag_search/run_9_2019-06-01_16-59-36q4s15tv2 67", "logs/vampire_ag_search/run_10_2019-06-01_17-00-18dq5q68kx 68", "logs/vampire_ag_search/run_32_2019-06-01_17-42-21nf4umswh 70", "logs/vampire_ag_search/run_80_2019-06-01_19-01-27vs9pobl2 77", "logs/vampire_ag_search/run_72_2019-06-01_18-46-47n2nm9thz 82", "logs/vampire_ag_search/run_65_2019-06-01_18-36-05e569ytyx 107", "logs/vampire_ag_search/run_26_2019-06-01_17-31-00o3oltc3_ 121", "logs/vampire_ag_search/run_41_2019-06-01_17-54-16qln09rqr 107", "logs/vampire_ag_search/run_63_2019-06-01_18-29-23psp44odn 115", "logs/vampire_ag_search/run_75_2019-06-01_18-56-34i415eul1 117", "logs/vampire_ag_search/run_8_2019-06-01_16-56-356dh2vm9v 115", "logs/vampire_ag_search/run_68_2019-06-01_18-42-07_g8yiqbi 70", "logs/vampire_ag_search/run_18_2019-06-01_17-21-58kf1061om 80", "logs/vampire_ag_search/run_25_2019-06-01_17-30-436z6wc276 94", "logs/vampire_ag_search/run_55_2019-06-01_18-19-16evoydya9 70"] 20 | }, 21 | "ENCODER": { 22 | "sampling strategy": "choice", 23 | "choices": ["AVERAGE"] 24 | }, 25 | "EMBEDDING_DROPOUT": 0.5, 26 | "LEARNING_RATE": 0.004, 27 | "DROPOUT": 0.5, 28 | "BATCH_SIZE": 32, 29 | "NUM_ENCODER_LAYERS": { 30 | "sampling strategy": "choice", 31 | "choices": [1, 2, 3] 32 | }, 33 | "NUM_OUTPUT_LAYERS": { 34 | "sampling strategy": "choice", 35 | "choices": [1, 2, 3] 36 | }, 37 | "MAX_FILTER_SIZE": { 38 | "sampling strategy": "integer", 39 | "bounds": [3, 6] 40 | }, 41 | "NUM_FILTERS": { 42 | "sampling strategy": "integer", 43 | "bounds": [64, 512] 44 | }, 45 | "HIDDEN_SIZE": { 46 | "sampling strategy": "integer", 47 | "bounds": [64, 512] 48 | }, 49 | "AGGREGATIONS": { 50 | "sampling strategy": "subset", 51 | "choices": ["maxpool", "meanpool", "attention", 
"final_state"] 52 | }, 53 | "MAX_CHARACTER_FILTER_SIZE": { 54 | "sampling strategy": "integer", 55 | "bounds": [3, 6] 56 | }, 57 | "NUM_CHARACTER_FILTERS": { 58 | "sampling strategy": "integer", 59 | "bounds": [16, 64] 60 | }, 61 | "CHARACTER_HIDDEN_SIZE": { 62 | "sampling strategy": "integer", 63 | "bounds": [16, 128] 64 | }, 65 | "CHARACTER_EMBEDDING_DIM": { 66 | "sampling strategy": "integer", 67 | "bounds": [16, 128] 68 | }, 69 | "CHARACTER_ENCODER": { 70 | "sampling strategy": "choice", 71 | "choices": ["LSTM", "CNN", "AVERAGE"] 72 | }, 73 | "NUM_CHARACTER_ENCODER_LAYERS": { 74 | "sampling strategy": "choice", 75 | "choices": [1, 2] 76 | } 77 | } -------------------------------------------------------------------------------- /scripts/figures/regplot.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | sns.set(font_scale=1.3, style='white') 7 | 8 | 9 | # if __name__ == "__main__": 10 | # fontsize=26 11 | # fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 12 | # df = pd.read_json("hyperparameter_search_results/vampire_npmi_classifier_search.jsonl", lines=True) 13 | # sns.regplot(df['best_validation_npmi'], df['best_validation_accuracy']) 14 | # ax.set_xlabel("Best validation NPMI", ) 15 | # ax.set_ylabel("Best validation accuracy", ) 16 | # for tick in ax.xaxis.get_major_ticks(): 17 | # tick.label.set_fontsize(fontsize) 18 | # for label in ax.xaxis.get_ticklabels()[::2]: 19 | # label.set_visible(False) 20 | 21 | # for tick in ax.yaxis.get_major_ticks(): 22 | # tick.label.set_fontsize(fontsize) 23 | # ax.set_ylim([0.60, 0.85]) 24 | # plt.savefig("results/regplot_figure.pdf", dpi=300) 25 | 26 | 27 | # if __name__ == "__main__": 28 | # fontsize=26 29 | # fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 30 | # df = pd.read_json("hyperparameter_search_results/vampire_npmi_classifier_search.jsonl", lines=True) 31 | # sns.regplot(df['best_validation_npmi'], df['best_validation_nll']) 32 | # ax.set_xlabel("Best validation NPMI", ) 33 | # ax.set_ylabel("Best validation NLL", ) 34 | # for tick in ax.xaxis.get_major_ticks(): 35 | # tick.label.set_fontsize(fontsize) 36 | # for tick in ax.yaxis.get_major_ticks(): 37 | # tick.label.set_fontsize(fontsize) 38 | # for label in ax.xaxis.get_ticklabels()[::2]: 39 | # label.set_visible(False) 40 | # # ax.set_ylim([0.60, 0.85]) 41 | # plt.savefig("results/regplot_figure_1.pdf", dpi=300) 42 | 43 | 44 | # if __name__ == '__main__': 45 | # fontsize=23 46 | # fig, ax = plt.subplots(1, 1, figsize=(10, 10)) 47 | # df = pd.read_json("hyperparameter_search_results/clf_search.jsonl", lines=True) 48 | # sns.boxplot(df['model.encoder.architecture.type'],df['best_validation_accuracy']) 49 | # ax.set_xticklabels(["CNN", "LSTM", "Averaging"], fontsize=26) 50 | # ax.set_xlabel('Classifier Encoder', fontsize=26) 51 | # ax.set_ylabel("Validation Accuracy", fontsize=26) 52 | # for tick in ax.xaxis.get_major_ticks(): 53 | # tick.label.set_fontsize(fontsize) 54 | # for tick in ax.yaxis.get_major_ticks(): 55 | # tick.label.set_fontsize(fontsize) 56 | # plt.savefig("results/clf_accuracy_figure.pdf", dpi=300) 57 | 58 | if __name__ == '__main__': 59 | import matplotlib.gridspec as gridspec 60 | fig, ax = plt.subplots(2, 2) 61 | 62 | ax1 = ax[0,0] 63 | ax2 = ax[0,1] 64 | ax3 = ax[1,0] 65 | ax4 = ax[1,1] 66 | 67 | df = pd.read_json("hyperparameter_search_results/vampire_npmi_classifier_search.jsonl", lines=True) 68 | 
sns.regplot(df['best_validation_npmi'], df['best_validation_nll'], ax=ax1, color='black')
69 |     ax1.set_xlabel("NPMI")
70 |     ax1.set_ylabel("NLL")
71 |     # for tick in ax3.xaxis.get_major_ticks():
72 |     #     tick.label.set_fontsize(fontsize)
73 |     # for tick in ax3.yaxis.get_major_ticks():
74 |     #     tick.label.set_fontsize(fontsize)
75 |     ax1.xaxis.set_ticks([0.06, 0.14])
76 |     ax1.set_ylim([820, 900])
77 |     ax1.yaxis.set_ticks([840, 860, 880])
78 | 
79 | 
80 |     ax1.text(-0.1, 1.15, "A", transform=ax1.transAxes,
81 |              fontsize=16, fontweight='bold', va='top', ha='right')
82 | 
83 |     df = pd.read_json("hyperparameter_search_results/vampire_npmi_classifier_search.jsonl", lines=True)
84 |     sns.regplot(df['best_validation_npmi'], df['best_validation_accuracy'], ax=ax2)
85 |     ax2.set_xlabel("NPMI")
86 |     ax2.set_ylabel("Accuracy")
87 |     # for tick in ax2.yaxis.get_major_ticks():
88 |     #     tick.label.set_fontsize(fontsize)
89 |     ax2.set_ylim([0.60, 0.85])
90 |     ax2.xaxis.set_ticks([0.06, 0.14])
91 |     ax2.yaxis.set_ticks([0.7, 0.8])
92 | 
93 |     # for tick in ax2.xaxis.get_major_ticks():
94 |     #     tick.label.set_fontsize(fontsize)
95 |     # for label in ax2.xaxis.get_ticklabels()[::2]:
96 |     #     label.set_visible(False)
97 | 
98 |     ax2.text(-0.1, 1.15, "B", transform=ax2.transAxes,
99 |              fontsize=16, fontweight='bold', va='top', ha='right')
100 | 
101 | 
102 |     df = pd.read_json("hyperparameter_search_results/vampire_nll_classifier_search.jsonl", lines=True)
103 | 
104 |     sns.regplot(df['best_validation_nll'], df['best_validation_accuracy'], ax=ax3)
105 |     ax3.set_xlabel("NLL")
106 |     ax3.set_ylabel("Accuracy")
107 |     # for tick in ax4.xaxis.get_major_ticks():
108 |     #     tick.label.set_fontsize(fontsize)
109 |     # for tick in ax4.yaxis.get_major_ticks():
110 |     #     tick.label.set_fontsize(fontsize)
111 |     # for label in ax4.xaxis.get_ticklabels()[::2]:
112 |     #     label.set_visible(False)
113 |     ax3.text(-0.1, 1.15, "C", transform=ax3.transAxes,
114 |              fontsize=16, fontweight='bold', va='top', ha='right')
115 | 
116 |     df = pd.read_json("hyperparameter_search_results/vampire_nll_classifier_search.jsonl", lines=True)
117 |     df1 = pd.read_json("hyperparameter_search_results/vampire_npmi_classifier_search.jsonl", lines=True)
118 |     master = pd.concat([df, df1], axis=0)
119 |     master['trainer.validation_metric_y'] = master['trainer.validation_metric_y'].fillna('-nll')
120 |     sns.boxplot(master['trainer.validation_metric_y'], master['best_validation_accuracy'], ax=ax4, order=["+npmi", "-nll"])
121 |     ax4.set_xticklabels(["NPMI", "NLL"])
122 |     ax4.set_xlabel('Criterion')
123 |     ax4.set_ylabel("Accuracy")
124 |     # for tick in ax1.xaxis.get_major_ticks():
125 |     #     tick.label.set_fontsize(fontsize)
126 |     # for tick in ax1.yaxis.get_major_ticks():
127 |     #     tick.label.set_fontsize(fontsize)
128 |     ax4.set_ylim([0.4, 0.9])
129 |     ax4.yaxis.set_ticks([0.6, 0.8])
130 |     ax4.text(-0.1, 1.15, "D", transform=ax4.transAxes,
131 |              fontsize=16, fontweight='bold', va='top', ha='right')
132 | 
133 |     plt.tight_layout()
134 | 
135 |     plt.savefig("figure_4.pdf", dpi=300)
136 | 
-------------------------------------------------------------------------------- /search_spaces/classifier_hatespeech_search.json: --------------------------------------------------------------------------------
1 | {
2 |     "LAZY_DATASET_READER": 0,
3 |     "CUDA_DEVICE": 0,
4 |     "EVALUATE_ON_TEST": 0,
5 |     "NUM_EPOCHS": 50,
6 |     "SEED": {
7 |         "sampling strategy": "integer",
8 |         "bounds": [0, 100000]
9 |     },
10 |     "TRAIN_PATH": 
"s3://suching-dev/final-datasets/hatespeech/train_pretokenized.jsonl", 11 | "DEV_PATH": "s3://suching-dev/final-datasets/hatespeech/dev_pretokenized.jsonl", 12 | "TEST_PATH": "s3://suching-dev/final-datasets/hatespeech/test_pretokenized.jsonl", 13 | "THROTTLE": 200, 14 | "USE_SPACY_TOKENIZER": 0, 15 | "FREEZE_EMBEDDINGS": ["VAMPIRE"], 16 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 17 | "VAMPIRE_DIRECTORY": { 18 | "sampling strategy": "choice", 19 | "choices": ["logs/vampire_hatespeech_search/run_5_2019-06-01_12-52-55jkcnokqe 103", "logs/vampire_hatespeech_search/run_31_2019-06-01_13-20-51p25jt72k 68", "logs/vampire_hatespeech_search/run_48_2019-06-01_13-40-04d305225w 114", "logs/vampire_hatespeech_search/run_45_2019-06-01_13-38-59yp36qlcy 93", "logs/vampire_hatespeech_search/run_2_2019-06-01_12-52-555o40j0dj 96", "logs/vampire_hatespeech_search/run_52_2019-06-01_13-44-35jrlza5ih 64", "logs/vampire_hatespeech_search/run_71_2019-06-01_14-08-201oyojz0x 85", "logs/vampire_hatespeech_search/run_4_2019-06-01_12-52-559huf7jy2 102", "logs/vampire_hatespeech_search/run_43_2019-06-01_13-34-58s3xwrfrt 72", "logs/vampire_hatespeech_search/run_30_2019-06-01_13-20-24332rqxmc 127", "logs/vampire_hatespeech_search/run_68_2019-06-01_14-02-07fhmv9c5f 71", "logs/vampire_hatespeech_search/run_41_2019-06-01_13-31-156ghrhgx8 100", "logs/vampire_hatespeech_search/run_38_2019-06-01_13-27-07kufxcguz 64", "logs/vampire_hatespeech_search/run_53_2019-06-01_13-45-147ohgxvvy 69", "logs/vampire_hatespeech_search/run_9_2019-06-01_12-56-39sh3eng8c 109", "logs/vampire_hatespeech_search/run_29_2019-06-01_13-19-01n4lqbht6 119", "logs/vampire_hatespeech_search/run_25_2019-06-01_13-15-51mwkqxh9p 105", "logs/vampire_hatespeech_search/run_20_2019-06-01_13-10-092i60t7ch 65", "logs/vampire_hatespeech_search/run_49_2019-06-01_13-41-00o6lhyx74 111", "logs/vampire_hatespeech_search/run_19_2019-06-01_13-09-432vi_3diq 103", "logs/vampire_hatespeech_search/run_11_2019-06-01_12-59-52c1pvwu5g 77", "logs/vampire_hatespeech_search/run_21_2019-06-01_13-10-251vll9dy7 96", "logs/vampire_hatespeech_search/run_27_2019-06-01_13-17-285wn25ed7 103", "logs/vampire_hatespeech_search/run_1_2019-06-01_12-52-55cvlqk5eu 95", "logs/vampire_hatespeech_search/run_61_2019-06-01_13-55-45az7n14vr 94", "logs/vampire_hatespeech_search/run_17_2019-06-01_13-07-40grw9lbu8 83", "logs/vampire_hatespeech_search/run_69_2019-06-01_14-02-17sxh0ug_a 104", "logs/vampire_hatespeech_search/run_40_2019-06-01_13-30-39muh5_tqp 118", "logs/vampire_hatespeech_search/run_37_2019-06-01_13-26-55wqtt5fze 126", "logs/vampire_hatespeech_search/run_34_2019-06-01_13-24-45h29d346c 85", "logs/vampire_hatespeech_search/run_47_2019-06-01_13-39-46r_yx2umn 93", "logs/vampire_hatespeech_search/run_16_2019-06-01_13-07-0966awqte7 80", "logs/vampire_hatespeech_search/run_58_2019-06-01_13-49-57gn6lepbo 82", "logs/vampire_hatespeech_search/run_13_2019-06-01_13-03-41kve3d07k 119", "logs/vampire_hatespeech_search/run_70_2019-06-01_14-02-22crpfb25y 65", "logs/vampire_hatespeech_search/run_28_2019-06-01_13-17-299fraho0r 126", "logs/vampire_hatespeech_search/run_64_2019-06-01_13-58-002cijgzl1 108", "logs/vampire_hatespeech_search/run_3_2019-06-01_12-52-55a7qhurt5 78", "logs/vampire_hatespeech_search/run_57_2019-06-01_13-49-3285xxohh4 102", "logs/vampire_hatespeech_search/run_56_2019-06-01_13-49-04bhk5irkg 124", "logs/vampire_hatespeech_search/run_14_2019-06-01_13-06-13l8banf74 111", "logs/vampire_hatespeech_search/run_62_2019-06-01_13-56-22nnnyeosy 66", 
"logs/vampire_hatespeech_search/run_32_2019-06-01_13-22-3274lyair9 125", "logs/vampire_hatespeech_search/run_12_2019-06-01_13-01-297pty9u10 66", "logs/vampire_hatespeech_search/run_8_2019-06-01_12-55-47ovtzvwfd 115", "logs/vampire_hatespeech_search/run_6_2019-06-01_12-52-551hvf9c4g 99", "logs/vampire_hatespeech_search/run_66_2019-06-01_14-00-482vmq1zja 66", "logs/vampire_hatespeech_search/run_22_2019-06-01_13-11-40bggefvrr 126", "logs/vampire_hatespeech_search/run_50_2019-06-01_13-41-17mgcw0t2d 124", "logs/vampire_hatespeech_search/run_59_2019-06-01_13-51-415wfzpj6s 125", "logs/vampire_hatespeech_search/run_74_2019-06-01_14-12-20hd25quyr 110", "logs/vampire_hatespeech_search/run_67_2019-06-01_14-01-298xzs1faw 122", "logs/vampire_hatespeech_search/run_65_2019-06-01_13-58-42ows5baxh 120", "logs/vampire_hatespeech_search/run_24_2019-06-01_13-13-302oojuwms 66", "logs/vampire_hatespeech_search/run_35_2019-06-01_13-25-375v3ye12t 71", "logs/vampire_hatespeech_search/run_72_2019-06-01_14-08-26h9kwmtvm 66", "logs/vampire_hatespeech_search/run_60_2019-06-01_13-54-11yd9n8vtb 64", "logs/vampire_hatespeech_search/run_7_2019-06-01_12-52-55nl4uc17v 115", "logs/vampire_hatespeech_search/run_15_2019-06-01_13-06-573aedhcca 94", "logs/vampire_hatespeech_search/run_42_2019-06-01_13-32-37zbx6_0g5 97"] 20 | }, 21 | "ENCODER": { 22 | "sampling strategy": "choice", 23 | "choices": ["AVERAGE"] 24 | }, 25 | "EMBEDDING_DROPOUT": 0.5, 26 | "LEARNING_RATE": 0.004, 27 | "DROPOUT": 0.5, 28 | "BATCH_SIZE": 32, 29 | "NUM_ENCODER_LAYERS": { 30 | "sampling strategy": "choice", 31 | "choices": [1, 2, 3] 32 | }, 33 | "NUM_OUTPUT_LAYERS": { 34 | "sampling strategy": "choice", 35 | "choices": [1, 2, 3] 36 | }, 37 | "MAX_FILTER_SIZE": { 38 | "sampling strategy": "integer", 39 | "bounds": [3, 6] 40 | }, 41 | "NUM_FILTERS": { 42 | "sampling strategy": "integer", 43 | "bounds": [64, 512] 44 | }, 45 | "HIDDEN_SIZE": { 46 | "sampling strategy": "integer", 47 | "bounds": [64, 512] 48 | }, 49 | "AGGREGATIONS": { 50 | "sampling strategy": "subset", 51 | "choices": ["maxpool", "meanpool", "attention", "final_state"] 52 | }, 53 | "MAX_CHARACTER_FILTER_SIZE": { 54 | "sampling strategy": "integer", 55 | "bounds": [3, 6] 56 | }, 57 | "NUM_CHARACTER_FILTERS": { 58 | "sampling strategy": "integer", 59 | "bounds": [16, 64] 60 | }, 61 | "CHARACTER_HIDDEN_SIZE": { 62 | "sampling strategy": "integer", 63 | "bounds": [16, 128] 64 | }, 65 | "CHARACTER_EMBEDDING_DIM": { 66 | "sampling strategy": "integer", 67 | "bounds": [16, 128] 68 | }, 69 | "CHARACTER_ENCODER": { 70 | "sampling strategy": "choice", 71 | "choices": ["LSTM", "CNN", "AVERAGE"] 72 | }, 73 | "NUM_CHARACTER_ENCODER_LAYERS": { 74 | "sampling strategy": "choice", 75 | "choices": [1, 2] 76 | } 77 | } -------------------------------------------------------------------------------- /vampire/common/util.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import json 3 | import os 4 | import pickle 5 | from typing import Any, Dict, List 6 | 7 | import numpy as np 8 | import torch 9 | from allennlp.data import Vocabulary 10 | from scipy import sparse 11 | 12 | 13 | def compute_background_log_frequency(vocab: Vocabulary, vocab_namespace: str, precomputed_bg_file=None): 14 | """ 15 | Load in the word counts from the JSON file and compute the 16 | background log term frequency w.r.t this vocabulary. 
17 | """ 18 | # precomputed_word_counts = json.load(open(precomputed_word_counts, "r")) 19 | log_term_frequency = torch.FloatTensor(vocab.get_vocab_size(vocab_namespace)) 20 | if precomputed_bg_file is not None: 21 | with open(precomputed_bg_file, "r") as file_: 22 | precomputed_bg = json.load(file_) 23 | else: 24 | precomputed_bg = vocab._retained_counter.get(vocab_namespace) # pylint: disable=protected-access 25 | if precomputed_bg is None: 26 | return log_term_frequency 27 | for i in range(vocab.get_vocab_size(vocab_namespace)): 28 | token = vocab.get_token_from_index(i, vocab_namespace) 29 | if token in ("@@UNKNOWN@@", "@@PADDING@@", '@@START@@', '@@END@@') or token not in precomputed_bg: 30 | log_term_frequency[i] = 1e-12 31 | elif token in precomputed_bg: 32 | if precomputed_bg[token] == 0: 33 | log_term_frequency[i] = 1e-12 34 | else: 35 | log_term_frequency[i] = precomputed_bg[token] 36 | log_term_frequency = torch.log(log_term_frequency) 37 | return log_term_frequency 38 | 39 | 40 | def log_standard_categorical(logits: torch.Tensor): 41 | """ 42 | Calculates the cross entropy between a (one-hot) categorical vector 43 | and a standard (uniform) categorical distribution. 44 | :param p: one-hot categorical distribution 45 | :return: H(p, u) 46 | 47 | Originally from https://github.com/wohlert/semi-supervised-pytorch. 48 | """ 49 | # Uniform prior over y 50 | prior = torch.softmax(torch.ones_like(logits), dim=1) 51 | prior.requires_grad = False 52 | 53 | cross_entropy = -torch.sum(logits * torch.log(prior + 1e-8), dim=1) 54 | 55 | return cross_entropy 56 | 57 | 58 | def separate_labeled_unlabeled_instances(text: torch.LongTensor, 59 | classifier_text: torch.Tensor, 60 | label: torch.LongTensor, 61 | metadata: List[Dict[str, Any]]): 62 | """ 63 | Given a batch of examples, separate them into labeled and unlablled instances. 64 | """ 65 | labeled_instances = {} 66 | unlabeled_instances = {} 67 | is_labeled = [int(md['is_labeled']) for md in metadata] 68 | 69 | is_labeled = np.array(is_labeled) 70 | # labeled is zero everywhere an example is unlabeled and 1 otherwise. 
71 |     labeled_indices = (is_labeled != 0).nonzero()  # type: ignore
72 |     labeled_instances["tokens"] = text[labeled_indices]
73 |     labeled_instances["classifier_tokens"] = classifier_text[labeled_indices]
74 |     labeled_instances["label"] = label[labeled_indices]
75 | 
76 |     unlabeled_indices = (is_labeled == 0).nonzero()  # type: ignore
77 |     unlabeled_instances["tokens"] = text[unlabeled_indices]
78 |     unlabeled_instances["classifier_tokens"] = classifier_text[unlabeled_indices]
79 | 
80 |     return labeled_instances, unlabeled_instances
81 | 
82 | 
83 | def schedule(batch_num, anneal_type="sigmoid"):
84 |     """
85 |     Weight annealing scheduler: ramps a weight from near 0 to 1 as batch_num grows (the sigmoid variants are centered at batch 2500).
86 |     """
87 |     if anneal_type == "linear":
88 |         return min(1, batch_num / 2500)
89 |     elif anneal_type == "sigmoid":
90 |         return float(1 / (1 + np.exp(-0.0025 * (batch_num - 2500))))
91 |     elif anneal_type == "constant":
92 |         return 1.0
93 |     elif anneal_type == "reverse_sigmoid":
94 |         return float(1 / (1 + np.exp(0.0025 * (batch_num - 2500))))
95 |     else:
96 |         return 0.01
97 | 
98 | 
99 | def makedirs(directory):
100 |     if not os.path.exists(directory):
101 |         os.makedirs(directory)
102 | 
103 | 
104 | def write_to_json(data, output_filename, indent=2, sort_keys=True):
105 |     with codecs.open(output_filename, 'w', encoding='utf-8') as output_file:
106 |         json.dump(data, output_file, indent=indent, sort_keys=sort_keys)
107 | 
108 | 
109 | def read_json(input_filename):
110 |     with codecs.open(input_filename, 'r', encoding='utf-8') as input_file:
111 |         data = json.load(input_file)
112 |     return data
113 | 
114 | 
115 | def read_jsonlist(input_filename):
116 |     data = []
117 |     with codecs.open(input_filename, 'r', encoding='utf-8') as input_file:
118 |         for line in input_file:
119 |             data.append(json.loads(line))
120 |     return data
121 | 
122 | 
123 | def write_jsonlist(list_of_json_objects, output_filename, sort_keys=True):
124 |     with codecs.open(output_filename, 'w', encoding='utf-8') as output_file:
125 |         for obj in list_of_json_objects:
126 |             output_file.write(json.dumps(obj, sort_keys=sort_keys) + '\n')
127 | 
128 | 
129 | def pickle_data(data, output_filename):
130 |     with open(output_filename, 'wb') as outfile:
131 |         pickle.dump(data, outfile, pickle.HIGHEST_PROTOCOL)
132 | 
133 | 
134 | def unpickle_data(input_filename):
135 |     with open(input_filename, 'rb') as infile:
136 |         data = pickle.load(infile)
137 |     return data
138 | 
139 | 
140 | def read_text(input_filename):
141 |     with codecs.open(input_filename, 'r', encoding='utf-8') as input_file:
142 |         lines = [x.strip() for x in input_file.readlines()]
143 |     return lines
144 | 
145 | 
146 | def write_list_to_text(lines, output_filename, add_newlines=True, add_final_newline=False):
147 |     if add_newlines:
148 |         lines = '\n'.join(lines)
149 |         if add_final_newline:
150 |             lines += '\n'
151 |     else:
152 |         lines = ''.join(lines)
153 |         if add_final_newline:
154 |             lines += '\n'
155 | 
156 |     with codecs.open(output_filename, 'w', encoding='utf-8') as output_file:
157 |         output_file.writelines(lines)
158 | 
159 | 
160 | def save_sparse(sparse_matrix, output_filename):
161 |     assert sparse.issparse(sparse_matrix)
162 |     if sparse.isspmatrix_coo(sparse_matrix):
163 |         coo = sparse_matrix
164 |     else:
165 |         coo = sparse_matrix.tocoo()
166 |     row = coo.row
167 |     col = coo.col
168 |     data = coo.data
169 |     shape = coo.shape
170 |     np.savez(output_filename, row=row, col=col, data=data, shape=shape)
171 | 
172 | 
173 | def load_sparse(input_filename):
174 |     npy = np.load(input_filename)
175 |     coo_matrix = sparse.coo_matrix((npy['data'], (npy['row'], npy['col'])), shape=npy['shape'])
176 |     return coo_matrix.tocsc()
177 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # VAMPIRE
2 | 
3 | VAriational Methods for Pretraining In Resource-limited Environments
4 | 
5 | Read the paper [here](https://arxiv.org/abs/1906.02242).
6 | 
7 | ## Citation
8 | 
9 | ```
10 | @inproceedings{vampire,
11 |  author = {Suchin Gururangan and Tam Dang and Dallas Card and Noah A. Smith},
12 |  title = {Variational Pretraining for Semi-supervised Text Classification},
13 |  year = {2019},
14 |  booktitle = {Proceedings of ACL},
15 | }
16 | ```
17 | 
18 | 
19 | ## Installation
20 | 
21 | Install the necessary dependencies via `requirements.txt`, which includes the latest unreleased install of `allennlp` (from the `master` branch).
22 | 
23 | ```
24 | pip install -r requirements.txt
25 | ```
26 | 
27 | Install the spaCy English model with:
28 | 
29 | ```
30 | python -m spacy download en
31 | ```
32 | 
33 | Verify your installation by running:
34 | 
35 | ```
36 | SEED=42 pytest -v --color=yes vampire
37 | ```
38 | 
39 | All tests should pass.
40 | 
41 | 
42 | ## Install from Docker
43 | 
44 | Alternatively, you can install the repository with Docker.
45 | 
46 | First, build the container:
47 | 
48 | ```
49 | docker build -f Dockerfile --tag vampire/vampire:latest .
50 | ```
51 | 
52 | Then, run the container:
53 | 
54 | ```
55 | docker run -it vampire/vampire:latest
56 | ```
57 | 
58 | This will open a shell in a Docker container that has all the dependencies installed.
59 | 
60 | ## Download Data
61 | 
62 | Download your dataset of interest, and make sure it is made up of JSON files, where each line of each file corresponds to a separate instance. Each line must contain a `text` field, and optionally a `label` field.
63 | 
64 | In this tutorial we use the AG News dataset hosted on AllenNLP. Download it using the following script:
65 | 
66 | ```
67 | sh scripts/download_ag.sh
68 | ```
69 | 
70 | This will make an `examples/ag` directory with train, dev, and test files from the AG News corpus.
71 | 
72 | ## Preprocess data
73 | 
74 | To make pretraining fast, we precompute fixed bag-of-words representations of the data.
75 | 
76 | ```
77 | python -m scripts.preprocess_data \
78 |     --train-path examples/ag/train.jsonl \
79 |     --dev-path examples/ag/dev.jsonl \
80 |     --tokenize \
81 |     --tokenizer-type spacy \
82 |     --vocab-size 30000 \
83 |     --serialization-dir examples/ag
84 | ```
85 | 
86 | This script will tokenize your data and save the resulting output in the specified `serialization-dir`.
87 | 
88 | Alternatively, at `https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/preprocessed.tar`, we provide a tar file containing pre-processed AG News data (with the vocab size set to 30K).
89 | 
90 | Run
91 | 
92 | ```
93 | curl -Lo examples/ag/ag.tar https://s3-us-west-2.amazonaws.com/allennlp/datasets/ag-news/vampire_preprocessed_example.tar
94 | tar -xvf examples/ag/ag.tar -C examples/
95 | ```
96 | 
97 | to access its contents.
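
To sanity-check the preprocessed output, you can load the sparse matrices back the same way `vampire.common.util.load_sparse` does. A minimal sketch (assuming you ran the steps above, so the `examples/ag` paths exist):

```
import numpy as np
from scipy import sparse

# train.npz stores the COO components written by save_sparse
npz = np.load("examples/ag/train.npz")
train = sparse.coo_matrix((npz["data"], (npz["row"], npz["col"])),
                          shape=npz["shape"]).tocsc()
print(train.shape)  # (num_train_docs, vocab_size + 1); column 0 is @@UNKNOWN@@
```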
98 | 
99 | In `examples/ag` (after running the `preprocess_data` module or unpacking `ag.tar`), you should see:
100 | 
101 | * `train.npz` - pre-computed bag-of-words representations of the training data
102 | * `dev.npz` - pre-computed bag-of-words representations of the dev data
103 | * `vampire.bgfreq` - background word frequencies
104 | * `vocabulary/` - AllenNLP vocabulary directory
105 | 
106 | This script also creates a reference corpus used to calculate NPMI (normalized pointwise mutual information), a measure of topical coherence that we use for early stopping. By default, we use the validation data as the reference corpus. You can supply a `--reference-corpus-path` to the preprocessing script to use your own reference corpus.
107 | 
108 | In `examples/ag/reference`, you should see:
109 | 
110 | * `ref.npz` - pre-computed bag-of-words representations of the reference corpus (the dev data)
111 | * `ref.vocab.json` - the reference corpus vocabulary
112 | 
113 | ## Pretrain VAMPIRE
114 | 
115 | Set your data directory and vocabulary size as environment variables:
116 | 
117 | ```
118 | export DATA_DIR="$(pwd)/examples/ag"
119 | export VOCAB_SIZE=30000
120 | ```
121 | 
122 | If you're training on a dataset that's too large to fit into RAM, run VAMPIRE in lazy mode by additionally exporting:
123 | 
124 | ```
125 | export LAZY=1
126 | ```
127 | 
128 | Then train VAMPIRE:
129 | 
130 | ```
131 | python -m scripts.train \
132 |     --config training_config/vampire.jsonnet \
133 |     --serialization-dir model_logs/vampire \
134 |     --environment VAMPIRE \
135 |     --device -1
136 | ```
137 | 
138 | This model can be run on a CPU (`--device -1`). To run on a GPU instead, run with `--device 0` (or any other available CUDA device number).
139 | 
140 | This command will output training logs at `model_logs/vampire`.
141 | 
142 | For convenience, we include the `--override` flag to remove a previous experiment at the same serialization directory.
143 | 
144 | 
145 | ## Inspect topics learned
146 | 
147 | During training, we output the learned topics after each epoch in the serialization directory, under `model_logs/vampire`.
148 | 
149 | After your model is finished training, check out the `best_epoch` field in `model_logs/vampire/metrics.json`, which corresponds to the training epoch at which NPMI is highest.
150 | 
151 | Then open up the corresponding epoch's file in `model_logs/vampire/topics/`.
152 | 
153 | ## Use VAMPIRE with a downstream classifier
154 | 
155 | Using VAMPIRE with a downstream classifier is essentially the same as using regular ELMo. See [this documentation](https://github.com/allenai/allennlp/blob/master/docs/tutorials/how_to/elmo.md#using-elmo-with-existing-allennlp-models) for details on how to do that.
156 | 
157 | This library has some convenience functions for including VAMPIRE with a downstream classifier.
158 | 
159 | First, set some environment variables:
160 | * `VAMPIRE_DIR`: path to the newly trained VAMPIRE
161 | * `VAMPIRE_DIM`: dimensionality of the newly trained VAMPIRE (the token embedder needs it explicitly)
162 | * `THROTTLE`: the sample size of the data we want to train on.
163 | * `EVALUATE_ON_TEST`: whether or not you would like to evaluate on test 164 | 165 | 166 | ``` 167 | export VAMPIRE_DIR="$(pwd)/model_logs/vampire" 168 | export VAMPIRE_DIM=81 169 | export THROTTLE=200 170 | export EVALUATE_ON_TEST=0 171 | ``` 172 | 173 | Then, you can run the classifier: 174 | 175 | ``` 176 | python -m scripts.train \ 177 | --config training_config/classifier.jsonnet \ 178 | --serialization-dir model_logs/clf \ 179 | --environment CLASSIFIER \ 180 | --device -1 181 | ``` 182 | 183 | 184 | As with VAMPIRE, this model can be run on a CPU (`--device -1`). To run on a GPU instead, run with `--device 0` (or any other available CUDA device number) 185 | 186 | This command will output training logs at `model_logs/clf`. 187 | 188 | The dataset sample (specified by `THROTTLE`) is governed by the global seed supplied to the trainer; the same seed will result in the same subsampling of training data. You can set an explicit seed by passing the additional flag `--seed` to the `train` module. 189 | 190 | With 200 examples, we report a test accuracy of `83.9 +- 0.9` over 5 random seeds on the AG dataset. Note that your results may vary beyond these bounds under the low-resource setting. 191 | 192 | ## Troubleshooting 193 | 194 | If you're running into issues during training (e.g. NaN losses), checkout the [troubleshooting](TROUBLESHOOTING.md) file. 195 | -------------------------------------------------------------------------------- /search_spaces/classifier_yahoo_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "CUDA_DEVICE": 0, 4 | "EVALUATE_ON_TEST": 0, 5 | "NUM_EPOCHS": 50, 6 | "SEED": { 7 | "sampling strategy": "integer", 8 | "bounds": [0, 100000] 9 | }, 10 | "TRAIN_PATH": "s3://suching-dev/final-datasets/yahoo/train_pretokenized.jsonl", 11 | "DEV_PATH": "s3://suching-dev/final-datasets/yahoo/dev_pretokenized.jsonl", 12 | "TEST_PATH": "s3://suching-dev/final-datasets/yahoo/test_pretokenized.jsonl", 13 | "THROTTLE": 200, 14 | "USE_SPACY_TOKENIZER": 0, 15 | "FREEZE_EMBEDDINGS": ["VAMPIRE"], 16 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 17 | "VAMPIRE_DIRECTORY": { 18 | "sampling strategy": "choice", 19 | "choices": ["logs/vampire_yahoo_search/run_23_2019-06-01_19-28-28cvgkhcjc 108", "logs/vampire_yahoo_search/run_1_2019-06-01_19-20-122fsftfoo 109", "logs/vampire_yahoo_search/run_5_2019-06-02_00-35-39vudoel85 69", "logs/vampire_yahoo_search/run_54_2019-06-02_02-39-23z5f_swhc 116", "logs/vampire_yahoo_search/run_18_2019-06-02_01-10-45j4yfinee 83", "logs/vampire_yahoo_search/run_80_2019-06-02_03-43-468ubj4_6i 117", "logs/vampire_yahoo_search/run_83_2019-06-02_03-46-47zxk30onk 91", "logs/vampire_yahoo_search/run_69_2019-06-02_03-00-39holyeeco 89", "logs/vampire_yahoo_search/run_71_2019-06-01_20-29-40usc_nu7v 119", "logs/vampire_yahoo_search/run_97_2019-06-02_04-15-33djqyb39c 112", "logs/vampire_yahoo_search/run_74_2019-06-02_03-33-17qku8l0ny 95", "logs/vampire_yahoo_search/run_98_2019-06-01_20-59-27j6qdv6hr 118", "logs/vampire_yahoo_search/run_45_2019-06-01_20-02-21tvd4b3jt 75", "logs/vampire_yahoo_search/run_93_2019-06-02_04-07-37h0wfy15g 80", "logs/vampire_yahoo_search/run_91_2019-06-02_04-04-395c6emiot 91", "logs/vampire_yahoo_search/run_67_2019-06-02_02-58-28nfu7umgv 112", "logs/vampire_yahoo_search/run_27_2019-06-02_01-25-544xg9vkmh 110", "logs/vampire_yahoo_search/run_90_2019-06-01_20-51-12kf5kw1s_ 71", "logs/vampire_yahoo_search/run_21_2019-06-01_19-26-21emw42my_ 118", 
"logs/vampire_yahoo_search/run_3_2019-06-02_00-20-27c9cmmb0f 90", "logs/vampire_yahoo_search/run_6_2019-06-02_00-43-22o234xb99 73", "logs/vampire_yahoo_search/run_87_2019-06-02_03-55-03drzvx5h3 100", "logs/vampire_yahoo_search/run_99_2019-06-02_04-19-5427816b8s 101", "logs/vampire_yahoo_search/run_75_2019-06-01_20-36-48o4_0c0vf 125", "logs/vampire_yahoo_search/run_61_2019-06-02_02-48-19wixtbxqp 110", "logs/vampire_yahoo_search/run_19_2019-06-01_19-25-54qa16l8zm 86", "logs/vampire_yahoo_search/run_92_2019-06-02_04-05-32mu568gw0 127", "logs/vampire_yahoo_search/run_39_2019-06-01_19-46-23rl5qws51 111", "logs/vampire_yahoo_search/run_97_2019-06-01_20-58-47exekymow 118", "logs/vampire_yahoo_search/run_99_2019-06-01_21-00-40sjbb2aub 127", "logs/vampire_yahoo_search/run_76_2019-06-02_03-36-15tsk7gqcy 123", "logs/vampire_yahoo_search/run_20_2019-06-01_19-26-01nyb15mv5 86", "logs/vampire_yahoo_search/run_43_2019-06-02_02-09-45kg5lipi9 121", "logs/vampire_yahoo_search/run_0_2019-06-02_00-20-27c9jzzd3j 104", "logs/vampire_yahoo_search/run_32_2019-06-02_01-33-217bh87q5e 105", "logs/vampire_yahoo_search/run_82_2019-06-02_03-44-51jjxtkpxg 121", "logs/vampire_yahoo_search/run_9_2019-06-02_00-56-20l72zsw8t 120", "logs/vampire_yahoo_search/run_63_2019-06-02_02-52-08m1mvhc0q 113", "logs/vampire_yahoo_search/run_55_2019-06-01_20-10-10h_fr7eiy 120", "logs/vampire_yahoo_search/run_56_2019-06-01_20-13-068mq5eomm 73", "logs/vampire_yahoo_search/run_43_2019-06-01_19-56-38p5sw7873 102", "logs/vampire_yahoo_search/run_36_2019-06-02_01-56-500er73fee 112", "logs/vampire_yahoo_search/run_13_2019-06-02_01-02-03m6pvrkov 68", "logs/vampire_yahoo_search/run_46_2019-06-02_02-18-34surc115u 79", "logs/vampire_yahoo_search/run_85_2019-06-01_20-47-13eiqxaow1 79", "logs/vampire_yahoo_search/run_38_2019-06-02_01-59-310ry4vvjy 104", "logs/vampire_yahoo_search/run_28_2019-06-01_19-30-56kn33p_kk 88", "logs/vampire_yahoo_search/run_19_2019-06-02_01-12-11ty0h4x87 125", "logs/vampire_yahoo_search/run_68_2019-06-01_20-26-41hc6mwfpt 125", "logs/vampire_yahoo_search/run_68_2019-06-02_02-58-29vtehbtm9 109", "logs/vampire_yahoo_search/run_6_2019-06-01_19-20-13krww0roi 125", "logs/vampire_yahoo_search/run_38_2019-06-01_19-42-485g_d5hsu 99", "logs/vampire_yahoo_search/run_7_2019-06-02_00-44-09c9rpz43x 72", "logs/vampire_yahoo_search/run_79_2019-06-01_20-40-243rsqn7fn 103", "logs/vampire_yahoo_search/run_73_2019-06-01_20-31-17x64cnsi4 101", "logs/vampire_yahoo_search/run_72_2019-06-02_03-22-35uf8vwv1d 103", "logs/vampire_yahoo_search/run_70_2019-06-02_03-13-02x6o9cw16 121", "logs/vampire_yahoo_search/run_26_2019-06-01_19-29-19jyvqos4g 86", "logs/vampire_yahoo_search/run_51_2019-06-01_20-05-26b39_ikwo 126", "logs/vampire_yahoo_search/run_59_2019-06-01_20-15-563cj1izi_ 116", "logs/vampire_yahoo_search/run_21_2019-06-02_01-15-20m2wtq40d 105", "logs/vampire_yahoo_search/run_1_2019-06-02_00-20-27iad_iit7 67", "logs/vampire_yahoo_search/run_42_2019-06-01_19-56-10hv_ld2kd 77", "logs/vampire_yahoo_search/run_37_2019-06-01_19-41-58safmvujx 109", "logs/vampire_yahoo_search/run_83_2019-06-01_20-46-29_velu7vb 119", "logs/vampire_yahoo_search/run_40_2019-06-02_02-02-22b3oyob42 84", "logs/vampire_yahoo_search/run_85_2019-06-02_03-49-20ljijjy_h 64", "logs/vampire_yahoo_search/run_78_2019-06-01_20-40-14nnp13yuz 69", "logs/vampire_yahoo_search/run_30_2019-06-02_01-31-35i305wt_b 127", "logs/vampire_yahoo_search/run_2_2019-06-02_00-20-275wit9a6x 77", "logs/vampire_yahoo_search/run_81_2019-06-01_20-43-38yn128p50 97", 
"logs/vampire_yahoo_search/run_41_2019-06-01_19-55-130uoduv1u 90", "logs/vampire_yahoo_search/run_34_2019-06-02_01-36-05we4advtz 117", "logs/vampire_yahoo_search/run_22_2019-06-01_19-26-52vx51wogj 96", "logs/vampire_yahoo_search/run_96_2019-06-02_04-14-53xvzlidy2 93", "logs/vampire_yahoo_search/run_53_2019-06-02_02-38-11gpb6aujg 102", "logs/vampire_yahoo_search/run_8_2019-06-02_00-48-12l2cidtqj 127"] 20 | }, 21 | "ENCODER": { 22 | "sampling strategy": "choice", 23 | "choices": ["AVERAGE"] 24 | }, 25 | "EMBEDDING_DROPOUT": 0.5, 26 | "LEARNING_RATE": 0.004, 27 | "DROPOUT": 0.5, 28 | "BATCH_SIZE": 32, 29 | "NUM_ENCODER_LAYERS": { 30 | "sampling strategy": "choice", 31 | "choices": [1, 2, 3] 32 | }, 33 | "NUM_OUTPUT_LAYERS": { 34 | "sampling strategy": "choice", 35 | "choices": [1, 2, 3] 36 | }, 37 | "MAX_FILTER_SIZE": { 38 | "sampling strategy": "integer", 39 | "bounds": [3, 6] 40 | }, 41 | "NUM_FILTERS": { 42 | "sampling strategy": "integer", 43 | "bounds": [64, 512] 44 | }, 45 | "HIDDEN_SIZE": { 46 | "sampling strategy": "integer", 47 | "bounds": [64, 512] 48 | }, 49 | "AGGREGATIONS": { 50 | "sampling strategy": "subset", 51 | "choices": ["maxpool", "meanpool", "attention", "final_state"] 52 | }, 53 | "MAX_CHARACTER_FILTER_SIZE": { 54 | "sampling strategy": "integer", 55 | "bounds": [3, 6] 56 | }, 57 | "NUM_CHARACTER_FILTERS": { 58 | "sampling strategy": "integer", 59 | "bounds": [16, 64] 60 | }, 61 | "CHARACTER_HIDDEN_SIZE": { 62 | "sampling strategy": "integer", 63 | "bounds": [16, 128] 64 | }, 65 | "CHARACTER_EMBEDDING_DIM": { 66 | "sampling strategy": "integer", 67 | "bounds": [16, 128] 68 | }, 69 | "CHARACTER_ENCODER": { 70 | "sampling strategy": "choice", 71 | "choices": ["LSTM", "CNN", "AVERAGE"] 72 | }, 73 | "NUM_CHARACTER_ENCODER_LAYERS": { 74 | "sampling strategy": "choice", 75 | "choices": [1, 2] 76 | } 77 | } -------------------------------------------------------------------------------- /scripts/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from typing import List 5 | 6 | import nltk 7 | import numpy as np 8 | import pandas as pd 9 | import spacy 10 | from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter 11 | from scipy import sparse 12 | from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer 13 | from spacy.tokenizer import Tokenizer 14 | from tqdm import tqdm 15 | 16 | from vampire.common.util import read_text, save_sparse, write_to_json 17 | 18 | 19 | def load_data(data_path: str, tokenize: bool = False, tokenizer_type: str = "just_spaces") -> List[str]: 20 | if tokenizer_type == "just_spaces": 21 | tokenizer = SpacyWordSplitter() 22 | elif tokenizer_type == "spacy": 23 | nlp = spacy.load('en') 24 | tokenizer = Tokenizer(nlp.vocab) 25 | tokenized_examples = [] 26 | with tqdm(open(data_path, "r"), desc=f"loading {data_path}") as f: 27 | for line in f: 28 | if data_path.endswith(".jsonl") or data_path.endswith(".json"): 29 | example = json.loads(line) 30 | else: 31 | example = {"text": line} 32 | if tokenize: 33 | if tokenizer_type == 'just_spaces': 34 | tokens = list(map(str, tokenizer.split_words(example['text']))) 35 | elif tokenizer_type == 'spacy': 36 | tokens = list(map(str, tokenizer(example['text']))) 37 | text = ' '.join(tokens) 38 | else: 39 | text = example['text'] 40 | tokenized_examples.append(text) 41 | return tokenized_examples 42 | 43 | def main(): 44 | parser = argparse.ArgumentParser(formatter_class = 
argparse.ArgumentDefaultsHelpFormatter) 45 | parser.add_argument("--train-path", type=str, required=True, 46 | help="Path to the train jsonl file.") 47 | parser.add_argument("--dev-path", type=str, required=True, 48 | help="Path to the dev jsonl file.") 49 | parser.add_argument("--serialization-dir", "-s", type=str, required=True, 50 | help="Path to store the preprocessed output.") 51 | parser.add_argument("--tfidf", action='store_true', 52 | help="Use TF-IDF features instead of raw token counts.") 53 | parser.add_argument("--vocab-size", type=int, required=False, default=10000, 54 | help="Maximum size of the vocabulary extracted from the train and dev data.") 55 | parser.add_argument("--tokenize", action='store_true', 56 | help="Tokenize the input text before vectorizing.") 57 | parser.add_argument("--tokenizer-type", type=str, default="just_spaces", 58 | help="Tokenizer to use when --tokenize is set: 'just_spaces' or 'spacy'.") 59 | parser.add_argument("--reference-corpus-path", type=str, required=False, 60 | help="Optional path to a reference corpus; if omitted, the dev data is used.") 61 | parser.add_argument("--tokenize-reference", action='store_true', 62 | help="Tokenize the reference corpus before vectorizing.") 63 | parser.add_argument("--reference-tokenizer-type", type=str, default="just_spaces", 64 | help="Tokenizer to use for the reference corpus: 'just_spaces' or 'spacy'.") 65 | args = parser.parse_args() 66 | 67 | if not os.path.isdir(args.serialization_dir): 68 | os.mkdir(args.serialization_dir) 69 | 70 | vocabulary_dir = os.path.join(args.serialization_dir, "vocabulary") 71 | 72 | if not os.path.isdir(vocabulary_dir): 73 | os.mkdir(vocabulary_dir) 74 | 75 | tokenized_train_examples = load_data(args.train_path, args.tokenize, args.tokenizer_type) 76 | tokenized_dev_examples = load_data(args.dev_path, args.tokenize, args.tokenizer_type) 77 | 78 | print("fitting count vectorizer...") 79 | if args.tfidf: 80 | count_vectorizer = TfidfVectorizer(stop_words='english', max_features=args.vocab_size, token_pattern=r'\b[^\d\W]{3,30}\b') 81 | else: 82 | count_vectorizer = CountVectorizer(stop_words='english', max_features=args.vocab_size, token_pattern=r'\b[^\d\W]{3,30}\b') 83 | 84 | text = tokenized_train_examples + tokenized_dev_examples 85 | 86 | count_vectorizer.fit(tqdm(text)) 87 | 88 | vectorized_train_examples = count_vectorizer.transform(tqdm(tokenized_train_examples)) 89 | vectorized_dev_examples = count_vectorizer.transform(tqdm(tokenized_dev_examples)) 90 | 91 | if args.tfidf: 92 | reference_vectorizer = TfidfVectorizer(stop_words='english', token_pattern=r'\b[^\d\W]{3,30}\b') 93 | else: 94 | reference_vectorizer = CountVectorizer(stop_words='english', token_pattern=r'\b[^\d\W]{3,30}\b') 95 | if not args.reference_corpus_path: 96 | print("fitting reference corpus using development data...") 97 | reference_matrix = reference_vectorizer.fit_transform(tqdm(tokenized_dev_examples)) 98 | else: 99 | print(f"loading reference corpus at {args.reference_corpus_path}...") 100 | reference_examples = load_data(args.reference_corpus_path, args.tokenize_reference, args.reference_tokenizer_type) 101 | print("fitting reference corpus...") 102 | reference_matrix = reference_vectorizer.fit_transform(tqdm(reference_examples)) 103 | 104 | reference_vocabulary = reference_vectorizer.get_feature_names() 105 | 106 | # add @@unknown@@ token vector
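# NOTE: the vocabulary file written at the end of main() places '@@UNKNOWN@@' at
# index 0, so a column of zeros is prepended to each document-term matrix below
# to keep the matrix columns aligned with the vocabulary indices.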
107 | vectorized_train_examples = sparse.hstack((np.array([0] * len(tokenized_train_examples))[:,None], vectorized_train_examples)) 108 | vectorized_dev_examples = sparse.hstack((np.array([0] * len(tokenized_dev_examples))[:,None], vectorized_dev_examples)) 109 | master = sparse.vstack([vectorized_train_examples, vectorized_dev_examples]) 110 | 111 | # generate background frequency 112 | print("generating background frequency...") 113 | bgfreq = dict(zip(count_vectorizer.get_feature_names(), (np.array(master.sum(0)) / args.vocab_size).squeeze())) 114 | 115 | print("saving data...") 116 | save_sparse(vectorized_train_examples, os.path.join(args.serialization_dir, "train.npz")) 117 | save_sparse(vectorized_dev_examples, os.path.join(args.serialization_dir, "dev.npz")) 118 | if not os.path.isdir(os.path.join(args.serialization_dir, "reference")): 119 | os.mkdir(os.path.join(args.serialization_dir, "reference")) 120 | save_sparse(reference_matrix, os.path.join(args.serialization_dir, "reference", "ref.npz")) 121 | write_to_json(reference_vocabulary, os.path.join(args.serialization_dir, "reference", "ref.vocab.json")) 122 | write_to_json(bgfreq, os.path.join(args.serialization_dir, "vampire.bgfreq")) 123 | 124 | write_list_to_file(['@@UNKNOWN@@'] + count_vectorizer.get_feature_names(), os.path.join(vocabulary_dir, "vampire.txt")) 125 | write_list_to_file(['*tags', '*labels', 'vampire'], os.path.join(vocabulary_dir, "non_padded_namespaces.txt")) 126 | 127 | def write_list_to_file(ls, save_path): 128 | """ 129 | Write each element of 'ls' on its own line in the file designated by 'save_path'. 130 | """ 131 | # Open in write mode: each call here targets its own file, so nothing 132 | # needs to be appended across calls (the vocabulary file and the 133 | # non-padded-namespaces file are written separately). 134 | with open(save_path, "w") as out_file: 135 | for example in ls: 136 | out_file.write(example) 137 | out_file.write('\n') 138 | 139 | if __name__ == '__main__': 140 | main() 141 |
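A quick sanity check for the artifacts written above (an illustrative sketch rather than part of the original script; it assumes a `load_sparse` counterpart to `save_sparse` in `vampire.common.util`, and `data/imdb` is a made-up output directory):

import json
import os

from vampire.common.util import load_sparse  # assumed loader counterpart to save_sparse

serialization_dir = "data/imdb"  # illustrative path
train = load_sparse(os.path.join(serialization_dir, "train.npz"))
with open(os.path.join(serialization_dir, "vocabulary", "vampire.txt")) as f:
    vocab = [line.strip() for line in f]
with open(os.path.join(serialization_dir, "vampire.bgfreq")) as f:
    bgfreq = json.load(f)

# one matrix column per vocabulary entry, including the prepended @@UNKNOWN@@ slot
assert train.shape[1] == len(vocab)
assert vocab[0] == '@@UNKNOWN@@'
# background frequencies cover every vocabulary entry except @@UNKNOWN@@
assert len(bgfreq) == len(vocab) - 1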
-------------------------------------------------------------------------------- /vampire/data/dataset_readers/semisupervised_text_classification_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from io import TextIOWrapper 4 | from typing import Dict 5 | import numpy as np 6 | from overrides import overrides 7 | from allennlp.common.checks import ConfigurationError 8 | from allennlp.common.file_utils import cached_path 9 | from allennlp.data.dataset_readers import TextClassificationJsonReader 10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 11 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 12 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer 13 | from allennlp.data.tokenizers.sentence_splitter import SpacySentenceSplitter 14 | from allennlp.data.instance import Instance 15 | from allennlp.data.fields import LabelField, TextField, Field 16 | 17 | logger = logging.getLogger(__name__)  # pylint: disable=invalid-name 18 | 19 | 20 | @DatasetReader.register("semisupervised_text_classification_json") 21 | class SemiSupervisedTextClassificationJsonReader(TextClassificationJsonReader): 22 | """ 23 | Reads tokens and (optionally) their labels from a text classification dataset. 24 | 25 | This dataset reader inherits from TextClassificationJSONReader, but differs from its parent 26 | in that it is primed for semisupervised learning. This dataset reader allows for: 27 | 1) Ignoring labels in the training data (e.g. for unsupervised pretraining) 28 | 2) Reading additional unlabeled data from another file 29 | 3) Throttling the training data to a random subsample (according to the numpy seed), 30 | for analysis of the effect of semisupervised models on different amounts of labeled 31 | data 32 | 33 | Expects a "text" field and a "label" field in JSON format. 34 | 35 | The output of ``read`` is a list of ``Instances`` with the fields: 36 | tokens: ``TextField`` and 37 | label: ``LabelField``, if not ignoring labels. 38 | 39 | Parameters 40 | ---------- 41 | token_indexers : ``Dict[str, TokenIndexer]``, optional 42 | (default = ``{"tokens": SingleIdTokenIndexer()}``) 43 | We use this to define the input representation for the text. 44 | See :class:`TokenIndexer`. 45 | tokenizer : ``Tokenizer``, optional (default = ``WordTokenizer()``) 46 | Tokenizer to split the input text into words or other kinds of tokens. 47 | max_sequence_length : ``int``, optional (default = ``None``) 48 | If specified, will truncate tokens to the specified maximum length. 49 | ignore_labels : ``bool``, optional (default = ``False``) 50 | If specified, will ignore labels when reading data. 51 | sample : ``int``, optional (default = ``None``) 52 | If specified, will sample data to a specified length. 53 | **Note**: 54 | 1) This operation will *not* apply to any additional unlabeled data 55 | (specified in `additional_unlabeled_data_path`). 56 | 2) To produce a consistent subsample of data, use a consistent seed in your 57 | training config. 58 | skip_label_indexing : ``bool``, optional (default = ``False``) 59 | Whether or not to skip label indexing. You might want to skip label indexing if your 60 | labels are numbers, so the dataset reader doesn't re-number them starting from 0. 61 | lazy : ``bool``, optional (default = ``False``) 62 | Whether or not instances can be read lazily. 63 | """ 64 | def __init__(self, 65 | token_indexers: Dict[str, TokenIndexer] = None, 66 | tokenizer: Tokenizer = None, 67 | max_sequence_length: int = None, 68 | ignore_labels: bool = False, 69 | sample: int = None, 70 | skip_label_indexing: bool = False, 71 | lazy: bool = False) -> None: 72 | super().__init__(lazy=lazy, 73 | token_indexers=token_indexers, 74 | tokenizer=tokenizer, 75 | max_sequence_length=max_sequence_length, 76 | skip_label_indexing=skip_label_indexing) 77 | self._tokenizer = tokenizer or WordTokenizer() 78 | self._sample = sample 79 | self._max_sequence_length = max_sequence_length 80 | self._ignore_labels = ignore_labels 81 | self._skip_label_indexing = skip_label_indexing 82 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 83 | if self._segment_sentences: 84 | self._sentence_segmenter = SpacySentenceSplitter() 85 | 86 | @staticmethod 87 | def _reservoir_sampling(file_: TextIOWrapper, sample: int): 88 | """ 89 | A function for reading random lines from file without loading the 90 | entire file into memory. 91 | 92 | For more information, see here: https://en.wikipedia.org/wiki/Reservoir_sampling 93 | 94 | To create a k-length sample of a file, without knowing the length of the file in advance, 95 | we first create a reservoir array containing the first k elements of the file. Then, we further 96 | iterate through the file, replacing elements in the reservoir with decreasing probability. 97 | 98 | By induction, one can prove that if there are n items in the file, each item is sampled with probability 99 | k / n.
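        As an illustrative sketch of that guarantee (names like ``pool`` and ``picked``
        are made up for this example; seeding numpy makes the draw deterministic):

        >>> import io
        >>> import numpy as np
        >>> np.random.seed(0)
        >>> pool = io.StringIO("a\nb\nc\nd\ne\nf\n")
        >>> picked = list(SemiSupervisedTextClassificationJsonReader._reservoir_sampling(pool, 2))
        >>> len(picked)
        2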
100 | 101 | Parameters 102 | ---------- 103 | file_ : `_io.TextIOWrapper` - an open handle on the file to sample from 104 | sample : `int` - size of the random sample to draw 105 | 106 | Returns 107 | ------- 108 | result : `List[str]` - sampled lines of the file 109 | """ 110 | # instantiate file iterator 111 | file_iterator = iter(file_) 112 | 113 | try: 114 | # fill the reservoir array 115 | result = [next(file_iterator) for _ in range(sample)] 116 | except StopIteration: 117 | raise ConfigurationError(f"sample size {sample} larger than number of lines in file.") 118 | 119 | # replace reservoir elements with decreasing probability: the item at 0-indexed position `index` draws uniformly from the inclusive range [0, index], which keeps each item's inclusion probability at k / n 120 | for index, item in enumerate(file_iterator, start=sample): 121 | sample_index = np.random.randint(0, index + 1)  # numpy's upper bound is exclusive 122 | if sample_index < sample: 123 | result[sample_index] = item 124 | 125 | for line in result: 126 | yield line 127 | 128 | @overrides 129 | def _read(self, file_path): 130 | with open(cached_path(file_path), "r") as data_file: 131 | if self._sample is not None: 132 | data_file = self._reservoir_sampling(data_file, self._sample) 133 | for line in data_file: 134 | items = json.loads(line) 135 | text = items["text"] 136 | if self._ignore_labels: 137 | instance = self.text_to_instance(text=text, label=None) 138 | else: 139 | label = str(items.get('label')) 140 | instance = self.text_to_instance(text=text, label=label) 141 | if instance is not None and instance.fields['tokens'].tokens: 142 | yield instance 143 | 144 | @overrides 145 | def text_to_instance(self, text: str, label: str = None) -> Instance:  # type: ignore 146 | """ 147 | Parameters 148 | ---------- 149 | text : ``str``, required. 150 | The text to classify. 151 | label : ``str``, optional (default = ``None``) 152 | The label for this text. 153 | 154 | Returns 155 | ------- 156 | An ``Instance`` containing the following fields: 157 | tokens : ``TextField`` 158 | The tokens in the sentence or phrase. 159 | label : ``LabelField`` 160 | The label of the sentence or phrase.
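        Example (an illustrative sketch, not from the original source; assumes the
        default ``WordTokenizer``)::

            reader = SemiSupervisedTextClassificationJsonReader()
            instance = reader.text_to_instance("a great movie", label="pos")
            # the tokens field holds the tokenized text, the label field the raw label
            assert [t.text for t in instance.fields["tokens"].tokens] == ["a", "great", "movie"]
            assert instance.fields["label"].label == "pos"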
161 | """ 162 | # pylint: disable=arguments-differ 163 | fields: Dict[str, Field] = {} 164 | tokens = self._tokenizer.tokenize(text) 165 | if self._max_sequence_length is not None: 166 | tokens = self._truncate(tokens) 167 | fields['tokens'] = TextField(tokens, self._token_indexers) 168 | if label is not None: 169 | fields['label'] = LabelField(label, 170 | skip_indexing=self._skip_label_indexing) 171 | return Instance(fields) 172 | -------------------------------------------------------------------------------- /search_spaces/classifier_imdb_search.json: -------------------------------------------------------------------------------- 1 | { 2 | "LAZY_DATASET_READER": 0, 3 | "CUDA_DEVICE": 0, 4 | "EVALUATE_ON_TEST": 0, 5 | "NUM_EPOCHS": 50, 6 | "SEED": { 7 | "sampling strategy": "integer", 8 | "bounds": [0, 100000] 9 | }, 10 | "TRAIN_PATH": "s3://suching-dev/final-datasets/imdb/train_pretokenized.jsonl", 11 | "DEV_PATH": "s3://suching-dev/final-datasets/imdb/dev_pretokenized.jsonl", 12 | "TEST_PATH": "s3://suching-dev/final-datasets/imdb/test_pretokenized.jsonl", 13 | "THROTTLE": 200, 14 | "USE_SPACY_TOKENIZER": 0, 15 | "FREEZE_EMBEDDINGS": ["VAMPIRE"], 16 | "EMBEDDINGS": ["RANDOM", "VAMPIRE"], 17 | "VAMPIRE_DIRECTORY": { 18 | "sampling strategy": "choice", 19 | "choices": ["logs/vampire_imdb_search/run_91_2019-06-02_04-48-06ijlgz_6q 126", "logs/vampire_imdb_search/run_33_2019-06-02_01-55-382xq9buid 83", "logs/vampire_imdb_search/run_94_2019-06-02_04-59-51i7do9_8i 74", "logs/vampire_imdb_search/run_6_2019-06-01_22-21-47jeh238zo 69", "logs/vampire_imdb_search/run_64_2019-06-02_03-20-020sfendog 73", "logs/vampire_imdb_search/run_7_2019-06-02_00-29-09f_xgcb0z 78", "logs/vampire_imdb_search/run_74_2019-06-02_03-43-462aqur5s5 114", "logs/vampire_imdb_search/run_34_2019-06-02_01-58-59k63vs9ev 98", "logs/vampire_imdb_search/run_66_2019-06-02_03-22-01t0_x5m7f 111", "logs/vampire_imdb_search/run_11_2019-06-01_22-31-51e2o4kc45 95", "logs/vampire_imdb_search/run_36_2019-06-02_02-03-192tk3_le6 77", "logs/vampire_imdb_search/run_19_2019-06-01_22-43-35j4we9hpz 116", "logs/vampire_imdb_search/run_50_2019-06-02_02-36-44ra49nh4n 125", "logs/vampire_imdb_search/run_26_2019-06-02_01-25-25hkr43rg0 103", "logs/vampire_imdb_search/run_15_2019-06-02_00-58-13dtlp5msu 111", "logs/vampire_imdb_search/run_47_2019-06-02_02-32-045iyqr8bk 119", "logs/vampire_imdb_search/run_35_2019-06-02_02-01-20u5gv_4t7 107", "logs/vampire_imdb_search/run_19_2019-06-02_01-10-31izwywphe 65", "logs/vampire_imdb_search/run_22_2019-06-02_01-15-51m34jzqsr 87", "logs/vampire_imdb_search/run_39_2019-06-02_02-10-0156v62lzr 102", "logs/vampire_imdb_search/run_72_2019-06-02_03-39-04wr8dse28 77", "logs/vampire_imdb_search/run_52_2019-06-02_02-43-32j7skazfp 126", "logs/vampire_imdb_search/run_27_2019-06-01_22-57-07iqgqcmxk 97", "logs/vampire_imdb_search/run_84_2019-06-02_04-24-41kzedjqj6 87", "logs/vampire_imdb_search/run_2_2019-06-01_22-21-474_f5zu5f 108", "logs/vampire_imdb_search/run_79_2019-06-02_04-10-523rbrut4k 121", "logs/vampire_imdb_search/run_34_2019-06-01_23-05-22_7sjvsn7 107", "logs/vampire_imdb_search/run_41_2019-06-02_02-16-321vlnz8jm 124", "logs/vampire_imdb_search/run_85_2019-06-02_04-34-31x8pd9otp 80", "logs/vampire_imdb_search/run_18_2019-06-01_22-43-20v6hsn89z 118", "logs/vampire_imdb_search/run_16_2019-06-02_01-01-090bb76rdi 120", "logs/vampire_imdb_search/run_32_2019-06-02_01-49-09db2pq7dm 93", "logs/vampire_imdb_search/run_29_2019-06-02_01-38-292_1lpbw5 88", "logs/vampire_imdb_search/run_14_2019-06-02_00-49-22jc6l1ryj 97", 
"logs/vampire_imdb_search/run_18_2019-06-02_01-07-37p38ruh08 120", "logs/vampire_imdb_search/run_23_2019-06-02_01-19-11iv4xgwy6 80", "logs/vampire_imdb_search/run_16_2019-06-01_22-39-00kj7ljpl6 85", "logs/vampire_imdb_search/run_46_2019-06-02_02-26-44nnqrynz1 111", "logs/vampire_imdb_search/run_89_2019-06-02_04-42-35pp5ny8yh 101", "logs/vampire_imdb_search/run_2_2019-06-02_00-20-21ijlztkn8 112", "logs/vampire_imdb_search/run_44_2019-06-02_02-26-20wpao6vh2 91", "logs/vampire_imdb_search/run_23_2019-06-01_22-48-23he8fn9x4 74", "logs/vampire_imdb_search/run_97_2019-06-02_05-03-47_mx0_rzo 71", "logs/vampire_imdb_search/run_17_2019-06-01_22-42-26y0sa9gjk 67", "logs/vampire_imdb_search/run_4_2019-06-02_00-23-09k5hs1evb 122", "logs/vampire_imdb_search/run_9_2019-06-01_22-29-36gb3x0mtk 83", "logs/vampire_imdb_search/run_77_2019-06-02_04-02-456o3u0mk0 127", "logs/vampire_imdb_search/run_9_2019-06-02_00-39-27jsxkj4v5 79", "logs/vampire_imdb_search/run_21_2019-06-01_22-46-155uqu20sx 76", "logs/vampire_imdb_search/run_22_2019-06-01_22-47-128b1zagxx 67", "logs/vampire_imdb_search/run_8_2019-06-01_22-28-579pa36ihl 97", "logs/vampire_imdb_search/run_69_2019-06-02_03-27-380npzw923 67", "logs/vampire_imdb_search/run_80_2019-06-02_04-15-529czzt3o2 70", "logs/vampire_imdb_search/run_38_2019-06-02_02-05-39uvtg6xj0 81", "logs/vampire_imdb_search/run_71_2019-06-02_03-37-09obcz4sal 68", "logs/vampire_imdb_search/run_36_2019-06-01_23-05-55a6kh6uy9 79", "logs/vampire_imdb_search/run_37_2019-06-01_23-07-12nwtpiexq 89", "logs/vampire_imdb_search/run_40_2019-06-02_02-15-470qx4j_7h 116", "logs/vampire_imdb_search/run_87_2019-06-02_04-38-38a0ipi0q3 101", "logs/vampire_imdb_search/run_11_2019-06-02_00-45-21g0hjciv0 104", "logs/vampire_imdb_search/run_20_2019-06-01_22-44-1288czx4zf 96", "logs/vampire_imdb_search/run_31_2019-06-02_01-46-23c1mc6pn5 110", "logs/vampire_imdb_search/run_90_2019-06-02_04-47-43qcer2e9q 72", "logs/vampire_imdb_search/run_93_2019-06-02_04-59-1006g5fux1 103", "logs/vampire_imdb_search/run_98_2019-06-02_05-08-07i0kjg87q 123", "logs/vampire_imdb_search/run_78_2019-06-02_04-06-30jz2wz2k8 120", "logs/vampire_imdb_search/run_30_2019-06-02_01-39-522eyxca9a 70", "logs/vampire_imdb_search/run_8_2019-06-02_00-34-43588y5hmy 105", "logs/vampire_imdb_search/run_73_2019-06-02_03-41-05apko38oh 83", "logs/vampire_imdb_search/run_12_2019-06-02_00-48-22y1zwa19v 97", "logs/vampire_imdb_search/run_10_2019-06-02_00-44-590r7yu8tl 77", "logs/vampire_imdb_search/run_45_2019-06-02_02-26-26youl0d6h 111", "logs/vampire_imdb_search/run_54_2019-06-02_02-50-01b2o67ns2 120", "logs/vampire_imdb_search/run_5_2019-06-01_22-21-473o2cbkab 72", "logs/vampire_imdb_search/run_62_2019-06-02_03-09-00r_n4z60e 121", "logs/vampire_imdb_search/run_83_2019-06-02_04-21-48g35nquif 101", "logs/vampire_imdb_search/run_7_2019-06-01_22-21-47rcadyucz 64", "logs/vampire_imdb_search/run_13_2019-06-02_00-48-58u0imz93c 100", "logs/vampire_imdb_search/run_21_2019-06-02_01-14-3147hts7c6 91", "logs/vampire_imdb_search/run_65_2019-06-02_03-20-039n80pw_j 98", "logs/vampire_imdb_search/run_0_2019-06-01_22-21-46cuyfqol1 112", "logs/vampire_imdb_search/run_51_2019-06-02_02-36-4995anlxqd 105", "logs/vampire_imdb_search/run_15_2019-06-01_22-37-10i08epwat 89", "logs/vampire_imdb_search/run_59_2019-06-02_02-55-05hi5nmi25 88", "logs/vampire_imdb_search/run_82_2019-06-02_04-18-26fagp1f1w 94", "logs/vampire_imdb_search/run_56_2019-06-02_02-51-4797nxqw2p 122", "logs/vampire_imdb_search/run_61_2019-06-02_03-05-26qot_fg0r 124", 
"logs/vampire_imdb_search/run_32_2019-06-01_23-02-477rgyswxb 117", "logs/vampire_imdb_search/run_1_2019-06-01_22-21-47ougul4dn 84", "logs/vampire_imdb_search/run_99_2019-06-02_05-09-278jpez8lv 89", "logs/vampire_imdb_search/run_75_2019-06-02_03-48-33i5cy8utm 102", "logs/vampire_imdb_search/run_58_2019-06-02_02-53-517_5hfdwt 92", "logs/vampire_imdb_search/run_4_2019-06-01_22-21-47q4y3asxs 88", "logs/vampire_imdb_search/run_67_2019-06-02_03-25-392kkou37j 90", "logs/vampire_imdb_search/run_88_2019-06-02_04-38-54cyk9dp6t 112", "logs/vampire_imdb_search/run_24_2019-06-02_01-19-31qsj_xdk9 99", "logs/vampire_imdb_search/run_3_2019-06-01_22-21-47h15993ke 104", "logs/vampire_imdb_search/run_12_2019-06-01_22-32-2593wcbaq7 77", "logs/vampire_imdb_search/run_27_2019-06-02_01-29-17ykxgpvd3 79", "logs/vampire_imdb_search/run_43_2019-06-02_02-22-28azwhvty8 71", "logs/vampire_imdb_search/run_76_2019-06-02_04-00-40fgkl94w8 97", "logs/vampire_imdb_search/run_31_2019-06-01_22-58-109fj2ttef 95", "logs/vampire_imdb_search/run_1_2019-06-02_00-20-2050gugxkj 93", "logs/vampire_imdb_search/run_5_2019-06-02_00-27-06jr2x9rht 105", "logs/vampire_imdb_search/run_60_2019-06-02_03-04-32211ly3ak 87", "logs/vampire_imdb_search/run_25_2019-06-02_01-23-46yuqr3up4 64", "logs/vampire_imdb_search/run_81_2019-06-02_04-17-00wp97lizc 94"] 20 | }, 21 | "ENCODER": { 22 | "sampling strategy": "choice", 23 | "choices": ["AVERAGE"] 24 | }, 25 | "EMBEDDING_DROPOUT": 0.5, 26 | "LEARNING_RATE": 0.004, 27 | "DROPOUT": 0.5, 28 | "BATCH_SIZE": 32, 29 | "NUM_ENCODER_LAYERS": { 30 | "sampling strategy": "choice", 31 | "choices": [1, 2, 3] 32 | }, 33 | "NUM_OUTPUT_LAYERS": { 34 | "sampling strategy": "choice", 35 | "choices": [1, 2, 3] 36 | }, 37 | "MAX_FILTER_SIZE": { 38 | "sampling strategy": "integer", 39 | "bounds": [3, 6] 40 | }, 41 | "NUM_FILTERS": { 42 | "sampling strategy": "integer", 43 | "bounds": [64, 512] 44 | }, 45 | "HIDDEN_SIZE": { 46 | "sampling strategy": "integer", 47 | "bounds": [64, 512] 48 | }, 49 | "AGGREGATIONS": { 50 | "sampling strategy": "subset", 51 | "choices": ["maxpool", "meanpool", "attention", "final_state"] 52 | }, 53 | "MAX_CHARACTER_FILTER_SIZE": { 54 | "sampling strategy": "integer", 55 | "bounds": [3, 6] 56 | }, 57 | "NUM_CHARACTER_FILTERS": { 58 | "sampling strategy": "integer", 59 | "bounds": [16, 64] 60 | }, 61 | "CHARACTER_HIDDEN_SIZE": { 62 | "sampling strategy": "integer", 63 | "bounds": [16, 128] 64 | }, 65 | "CHARACTER_EMBEDDING_DIM": { 66 | "sampling strategy": "integer", 67 | "bounds": [16, 128] 68 | }, 69 | "CHARACTER_ENCODER": { 70 | "sampling strategy": "choice", 71 | "choices": ["LSTM", "CNN", "AVERAGE"] 72 | }, 73 | "NUM_CHARACTER_ENCODER_LAYERS": { 74 | "sampling strategy": "choice", 75 | "choices": [1, 2] 76 | } 77 | } -------------------------------------------------------------------------------- /vampire/tests/data/dataset_readers/semisupervised_text_classification_json_test.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=no-self-use,invalid-name 2 | import pytest 3 | from allennlp.common.checks import ConfigurationError 4 | from allennlp.common.params import Params 5 | from vampire.common.testing import VAETestCase 6 | from allennlp.common.util import ensure_list, prepare_environment 7 | 8 | from vampire.data.dataset_readers import SemiSupervisedTextClassificationJsonReader 9 | 10 | 11 | class TestSemiSupervisedTextClassificationJsonReader(VAETestCase): 12 | 13 | @pytest.mark.parametrize("lazy", (True, False)) 14 | def 
test_read_from_file(self, lazy): 15 | reader = SemiSupervisedTextClassificationJsonReader(lazy=lazy) 16 | ag_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 17 | instances = reader.read(ag_path) 18 | instances = ensure_list(instances) 19 | 20 | instance1 = {"tokens": ['...', 'And', 'I', 'never', 'thought', 'a', 'movie', 'deserved', 'to', 21 | 'be', 'awarded', 'a', '1', '!', 'But', 'this', 'one', 'is', 'honestly', 22 | 'the', 'worst', 'movie', 23 | 'I', "'ve", 'ever', 'watched', '.', 'My', 'wife', 'picked', 'it', 'up', 24 | 'because', 'of', 'the', 'cast', ',', 25 | 'but', 'the', 'storyline', 'right', 'since', 'the', 'DVD', 'box', 26 | 'seemed', 'quite', 27 | 'predictable', '.', 'It', 'is', 'not', 'a', 'mystery', ',', 'nor', 28 | 'a', 'juvenile', 29 | '-', 'catching', 'film', '.', 'It', 'does', 'not', 'include', 'any', 30 | 'sensuality', 31 | ',', 'if', 'that', "'s", 'what', 'the', 'title', 'could', 'remotely', 32 | 'have', 'suggest', 33 | 'any', 'of', 'you', '.', 'This', 'is', 'just', 'a', 'total', 'no', '-', 34 | 'no', '.', 35 | 'Do', "n't", 'waste', 'your', 'time', 'or', 'money', 'unless', 'you', 36 | 'feel', 'like', 37 | 'watching', 'a', 'bunch', 'of', 'youngsters', 'in', 'a', 'as', 38 | '-', 'grown', '-', 'up', 39 | 'kind', 'of', 'Gothic', 'setting', ',', 'where', 'a', 'killer', 40 | 'is', 'going', 'after', 41 | 'them', '.', 'Nothing', 'new', ',', 'nothing', 'interesting', ',', 42 | 'nothing', 'worth', 43 | 'watching', '.', 'Max', 'Makowski', 'makes', 'the', 'worst', 'of', 44 | 'Nick', 'Stahl', '.'], 45 | "label": "neg"} 46 | instance2 = {"tokens": ['The', 'fight', 'scenes', 'were', 'great', '.', 'Loved', 'the', 47 | 'old', 'and', 'newer', 48 | 'cylons', 'and', 'how', 'they', 'painted', 'the', 'ones', 'on', 49 | 'their', 'side', '.', 'It', 50 | 'was', 'the', 'ending', 'that', 'I', 'hated', '.', 'I', 'was', 51 | 'disappointed', 'that', 'it', 52 | 'was', 'earth', 'but', '150k', 'years', 'back', '.', 'But', 'to', 53 | 'travel', 'all', 'that', 54 | 'way', 'just', 'to', 'start', 'over', '?', 'Are', 'you', 'kidding', 55 | 'me', '?', '38k', 'people', 56 | 'that', 'fought', 'for', 'their', 'very', 'existence', 'and', 'once', 57 | 'they', 'get', 'to', 58 | 'paradise', ',', 'they', 'abandon', 'technology', '?', 'No', 59 | 'way', '.', 'Sure', 'they', 60 | 'were', 'eating', 'paper', 'and', 'rationing', 'food', ',', 61 | 'but', 'that', 'is', 'over', 62 | '.', 'They', 'can', 'live', 'like', 'humans', 'again', '.', 63 | 'They', 'only', 'have', 'one', 64 | 'good', 'doctor', '.', 'What', 'are', 'they', 'going', 'to', 'do', 65 | 'when', 'someone', 'has', 66 | 'a', 'tooth', 'ache', 'never', 'mind', 'giving', 'birth', '...', 67 | 'yea', 'right', '.', 68 | 'No', 'one', 'would', 'have', 'made', 'that', 'choice', '.'], 69 | "label": "pos"} 70 | instance3 = {"tokens": ['The', 'only', 'way', 'this', 'is', 'a', 'family', 'drama', 'is', 71 | 'if', 'parents', 'explain', 72 | 'everything', 'wrong', 'with', 'its', 'message.SPOILER', ':', 'they', 'feed', 74 | 'a', 'deer', 'for', 'a', 'year', 'and', 'then', 'kill', 'it', 75 | 'for', 'eating', 'their', 'food', 76 | 'after', 'killing', 'its', 'mother', 'and', 'at', 'first', 77 | 'pontificating', 'about', 'taking', 78 | 'responsibility', 'for', 'their', 'actions', '.', 'They', 'blame', 79 | 'bears', 'and', 'deer', 80 | 'for', '"', 'misbehaving', '"', 'by', 'eating', 'while', 'they', 81 | 'take', 'no', 'responsibility', 82 | 'to', 'use', 'adequate', 'locks', 'and', 'fences', 'or', 'even', 83 | 'learn', 'to', 'shoot', 84 | 'instead', 'of', 'twice', 'maiming', 'animals', 'and',
'letting', 85 | 'them', 'linger', '.'], 86 | "label": "neg"} 87 | 88 | assert len(instances) == 3 89 | fields = instances[0].fields 90 | assert [t.text for t in fields["tokens"].tokens] == instance1["tokens"] 91 | assert fields["label"].label == instance1["label"] 92 | fields = instances[1].fields 93 | assert [t.text for t in fields["tokens"].tokens] == instance2["tokens"] 94 | assert fields["label"].label == instance2["label"] 95 | fields = instances[2].fields 96 | assert [t.text for t in fields["tokens"].tokens] == instance3["tokens"] 97 | assert fields["label"].label == instance3["label"] 98 | 99 | def test_read_from_file_and_truncates_properly(self): 100 | 101 | reader = SemiSupervisedTextClassificationJsonReader(max_sequence_length=5) 102 | ag_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 103 | instances = reader.read(ag_path) 104 | instances = ensure_list(instances) 105 | 106 | instance1 = {"tokens": ['...', 'And', 'I', 'never', 'thought'], 107 | "label": "neg"} 108 | instance2 = {"tokens": ['The', 'fight', 'scenes', 'were', 'great'], 109 | "label": "pos"} 110 | instance3 = {"tokens": ['The', 'only', 'way', 'this', 'is'], 111 | "label": "neg"} 112 | 113 | assert len(instances) == 3 114 | fields = instances[0].fields 115 | assert [t.text for t in fields["tokens"].tokens] == instance1["tokens"] 116 | assert fields["label"].label == instance1["label"] 117 | fields = instances[1].fields 118 | assert [t.text for t in fields["tokens"].tokens] == instance2["tokens"] 119 | assert fields["label"].label == instance2["label"] 120 | fields = instances[2].fields 121 | assert [t.text for t in fields["tokens"].tokens] == instance3["tokens"] 122 | assert fields["label"].label == instance3["label"] 123 | 124 | def test_samples_properly(self): 125 | reader = SemiSupervisedTextClassificationJsonReader(sample=1, max_sequence_length=5) 126 | ag_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 127 | params = Params({"random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5}) 128 | prepare_environment(params) 129 | instances = reader.read(ag_path) 130 | instances = ensure_list(instances) 131 | instance = {"tokens": ['The', 'fight', 'scenes', 'were', 'great'], 132 | "label": "pos"} 133 | assert len(instances) == 1 134 | fields = instances[0].fields 135 | assert [t.text for t in fields["tokens"].tokens] == instance["tokens"] 136 | assert fields["label"].label == instance["label"] 137 | 138 | def test_sampling_fails_when_sample_size_larger_than_file_size(self): 139 | reader = SemiSupervisedTextClassificationJsonReader(sample=10, max_sequence_length=5) 140 | ag_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 141 | params = Params({"random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5}) 142 | prepare_environment(params) 143 | self.assertRaises(ConfigurationError, reader.read, ag_path) 144 | 145 | def test_samples_according_to_seed_properly(self): 146 | 147 | reader1 = SemiSupervisedTextClassificationJsonReader(sample=2, max_sequence_length=5) 148 | reader2 = SemiSupervisedTextClassificationJsonReader(sample=2, max_sequence_length=5) 149 | reader3 = SemiSupervisedTextClassificationJsonReader(sample=2, max_sequence_length=5) 150 | 151 | imdb_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 152 | params = Params({"random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5}) 153 | prepare_environment(params) 154 | instances1 = reader1.read(imdb_path) 155 | params = Params({"random_seed": 2, "numpy_seed": 2, "pytorch_seed": 2}) 156 | prepare_environment(params) 157 | instances2 = reader2.read(imdb_path) 158 | params 
= Params({"random_seed": 5, "numpy_seed": 5, "pytorch_seed": 5}) 159 | prepare_environment(params) 160 | instances3 = reader3.read(imdb_path) 161 | fields1 = [i.fields for i in instances1] 162 | fields2 = [i.fields for i in instances2] 163 | fields3 = [i.fields for i in instances3] 164 | tokens1 = [f['tokens'].tokens for f in fields1] 165 | tokens2 = [f['tokens'].tokens for f in fields2] 166 | tokens3 = [f['tokens'].tokens for f in fields3] 167 | text1 = [[t.text for t in doc] for doc in tokens1] 168 | text2 = [[t.text for t in doc] for doc in tokens2] 169 | text3 = [[t.text for t in doc] for doc in tokens3] 170 | assert text1 != text2 171 | assert text1 == text3 172 | 173 | def test_ignores_label_properly(self): 174 | 175 | imdb_labeled_path = self.FIXTURES_ROOT / "imdb" / "train.jsonl" 176 | reader = SemiSupervisedTextClassificationJsonReader(ignore_labels=True) 177 | instances = reader.read(imdb_labeled_path) 178 | instances = ensure_list(instances) 179 | fields = [i.fields for i in instances] 180 | labels = [f.get('label') for f in fields] 181 | assert labels == [None] * 3 182 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------