├── .gitignore ├── README.md ├── allennlp_internal_functions.py ├── attn_removal_check.py ├── attn_tests_lib ├── __init__.py ├── classifier_from_attn_and_input_vects.py ├── conv_seq2seq_encoder.py ├── deprecated │ ├── __init__.py │ ├── han_paper_attn_function.py │ └── talkative_intra_sentence_attention.py ├── extended_bucket_iterator.py ├── extended_bucket_iterator_for_reuse.py ├── flat_attention_network.py ├── intermediate_batch_iterator.py ├── pass_through_encoder.py ├── simple_han_attn_layer.py └── talkative_simple_han_attn_layer.py ├── before_model_training ├── calculate_sentence_stats.py ├── calculate_word_stats.py ├── get_vocab_and_initialize_word2vec_embeddings_from_dataset.py ├── get_vocab_and_initialize_word2vec_embeddings_from_dataset.sh ├── make_datasets.py └── subset_data.py ├── configs ├── amazon_flan_no_encoders.jsonnet ├── amazon_flan_with_convs.jsonnet ├── amazon_flan_with_rnns-actuallyflanconv.jsonnet ├── amazon_flan_with_rnns.jsonnet ├── amazon_han_from_paper-1.jsonnet ├── amazon_han_from_paper-2.jsonnet ├── amazon_han_from_paper-3.jsonnet ├── amazon_han_from_paper-4.jsonnet ├── amazon_han_from_paper-5.jsonnet ├── amazon_han_from_paper-6.jsonnet ├── amazon_han_from_paper-actuallyhanconv.jsonnet ├── amazon_han_from_paper.jsonnet ├── amazon_han_no_encoders.jsonnet ├── amazon_han_with_convs.jsonnet ├── imdb_flan_no_encoders.jsonnet ├── imdb_flan_with_convs.jsonnet ├── imdb_flan_with_rnns.jsonnet ├── imdb_han_from_paper.jsonnet ├── imdb_han_no_encoders.jsonnet ├── imdb_han_with_convs.jsonnet ├── old_configs │ ├── imdb_han_from_paper_ORIGINAL.jsonnet │ └── yelp_han_from_paper_ATTNPERF.jsonnet ├── whateverDatasetYouHaveInMind_sample_han.jsonnet ├── yahoo10cat_flan_no_encoders.jsonnet ├── yahoo10cat_flan_with_convs.jsonnet ├── yahoo10cat_flan_with_rnns.jsonnet ├── yahoo10cat_han_from_paper.jsonnet ├── yahoo10cat_han_no_encoders.jsonnet ├── yahoo10cat_han_with_convs.jsonnet ├── yelp_flan_no_encoders.jsonnet ├── yelp_flan_with_convs.jsonnet ├── yelp_flan_with_rnns-actuallyflanconv.jsonnet ├── yelp_flan_with_rnns.jsonnet ├── yelp_han_from_paper-1.jsonnet ├── yelp_han_from_paper-2.jsonnet ├── yelp_han_from_paper-3.jsonnet ├── yelp_han_from_paper-4.jsonnet ├── yelp_han_from_paper-5.jsonnet ├── yelp_han_from_paper-5evensmallerstep.jsonnet ├── yelp_han_from_paper-5smallerstep.jsonnet ├── yelp_han_from_paper-6.jsonnet ├── yelp_han_from_paper.jsonnet ├── yelp_han_no_encoders.jsonnet └── yelp_han_with_convs.jsonnet ├── data └── README.md ├── debugging.py ├── default_directories.py ├── figure_making ├── figure_maker.py ├── figure_maker_single_dataset.py └── table_maker.py ├── misc_scripts ├── get_attnperf_overlap.py ├── hyperparam_search.py ├── make_attnlabel1_dist_hist.py ├── make_rand_subset_data_file.py ├── make_subset_test_file.py ├── move_files_out_pre_test.py └── prob_predictor.py ├── plain_model_test.py ├── process_test_outputs.py ├── test_model.py ├── textcat ├── __init__.py ├── hierarchical_attention_network.py ├── sentence_splitter.py ├── sentence_tokenizer.py ├── textcat_reader.py └── textcat_reader_attnlabel.py └── train_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | *~ 3 | __pycache__/ 4 | .DS_Store 5 | */.DS_Store 6 | *.DS_Store 7 | imgs/ 8 | data/ 9 | attn-test-output/ 10 | models/ 11 | generated_tex_files/ 12 | vocabs/ 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
Attention Tests for Text Classification 2 | 3 | Required to run code: 4 | 5 | - python3 (I used python 3.6.3) 6 | - allennlp and all its dependencies (works with both 0.7.1 and 0.8.3) 7 | - numpy 8 | - scipy 9 | - matplotlib 10 | - seaborn 11 | - pytorch 12 | - tqdm 13 | - statsmodels 14 | - pandas 15 | - gensim 16 | 17 | This repo contains code to train and test the attention mechanisms in various text classification models for different 18 | interpretability-related criteria. 19 | 20 | When you download the repo, open 21 | `default_directories.py` and set the directories in there to the 22 | locations you'd like to use. With the exception of `base_data_dir` and 23 | `dir_with_config_files`, all other directories will be created by 24 | the scripts that use them when necessary; those two are the only ones 25 | that need to be set up separately. 26 | 27 | Suppose that you wanted to train and test a new model from scratch. 28 | Here's the pipeline you'd need to follow: 29 | 30 | ## Preparing the dataset 31 | 32 | If this is the first time you've used this dataset in this 33 | repository, there are a few preprocessing steps that need to happen 34 | first. (If you're training a model on a dataset you've already 35 | set up for use in this repository, you can skip this section.) 36 | 37 | First, your data needs to be in a format that the code is 38 | designed to work with. The code expects each file of data to 39 | have one document's information per line, with the different 40 | fields tab-separated. (Therefore, all tabs or newlines in the 41 | original documents need to be replaced as part of the preprocessing.) 42 | The only two fields that the code counts on (though there can 43 | be more) are called `category` and `tokens`, and can appear in 44 | any order (there can also be other fields in the data files; they 45 | will just be ignored). `category`, which is the label for the 46 | document, can be an arbitrary string; it will be automatically 47 | mapped to an int later. `tokens` is just the text of the document 48 | without newlines or tabs. 49 | 50 | To get your data into that format, you can either process it yourself, 51 | or reuse some of the code that I wrote in `make_datasets.py` in 52 | `before_model_training/`. 53 | (I haven't gone back and looked at that code in a long time, 54 | though, so use at your own risk.) 55 | 56 | Then split up your data into train, dev, and test as three 57 | separate files, each with the same format described above. 58 | (I think I used `subset_data.py` in `before_model_training/` 59 | for this, but again, it's been 60 | a while; you might be better off just doing this on your own.) 61 | 62 | Once your data is in the correct format, you need to make 63 | pretrained embeddings and a corresponding allennlp vocabulary 64 | for your dataset. 65 | In `before_model_training/`, open 66 | `get_vocab_and_initialize_word2vec_embeddings_from_dataset.py` 67 | and edit the parameters in lines 23-29 to the desired 68 | values. Then, run 69 | `get_vocab_and_initialize_word2vec_embeddings_from_dataset.sh` 70 | (the bash script, not the python script). We run the bash 71 | script because it splits up the python script into three 72 | separate stages, which, if run as part of the same call to the 73 | python script, crash for memory-related reasons on larger 74 | datasets. 75 | 76 | ## Training a model 77 | 78 | Say you wanted to train a HAN with an RNN encoder on the IMDB 79 | dataset. 
Then, provided you were on a machine with gpus and 80 | wanted to use gpu 0, you could run the command 81 | 82 | ``` 83 | python3 train_model.py --model hanrnn --dataset-name imdb --gpu 0 84 | ``` 85 | 86 | Options for different models are listed in line 759 of `train_model.py`. 87 | (If you want to run on a cpu rather than a gpu, supply -1 as the value for `--gpu`.) The `--model` and `--dataset-name` parameters are used to fetch 88 | the correct config file from the configs directory specified in 89 | `default_directories.py`. Let `{file_ending}` be the file ending 90 | that `--model` (in this case, `hanrnn`) maps to in `corresponding_config_files` (line 30 of `train_model.py`). 91 | Then the config filename that `train_model.py` will look for will be 92 | ``` 93 | {dir_with_config_files}/{--dataset-name}_{file_ending} 94 | ``` 95 | 96 | That's the config file that you need to set up; see example config 97 | files in `configs/` for reference. 98 | 99 | More optional parameters are listed in lines 755-781 of `train_model.py`. 100 | A particularly useful one is `--optional-model-tag`; allennlp 101 | doesn't allow a model folder to be overwritten, so if you had 102 | already trained a HANrnn for the IMDB dataset and wanted to train 103 | another one, you could add an optional model tag that `train_model.py` 104 | would append to the new model's specific directory, thus 105 | differentiating it from the old model's directory. 106 | 107 | ## Extracting information about attention from that model 108 | 109 | Once you've got a trained model that you'd like to analyze, 110 | the next step is to run a lot of tests on it and write the 111 | results to a bunch of .csv files in a test-result directory created for this model. This is handled automatically 112 | by `test_model.py`. For the example that we described above 113 | with an IMDB HANrnn, the generated model directory would be 114 | `{base_serialized_models_dir}/imdb-hanrnn`. Assuming you wanted 115 | to get information about how this model's attention works 116 | on all the instances in the data file `{base_data_dir}/imdb_test.tsv`, 117 | you could run the command 118 | 119 | ``` 120 | python3 test_model.py --model-folder-name imdb-hanrnn --test-data-file imdb_test.tsv --gpu 0 --optional-folder-tag testdata 121 | ``` 122 | 123 | (`test_model.py` reads in the same parent directory names as `train_model.py` does, 124 | so those aren't provided in the command.) This will write a bunch 125 | of .csv files to a directory (that this script creates) called 126 | `{base_output_dir}/imdb-hanrnn-testdata/`. 127 | Once again, there are optional 128 | arguments listed in lines 1481-1506 of `test_model.py`. 129 | `--optional-folder-tag` is what allows you to have separate test 130 | result directories for the same model on, say, both its test and dev 131 | sets. (If we'd left it off, our results directory would have been `{base_output_dir}/imdb-hanrnn/`.) 132 | 133 | ## Exploring the test results 134 | 135 | Once the .csv files from the previous step have been created, 136 | you can either work with them yourself, or use the scripts provided 137 | here to analyze them. 138 | 139 | If you want a summary of some test statistics in text form, 140 | `process_test_outputs.py` is the file you want. To run it for our 141 | IMDB HANrnn, we would run 142 | 143 | ``` 144 | python3 process_test_outputs.py --model-folder-name imdb-hanrnn-testdata 145 | ``` 146 | 147 | Optional arguments are listed in lines 986-997 of `process_test_outputs.py`.
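
If you'd rather dig into the raw .csv files yourself (the "work with them yourself" option mentioned above), a minimal sketch like the following is one place to start. It only assumes pandas (already a dependency); the results-directory path stands in for whatever `{base_output_dir}/imdb-hanrnn-testdata/` resolves to on your machine, and since the individual .csv filenames and columns aren't documented in this README, the script just prints an overview of each file it finds:

```
import glob
import os

import pandas as pd

# Point this at the test-result directory created by test_model.py,
# i.e. {base_output_dir}/imdb-hanrnn-testdata/ with your base_output_dir
# from default_directories.py filled in.
results_dir = "/path/to/base_output_dir/imdb-hanrnn-testdata"

for csv_path in sorted(glob.glob(os.path.join(results_dir, "*.csv"))):
    df = pd.read_csv(csv_path)
    # Column names vary depending on which attention test produced the file,
    # so just show a quick summary of each one.
    print(os.path.basename(csv_path), df.shape)
    print(df.head(), "\n")
```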
148 | 149 | If we instead want to generate figures, you can use `figure_making/figure_maker.py` to 150 | generate figures. This script is MUCH more rough around the edges-- 151 | there's a lot of stuff in it that you'd probably need to modify for 152 | use on your generated test results (hard-coded model tags, the expectation 153 | that one of each model has been generated, etc.). But it's here in case it's 154 | helpful. 155 | 156 | The same goes for `figure_making/table_maker.py`, which generates a bunch 157 | of LaTeX tables looking at differences in single-weight 158 | decision flips; it's probably not going to be useful outside of 159 | the specific setting I used it in. 160 | -------------------------------------------------------------------------------- /allennlp_internal_functions.py: -------------------------------------------------------------------------------- 1 | from allennlp.data.dataset_readers import DatasetReader 2 | from allennlp.data import Instance 3 | from allennlp.common.params import Params 4 | from allennlp.common.tee_logger import TeeLogger 5 | from typing import Dict, Any, Iterable, List 6 | import json 7 | import logging 8 | import torch 9 | import sys 10 | 11 | 12 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 13 | 14 | 15 | def dump_metrics(file_path: str, metrics: Dict[str, Any], log: bool = False) -> None: 16 | metrics_json = json.dumps(metrics, indent=2) 17 | with open(file_path, "w") as metrics_file: 18 | metrics_file.write(metrics_json) 19 | if log: 20 | logger.info("Metrics: %s", metrics_json) 21 | 22 | 23 | def datasets_from_params(params: Params) -> Dict[str, Iterable[Instance]]: 24 | """ 25 | Load all the datasets specified by the config. 26 | """ 27 | dataset_reader = DatasetReader.from_params(params.pop('dataset_reader')) 28 | validation_dataset_reader_params = params.pop("validation_dataset_reader", None) 29 | 30 | validation_and_test_dataset_reader: DatasetReader = dataset_reader 31 | if validation_dataset_reader_params is not None: 32 | logger.info("Using a separate dataset reader to load validation and test data.") 33 | validation_and_test_dataset_reader = DatasetReader.from_params(validation_dataset_reader_params) 34 | 35 | train_data_path = params.pop('train_data_path') 36 | logger.info("Reading training data from %s", train_data_path) 37 | train_data = dataset_reader.read(train_data_path) 38 | 39 | datasets: Dict[str, Iterable[Instance]] = {"train": train_data} 40 | 41 | validation_data_path = params.pop('validation_data_path', None) 42 | if validation_data_path is not None: 43 | logger.info("Reading validation data from %s", validation_data_path) 44 | validation_data = validation_and_test_dataset_reader.read(validation_data_path) 45 | datasets["validation"] = validation_data 46 | 47 | test_data_path = params.pop("test_data_path", None) 48 | if test_data_path is not None: 49 | logger.info("Reading test data from %s", test_data_path) 50 | test_data = validation_and_test_dataset_reader.read(test_data_path) 51 | datasets["test"] = test_data 52 | 53 | return datasets 54 | 55 | 56 | def get_frozen_and_tunable_parameter_names(model: torch.nn.Module) -> List: 57 | frozen_parameter_names = [] 58 | tunable_parameter_names = [] 59 | for name, parameter in model.named_parameters(): 60 | if not parameter.requires_grad: 61 | frozen_parameter_names.append(name) 62 | else: 63 | tunable_parameter_names.append(name) 64 | return [frozen_parameter_names, tunable_parameter_names] 65 | 66 | 67 | def cleanup_global_logging(stdout_handler: 
logging.FileHandler) -> None: 68 | """ 69 | This function closes any open file handles and logs set up by `prepare_global_logging`. 70 | Parameters 71 | ---------- 72 | stdout_handler : ``logging.FileHandler``, required. 73 | The file handler returned from `prepare_global_logging`, attached to the global logger. 74 | """ 75 | stdout_handler.close() 76 | logging.getLogger().removeHandler(stdout_handler) 77 | 78 | if isinstance(sys.stdout, TeeLogger): 79 | sys.stdout = sys.stdout.cleanup() 80 | if isinstance(sys.stderr, TeeLogger): 81 | sys.stderr = sys.stderr.cleanup() 82 | -------------------------------------------------------------------------------- /attn_tests_lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_seq2seq_encoder import ConvSeq2SeqEncoder 2 | from .classifier_from_attn_and_input_vects import \ 3 | ClassifierFromAttnAndInputVects, GradientReportingClassifierFromAttnAndInputVects 4 | from .extended_bucket_iterator import ExtendedBucketIterator 5 | from .extended_bucket_iterator_for_reuse import ExtendedBucketIteratorForReuse 6 | from .intermediate_batch_iterator import IntermediateBatchIterator 7 | from .intermediate_batch_iterator import AttentionIterator 8 | from .intermediate_batch_iterator import GradientsIterator 9 | from .intermediate_batch_iterator import load_attn_dists 10 | from .intermediate_batch_iterator import load_log_unnormalized_attn_dists 11 | from .simple_han_attn_layer import SimpleHanAttention 12 | from .talkative_simple_han_attn_layer import TalkativeSimpleHanAttention 13 | from .flat_attention_network import FlatAttentionNetwork 14 | from .pass_through_encoder import PassThroughSeq2SeqEncoder 15 | -------------------------------------------------------------------------------- /attn_tests_lib/classifier_from_attn_and_input_vects.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from allennlp.nn import util 4 | import pickle 5 | from allennlp.modules import FeedForward 6 | from overrides import overrides 7 | 8 | 9 | class ClassifierFromAttnAndInputVects(torch.nn.Module): 10 | def __init__(self, 11 | classification_module: FeedForward) -> None: 12 | super(ClassifierFromAttnAndInputVects, self).__init__() 13 | self._classification_module = classification_module 14 | 15 | def forward(self, 16 | input_vects: torch.Tensor, 17 | intra_sentence_attention: torch.Tensor)-> Dict[str, torch.Tensor]: 18 | # Shape: (batch_size, sequence_length, projection_dim) 19 | batch_size = input_vects.size(0) 20 | sequence_length = input_vects.size(1) 21 | output_token_representation = input_vects 22 | attn_weights = intra_sentence_attention 23 | 24 | attn_weights = attn_weights.unsqueeze(2).expand(batch_size, sequence_length, 25 | output_token_representation.size(2)) 26 | # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads]) 27 | correct = (((attn_weights[:, :, 0].sum(dim=1) > .98) & (attn_weights[:, :, 0].sum(dim=1) < 1.02)) | 28 | (attn_weights[:, :, 0].sum(dim=1) == 0)) 29 | assert torch.sum(correct.float()) == batch_size, \ 30 | (str(attn_weights[(((attn_weights[:, :, 0].sum(dim=1) <= .98) | 31 | (attn_weights[:, :, 0].sum(dim=1) >= 1.02)) & 32 | (attn_weights[:, :, 0].sum(dim=1) != 0))]) + "\n" + 33 | str(torch.sum(attn_weights, dim=1)[(((attn_weights[:, :, 0].sum(dim=1) <= .98) | 34 | (attn_weights[:, :, 0].sum(dim=1) >= 1.02)) & 35 | (attn_weights[:, :, 0].sum(dim=1) != 0))])) 36 | combined_tensors = 
output_token_representation * attn_weights 37 | 38 | document_repr = torch.sum(combined_tensors, 1) 39 | 40 | label_logits = self._classification_module(document_repr.view(batch_size, -1)) 41 | label_probs = torch.nn.functional.softmax(label_logits, dim=-1) 42 | 43 | output_dict = {"label_logits": label_logits, "label_probs": label_probs} 44 | 45 | return output_dict 46 | 47 | 48 | def backward_gradient_reporting_template(grad_input, filename): 49 | """ 50 | Wrap this in a lambda function providing the filename when registering it as a backwards hook 51 | :param input: 52 | :param filename: 53 | :return: 54 | """ 55 | tensors_to_cat = [grad_input[j].view(1, -1) for j in range(len(grad_input))] 56 | with open(filename, 'wb') as f: 57 | pickle.dump(torch.cat(tensors_to_cat, dim=0).cpu(), f) 58 | 59 | 60 | class GradientReportingClassifierFromAttnAndInputVects(torch.nn.Module): 61 | def __init__(self, 62 | classification_module: torch.nn.Module, 63 | temp_filename: str = None) -> None: 64 | super(GradientReportingClassifierFromAttnAndInputVects, self).__init__() 65 | self._classification_module = classification_module 66 | assert self._classification_module.__class__.__name__ == 'FeedForward', \ 67 | ("GradientReportingClassifierFromAttnAndInputVects currently assumes a feedforward output classifier " + 68 | "for dropout-zeroing purposes, but given output classifier type was " + 69 | self._classification_module.__class__.__name__) 70 | dropout_module_list = self._classification_module._dropout 71 | # set p in all dropouts to 0 72 | for i in range(len(dropout_module_list)): 73 | modified_dropout = dropout_module_list.__getitem__(i) 74 | modified_dropout.p = 0.0 75 | dropout_module_list.__setitem__(i, modified_dropout) 76 | self._temp_filename = temp_filename 77 | 78 | def set_temp_filename(self, fname): 79 | self._temp_filename = fname 80 | 81 | def forward(self, 82 | input_vects: torch.Tensor, 83 | intra_sentence_attention: torch.Tensor)-> Dict[str, torch.Tensor]: 84 | # Shape: (batch_size, sequence_length, projection_dim) 85 | batch_size = input_vects.size(0) 86 | sequence_length = input_vects.size(1) 87 | output_token_representation = input_vects 88 | attn_weights = intra_sentence_attention 89 | 90 | attn_weights.register_hook(lambda grad: backward_gradient_reporting_template(grad, self._temp_filename)) 91 | 92 | attn_weights = attn_weights.unsqueeze(2).expand(batch_size, sequence_length, 93 | output_token_representation.size(2)) 94 | # Shape: (batch_size, sequence_length, [num_heads,] projection_dim [/ num_heads]) 95 | 96 | correct = (((attn_weights[:, :, 0].sum(dim=1) > .98) & (attn_weights[:, :, 0].sum(dim=1) < 1.02)) | 97 | (attn_weights[:, :, 0].sum(dim=1) == 0)) 98 | assert torch.sum(correct.float()) == batch_size, \ 99 | (str(attn_weights[(((attn_weights[:, :, 0].sum(dim=1) <= .98) | 100 | (attn_weights[:, :, 0].sum(dim=1) >= 1.02)) & 101 | (attn_weights[:, :, 0].sum(dim=1) != 0))]) + "\n" + 102 | str(torch.sum(attn_weights, dim=1)[(((attn_weights[:, :, 0].sum(dim=1) <= .98) | 103 | (attn_weights[:, :, 0].sum(dim=1) >= 1.02)) & 104 | (attn_weights[:, :, 0].sum(dim=1) != 0))])) 105 | 106 | combined_tensors = output_token_representation * attn_weights 107 | 108 | document_repr = torch.sum(combined_tensors, 1) 109 | 110 | label_logits = self._classification_module(document_repr.view(batch_size, -1)) 111 | label_probs = torch.nn.functional.softmax(label_logits, dim=-1) 112 | 113 | output_dict = {"label_logits": label_logits, "label_probs": label_probs} 114 | 115 | return output_dict 116 
| -------------------------------------------------------------------------------- /attn_tests_lib/deprecated/__init__.py: -------------------------------------------------------------------------------- 1 | from .han_paper_attn_function import HanPaperSimilarityFunction 2 | from .talkative_intra_sentence_attention import TalkativeIntraSentenceAttentionEncoder -------------------------------------------------------------------------------- /attn_tests_lib/deprecated/han_paper_attn_function.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | import torch 3 | 4 | from allennlp.modules.similarity_functions.similarity_function import SimilarityFunction 5 | 6 | 7 | #@SimilarityFunction.register("han_paper") 8 | class HanPaperSimilarityFunction(SimilarityFunction): 9 | def __init__(self, 10 | input_dim: int, 11 | context_vect_dim: int) -> None: 12 | super(HanPaperSimilarityFunction, self).__init__() 13 | self._mlp = torch.nn.Linear(input_dim, context_vect_dim, bias=True) 14 | self._context_dot_product = torch.nn.Linear(context_vect_dim, 1, bias=False) 15 | 16 | @overrides 17 | def forward(self, tensor_1: torch.Tensor, tensor_2: torch.Tensor) -> torch.Tensor: 18 | # un-expand tensor_1, which is the only one we'll use 19 | tensor_1 = tensor_1[:, :, 0, :].view(tensor_1.size(0), tensor_1.size(1), tensor_1.size(3)).contiguous() 20 | 21 | # new shape: batch_size x seq_len x embedding_dim 22 | batch_size = tensor_1.size(0) 23 | tensor_1 = tensor_1.view(batch_size * tensor_1.size(1), tensor_1.size(2)) 24 | tensor_1 = torch.tanh(self._mlp(tensor_1)) 25 | tensor_1 = self._context_dot_product(tensor_1) 26 | tensor_1 = tensor_1.view(batch_size, -1) # batch_size x seq_len 27 | tensor_1 = tensor_1.unsqueeze(1).expand(batch_size, tensor_1.size(1), tensor_1.size(1)) 28 | return tensor_1 29 | -------------------------------------------------------------------------------- /attn_tests_lib/extended_bucket_iterator.py: -------------------------------------------------------------------------------- 1 | import random 2 | from overrides import overrides 3 | from typing import List, Tuple, Iterable, cast, Dict 4 | from allennlp.data.instance import Instance 5 | from allennlp.data.dataset import Batch 6 | from allennlp.common.util import lazy_groups_of, add_noise_to_dict_values 7 | from allennlp.data.iterators.bucket_iterator import BucketIterator 8 | from allennlp.data.iterators.data_iterator import DataIterator 9 | from allennlp.data.vocabulary import Vocabulary 10 | import math 11 | 12 | 13 | def sort_by_padding_modified(instances: List[Instance], 14 | sorting_keys: List[Tuple[str, str]], # pylint: disable=invalid-sequence-index 15 | vocab: Vocabulary, 16 | padding_noise: float = 0.0) -> List[Instance]: 17 | """ 18 | Sorts the instances by their padding lengths, using the keys in 19 | ``sorting_keys`` (in the order in which they are provided). ``sorting_keys`` is a list of 20 | ``(field_name, padding_key)`` tuples. 
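The main modification here is that a ``("sentences", "num_sentences")`` padding key (the number of sentence fields in each instance) is added, which is what lets the HAN configs in this repo use, for example, ``sorting_keys = [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]]``.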
21 | """ 22 | instances_with_lengths = [] 23 | for instance in instances: 24 | # Make sure instance is indexed before calling .get_padding 25 | instance.index_fields(vocab) 26 | padding_lengths = instance.get_padding_lengths() 27 | padding_lengths["sentences"] = {"num_sentences": len(instance.fields['tokens'].field_list)} 28 | padding_lengths = cast(Dict[str, Dict[str, float]], padding_lengths) 29 | if padding_noise > 0.0: 30 | noisy_lengths = {} 31 | for field_name, field_lengths in padding_lengths.items(): 32 | noisy_lengths[field_name] = add_noise_to_dict_values(field_lengths, padding_noise) 33 | padding_lengths = noisy_lengths 34 | instance_with_lengths = ([padding_lengths[field_name][padding_key] 35 | for (field_name, padding_key) in sorting_keys], 36 | instance) 37 | instances_with_lengths.append(instance_with_lengths) 38 | instances_with_lengths.sort(key=lambda x: x[0]) 39 | return [instance_with_lengths[-1] for instance_with_lengths in instances_with_lengths] 40 | 41 | 42 | @DataIterator.register("extended_bucket") 43 | class ExtendedBucketIterator(BucketIterator): 44 | def __init__(self, 45 | sorting_keys: List[Tuple[str, str]], 46 | padding_noise: float = 0.1, 47 | biggest_batch_first: bool = False, 48 | batch_size: int = 32, 49 | instances_per_epoch: int = None, 50 | max_instances_in_memory: int = None, 51 | cache_instances: bool = False, 52 | track_epoch: bool = False, 53 | maximum_samples_per_batch: Tuple[str, int] = None) -> None: 54 | super().__init__(sorting_keys, padding_noise=padding_noise, biggest_batch_first=biggest_batch_first, 55 | batch_size=batch_size, instances_per_epoch=instances_per_epoch, 56 | max_instances_in_memory=max_instances_in_memory, cache_instances=cache_instances, 57 | track_epoch=track_epoch, maximum_samples_per_batch=maximum_samples_per_batch) 58 | # look out for [sentences, num_sentences] 59 | self._change_create_batches = False 60 | for key in sorting_keys: 61 | if key[0] == "sentences": 62 | self._change_create_batches = True 63 | 64 | @overrides 65 | def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 66 | if not self._change_create_batches: 67 | for ret_val in iter(super()._create_batches(instances, shuffle)): 68 | yield ret_val 69 | else: 70 | for ret_val in iter(self._modified_create_batches(instances, shuffle)): 71 | yield ret_val 72 | 73 | def old_ensure_batch_is_sufficiently_small(self, batch_instances: Iterable[Instance]) -> List[List[Instance]]: 74 | """ 75 | If self._maximum_samples_per_batch is specified, then split the batch into smaller 76 | sub-batches if it exceeds the maximum size. 
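For example, with ``maximum_samples_per_batch = ("list_num_tokens", 9000)`` (the value used in the HAN configs in this repo), a batch is split into sub-batches whenever the largest padding length for that key within the batch, multiplied by the number of instances in the batch, exceeds 9000.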
77 | """ 78 | if self._maximum_samples_per_batch is None: 79 | return [list(batch_instances)] 80 | 81 | # check if we need to break into smaller chunks 82 | key, limit = self._maximum_samples_per_batch 83 | padding_length = -1 84 | list_batch_instances = list(batch_instances) 85 | for instance in list_batch_instances: 86 | if self.vocab is not None: 87 | # we index here to ensure that shape information is available, 88 | # as in some cases (with self._maximum_samples_per_batch) 89 | # we need access to shaping information before batches are constructed) 90 | instance.index_fields(self.vocab) 91 | field_lengths = instance.get_padding_lengths() 92 | for _, lengths in field_lengths.items(): 93 | try: 94 | padding_length = max(padding_length, 95 | lengths[key]) 96 | except KeyError: 97 | pass 98 | 99 | if padding_length * len(list_batch_instances) > limit: 100 | # need to shrink 101 | num_samples = padding_length * len(list_batch_instances) 102 | num_shrunk_batches = math.ceil(num_samples / float(limit)) 103 | shrunk_batch_size = math.ceil(len(list_batch_instances) / num_shrunk_batches) 104 | shrunk_batches = [] 105 | start = 0 106 | while start < len(list_batch_instances): 107 | end = start + shrunk_batch_size 108 | shrunk_batches.append(list_batch_instances[start:end]) 109 | start = end 110 | return shrunk_batches 111 | else: 112 | return [list_batch_instances] 113 | 114 | def _modified_create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]: 115 | for instance_list in self._memory_sized_lists(instances): 116 | 117 | instance_list = sort_by_padding_modified(instance_list, 118 | self._sorting_keys, 119 | self.vocab, 120 | self._padding_noise) 121 | 122 | batches = [] 123 | for batch_instances in lazy_groups_of(iter(instance_list), self._batch_size): 124 | for possibly_smaller_batches in self.old_ensure_batch_is_sufficiently_small(batch_instances): 125 | batches.append(Batch(possibly_smaller_batches)) 126 | 127 | move_to_front = self._biggest_batch_first and len(batches) > 1 128 | if move_to_front: 129 | # We'll actually pop the last _two_ batches, because the last one might not be full. 130 | last_batch = batches.pop() 131 | penultimate_batch = batches.pop() 132 | if shuffle: 133 | # NOTE: if shuffle is false, the data will still be in a different order 134 | # because of the bucket sorting. 
135 | random.shuffle(batches) 136 | if move_to_front: 137 | batches.insert(0, penultimate_batch) 138 | batches.insert(0, last_batch) 139 | 140 | yield from batches 141 | -------------------------------------------------------------------------------- /attn_tests_lib/pass_through_encoder.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | import torch 3 | from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder 4 | 5 | @Seq2SeqEncoder.register("pass_through_encoder") 6 | class PassThroughSeq2SeqEncoder(Seq2SeqEncoder): 7 | def __init__(self, 8 | input_size: int, 9 | hidden_size: int) -> None: 10 | super(PassThroughSeq2SeqEncoder, self).__init__() 11 | self.input_size = input_size 12 | self.hidden_size = hidden_size 13 | 14 | @overrides 15 | def forward(self, # pylint: disable=arguments-differ 16 | inputs: torch.Tensor, # not packed 17 | mask: torch.Tensor, 18 | hidden_state: torch.Tensor = None) -> torch.Tensor: 19 | # assume batch is first 20 | return inputs 21 | 22 | @overrides 23 | def is_bidirectional(self) -> bool: 24 | return False 25 | 26 | @overrides 27 | def get_input_dim(self) -> int: 28 | return self.input_size 29 | 30 | @overrides 31 | def get_output_dim(self) -> int: 32 | return self.hidden_size -------------------------------------------------------------------------------- /attn_tests_lib/simple_han_attn_layer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder 3 | from allennlp.modules.seq2seq_encoders.intra_sentence_attention import IntraSentenceAttentionEncoder 4 | from allennlp.nn import util 5 | import torch 6 | import os 7 | import numpy as np 8 | 9 | 10 | @Seq2SeqEncoder.register("simple_han_attention") 11 | class SimpleHanAttention(Seq2SeqEncoder): 12 | def __init__(self, 13 | input_dim : int = None, 14 | context_vector_dim: int = None) -> None: 15 | super(SimpleHanAttention, self).__init__() 16 | self._mlp = torch.nn.Linear(input_dim, context_vector_dim, bias=True) 17 | self._context_dot_product = torch.nn.Linear(context_vector_dim, 1, bias=False) 18 | self.vec_dim = self._mlp.weight.size(1) 19 | 20 | @overrides 21 | def get_input_dim(self) -> int: 22 | return self.vec_dim 23 | 24 | @overrides 25 | def get_output_dim(self) -> int: 26 | return self.vec_dim 27 | 28 | @overrides 29 | def is_bidirectional(self): 30 | return False 31 | 32 | @overrides 33 | def forward(self, tokens: torch.Tensor, mask: torch.Tensor): # pylint: disable=arguments-differ 34 | assert mask is not None 35 | batch_size, sequence_length, embedding_dim = tokens.size() 36 | 37 | attn_weights = tokens.view(batch_size * sequence_length, embedding_dim) 38 | attn_weights = torch.tanh(self._mlp(attn_weights)) 39 | attn_weights = self._context_dot_product(attn_weights) 40 | attn_weights = attn_weights.view(batch_size, -1) # batch_size x seq_len 41 | attn_weights = util.masked_softmax(attn_weights, mask) 42 | attn_weights = attn_weights.unsqueeze(2).expand(batch_size, sequence_length, embedding_dim) 43 | 44 | return tokens * attn_weights 45 | -------------------------------------------------------------------------------- /attn_tests_lib/talkative_simple_han_attn_layer.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder 
3 | from allennlp.modules.seq2seq_encoders.intra_sentence_attention import IntraSentenceAttentionEncoder 4 | from allennlp.nn import util 5 | import torch 6 | import os 7 | import numpy as np 8 | from attn_tests_lib import SimpleHanAttention 9 | 10 | 11 | def binary_search_for_num_non_padded_tokens_in_instance(full_array, row_ind): 12 | assert full_array.dim() == 2 13 | open_for_checking_start = 0 14 | open_for_checking_endplus1 = full_array.size(1) 15 | look_at = (open_for_checking_start + open_for_checking_endplus1) // 2 16 | first_zero_ind_identified = None 17 | while first_zero_ind_identified is None: 18 | if full_array[row_ind, look_at] != 0: 19 | open_for_checking_start = look_at + 1 20 | else: 21 | if full_array[row_ind, look_at - 1] != 0: 22 | first_zero_ind_identified = look_at 23 | else: 24 | open_for_checking_endplus1 = look_at 25 | if open_for_checking_start == open_for_checking_endplus1: 26 | assert open_for_checking_endplus1 == full_array.size(1) 27 | first_zero_ind_identified = full_array.size(1) 28 | else: 29 | look_at = (open_for_checking_start + open_for_checking_endplus1) // 2 30 | return first_zero_ind_identified 31 | 32 | 33 | @Seq2SeqEncoder.register("talkative_simple_han_attention") 34 | class TalkativeSimpleHanAttention(Seq2SeqEncoder): 35 | def __init__(self, 36 | attn_params: SimpleHanAttention, 37 | attn_weight_filename, 38 | corr_vector_dir, 39 | total_num_test_instances) -> None: 40 | super(TalkativeSimpleHanAttention, self).__init__() 41 | self._mlp = attn_params._mlp 42 | self._context_dot_product = attn_params._context_dot_product 43 | self.vec_dim = self._mlp.weight.size(1) 44 | self._total_num_test_instances = total_num_test_instances 45 | self._attn_weight_filename = attn_weight_filename 46 | self._input_vector_dir_name = corr_vector_dir 47 | if not self._input_vector_dir_name.endswith('/'): 48 | self._input_vector_dir_name += '/' 49 | self._next_available_counter_ind_file = self._input_vector_dir_name + "next_available_counter.txt" 50 | 51 | @overrides 52 | def get_input_dim(self) -> int: 53 | return self.vec_dim 54 | 55 | @overrides 56 | def get_output_dim(self) -> int: 57 | return self.vec_dim 58 | 59 | @overrides 60 | def is_bidirectional(self): 61 | return False 62 | 63 | @overrides 64 | def forward(self, tokens: torch.Tensor, mask: torch.Tensor): # pylint: disable=arguments-differ 65 | assert mask is not None 66 | batch_size, sequence_length, embedding_dim = tokens.size() 67 | 68 | attn_weights = tokens.view(batch_size * sequence_length, embedding_dim) 69 | attn_weights = torch.tanh(self._mlp(attn_weights)) 70 | attn_weights = self._context_dot_product(attn_weights) 71 | attn_weights = attn_weights.view(batch_size, -1) # batch_size x seq_len 72 | 73 | self.report_unnormalized_log_attn_weights(mask * attn_weights, tokens, mask) 74 | 75 | attn_weights = util.masked_softmax(attn_weights, mask) 76 | attn_weights = attn_weights.unsqueeze(2).expand(batch_size, sequence_length, embedding_dim) 77 | 78 | return tokens * attn_weights 79 | 80 | def report_unnormalized_log_attn_weights(self, attn_weights, input_vects, mask): 81 | assert attn_weights.dim() == 2, \ 82 | ("Size of attn weights (" + str(attn_weights.size()) + ") indicates multiheaded attention, but " + 83 | "TalkativeSimpleHanAttention currently assumes single-head attention.") 84 | if not os.path.isfile(self._attn_weight_filename): 85 | next_available_ind = 1 86 | if not os.path.isdir(self._input_vector_dir_name): 87 | os.makedirs(self._input_vector_dir_name) 88 | else: 89 | if 
len(os.listdir(self._input_vector_dir_name)) != 0: 90 | print("ERROR: couldn't find file " + str(self._attn_weight_filename) + ", but " + 91 | self._input_vector_dir_name + " exists and is nonempty.") 92 | exit(1) 93 | else: 94 | assert os.path.isfile(self._next_available_counter_ind_file) 95 | with open(self._next_available_counter_ind_file, 'r') as f: 96 | next_available_ind = int(f.readline()) 97 | assert next_available_ind <= self._total_num_test_instances, \ 98 | "Looks like you're overwriting previously saved results." 99 | input_vects_filename = (self._input_vector_dir_name + str(next_available_ind) + '-' + 100 | str(next_available_ind + attn_weights.size(0))) 101 | np.save(input_vects_filename, input_vects.data.cpu().numpy()) 102 | with open(self._attn_weight_filename, 'a') as f: 103 | for i in range(attn_weights.size(0)): 104 | f.write(str(next_available_ind) + ": ") 105 | num_nonpadding_pieces_in_row = \ 106 | binary_search_for_num_non_padded_tokens_in_instance(mask, i) 107 | assert num_nonpadding_pieces_in_row > 0, str(mask) 108 | weights_to_write = [float(attn_weights[i, j]) 109 | for j in range(num_nonpadding_pieces_in_row)] 110 | min_val = min(weights_to_write) 111 | max_val = max(weights_to_write) 112 | val_to_subtract = ((max_val - min_val) / 2) + min_val 113 | f.write(str(" ".join([str(w - val_to_subtract) for w in weights_to_write]))) 114 | f.write('\n') 115 | next_available_ind += 1 116 | with open(self._next_available_counter_ind_file, 'w') as f: 117 | f.write(str(next_available_ind)) 118 | -------------------------------------------------------------------------------- /before_model_training/calculate_sentence_stats.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm import tqdm 4 | from allennlp.data.dataset_readers.text_classification.textcat import TextCatReader 5 | from allennlp.data.tokenizers import WordTokenizer 6 | from allennlp.data.tokenizers.word_filter import PassThroughWordFilter 7 | import numpy as np 8 | 9 | filenames_to_use = [] 10 | for i in range(1, len(sys.argv)): 11 | assert os.path.isfile(sys.argv[i]) 12 | filenames_to_use.append(sys.argv[i]) 13 | 14 | write_to_filename = filenames_to_use[0] 15 | write_to_filename = write_to_filename[:write_to_filename.rfind('_')] 16 | if '/' in write_to_filename: 17 | write_to_filename = write_to_filename[write_to_filename.rfind('/') + 1:] 18 | write_to_filename = write_to_filename + "_sentstats.txt" 19 | print("Will write results to " + write_to_filename) 20 | 21 | def get_nth_field_in_line(line, ind): 22 | counter = 0 23 | while counter < ind: 24 | line = line[line.index('\t') + 1:] 25 | counter += 1 26 | if line.rfind('\t') == -1: 27 | # this is the last field, so just remove the trailing newline 28 | return line[:-1] 29 | else: 30 | # return the remaining line up to the next tab 31 | return line[:line.index('\t')] 32 | 33 | def get_info_about_data_len_distribution(filepaths): 34 | numsents_maxnumtokens = [] 35 | for filepath in filepaths: 36 | first_line = True 37 | with open(filepath, 'r') as f: 38 | for line in f: 39 | if first_line: 40 | # find which fields are num_sentences and max_num_tokens_in_sentence 41 | temp_line = line 42 | num_sents_field_ind = 0 43 | while not (temp_line.startswith('num_sentences\t') or temp_line.startswith('num_sentences\n')): 44 | temp_line = temp_line[temp_line.index('\t') + 1:] 45 | num_sents_field_ind += 1 46 | temp_line = line 47 | max_num_tokens_field_ind = 0 48 | while not 
(temp_line.startswith('max_num_tokens_in_sentence\t') or 49 | temp_line.startswith('max_num_tokens_in_sentence\n')): 50 | temp_line = temp_line[temp_line.index('\t') + 1:] 51 | max_num_tokens_field_ind += 1 52 | first_line = False 53 | else: 54 | if line.strip() == '': 55 | continue 56 | num_sents = int(get_nth_field_in_line(line, num_sents_field_ind)) 57 | max_num_tokens = int(get_nth_field_in_line(line, max_num_tokens_field_ind)) 58 | numsents_maxnumtokens.append((num_sents, max_num_tokens)) 59 | 60 | num_sentences = [tup[0] for tup in numsents_maxnumtokens] 61 | return num_sentences 62 | 63 | 64 | all_lengths = get_info_about_data_len_distribution(filenames_to_use) 65 | 66 | arr_of_lengths = np.array(all_lengths) 67 | m = np.mean(arr_of_lengths) 68 | sd = np.std(arr_of_lengths) 69 | with open(write_to_filename, 'w') as f: 70 | f.write('Mean: ' + str(m) + '\n') 71 | f.write('SD: ' + str(sd) + '\n') 72 | 73 | print("Done calculating sentence stats.") 74 | -------------------------------------------------------------------------------- /before_model_training/calculate_word_stats.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from tqdm import tqdm 4 | from allennlp.data.dataset_readers.text_classification.textcat import TextCatReader 5 | from allennlp.data.tokenizers import WordTokenizer 6 | from allennlp.data.tokenizers.word_filter import PassThroughWordFilter 7 | import numpy as np 8 | 9 | filenames_to_use = [] 10 | for i in range(1, len(sys.argv)): 11 | assert os.path.isfile(sys.argv[i]) 12 | filenames_to_use.append(sys.argv[i]) 13 | 14 | write_to_filename = filenames_to_use[0] 15 | write_to_filename = write_to_filename[:write_to_filename.rfind('_')] 16 | if '/' in write_to_filename: 17 | write_to_filename = write_to_filename[write_to_filename.rfind('/') + 1:] 18 | write_to_filename = write_to_filename + "_wordstats.txt" 19 | print("Will write results to " + write_to_filename) 20 | 21 | class InstanceLenGenerator: 22 | def __init__(self, allennlp_formatted_reader, filepaths): 23 | self.allennlp_formatted_reader = allennlp_formatted_reader 24 | self.filepaths = filepaths 25 | 26 | def __iter__(self): 27 | for filepath in self.filepaths: 28 | for instance in tqdm(self.allennlp_formatted_reader._read(file_path=filepath)): 29 | instance_as_text_field = instance.fields['tokens'] 30 | yield len(instance_as_text_field.tokens) 31 | 32 | 33 | allennlp_reader = TextCatReader(word_tokenizer=WordTokenizer(word_filter=PassThroughWordFilter()), 34 | segment_sentences=False) 35 | len_generator = InstanceLenGenerator(allennlp_reader, filenames_to_use) 36 | all_lengths = [] 37 | for length in iter(len_generator): 38 | all_lengths.append(length) 39 | 40 | arr_of_lengths = np.array(all_lengths) 41 | m = np.mean(arr_of_lengths) 42 | sd = np.std(arr_of_lengths) 43 | with open(write_to_filename, 'w') as f: 44 | f.write('Mean: ' + str(m) + '\n') 45 | f.write('SD: ' + str(sd) + '\n') 46 | 47 | print("Done calculating word stats.") 48 | -------------------------------------------------------------------------------- /before_model_training/get_vocab_and_initialize_word2vec_embeddings_from_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python3 get_vocab_and_initialize_word2vec_embeddings_from_dataset.py --run-part 1 4 | python3 get_vocab_and_initialize_word2vec_embeddings_from_dataset.py --run-part 2 5 | python3 
get_vocab_and_initialize_word2vec_embeddings_from_dataset.py --run-part 3 6 | -------------------------------------------------------------------------------- /configs/amazon_flan_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.6, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "pass_through_encoder", 37 | "input_size": 200, 38 | "hidden_size": 200, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 200, 43 | "context_vector_dim": 200 44 | }, 45 | "output_logit": { 46 | "input_dim": 200, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.4, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "max_instances_in_memory": 1000000, 65 | "batch_size": 64, 66 | "maximum_samples_per_batch": ["num_tokens", 135000], // confirmed that this affects batch size 67 | "biggest_batch_first": false 68 | }, 69 | "trainer": { 70 | "optimizer": { 71 | "type": "adam", 72 | "lr": 0.0002 73 | }, 74 | "validation_metric": "+accuracy", 75 | "num_serialized_models_to_keep": 2, 76 | "num_epochs": 60, 77 | //"grad_norm": 10.0, 78 | "grad_clipping": 50.0, 79 | "patience": 10, 80 | "cuda_device": 0, 81 | "learning_rate_scheduler": { 82 | "type": "reduce_on_plateau", 83 | "factor": 0.5, 84 | "mode": "max", 85 | "patience": 0 86 | }, 87 | "shuffle": true 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /configs/amazon_flan_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | 
"type": "flan", 24 | "pre_document_encoder_dropout": 0.6, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "context_vector_dim": 100, 43 | "input_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.4, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "max_instances_in_memory": 1000000, 64 | "sorting_keys": [["tokens", "num_tokens"]], 65 | "batch_size": 64, 66 | "maximum_samples_per_batch": ["num_tokens", 135000], // confirmed that this affects batch size 67 | "biggest_batch_first": false 68 | }, 69 | "trainer": { 70 | "optimizer": { 71 | "type": "adam", 72 | "lr": 0.0002 73 | }, 74 | "validation_metric": "+accuracy", 75 | "num_serialized_models_to_keep": 2, 76 | "num_epochs": 60, 77 | "grad_norm": 10.0, 78 | "patience": 5, 79 | "cuda_device": 0, 80 | "learning_rate_scheduler": { 81 | "type": "reduce_on_plateau", 82 | "factor": 0.5, 83 | "mode": "max", 84 | "patience": 0 85 | }, 86 | "shuffle": true 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /configs/amazon_flan_with_rnns-actuallyflanconv.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.6, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "context_vector_dim": 100, 43 | "input_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.4, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | 
[".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "max_instances_in_memory": 1000000, 64 | "sorting_keys": [["tokens", "num_tokens"]], 65 | "batch_size": 64, 66 | "maximum_samples_per_batch": ["num_tokens", 135000], // confirmed that this affects batch size 67 | "biggest_batch_first": false 68 | }, 69 | "trainer": { 70 | "optimizer": { 71 | "type": "adam", 72 | "lr": 0.0002 73 | }, 74 | "validation_metric": "+accuracy", 75 | "num_serialized_models_to_keep": 2, 76 | "num_epochs": 60, 77 | "grad_norm": 10.0, 78 | "patience": 5, 79 | "cuda_device": 0, 80 | "learning_rate_scheduler": { 81 | "type": "reduce_on_plateau", 82 | "factor": 0.5, 83 | "mode": "max", 84 | "patience": 0 85 | }, 86 | "shuffle": true 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /configs/amazon_flan_with_rnns.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.6, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "gru", 37 | "num_layers": 1, 38 | "bidirectional": true, 39 | "input_size": 200, 40 | "hidden_size": 50, 41 | }, 42 | "word_attention": { 43 | "type": "simple_han_attention", 44 | "input_dim": 100, 45 | "context_vector_dim": 100 46 | }, 47 | "output_logit": { 48 | "input_dim": 100, 49 | "num_layers": 1, 50 | "hidden_dims": 5, 51 | "dropout": 0.4, 52 | "activations": "linear" 53 | }, 54 | "initializer": [ 55 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 56 | [".*linear_layers.*bias", {"type": "zero"}], 57 | [".*weight_ih.*", {"type": "xavier_uniform"}], 58 | [".*weight_hh.*", {"type": "orthogonal"}], 59 | [".*bias_ih.*", {"type": "zero"}], 60 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 61 | ] 62 | }, 63 | "iterator": { 64 | "type": "extended_bucket", 65 | "sorting_keys": [["tokens", "num_tokens"]], 66 | "max_instances_in_memory": 1000000, 67 | "batch_size": 64, 68 | "maximum_samples_per_batch": ["num_tokens", 135000], // confirmed that this affects batch size 69 | "biggest_batch_first": false 70 | }, 71 | "trainer": { 72 | "optimizer": { 73 | "type": "adam", 74 | "lr": 0.0002 75 | }, 76 | "validation_metric": "+accuracy", 77 | "num_serialized_models_to_keep": 2, 78 | "num_epochs": 60, 79 | //"grad_norm": 10.0, 80 | "grad_clipping": 50.0, 81 | "patience": 5, 82 | "cuda_device": 0, 83 | "learning_rate_scheduler": { 84 | "type": "reduce_on_plateau", 85 | "factor": 0.5, 86 | "mode": "max", 87 | "patience": 0 88 | }, 89 | "shuffle": true 90 | } 91 | 
} 92 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-1.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.4, 25 | "pre_document_encoder_dropout": 0.4, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.4, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-2.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | 
}, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.6, 25 | "pre_document_encoder_dropout": 0.3, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.3, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-3.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.5, 25 | "pre_document_encoder_dropout": 0.1, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": 
"embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.5, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-4.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.6, 25 | "pre_document_encoder_dropout": 0.2, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | 
"type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.4, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0002 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-5.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.5, 25 | "pre_document_encoder_dropout": 0.1, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.5, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | 
[".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.00005 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-6.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.5, 25 | "pre_document_encoder_dropout": 0.1, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.5, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 
| "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.00001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_from_paper-actuallyhanconv.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.2, 25 | "pre_sentence_encoder_dropout": 0.6, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "convolutional_rnn_substitute", 38 | "input_size": 200, 39 | "hidden_size": 100, 40 | }, 41 | "document_encoder": { 42 | "type": "convolutional_rnn_substitute", 43 | "input_size": 100, 44 | "hidden_size": 100, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "context_vector_dim": 100, 49 | "input_dim": 100 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "context_vector_dim": 100, 54 | "input_dim": 100 55 | }, 56 | "output_logit": { 57 | "input_dim": 100, 58 | "num_layers": 1, 59 | "hidden_dims": 5, 60 | "dropout": 0.4, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "max_instances_in_memory": 1000000, 75 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 76 | "batch_size": 64, 77 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 78 | "biggest_batch_first": false 79 | }, 80 | "trainer": { 81 | "optimizer": { 82 | "type": "adam", 83 | "lr": 0.0002 84 | }, 85 | "validation_metric": "+accuracy", 86 | "num_serialized_models_to_keep": 2, 87 | "num_epochs": 60, 88 | "grad_norm": 10.0, 89 | "patience": 5, 90 | "cuda_device": 1, 91 | "learning_rate_scheduler": { 92 | "type": "reduce_on_plateau", 93 | "factor": 0.5, 94 | "mode": "max", 95 | "patience": 0 96 | }, 97 | "shuffle": true 98 | } 99 | } 100 | 
-------------------------------------------------------------------------------- /configs/amazon_han_from_paper.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.6, 25 | "pre_document_encoder_dropout": 0.2, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 5, 64 | "dropout": 0.4, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "max_instances_in_memory": 1000000, 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0006 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 1, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/amazon_han_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | 
"datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.6, 25 | "pre_document_encoder_dropout": 0.2, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "pass_through_encoder", 38 | "input_size": 200, 39 | "hidden_size": 200, 40 | }, 41 | "document_encoder": { 42 | "type": "pass_through_encoder", 43 | "input_size": 200, 44 | "hidden_size": 200, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 200, 49 | "context_vector_dim": 200 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 200, 54 | "context_vector_dim": 200 55 | }, 56 | "output_logit": { 57 | "input_dim": 200, 58 | "num_layers": 1, 59 | "hidden_dims": 5, 60 | "dropout": 0.4, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "max_instances_in_memory": 1000000, 76 | "batch_size": 64, 77 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 78 | "biggest_batch_first": false 79 | }, 80 | "trainer": { 81 | "optimizer": { 82 | "type": "adam", 83 | "lr": 0.0002 84 | }, 85 | "validation_metric": "+accuracy", 86 | "num_serialized_models_to_keep": 2, 87 | "num_epochs": 60, 88 | //"grad_norm": 10.0, 89 | "grad_clipping": 50.0, 90 | "patience": 10, 91 | "cuda_device": 1, 92 | "learning_rate_scheduler": { 93 | "type": "reduce_on_plateau", 94 | "factor": 0.5, 95 | "mode": "max", 96 | "patience": 0 97 | }, 98 | "shuffle": true 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /configs/amazon_han_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 9, 3 | "numpy_seed": 289, 4 | "pytorch_seed": 218, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/amazon-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/amazon_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/amazon_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.2, 25 | "pre_sentence_encoder_dropout": 0.6, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": 
"/homes/gws/sofias6/data/amazon_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "convolutional_rnn_substitute", 38 | "input_size": 200, 39 | "hidden_size": 100, 40 | }, 41 | "document_encoder": { 42 | "type": "convolutional_rnn_substitute", 43 | "input_size": 100, 44 | "hidden_size": 100, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "context_vector_dim": 100, 49 | "input_dim": 100 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "context_vector_dim": 100, 54 | "input_dim": 100 55 | }, 56 | "output_logit": { 57 | "input_dim": 100, 58 | "num_layers": 1, 59 | "hidden_dims": 5, 60 | "dropout": 0.4, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "max_instances_in_memory": 1000000, 75 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 76 | "batch_size": 64, 77 | "maximum_samples_per_batch": ["list_num_tokens", 9000], // confirmed that this affects batch size 78 | "biggest_batch_first": false 79 | }, 80 | "trainer": { 81 | "optimizer": { 82 | "type": "adam", 83 | "lr": 0.0002 84 | }, 85 | "validation_metric": "+accuracy", 86 | "num_serialized_models_to_keep": 2, 87 | "num_epochs": 60, 88 | "grad_norm": 10.0, 89 | "patience": 5, 90 | "cuda_device": 1, 91 | "learning_rate_scheduler": { 92 | "type": "reduce_on_plateau", 93 | "factor": 0.5, 94 | "mode": "max", 95 | "patience": 0 96 | }, 97 | "shuffle": true 98 | } 99 | } 100 | -------------------------------------------------------------------------------- /configs/imdb_flan_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "pass_through_encoder", 37 | "input_size": 200, 38 | "hidden_size": 200, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 200, 43 | "context_vector_dim": 200 44 | }, 45 | "output_logit": { 46 | "input_dim": 200, 47 | "num_layers": 1, 48 | "hidden_dims": 10, 49 | "dropout": 0.3457355626352195, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | 
[".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 20000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0004 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 60, 76 | "grad_norm": 10.0, 77 | "patience": 10, 78 | "cuda_device": 2, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /configs/imdb_flan_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 100, 43 | "context_vector_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 10, 49 | "dropout": 0.3457355626352195, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 20000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0004 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 15, 76 | "grad_norm": 10.0, 77 | "patience": 5, 78 | "cuda_device": 2, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | 
-------------------------------------------------------------------------------- /configs/imdb_flan_with_rnns.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "gru", 37 | "num_layers": 1, 38 | "bidirectional": true, 39 | "input_size": 200, 40 | "hidden_size": 50, 41 | }, 42 | "word_attention": { 43 | "type": "simple_han_attention", 44 | "input_dim": 100, 45 | "context_vector_dim": 100 46 | }, 47 | "output_logit": { 48 | "input_dim": 100, 49 | "num_layers": 1, 50 | "hidden_dims": 10, 51 | "dropout": 0.3457355626352195, 52 | "activations": "linear" 53 | }, 54 | "initializer": [ 55 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 56 | [".*linear_layers.*bias", {"type": "zero"}], 57 | [".*weight_ih.*", {"type": "xavier_uniform"}], 58 | [".*weight_hh.*", {"type": "orthogonal"}], 59 | [".*bias_ih.*", {"type": "zero"}], 60 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 61 | ] 62 | }, 63 | "iterator": { 64 | "type": "extended_bucket", 65 | "sorting_keys": [["tokens", "num_tokens"]], 66 | "batch_size": 64, 67 | "maximum_samples_per_batch": ["num_tokens", 20000], // confirmed that this affects batch size 68 | "biggest_batch_first": false 69 | }, 70 | "trainer": { 71 | "optimizer": { 72 | "type": "adam", 73 | "lr": 0.0004 74 | }, 75 | "validation_metric": "+accuracy", 76 | "num_serialized_models_to_keep": 2, 77 | "num_epochs": 15, 78 | "grad_norm": 10.0, 79 | "patience": 5, 80 | "cuda_device": 2, 81 | "learning_rate_scheduler": { 82 | "type": "reduce_on_plateau", 83 | "factor": 0.5, 84 | "mode": "max", 85 | "patience": 0 86 | }, 87 | "shuffle": true 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /configs/imdb_han_from_paper.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat-attnlabel", 7 | "model_folder_name": "hanrnn-postattnfix", 8 | "segment_sentences": true, 9 | "word_tokenizer": "word", 10 | "token_indexers": { 11 | "tokens": { 12 | "type": "single_id", 13 | "lowercase_tokens": true 14 | } 15 | } 16 | }, 17 | "datasets_for_vocab_creation": [], 18 | "vocabulary": { 19 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 20 | }, 21 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 22 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 23 | "model": { 24 | "type": "han", 25 | "calculate_f1": true, 26 | "pre_sentence_encoder_dropout": 
0.44446746096594764, 27 | "pre_document_encoder_dropout": 0.22016423400055152, 28 | "text_field_embedder": { 29 | "token_embedders": { 30 | "tokens": { 31 | "type": "embedding", 32 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 33 | "embedding_dim": 200, 34 | "trainable": true 35 | } 36 | } 37 | }, 38 | "sentence_encoder": { 39 | "type": "gru", 40 | "num_layers": 1, 41 | "bidirectional": true, 42 | "input_size": 200, 43 | "hidden_size": 50, 44 | }, 45 | "document_encoder": { 46 | "type": "gru", 47 | "num_layers": 1, 48 | "bidirectional": true, 49 | "input_size": 100, 50 | "hidden_size": 50, 51 | }, 52 | "word_attention": { 53 | "type": "simple_han_attention", 54 | "input_dim": 100, 55 | "context_vector_dim": 100 56 | }, 57 | "sentence_attention": { 58 | "type": "simple_han_attention", 59 | "input_dim": 100, 60 | "context_vector_dim": 100 61 | }, 62 | "output_logit": { 63 | "input_dim": 100, 64 | "num_layers": 1, 65 | "hidden_dims": 2, 66 | "dropout": 0.2457355626352195, 67 | "activations": "linear" 68 | }, 69 | "initializer": [ 70 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 71 | [".*linear_layers.*bias", {"type": "zero"}], 72 | [".*weight_ih.*", {"type": "xavier_uniform"}], 73 | [".*weight_hh.*", {"type": "orthogonal"}], 74 | [".*bias_ih.*", {"type": "zero"}], 75 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 76 | ] 77 | }, 78 | "iterator": { 79 | "type": "extended_bucket", 80 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 81 | "batch_size": 64, 82 | "maximum_samples_per_batch": ["list_num_tokens", 2000], // confirmed that this affects batch size 83 | "biggest_batch_first": false 84 | }, 85 | "trainer": { 86 | "optimizer": { 87 | "type": "adam", 88 | "lr": 0.0004 89 | }, 90 | "validation_metric": "+accuracy", 91 | "num_serialized_models_to_keep": 2, 92 | "num_epochs": 15, 93 | "grad_norm": 10.0, 94 | "patience": 5, 95 | "cuda_device": 2, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/imdb_han_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "pass_through_encoder", 38 | "input_size": 200, 39 | "hidden_size": 200, 40 | }, 41 | "document_encoder": { 42 | "type": "pass_through_encoder", 43 | "input_size": 200, 44 | 
"hidden_size": 200, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 200, 49 | "context_vector_dim": 200 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 200, 54 | "context_vector_dim": 200 55 | }, 56 | "output_logit": { 57 | "input_dim": 200, 58 | "num_layers": 1, 59 | "hidden_dims": 10, 60 | "dropout": 0.2457355626352195, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "batch_size": 64, 76 | "maximum_samples_per_batch": ["list_num_tokens", 2000], // confirmed that this affects batch size 77 | "biggest_batch_first": false 78 | }, 79 | "trainer": { 80 | "optimizer": { 81 | "type": "adam", 82 | "lr": 0.0004 83 | }, 84 | "validation_metric": "+accuracy", 85 | "num_serialized_models_to_keep": 2, 86 | "num_epochs": 60, 87 | "grad_norm": 10.0, 88 | "patience": 10, 89 | "cuda_device": 2, 90 | "learning_rate_scheduler": { 91 | "type": "reduce_on_plateau", 92 | "factor": 0.5, 93 | "mode": "max", 94 | "patience": 0 95 | }, 96 | "shuffle": true 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /configs/imdb_han_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "convolutional_rnn_substitute", 38 | "input_size": 200, 39 | "hidden_size": 100, 40 | }, 41 | "document_encoder": { 42 | "type": "convolutional_rnn_substitute", 43 | "input_size": 100, 44 | "hidden_size": 100, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 100, 49 | "context_vector_dim": 100 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "output_logit": { 57 | "input_dim": 100, 58 | "num_layers": 1, 59 | "hidden_dims": 10, 60 | "dropout": 0.2457355626352195, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": 
"xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "batch_size": 64, 76 | "maximum_samples_per_batch": ["list_num_tokens", 2000], // confirmed that this affects batch size 77 | "biggest_batch_first": false 78 | }, 79 | "trainer": { 80 | "optimizer": { 81 | "type": "adam", 82 | "lr": 0.0004 83 | }, 84 | "validation_metric": "+accuracy", 85 | "num_serialized_models_to_keep": 2, 86 | "num_epochs": 15, 87 | "grad_norm": 10.0, 88 | "patience": 5, 89 | "cuda_device": 2, 90 | "learning_rate_scheduler": { 91 | "type": "reduce_on_plateau", 92 | "factor": 0.5, 93 | "mode": "max", 94 | "patience": 0 95 | }, 96 | "shuffle": true 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /configs/old_configs/imdb_han_from_paper_ORIGINAL.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/imdb-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/imdb_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 10, 64 | "dropout": 0.2457355626352195, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "batch_size": 64, 80 | "maximum_samples_per_batch": ["list_num_tokens", 2000], // confirmed that this affects batch size 81 | "biggest_batch_first": false 82 | }, 83 | "trainer": { 84 | "optimizer": { 85 | "type": "adam", 86 | 
"lr": 0.0004 87 | }, 88 | "validation_metric": "+accuracy", 89 | "num_serialized_models_to_keep": 2, 90 | "num_epochs": 15, 91 | "grad_norm": 10.0, 92 | "patience": 5, 93 | "cuda_device": 2, 94 | "learning_rate_scheduler": { 95 | "type": "reduce_on_plateau", 96 | "factor": 0.5, 97 | "mode": "max", 98 | "patience": 0 99 | }, 100 | "shuffle": true 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /configs/old_configs/yelp_han_from_paper_ATTNPERF.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat-attnlabel", 7 | "model_folder_name": "hanrnn-postattnfix-2", 8 | "segment_sentences": true, 9 | "word_tokenizer": "word", 10 | "token_indexers": { 11 | "tokens": { 12 | "type": "single_id", 13 | "lowercase_tokens": true 14 | } 15 | } 16 | }, 17 | "datasets_for_vocab_creation": [], 18 | "vocabulary": { 19 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 20 | }, 21 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 22 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 23 | "model": { 24 | "type": "han", 25 | "calculate_f1": true, 26 | "loss_class_weights": [1.0, 2.226], 27 | "pre_document_encoder_dropout": 0.4, 28 | "pre_sentence_encoder_dropout": 0.4, 29 | "text_field_embedder": { 30 | "token_embedders": { 31 | "tokens": { 32 | "type": "embedding", 33 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 34 | "embedding_dim": 200, 35 | "trainable": true, 36 | "max_norm": 1.0 37 | } 38 | } 39 | }, 40 | "sentence_encoder": { 41 | "type": "gru", 42 | "num_layers": 1, 43 | "bidirectional": true, 44 | "input_size": 200, 45 | "hidden_size": 50, 46 | }, 47 | "document_encoder": { 48 | "type": "gru", 49 | "num_layers": 1, 50 | "bidirectional": true, 51 | "input_size": 100, 52 | "hidden_size": 50, 53 | }, 54 | "word_attention": { 55 | "type": "simple_han_attention", 56 | "input_dim": 100, 57 | "context_vector_dim": 100 58 | }, 59 | "sentence_attention": { 60 | "type": "simple_han_attention", 61 | "input_dim": 100, 62 | "context_vector_dim": 100 63 | }, 64 | "output_logit": { 65 | "input_dim": 100, 66 | "num_layers": 1, 67 | "hidden_dims": 2, 68 | "dropout": 0.4, 69 | "activations": "linear" 70 | }, 71 | "initializer": [ 72 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 73 | [".*linear_layers.*bias", {"type": "zero"}], 74 | [".*weight_ih.*", {"type": "xavier_uniform"}], 75 | [".*weight_hh.*", {"type": "orthogonal"}], 76 | [".*bias_ih.*", {"type": "zero"}], 77 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 78 | ] 79 | }, 80 | "iterator": { 81 | "type": "extended_bucket", 82 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 83 | "batch_size": 64, 84 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 85 | "biggest_batch_first": false 86 | }, 87 | "trainer": { 88 | "optimizer": { 89 | "type": "sgd", 90 | "lr": 0.001, 91 | "momentum": 0.9 92 | }, 93 | "validation_metric": "+accuracy", 94 | "num_serialized_models_to_keep": 2, 95 | "num_epochs": 15, 96 | //"grad_norm": 10.0, 97 | "grad_clipping": 50.0, 98 | "patience": 5, 99 | "cuda_device": 0, 100 | "learning_rate_scheduler": { 101 | "type": "reduce_on_plateau", 102 | "factor": 0.5, 103 | "mode": "max", 104 | "patience": 0 105 | }, 106 | "shuffle": true 107 | } 108 | } 
109 | -------------------------------------------------------------------------------- /configs/whateverDatasetYouHaveInMind_sample_han.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "textcat", 4 | "segment_sentences": true, 5 | "word_tokenizer": "word", 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | } 11 | } 12 | }, 13 | "datasets_for_vocab_creation": ["train"], // this assumes you want to create the vocabulary from the training set-- feel free to ask me if you want to do something else 14 | "train_data_path": "/homes/gws/sofias6/data/imdb_train.tsv", // replace this with your training data file, in the format listed at the top of textcat.TextCatReader (in textcat_reader.py) 15 | "validation_data_path": "/homes/gws/sofias6/data/imdb_dev.tsv", // replace this with your dev data file, in the format listed at the top of textcat.TextCatReader (in textcat_reader.py) 16 | "model": { 17 | "type": "han", 18 | "pre_sentence_encoder_dropout": 0.4, 19 | "pre_document_encoder_dropout": 0.4, 20 | "text_field_embedder": { 21 | "token_embedders": { 22 | "tokens": { 23 | "type": "embedding", 24 | "pretrained_file": "https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.6B.50d.txt.gz", 25 | "embedding_dim": 50, 26 | "trainable": true 27 | } 28 | } 29 | }, 30 | "sentence_encoder": { 31 | "type": "gru", 32 | "num_layers": 1, 33 | "bidirectional": true, 34 | "input_size": 50, 35 | "hidden_size": 50, 36 | }, 37 | "document_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 100, 42 | "hidden_size": 50, 43 | }, 44 | "word_attention": { 45 | "type": "intra_sentence_attention", 46 | "input_dim": 100, 47 | "combination": "2", 48 | "similarity_function": { 49 | "type": "han_paper", 50 | "input_dim": 100, 51 | "context_vect_dim": 100 52 | }, 53 | }, 54 | "sentence_attention": { 55 | "type": "intra_sentence_attention", 56 | "input_dim": 100, 57 | "combination": "2", 58 | "similarity_function": { 59 | "type": "han_paper", 60 | "input_dim": 100, 61 | "context_vect_dim": 100 62 | }, 63 | }, 64 | "output_logit": { 65 | "input_dim": 100, 66 | "num_layers": 1, 67 | "hidden_dims": 10, 68 | "dropout": 0.4, 69 | "activations": "linear" 70 | }, 71 | "initializer": [ 72 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 73 | [".*linear_layers.*bias", {"type": "zero"}], 74 | [".*weight_ih.*", {"type": "xavier_uniform"}], 75 | [".*weight_hh.*", {"type": "orthogonal"}], 76 | [".*bias_ih.*", {"type": "zero"}], 77 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 78 | ] 79 | }, 80 | "iterator": { 81 | "type": "extended_bucket", 82 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 83 | "batch_size": 64, 84 | "maximum_samples_per_batch": ["list_num_tokens", 2000], // confirmed that this affects batch size 85 | "biggest_batch_first": false 86 | }, 87 | "trainer": { 88 | "optimizer": { 89 | "type": "adam", 90 | "lr": 0.0004 91 | }, 92 | "validation_metric": "+accuracy", 93 | "num_serialized_models_to_keep": 2, 94 | "num_epochs": 15, 95 | "grad_norm": 10.0, 96 | "patience": 5, 97 | "cuda_device": -1, // swap this to a gpu id on the machine (such as 0 or 1) if you want to use a gpu 98 | "learning_rate_scheduler": { 99 | "type": "reduce_on_plateau", 100 | "factor": 0.5, 101 | "mode": "max", 102 | "patience": 0 103 | }, 104 | "shuffle": true 105 | } 106 | } 107 | 
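The sample config above is the one meant to be adapted to a new dataset (its inline comments point out the fields to replace). The repository ships its own training script (`train_model.py`), so the following is only an illustrative sketch, not the repo's documented workflow: it shows how a config like this could, in principle, be handed to allennlp 0.7.x/0.8.x directly. The paths, the serialization directory, and the assumption that importing `textcat` and `attn_tests_lib` registers every custom component named in the configs are mine, not the repo's.

```python
# Minimal sketch, NOT the repo's documented workflow (train_model.py is the repo's own entry point).
# Assumes allennlp 0.7.x/0.8.x, and that importing these two packages registers the
# custom dataset readers, models, iterators, and attention modules the configs refer to.
from allennlp.common.util import import_submodules
from allennlp.commands.train import train_model_from_file

import_submodules("textcat")         # custom dataset readers and the HAN model
import_submodules("attn_tests_lib")  # custom iterators, encoders, attention layers

# Both paths below are placeholders -- point them at your own config and output directory.
train_model_from_file(
    "configs/whateverDatasetYouHaveInMind_sample_han.jsonnet",
    "models/sample_han_run",
)
```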
-------------------------------------------------------------------------------- /configs/yahoo10cat_flan_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "pass_through_encoder", 37 | "input_size": 200, 38 | "hidden_size": 200, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 200, 43 | "context_vector_dim": 200 44 | }, 45 | "output_logit": { 46 | "input_dim": 200, 47 | "num_layers": 1, 48 | "hidden_dims": 10, 49 | "dropout": 0.4457355626352195, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 15000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0004 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 60, 76 | "grad_norm": 10.0, 77 | "patience": 10, 78 | "cuda_device": 2, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /configs/yahoo10cat_flan_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | 
"token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 100, 43 | "context_vector_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 10, 49 | "dropout": 0.4457355626352195, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 15000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0004 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 15, 76 | "grad_norm": 10.0, 77 | "patience": 5, 78 | "cuda_device": 2, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /configs/yahoo10cat_flan_with_rnns.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 370, 3 | "numpy_seed": 944, 4 | "pytorch_seed": 972, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.44446746096594764, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "gru", 37 | "num_layers": 1, 38 | "bidirectional": true, 39 | "input_size": 200, 40 | "hidden_size": 50, 41 | }, 42 | "word_attention": { 43 | "type": "simple_han_attention", 44 | "input_dim": 100, 45 | "context_vector_dim": 100 46 | }, 47 | "output_logit": { 48 | "input_dim": 100, 49 | "num_layers": 1, 50 | "hidden_dims": 10, 51 | "dropout": 0.4457355626352195, 52 | "activations": "linear" 53 | }, 54 | "initializer": [ 55 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 56 | [".*linear_layers.*bias", {"type": "zero"}], 57 | [".*weight_ih.*", {"type": "xavier_uniform"}], 58 | [".*weight_hh.*", {"type": "orthogonal"}], 59 | [".*bias_ih.*", {"type": 
"zero"}], 60 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 61 | ] 62 | }, 63 | "iterator": { 64 | "type": "extended_bucket", 65 | "sorting_keys": [["tokens", "num_tokens"]], 66 | "batch_size": 64, 67 | "maximum_samples_per_batch": ["num_tokens", 15000], // confirmed that this affects batch size 68 | "biggest_batch_first": false 69 | }, 70 | "trainer": { 71 | "optimizer": { 72 | "type": "adam", 73 | "lr": 0.0004 74 | }, 75 | "validation_metric": "+accuracy", 76 | "num_serialized_models_to_keep": 2, 77 | "num_epochs": 15, 78 | "grad_norm": 10.0, 79 | "patience": 5, 80 | "cuda_device": 2, 81 | "learning_rate_scheduler": { 82 | "type": "reduce_on_plateau", 83 | "factor": 0.5, 84 | "mode": "max", 85 | "patience": 0 86 | }, 87 | "shuffle": true 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /configs/yahoo10cat_han_from_paper.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 122, 3 | "numpy_seed": 462, 4 | "pytorch_seed": 458, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "gru", 38 | "num_layers": 1, 39 | "bidirectional": true, 40 | "input_size": 200, 41 | "hidden_size": 50, 42 | }, 43 | "document_encoder": { 44 | "type": "gru", 45 | "num_layers": 1, 46 | "bidirectional": true, 47 | "input_size": 100, 48 | "hidden_size": 50, 49 | }, 50 | "word_attention": { 51 | "type": "simple_han_attention", 52 | "input_dim": 100, 53 | "context_vector_dim": 100 54 | }, 55 | "sentence_attention": { 56 | "type": "simple_han_attention", 57 | "input_dim": 100, 58 | "context_vector_dim": 100 59 | }, 60 | "output_logit": { 61 | "input_dim": 100, 62 | "num_layers": 1, 63 | "hidden_dims": 10, 64 | "dropout": 0.3748675587736766, 65 | "activations": "linear" 66 | }, 67 | "initializer": [ 68 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 69 | [".*linear_layers.*bias", {"type": "zero"}], 70 | [".*weight_ih.*", {"type": "xavier_uniform"}], 71 | [".*weight_hh.*", {"type": "orthogonal"}], 72 | [".*bias_ih.*", {"type": "zero"}], 73 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 74 | ] 75 | }, 76 | "iterator": { 77 | "type": "extended_bucket", 78 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 79 | "batch_size": 64, 80 | "maximum_samples_per_batch": ["list_num_tokens", 1000], // confirmed that this affects batch size 81 | "biggest_batch_first": false 82 | }, 83 | "trainer": { 84 | "optimizer": { 85 | "type": "adam", 86 | "lr": 0.0004 87 | }, 88 | "validation_metric": "+accuracy", 89 | "num_serialized_models_to_keep": 2, 90 | "num_epochs": 15, 91 | 
"grad_norm": 10.0, 92 | "patience": 5, 93 | "cuda_device": 0, 94 | "learning_rate_scheduler": { 95 | "type": "reduce_on_plateau", 96 | "factor": 0.5, 97 | "mode": "max", 98 | "patience": 0 99 | }, 100 | "shuffle": true 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /configs/yahoo10cat_han_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 122, 3 | "numpy_seed": 462, 4 | "pytorch_seed": 458, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "pass_through_encoder", 38 | "input_size": 200, 39 | "hidden_size": 200, 40 | }, 41 | "document_encoder": { 42 | "type": "pass_through_encoder", 43 | "input_size": 200, 44 | "hidden_size": 200, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 200, 49 | "context_vector_dim": 200 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 200, 54 | "context_vector_dim": 200 55 | }, 56 | "output_logit": { 57 | "input_dim": 200, 58 | "num_layers": 1, 59 | "hidden_dims": 10, 60 | "dropout": 0.3748675587736766, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "batch_size": 64, 76 | "maximum_samples_per_batch": ["list_num_tokens", 1000], // confirmed that this affects batch size 77 | "biggest_batch_first": false 78 | }, 79 | "trainer": { 80 | "optimizer": { 81 | "type": "adam", 82 | "lr": 0.0004 83 | }, 84 | "validation_metric": "+accuracy", 85 | "num_serialized_models_to_keep": 2, 86 | "num_epochs": 60, 87 | "grad_norm": 10.0, 88 | "patience": 10, 89 | "cuda_device": 0, 90 | "learning_rate_scheduler": { 91 | "type": "reduce_on_plateau", 92 | "factor": 0.5, 93 | "mode": "max", 94 | "patience": 0 95 | }, 96 | "shuffle": true 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /configs/yahoo10cat_han_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 122, 3 | "numpy_seed": 462, 4 | "pytorch_seed": 458, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": 
true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yahoo10cat-lowercase-vocab" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yahoo10cat_train.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yahoo10cat_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_sentence_encoder_dropout": 0.44446746096594764, 25 | "pre_document_encoder_dropout": 0.22016423400055152, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yahoo10cat_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "convolutional_rnn_substitute", 38 | "input_size": 200, 39 | "hidden_size": 100, 40 | }, 41 | "document_encoder": { 42 | "type": "convolutional_rnn_substitute", 43 | "input_size": 100, 44 | "hidden_size": 100, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 100, 49 | "context_vector_dim": 100 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "output_logit": { 57 | "input_dim": 100, 58 | "num_layers": 1, 59 | "hidden_dims": 10, 60 | "dropout": 0.3748675587736766, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "batch_size": 64, 76 | "maximum_samples_per_batch": ["list_num_tokens", 1000], // confirmed that this affects batch size 77 | "biggest_batch_first": false 78 | }, 79 | "trainer": { 80 | "optimizer": { 81 | "type": "adam", 82 | "lr": 0.0004 83 | }, 84 | "validation_metric": "+accuracy", 85 | "num_serialized_models_to_keep": 2, 86 | "num_epochs": 15, 87 | "grad_norm": 10.0, 88 | "patience": 5, 89 | "cuda_device": 0, 90 | "learning_rate_scheduler": { 91 | "type": "reduce_on_plateau", 92 | "factor": 0.5, 93 | "mode": "max", 94 | "patience": 0 95 | }, 96 | "shuffle": true 97 | } 98 | } 99 | -------------------------------------------------------------------------------- /configs/yelp_flan_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.7, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 
| "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "pass_through_encoder", 37 | "input_size": 200, 38 | "hidden_size": 200, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 200, 43 | "context_vector_dim": 200 44 | }, 45 | "output_logit": { 46 | "input_dim": 200, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.7, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 90000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0001 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 60, 76 | //"grad_norm": 10.0, 77 | "grad_clipping": 50.0, 78 | "patience": 10, 79 | "cuda_device": 0, 80 | "learning_rate_scheduler": { 81 | "type": "reduce_on_plateau", 82 | "factor": 0.5, 83 | "mode": "max", 84 | "patience": 0 85 | }, 86 | "shuffle": true 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /configs/yelp_flan_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.7, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 100, 43 | "context_vector_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.7, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 
62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 90000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0001 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 60, 76 | "grad_norm": 10.0, 77 | "patience": 5, 78 | "cuda_device": 1, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /configs/yelp_flan_with_rnns-actuallyflanconv.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.7, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "convolutional_rnn_substitute", 37 | "input_size": 200, 38 | "hidden_size": 100, 39 | }, 40 | "word_attention": { 41 | "type": "simple_han_attention", 42 | "input_dim": 100, 43 | "context_vector_dim": 100 44 | }, 45 | "output_logit": { 46 | "input_dim": 100, 47 | "num_layers": 1, 48 | "hidden_dims": 5, 49 | "dropout": 0.7, 50 | "activations": "linear" 51 | }, 52 | "initializer": [ 53 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 54 | [".*linear_layers.*bias", {"type": "zero"}], 55 | [".*weight_ih.*", {"type": "xavier_uniform"}], 56 | [".*weight_hh.*", {"type": "orthogonal"}], 57 | [".*bias_ih.*", {"type": "zero"}], 58 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 59 | ] 60 | }, 61 | "iterator": { 62 | "type": "extended_bucket", 63 | "sorting_keys": [["tokens", "num_tokens"]], 64 | "batch_size": 64, 65 | "maximum_samples_per_batch": ["num_tokens", 90000], // confirmed that this affects batch size 66 | "biggest_batch_first": false 67 | }, 68 | "trainer": { 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 0.0001 72 | }, 73 | "validation_metric": "+accuracy", 74 | "num_serialized_models_to_keep": 2, 75 | "num_epochs": 60, 76 | "grad_norm": 10.0, 77 | "patience": 5, 78 | "cuda_device": 1, 79 | "learning_rate_scheduler": { 80 | "type": "reduce_on_plateau", 81 | "factor": 0.5, 82 | "mode": "max", 83 | "patience": 0 84 | }, 85 | "shuffle": true 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /configs/yelp_flan_with_rnns.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | 
"dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": false, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "flan", 24 | "pre_document_encoder_dropout": 0.7, 25 | "text_field_embedder": { 26 | "token_embedders": { 27 | "tokens": { 28 | "type": "embedding", 29 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 30 | "embedding_dim": 200, 31 | "trainable": true 32 | } 33 | } 34 | }, 35 | "document_encoder": { 36 | "type": "gru", 37 | "num_layers": 1, 38 | "bidirectional": true, 39 | "input_size": 200, 40 | "hidden_size": 50, 41 | }, 42 | "word_attention": { 43 | "type": "simple_han_attention", 44 | "input_dim": 100, 45 | "context_vector_dim": 100 46 | }, 47 | "output_logit": { 48 | "input_dim": 100, 49 | "num_layers": 1, 50 | "hidden_dims": 5, 51 | "dropout": 0.7, 52 | "activations": "linear" 53 | }, 54 | "initializer": [ 55 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 56 | [".*linear_layers.*bias", {"type": "zero"}], 57 | [".*weight_ih.*", {"type": "xavier_uniform"}], 58 | [".*weight_hh.*", {"type": "orthogonal"}], 59 | [".*bias_ih.*", {"type": "zero"}], 60 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 61 | ] 62 | }, 63 | "iterator": { 64 | "type": "extended_bucket", 65 | "sorting_keys": [["tokens", "num_tokens"]], 66 | "batch_size": 64, 67 | "maximum_samples_per_batch": ["num_tokens", 90000], // confirmed that this affects batch size 68 | "biggest_batch_first": false 69 | }, 70 | "trainer": { 71 | "optimizer": { 72 | "type": "adam", 73 | "lr": 0.0001 74 | }, 75 | "validation_metric": "+accuracy", 76 | "num_serialized_models_to_keep": 2, 77 | "num_epochs": 60, 78 | //"grad_norm": 10.0, 79 | "grad_clipping": 50.0, 80 | "patience": 5, 81 | "cuda_device": 0, 82 | "learning_rate_scheduler": { 83 | "type": "reduce_on_plateau", 84 | "factor": 0.5, 85 | "mode": "max", 86 | "patience": 0 87 | }, 88 | "shuffle": true 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-1.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.2, 25 | "pre_sentence_encoder_dropout": 0.6, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | 
"sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.2, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0004 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-2.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.4, 25 | "pre_sentence_encoder_dropout": 0.6, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | 
"context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.3, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "sgd", 87 | "lr": 0.001, 88 | "momentum": 0.9 89 | }, 90 | "validation_metric": "+accuracy", 91 | "num_serialized_models_to_keep": 2, 92 | "num_epochs": 60, 93 | //"grad_norm": 10.0, 94 | "grad_clipping": 50.0, 95 | "patience": 5, 96 | "cuda_device": 0, 97 | "learning_rate_scheduler": { 98 | "type": "reduce_on_plateau", 99 | "factor": 0.5, 100 | "mode": "max", 101 | "patience": 0 102 | }, 103 | "shuffle": true 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-3.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.2, 25 | "pre_sentence_encoder_dropout": 0.5, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.5, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 
76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-4.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.3, 25 | "pre_sentence_encoder_dropout": 0.3, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.7, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0008 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | 
"grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-5.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.1, 25 | "pre_sentence_encoder_dropout": 0.7, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.7, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0004 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-5evensmallerstep.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 
781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.1, 25 | "pre_sentence_encoder_dropout": 0.7, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.7, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.000025 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-5smallerstep.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": 
"/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.1, 25 | "pre_sentence_encoder_dropout": 0.7, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.7, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0001 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper-6.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.0, 25 | "pre_sentence_encoder_dropout": 0.8, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 
41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | "hidden_dims": 5, 65 | "dropout": 0.8, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "adam", 87 | "lr": 0.0004 88 | }, 89 | "validation_metric": "+accuracy", 90 | "num_serialized_models_to_keep": 2, 91 | "num_epochs": 60, 92 | //"grad_norm": 10.0, 93 | "grad_clipping": 50.0, 94 | "patience": 5, 95 | "cuda_device": 0, 96 | "learning_rate_scheduler": { 97 | "type": "reduce_on_plateau", 98 | "factor": 0.5, 99 | "mode": "max", 100 | "patience": 0 101 | }, 102 | "shuffle": true 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /configs/yelp_han_from_paper.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.4, 25 | "pre_sentence_encoder_dropout": 0.4, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "gru", 39 | "num_layers": 1, 40 | "bidirectional": true, 41 | "input_size": 200, 42 | "hidden_size": 50, 43 | }, 44 | "document_encoder": { 45 | "type": "gru", 46 | "num_layers": 1, 47 | "bidirectional": true, 48 | "input_size": 100, 49 | "hidden_size": 50, 50 | }, 51 | "word_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "sentence_attention": { 57 | "type": "simple_han_attention", 58 | "input_dim": 100, 59 | "context_vector_dim": 100 60 | }, 61 | "output_logit": { 62 | "input_dim": 100, 63 | "num_layers": 1, 64 | 
"hidden_dims": 5, 65 | "dropout": 0.3, 66 | "activations": "linear" 67 | }, 68 | "initializer": [ 69 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 70 | [".*linear_layers.*bias", {"type": "zero"}], 71 | [".*weight_ih.*", {"type": "xavier_uniform"}], 72 | [".*weight_hh.*", {"type": "orthogonal"}], 73 | [".*bias_ih.*", {"type": "zero"}], 74 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 75 | ] 76 | }, 77 | "iterator": { 78 | "type": "extended_bucket", 79 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 80 | "batch_size": 64, 81 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 82 | "biggest_batch_first": false 83 | }, 84 | "trainer": { 85 | "optimizer": { 86 | "type": "sgd", 87 | "lr": 0.0004, 88 | "momentum": 0.9 89 | }, 90 | "validation_metric": "+accuracy", 91 | "num_serialized_models_to_keep": 2, 92 | "num_epochs": 60, 93 | //"grad_norm": 10.0, 94 | "grad_clipping": 50.0, 95 | "patience": 5, 96 | "cuda_device": 0, 97 | "learning_rate_scheduler": { 98 | "type": "reduce_on_plateau", 99 | "factor": 0.5, 100 | "mode": "max", 101 | "patience": 0 102 | }, 103 | "shuffle": true 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /configs/yelp_han_no_encoders.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.1, 25 | "pre_sentence_encoder_dropout": 0.7, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true, 33 | "max_norm": 1.0 34 | } 35 | } 36 | }, 37 | "sentence_encoder": { 38 | "type": "pass_through_encoder", 39 | "input_size": 200, 40 | "hidden_size": 200, 41 | }, 42 | "document_encoder": { 43 | "type": "pass_through_encoder", 44 | "input_size": 200, 45 | "hidden_size": 200, 46 | }, 47 | "word_attention": { 48 | "type": "simple_han_attention", 49 | "input_dim": 200, 50 | "context_vector_dim": 200 51 | }, 52 | "sentence_attention": { 53 | "type": "simple_han_attention", 54 | "input_dim": 200, 55 | "context_vector_dim": 200 56 | }, 57 | "output_logit": { 58 | "input_dim": 200, 59 | "num_layers": 1, 60 | "hidden_dims": 5, 61 | "dropout": 0.7, 62 | "activations": "linear" 63 | }, 64 | "initializer": [ 65 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 66 | [".*linear_layers.*bias", {"type": "zero"}], 67 | [".*weight_ih.*", {"type": "xavier_uniform"}], 68 | [".*weight_hh.*", {"type": "orthogonal"}], 69 | [".*bias_ih.*", {"type": "zero"}], 70 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 71 | ] 72 | }, 73 | "iterator": { 74 | "type": "extended_bucket", 75 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 76 | "batch_size": 64, 77 
| "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 78 | "biggest_batch_first": false 79 | }, 80 | "trainer": { 81 | "optimizer": { 82 | "type": "adam", 83 | "lr": 0.0001 84 | }, 85 | "validation_metric": "+accuracy", 86 | "num_serialized_models_to_keep": 2, 87 | "num_epochs": 60, 88 | //"grad_norm": 10.0, 89 | "grad_clipping": 50.0, 90 | "patience": 10, 91 | "cuda_device": 0, 92 | "learning_rate_scheduler": { 93 | "type": "reduce_on_plateau", 94 | "factor": 0.5, 95 | "mode": "max", 96 | "patience": 0 97 | }, 98 | "shuffle": true 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /configs/yelp_han_with_convs.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "random_seed": 217, 3 | "numpy_seed": 735, 4 | "pytorch_seed": 781, 5 | "dataset_reader": { 6 | "type": "textcat", 7 | "segment_sentences": true, 8 | "word_tokenizer": "word", 9 | "token_indexers": { 10 | "tokens": { 11 | "type": "single_id", 12 | "lowercase_tokens": true 13 | } 14 | } 15 | }, 16 | "datasets_for_vocab_creation": [], 17 | "vocabulary": { 18 | "directory_path": "/homes/gws/sofias6/vocabs/yelp-lowercase-vocab-30its" 19 | }, 20 | "train_data_path": "/homes/gws/sofias6/data/yelp_train_remoutliers.tsv", 21 | "validation_data_path": "/homes/gws/sofias6/data/yelp_dev.tsv", 22 | "model": { 23 | "type": "han", 24 | "pre_document_encoder_dropout": 0.1, 25 | "pre_sentence_encoder_dropout": 0.7, 26 | "text_field_embedder": { 27 | "token_embedders": { 28 | "tokens": { 29 | "type": "embedding", 30 | "pretrained_file": "/homes/gws/sofias6/data/yelp_train_lowercase_embeddings.h5", 31 | "embedding_dim": 200, 32 | "trainable": true 33 | } 34 | } 35 | }, 36 | "sentence_encoder": { 37 | "type": "convolutional_rnn_substitute", 38 | "input_size": 200, 39 | "hidden_size": 100, 40 | }, 41 | "document_encoder": { 42 | "type": "convolutional_rnn_substitute", 43 | "input_size": 100, 44 | "hidden_size": 100, 45 | }, 46 | "word_attention": { 47 | "type": "simple_han_attention", 48 | "input_dim": 100, 49 | "context_vector_dim": 100 50 | }, 51 | "sentence_attention": { 52 | "type": "simple_han_attention", 53 | "input_dim": 100, 54 | "context_vector_dim": 100 55 | }, 56 | "output_logit": { 57 | "input_dim": 100, 58 | "num_layers": 1, 59 | "hidden_dims": 5, 60 | "dropout": 0.7, 61 | "activations": "linear" 62 | }, 63 | "initializer": [ 64 | [".*linear_layers.*weight", {"type": "xavier_uniform"}], 65 | [".*linear_layers.*bias", {"type": "zero"}], 66 | [".*weight_ih.*", {"type": "xavier_uniform"}], 67 | [".*weight_hh.*", {"type": "orthogonal"}], 68 | [".*bias_ih.*", {"type": "zero"}], 69 | [".*bias_hh.*", {"type": "lstm_hidden_bias"}] 70 | ] 71 | }, 72 | "iterator": { 73 | "type": "extended_bucket", 74 | "sorting_keys": [["sentences", "num_sentences"], ["tokens", "list_num_tokens"]], 75 | "batch_size": 64, 76 | "maximum_samples_per_batch": ["list_num_tokens", 6000], // confirmed that this affects batch size 77 | "biggest_batch_first": false 78 | }, 79 | "trainer": { 80 | "optimizer": { 81 | "type": "adam", 82 | "lr": 0.0001 83 | }, 84 | "validation_metric": "+accuracy", 85 | "num_serialized_models_to_keep": 2, 86 | "num_epochs": 60, 87 | "grad_norm": 10.0, 88 | "patience": 5, 89 | "cuda_device": 0, 90 | "learning_rate_scheduler": { 91 | "type": "reduce_on_plateau", 92 | "factor": 0.5, 93 | "mode": "max", 94 | "patience": 0 95 | }, 96 | "shuffle": true 97 | } 98 | } 99 | 
-------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | Before running `train_model.py`, there must exist a data directory 2 | containing properly formatted data files. 3 | 4 | (There also need to exist pretrained word embeddings and an 5 | allennlp vocabulary; the paths to those must be specified in 6 | the config files.) 7 | 8 | If you want to use a different data directory location, just change 9 | the location listed in `default_directories.py`. -------------------------------------------------------------------------------- /default_directories.py: -------------------------------------------------------------------------------- 1 | base_output_dir = 'attn-test-output/' 2 | base_serialized_models_dir = 'models/' 3 | base_data_dir = 'data/' 4 | dir_with_config_files = 'configs/' 5 | images_dir = 'imgs/' 6 | tex_files_dir = 'generated_tex_files/' 7 | vocabs_dir = 'vocabs/' 8 | 9 | all_dirs = list(globals().keys()) 10 | for dirname in all_dirs: 11 | if not dirname.startswith('__') and not globals()[dirname].endswith('/'): 12 | globals()[dirname] = globals()[dirname] + '/' 13 | -------------------------------------------------------------------------------- /misc_scripts/get_attnperf_overlap.py: -------------------------------------------------------------------------------- 1 | # all label filenames should have same length. Each line should consist of either a 1 or a 0. 2 | label_filenames = [] 3 | 4 | list_of_lists = [] 5 | with open(label_filenames[0], 'r') as f: 6 | for line in f: 7 | line = line.strip() 8 | if line != '': 9 | list_of_lists.append([int(line)]) 10 | 11 | for i in range(1, len(label_filenames)): 12 | with open(label_filenames[i], 'r') as f: 13 | counter = 0 14 | for line in f: 15 | line = line.strip() 16 | if line != '': 17 | list_of_lists[counter].append(int(line)) 18 | counter += 1 19 | 20 | results_tally = [0] * (len(list_of_lists[0]) + 1) 21 | 22 | for label_list in list_of_lists: 23 | s = sum(label_list) 24 | results_tally[s] += 1 25 | 26 | total_num_instances = sum(results_tally) 27 | 28 | print() 29 | for i in range(len(results_tally)): 30 | print("Num instances with a totaled label of " + str(i) + " across all models: " + 31 | str(results_tally[i]) + " (" + str(100 * results_tally[i] / total_num_instances) + "%)") 32 | print() 33 | -------------------------------------------------------------------------------- /misc_scripts/make_attnlabel1_dist_hist.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | from math import fabs 6 | matplotlib.use('Agg') 7 | 8 | from tqdm import tqdm 9 | from math import ceil 10 | 11 | filepath = '/homes/gws/sofias6/models/imdb-hanrnn-attnperf/imdb_test_reallabel_guessedlabel.csv' 12 | output_filename = '/homes/gws/sofias6/attn-tests/imgs/data_descriptions/bad_attn_performance_dist/imdb-hanrnn-test-guessedperformance' 13 | 14 | if filepath.endswith('.csv'): 15 | is_csv = True 16 | else: 17 | is_csv = False 18 | 19 | list_of_lens = [] 20 | counter = 0 21 | with open(filepath, 'r') as f: 22 | if is_csv: 23 | f.readline() 24 | for line in tqdm(f): 25 | line = line.strip() 26 | if is_csv: 27 | line = line[line.index(',') + 1:] 28 | if line == '1': 29 | list_of_lens.append(counter) 30 | counter += 1 31 | 32 | print("Collected " + str(len(list_of_lens)) + " indices corresponding to 
1s") 33 | 34 | fig = plt.figure() 35 | 36 | plt.title(output_filename[output_filename.rfind('/') + 1:] + ' dist of 1 labels') 37 | 38 | bin_width = ceil(counter / 50) 39 | plt.hist(list_of_lens, bins=[i * bin_width for i in range(ceil(counter / bin_width) + 1)]) 40 | plt.savefig(output_filename) 41 | 42 | plt.close(fig) 43 | 44 | print("Successfully made figure.") 45 | print("Saved figure to " + output_filename) 46 | -------------------------------------------------------------------------------- /misc_scripts/make_rand_subset_data_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from random import shuffle 3 | from random import random 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser( 8 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument("--filename", type=str, required=True, 10 | help="Local filename of data") 11 | parser.add_argument("--frac-to-take", type=float, required=True, 12 | help="How much of the data to store in the new filename") 13 | parser.add_argument("--new-filename", type=str, required=True, 14 | help="New filename for the data") 15 | 16 | parser.add_argument("--data-dir", required=False, 17 | default="/homes/gws/sofias6/data/", 18 | help="Base data dir") 19 | args = parser.parse_args() 20 | 21 | instances = [] 22 | with open(args.data_dir + args.filename, 'r') as old_f: 23 | first_line = old_f.readline() 24 | for line in old_f: 25 | if line.strip() != '': 26 | instances.append(line) 27 | shuffle(instances) 28 | 29 | took = 0 30 | with open(args.data_dir + args.new_filename, 'w') as f: 31 | f.write(first_line) 32 | for instance in instances: 33 | decider = random() 34 | if decider < args.frac_to_take: 35 | f.write(instance) 36 | took += 1 37 | 38 | print("Wrote " + str(took) + " / " + str(len(instances)) + " instances to file " + 39 | str(args.data_dir + args.new_filename)) 40 | 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /misc_scripts/make_subset_test_file.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | test_data_file = sys.argv[1] 4 | take_first_n = int(sys.argv[2]) 5 | 6 | 7 | list_to_take = [] 8 | counter = 0 9 | with open(test_data_file, 'r') as f: 10 | for line in f: 11 | if counter == 0: 12 | first_line = line 13 | elif counter <= take_first_n: 14 | list_to_take.append(line) 15 | else: 16 | break 17 | counter += 1 18 | 19 | 20 | with open(test_data_file[:test_data_file.rfind('.')] + '_first' + str(take_first_n) + '.tsv', 'w') as f: 21 | f.write(first_line) 22 | for line in list_to_take: 23 | f.write(line) 24 | -------------------------------------------------------------------------------- /misc_scripts/move_files_out_pre_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | from glob import glob 5 | 6 | dir = sys.argv[1] 7 | if not dir.endswith('/'): 8 | dir += '/' 9 | new_dir = dir + 'old_test_results/' 10 | os.makedirs(new_dir) 11 | 12 | fnames_to_move = [ 13 | 'dec_flip_stats.csv', 14 | 'first_vs_second.csv', 15 | 'grad_based_stats.csv', 16 | 'rand_nontop_decflipjs.csv', 17 | 'rand_sample_stats.csv' 18 | ] 19 | 20 | for fname in fnames_to_move: 21 | shutil.move(dir + fname, new_dir + fname) 22 | 23 | paths_starting_with__sentence = list(glob(dir + '_sentence_attention*')) 24 | if len(paths_starting_with__sentence) > 0: 25 | 
inner_dir = dir + '_sentence_attention_corresponding_vects/' 26 | new_inner_dir = new_dir + '_sentence_attention_corresponding_vects/' 27 | else: 28 | inner_dir = dir + '_word_attention_corresponding_vects/' 29 | new_inner_dir = new_dir + '_word_attention_corresponding_vects/' 30 | 31 | os.makedirs(new_inner_dir) 32 | 33 | inner_fnames_to_move = [ 34 | 'gradients/', 35 | 'next_available_counter.txt' 36 | ] 37 | 38 | for fname in inner_fnames_to_move: 39 | shutil.move(inner_dir + fname, new_inner_dir + fname) 40 | 41 | print("Done moving files into " + new_dir) 42 | -------------------------------------------------------------------------------- /misc_scripts/prob_predictor.py: -------------------------------------------------------------------------------- 1 | probs_of_uninterp = [0.0871252488287455, 0.17563790642028432, 0.653690708382764, 0.790742565298695] # yahoo 2 | total = 49733 # yahoo 3 | probs_of_uninterp = [0.45549544208107906, 0.5309419699103238, 0.6451493366931002, 0.5466538204995183] # imdb 4 | total = 13493 # imdb 5 | probs_of_uninterp = [0.24615796937231565, 0.35901867922351965, 0.7142445193814755, 0.5339150376909142] # amazon 6 | total = 589532 # amazon 7 | probs_of_uninterp = [0.31526378871669786, 0.47784763546901615, 0.7055743633955044, 0.6329667556826545] # yelp 8 | total = 47557 # yelp 9 | 10 | 11 | def get_predicted_probabilities(p1, p2, p3, p4): 12 | prob_all_4 = p1 * p2 * p3 * p4 13 | prob_exactly_3 = ((1 - p1) * p2 * p3 * p4) + (p1 * (1 - p2) * p3 * p4) + (p1 * p2 * (1 - p3) * p4) + \ 14 | (p1 * p2 * p3 * (1 - p4)) 15 | list_of_probs = [p1, p2, p3, p4] 16 | prob_exactly_2 = 0 17 | for i in range(4): 18 | for j in range(4): 19 | if j <= i: 20 | continue 21 | other_inds = {0:0, 1:1, 2:2, 3:3} 22 | del other_inds[i] 23 | del other_inds[j] 24 | other_inds = list(other_inds.keys()) 25 | prob_exactly_2 += (list_of_probs[i] * list_of_probs[j] * (1 - list_of_probs[other_inds[0]]) * 26 | (1 - list_of_probs[other_inds[1]])) 27 | prob_exactly_1 = (p1 * (1 - p2) * (1 - p3) * (1 - p4)) + ((1 - p1) * p2 * (1 - p3) * (1 - p4)) + \ 28 | ((1 - p1) * (1 - p2) * p3 * (1 - p4)) + ((1 - p1) * (1 - p2) * (1 - p3) * p4) 29 | prob_exactly_0 = ((1 - p1) * (1 - p2) * (1 - p3) * (1 - p4)) 30 | return prob_all_4, prob_exactly_3, prob_exactly_2, prob_exactly_1, prob_exactly_0 31 | 32 | 33 | prob_all_4, prob_exactly_3, prob_exactly_2, prob_exactly_1, prob_exactly_0 =\ 34 | get_predicted_probabilities(probs_of_uninterp[0], probs_of_uninterp[1], probs_of_uninterp[2], probs_of_uninterp[3]) 35 | 36 | print("\tProbs sum to " + str(prob_all_4 + prob_exactly_0 + prob_exactly_1 + prob_exactly_2 + prob_exactly_3)) 37 | print() 38 | 39 | print("Prob all 4: " + str(prob_all_4)) 40 | print("Prob exactly 3: " + str(prob_exactly_3)) 41 | print("Prob exactly 2: " + str(prob_exactly_2)) 42 | print("Prob exactly 1: " + str(prob_exactly_1)) 43 | print("Prob never uninterpretable: " + str(prob_exactly_0)) 44 | 45 | print() 46 | 47 | print("Num expected 4: " + str(prob_all_4 * total)) 48 | print("Num expected 3: " + str(prob_exactly_3 * total)) 49 | print("Num expected 2: " + str(prob_exactly_2 * total)) 50 | print("Num expected 1: " + str(prob_exactly_1 * total)) 51 | print("Num expected 0: " + str(prob_exactly_0 * total)) 52 | -------------------------------------------------------------------------------- /textcat/__init__.py: -------------------------------------------------------------------------------- 1 | from .textcat_reader import TextCatReader 2 | from .hierarchical_attention_network import 
HierarchicalAttentionNetwork 3 | from .sentence_tokenizer import SentenceTokenizer 4 | -------------------------------------------------------------------------------- /textcat/sentence_splitter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from overrides import overrides 3 | from allennlp.common import Registrable 4 | from allennlp.common.util import get_spacy_model 5 | 6 | 7 | class SentenceSplitter(Registrable): 8 | """ 9 | A ``SentenceSplitter`` splits strings into sentences. 10 | """ 11 | default_implementation = 'spacy' 12 | 13 | def split_sentences(self, text: str) -> List[str]: 14 | """ 15 | Splits ``texts`` into a list of :class:`Token` objects. 16 | """ 17 | raise NotImplementedError 18 | 19 | def batch_split_sentences(self, texts: List[str]) -> List[List[str]]: 20 | """ 21 | This method lets you take advantage of spacy's batch processing. 22 | Default implementation is to just iterate over the texts and call ``split_sentences``. 23 | """ 24 | return [self.split_sentences(text) for text in texts] 25 | 26 | 27 | @SentenceSplitter.register('spacy') 28 | class SpacySentenceSplitter(SentenceSplitter): 29 | """ 30 | A ``SentenceSplitter`` that uses spaCy's built-in sentence boundary detection. 31 | Spacy's default sentence splitter uses a dependency parse to detect sentence boundaries, so 32 | it is slow, but accurate. 33 | Another option is to use rule-based sentence boundary detection. It's fast and has a small memory footprint, 34 | since it uses punctuation to detect sentence boundaries. This can be activated with the `rule_based` flag. 35 | By default, ``SpacySentenceSplitter`` calls the default spacy boundary detector. 36 | """ 37 | def __init__(self, 38 | language: str = 'en_core_web_sm', 39 | rule_based: bool = False) -> None: 40 | # we need spacy's dependency parser if we're not using rule-based sentence boundary detection. 41 | self.spacy = get_spacy_model(language, parse=not rule_based, ner=False, pos_tags=False) 42 | if rule_based: 43 | # we use `sbd`, a built-in spacy module for rule-based sentence boundary detection. 44 | if not self.spacy.has_pipe('sbd'): 45 | sbd = self.spacy.create_pipe('sbd') 46 | self.spacy.add_pipe(sbd) 47 | 48 | @overrides 49 | def split_sentences(self, text: str) -> List[str]: 50 | return [sent.string.strip() for sent in self.spacy(text).sents] 51 | 52 | @overrides 53 | def batch_split_sentences(self, texts: List[str]) -> List[List[str]]: 54 | return [[sentence.string.strip() for sentence in doc.sents] for doc in self.spacy.pipe(texts)] 55 | -------------------------------------------------------------------------------- /textcat/sentence_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from overrides import overrides 4 | 5 | from allennlp.data.tokenizers.token import Token 6 | from allennlp.data.tokenizers.tokenizer import Tokenizer 7 | from .sentence_splitter import SpacySentenceSplitter 8 | 9 | 10 | @Tokenizer.register("sentence") 11 | class SentenceTokenizer(Tokenizer): 12 | """ 13 | A ``WordTokenizer`` handles the splitting of strings into words as well as any desired 14 | post-processing (e.g., stemming, filtering, etc.). Note that we leave one particular piece of 15 | post-processing for later: the decision of whether or not to lowercase the token. 
This is for 16 | two reasons: (1) if you want to make two different casing decisions for whatever reason, you 17 | won't have to run the tokenizer twice, and more importantly (2) if you want to lowercase words 18 | for your word embedding, but retain capitalization in a character-level representation, we need 19 | to retain the capitalization here. 20 | 21 | Unlike ``WordTokenizer``, this tokenizer takes no constructor parameters: it always delegates to a 22 | ``SpacySentenceSplitter`` with default settings, and each element it returns is one full sentence 23 | rather than a single word. 24 | """ 25 | def __init__(self) -> None: 26 | self._sentence_splitter = SpacySentenceSplitter() 27 | 28 | @overrides 29 | def tokenize(self, text: str) -> List[Token]: 30 | """ 31 | Splits ``text`` into sentences using the underlying ``SpacySentenceSplitter``. 32 | 33 | Note that each returned element is a sentence string, not a ``Token`` object; word-level 34 | tokenization of each sentence is handled later by the dataset reader. 35 | """ 36 | sents = self._sentence_splitter.split_sentences(text) 37 | return sents 38 | 39 | @overrides 40 | def batch_tokenize(self, texts: List[str]) -> List[List[Token]]: 41 | batched_sents = self._sentence_splitter.batch_split_sentences(texts) 42 | return batched_sents 43 | -------------------------------------------------------------------------------- /textcat/textcat_reader.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import logging 3 | import numpy as np 4 | import re 5 | from overrides import overrides 6 | 7 | from allennlp.common.file_utils import cached_path 8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 9 | from allennlp.data.fields import LabelField, TextField, Field, ListField 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 12 | from allennlp.data.tokenizers import Token 13 | from allennlp.common.checks import ConfigurationError 14 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer 15 | from .sentence_tokenizer import SentenceTokenizer 16 | from allennlp.data.tokenizers.word_filter import StopwordFilter, PassThroughWordFilter 17 | 18 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 19 | 20 | 21 | @DatasetReader.register("textcat") 22 | class TextCatReader(DatasetReader): 23 | """ 24 | Reads tokens and their topic labels. 25 | 26 | Assumes that data in file_path provided to _read is tab-separated, containing (at least) the two 27 | fields 'tokens' and 'category', in no particular order, with each document/label on one line. 28 | (So this means that documents must not contain either newlines or tabs.) 29 | 30 | Example: 31 | 32 | category tokens 33 | sample_label_1 This is a document. It contains a couple of sentences. 34 | sample_label_1 This is another document. It also contains two sentences.
35 | sample_label_2 This document has a different label. 36 | 37 | and so on. 38 | 39 | The output of ``read`` is a list of ``Instance`` s with the fields: 40 | tokens: ``TextField`` and 41 | label: ``LabelField`` 42 | 43 | Parameters 44 | ---------- 45 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) 46 | We use this to define the input representation for the text. See :class:`TokenIndexer`. 47 | lazy : ``bool``, optional, (default = ``False``) 48 | Whether or not instances can be read lazily. 49 | """ 50 | def __init__(self, 51 | token_indexers: Dict[str, TokenIndexer] = None, 52 | word_tokenizer: Tokenizer = None, 53 | segment_sentences: bool = False, 54 | lazy: bool = False, 55 | column_titles_to_index: List[str] = ("tokens", )) -> None: 56 | super().__init__(lazy=lazy) 57 | self._word_tokenizer = word_tokenizer or WordTokenizer(word_filter=PassThroughWordFilter()) 58 | self._segment_sentences = segment_sentences 59 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 60 | self._column_titles_to_index = column_titles_to_index 61 | assert len(self._column_titles_to_index) > 0 62 | if self._segment_sentences: 63 | self._sentence_segmenter = SentenceTokenizer() 64 | 65 | 66 | @overrides 67 | def _read(self, file_path): 68 | with open(cached_path(file_path), "r") as data_file: 69 | logger.info("Reading instances from lines in file at: %s", file_path) 70 | columns = data_file.readline().strip('\n').split('\t') 71 | token_col_inds = [columns.index(self._column_titles_to_index[field_ind]) 72 | for field_ind in range(len(self._column_titles_to_index))] 73 | for line in data_file.readlines(): 74 | if not line: 75 | continue 76 | items = line.strip("\n").split("\t") 77 | tokens = '' 78 | for col_ind in token_col_inds: 79 | tokens += items[col_ind] + ' ' 80 | tokens = tokens[:-1] 81 | tokens = items[columns.index("tokens")] 82 | if len(tokens.strip()) == 0: 83 | continue 84 | category = items[columns.index("category")] 85 | instance = self.text_to_instance(tokens=tokens, 86 | category=category) 87 | if instance is not None: 88 | yield instance 89 | 90 | 91 | @overrides 92 | def text_to_instance(self, tokens: List[str], category: str = None) -> Instance: # type: ignore 93 | """ 94 | We take `pre-tokenized` input here, because we don't have a tokenizer in this class. 95 | 96 | Parameters 97 | ---------- 98 | tokens : ``List[str]``, required. 99 | The tokens in a given sentence. 100 | category ``str``, optional, (default = None). 101 | The category for this sentence. 102 | 103 | Returns 104 | ------- 105 | An ``Instance`` containing the following fields: 106 | tokens : ``TextField`` 107 | The tokens in the sentence or phrase. 108 | label : ``LabelField`` 109 | The category label of the sentence or phrase. 
110 | """ 111 | # pylint: disable=arguments-differ 112 | fields: Dict[str, Field] = {} 113 | text_fields = [] 114 | if self._segment_sentences: 115 | sentence_tokens = self._sentence_segmenter.tokenize(tokens) 116 | original_len_sentence_tokens = len(sentence_tokens) 117 | corresponding_sentence_ind = 0 118 | for i in range(original_len_sentence_tokens): 119 | sentence = sentence_tokens[corresponding_sentence_ind] 120 | word_tokens = self._word_tokenizer.tokenize(sentence) 121 | if len(word_tokens) == 0: 122 | del sentence_tokens[corresponding_sentence_ind] 123 | continue 124 | text_fields.append(TextField(word_tokens, self._token_indexers)) 125 | corresponding_sentence_ind += 1 126 | if len(text_fields) == 0: 127 | return None 128 | fields['tokens'] = ListField(text_fields) 129 | else: 130 | fields['tokens'] = TextField(self._word_tokenizer.tokenize(tokens), 131 | self._token_indexers) 132 | fields['label'] = LabelField(category) 133 | return Instance(fields) 134 | -------------------------------------------------------------------------------- /textcat/textcat_reader_attnlabel.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import logging 3 | import numpy as np 4 | import re 5 | from overrides import overrides 6 | 7 | from allennlp.common.file_utils import cached_path 8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 9 | from allennlp.data.fields import LabelField, TextField, Field, ListField 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 12 | from allennlp.data.tokenizers import Token 13 | from allennlp.common.checks import ConfigurationError 14 | from allennlp.data.tokenizers import Tokenizer, WordTokenizer 15 | from .sentence_tokenizer import SentenceTokenizer 16 | from allennlp.data.tokenizers.word_filter import StopwordFilter, PassThroughWordFilter 17 | 18 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 19 | 20 | 21 | @DatasetReader.register("textcat-attnlabel") 22 | class TextCatAttnReader(DatasetReader): 23 | """ 24 | Reads tokens and their topic labels. 25 | 26 | Assumes that data in file_path provided to _read is tab-separated, containing (at least) the two 27 | fields 'tokens' and 'category', in no particular order, with each document/label on one line. 28 | (So this means that documents must not contain either newlines or tabs.) 29 | 30 | Example: 31 | 32 | category tokens 33 | sample_label_1 This is a document. It contains a couple of sentences. 34 | sample_label_1 This is another document. It also contains two sentences. 35 | sample_label_2 This document has a different label. 36 | 37 | and so on. 38 | 39 | The output of ``read`` is a list of ``Instance`` s with the fields: 40 | tokens: ``TextField`` and 41 | label: ``LabelField`` 42 | 43 | Parameters 44 | ---------- 45 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) 46 | We use this to define the input representation for the text. See :class:`TokenIndexer`. 47 | lazy : ``bool``, optional, (default = ``False``) 48 | Whether or not instances can be read lazily. 
49 | """ 50 | 51 | def __init__(self, 52 | model_folder_name: str, 53 | token_indexers: Dict[str, TokenIndexer] = None, 54 | word_tokenizer: Tokenizer = None, 55 | segment_sentences: bool = False, 56 | lazy: bool = False, 57 | column_titles_to_index: List[str] = ("tokens",)) -> None: 58 | super().__init__(lazy=lazy) 59 | if model_folder_name.endswith('/'): 60 | model_folder_name = model_folder_name[:-1] 61 | if '/' in model_folder_name: 62 | model_folder_name = model_folder_name[model_folder_name.rfind('/') + 1:] 63 | self.model_folder_name = model_folder_name 64 | self._word_tokenizer = word_tokenizer or WordTokenizer(word_filter=PassThroughWordFilter()) 65 | self._segment_sentences = segment_sentences 66 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 67 | self._column_titles_to_index = column_titles_to_index 68 | assert len(self._column_titles_to_index) > 0 69 | if self._segment_sentences: 70 | self._sentence_segmenter = SentenceTokenizer() 71 | 72 | @overrides 73 | def _read(self, file_path): 74 | label_file_path = file_path[:file_path.rfind('.')] + "_attnperformancelabels_" + self.model_folder_name + ".txt" 75 | label_file = open(label_file_path, 'r') 76 | with open(cached_path(file_path), "r") as data_file: 77 | logger.info("Reading instances from lines in file at: %s", file_path) 78 | columns = data_file.readline().strip('\n').split('\t') 79 | token_col_inds = [columns.index(self._column_titles_to_index[field_ind]) 80 | for field_ind in range(len(self._column_titles_to_index))] 81 | for line in data_file.readlines(): 82 | if not line: 83 | continue 84 | items = line.strip("\n").split("\t") 85 | tokens = '' 86 | for col_ind in token_col_inds: 87 | tokens += items[col_ind] + ' ' 88 | tokens = tokens[:-1] 89 | tokens = items[columns.index("tokens")] 90 | if len(tokens.strip()) == 0: 91 | continue 92 | instance = self.text_to_instance(tokens=tokens) 93 | if instance is not None: 94 | str_category = label_file.readline().strip() 95 | assert str_category != '' 96 | instance.fields['label'] = LabelField(str_category) 97 | yield instance 98 | 99 | try: 100 | next_label = label_file.readline() 101 | if isinstance(next_label, str) and len(next_label.strip()) >= 1: 102 | assert not next_label[0].isdigit(), \ 103 | "We had too many labels corresponding to the given data file " + file_path 104 | except: 105 | pass 106 | label_file.close() 107 | 108 | @overrides 109 | def text_to_instance(self, tokens: List[str]) -> Instance: # type: ignore 110 | """ 111 | We take `pre-tokenized` input here, because we don't have a tokenizer in this class. 112 | 113 | Parameters 114 | ---------- 115 | tokens : ``List[str]``, required. 116 | The tokens in a given sentence. 117 | category ``str``, optional, (default = None). 118 | The category for this sentence. 119 | 120 | Returns 121 | ------- 122 | An ``Instance`` containing the following fields: 123 | tokens : ``TextField`` 124 | The tokens in the sentence or phrase. 125 | label : ``LabelField`` 126 | The category label of the sentence or phrase. 
127 | """ 128 | # pylint: disable=arguments-differ 129 | fields: Dict[str, Field] = {} 130 | text_fields = [] 131 | if self._segment_sentences: 132 | sentence_tokens = self._sentence_segmenter.tokenize(tokens) 133 | original_len_sentence_tokens = len(sentence_tokens) 134 | corresponding_sentence_ind = 0 135 | for i in range(original_len_sentence_tokens): 136 | sentence = sentence_tokens[corresponding_sentence_ind] 137 | word_tokens = self._word_tokenizer.tokenize(sentence) 138 | if len(word_tokens) == 0: 139 | del sentence_tokens[corresponding_sentence_ind] 140 | continue 141 | text_fields.append(TextField(word_tokens, self._token_indexers)) 142 | corresponding_sentence_ind += 1 143 | if len(text_fields) == 0: 144 | return None 145 | fields['tokens'] = ListField(text_fields) 146 | else: 147 | list_of_word_tokens = self._word_tokenizer.tokenize(tokens) 148 | if len(list_of_word_tokens) == 0: 149 | return None 150 | fields['tokens'] = TextField(list_of_word_tokens, 151 | self._token_indexers) 152 | return Instance(fields) 153 | --------------------------------------------------------------------------------
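A quick, hypothetical illustration of the naming convention that `TextCatAttnReader._read` uses to locate the attention-performance label file accompanying a data file (the data path and model folder name below are made-up placeholders, not files shipped with this repo):

```python
# Sketch of the label-file path construction performed in TextCatAttnReader._read.
# Both values below are hypothetical examples, not real paths from this repository.
file_path = "/path/to/data/imdb_test.tsv"        # assumed location of a test data file
model_folder_name = "imdb_han_from_paper"        # assumed name of a trained model's folder

label_file_path = (file_path[:file_path.rfind('.')]
                   + "_attnperformancelabels_" + model_folder_name + ".txt")
print(label_file_path)
# /path/to/data/imdb_test_attnperformancelabels_imdb_han_from_paper.txt
```

The reader then pairs the labels in that file, one per line and in order, with each data-file document that actually yields an instance.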