├── .gitignore ├── all_cat.txt ├── allennlp_glue_patch ├── __init__.py ├── bag_of_word_corrected.py ├── basic_regressor.py ├── dataset_readers │ ├── BinarySentiment.py │ ├── SimilarityRegression.py │ └── __init__.py ├── notes.md ├── pretrained_transformer_tokenizer_corrected.py ├── roberta_pooler.py ├── stsb_regressor.py └── stsb_text_field_embedder.py ├── config ├── experiment │ ├── bert-base-pool[imdb]-40.jsonnet │ ├── bert-base-pool[sst-2]-20.jsonnet │ ├── bert-base-pool[sst-2]-23.jsonnet │ ├── bert-base-pool[sst-2]-24.jsonnet │ ├── bert-base-pool[sst-2]-26.jsonnet │ ├── bert-base-pool[sst-2]-28.jsonnet │ ├── bert-base-pool[sst-2]-29.jsonnet │ ├── bert-large-pool[imdb]-39.jsonnet │ ├── bert-large-pool[sst-2]-21.jsonnet │ ├── bert-large-pool[sst-2]-25.jsonnet │ ├── bert-large-pool[sst-2]-27.jsonnet │ ├── bert-large-pool[sst-2]-30.jsonnet │ ├── bow-sum[imdb]-31.jsonnet │ ├── bow-sum[sst-2]-1.jsonnet │ ├── glove-cnn[imdb]-37.jsonnet │ ├── glove-cnn[sst-2]-17.jsonnet │ ├── glove-cnn[sst-2]-18.jsonnet │ ├── glove-cnn[sst-2]-19.jsonnet │ ├── glove-lstm[imdb]-36.jsonnet │ ├── glove-lstm[sst-2]-15.jsonnet │ ├── glove-lstm[sst-2]-16.jsonnet │ ├── glove-sum[imdb]-35.jsonnet │ ├── glove-sum[sst-2]-10.jsonnet │ ├── glove-sum[sst-2]-11.jsonnet │ ├── glove-sum[sst-2]-12.jsonnet │ ├── glove-sum[sst-2]-13.jsonnet │ ├── glove-sum[sst-2]-14.jsonnet │ ├── roberta-large-pool[imdb]-38.jsonnet │ ├── roberta-large-pool[sst-2]-22.jsonnet │ ├── word2vec-cnn[imdb]-34.jsonnet │ ├── word2vec-cnn[sst-2]-7.jsonnet │ ├── word2vec-cnn[sst-2]-8.jsonnet │ ├── word2vec-cnn[sst-2]-9.jsonnet │ ├── word2vec-lstm[imdb]-33.jsonnet │ ├── word2vec-lstm[sst-2]-6.jsonnet │ ├── word2vec-sum[imdb]-32.jsonnet │ ├── word2vec-sum[sst-2]-2.jsonnet │ ├── word2vec-sum[sst-2]-3.jsonnet │ ├── word2vec-sum[sst-2]-4.jsonnet │ └── word2vec-sum[sst-2]-5.jsonnet ├── experiment2 │ ├── bert-base-pool[msrpar]-57.jsonnet │ ├── bow-sum[msrpar]-41.jsonnet │ ├── glove-cnn[msrpar]-54.jsonnet │ ├── glove-cnn[msrpar]-55.jsonnet │ 
├── glove-cnn[msrpar]-56.jsonnet │ ├── glove-lstm[msrpar]-49.jsonnet │ ├── glove-lstm[msrpar]-50.jsonnet │ ├── glove-lstm[msrpar]-51.jsonnet │ ├── glove-lstm[msrpar]-52.jsonnet │ ├── glove-lstm[msrpar]-53.jsonnet │ ├── glove-sum[msrpar]-48.jsonnet │ ├── word2vec-cnn[msrpar]-47.jsonnet │ ├── word2vec-lstm[msrpar]-43.jsonnet │ ├── word2vec-lstm[msrpar]-44.jsonnet │ ├── word2vec-lstm[msrpar]-45.jsonnet │ ├── word2vec-lstm[msrpar]-46.jsonnet │ └── word2vec-sum[msrpar]-42.jsonnet └── test │ ├── sst-cnn-bow.json │ ├── sst-cnn.json │ ├── sst-lstm.json │ ├── sst-roberta.jsonnet │ ├── sst-sum.json │ ├── stsball-bert.jsonnet │ ├── stsball-bow.json │ ├── stsball-glove-lstm.json │ └── stsball-w2v.json ├── readme.md ├── result.txt ├── results.txt ├── scripts ├── 20191006.py ├── generate_config.py ├── run_glue_tasks.sh ├── run_record.sh ├── run_squad.sh └── temp_run.sh ├── src ├── amazon_preprocess │ ├── fetch_result.py │ ├── preprocess.py │ ├── sample.py │ └── sample2.py ├── data_loader.py ├── deprecated_run.py ├── mnli_preprocess │ └── splitMNLI.py ├── record2squad │ ├── .ipynb_checkpoints │ │ └── ReCoRD2SQuAD-checkpoint.ipynb │ └── ReCoRD2SQuAD.ipynb ├── run.py ├── run_glue.py ├── run_record(deprecated).py ├── run_squad.py ├── sts_preprocess │ ├── resplit.py │ └── resplit2.py ├── utils_glue.py ├── utils_record(deprecated).py ├── utils_squad.py └── utils_squad_evaluate.py ├── statistic.txt └── sts.png /.gitignore: -------------------------------------------------------------------------------- 1 | log.txt 2 | model/* 3 | runs/* 4 | data/* 5 | results/* 6 | amazon_split/* 7 | .vscode 8 | focus.json 9 | focus.pickle 10 | *.pyc -------------------------------------------------------------------------------- /allennlp_glue_patch/__init__.py: -------------------------------------------------------------------------------- 1 | import allennlp_glue_patch.dataset_readers -------------------------------------------------------------------------------- 
import torch
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
from allennlp.data import Vocabulary
from allennlp.common import Params
from allennlp.nn.util import get_text_field_mask
from allennlp.common.checks import ConfigurationError


@TokenEmbedder.register("bag_of_word_counts_corrected")
class CorrectedBagOfWordCountsTokenEmbedder(TokenEmbedder):
    """
    Represents a sequence of tokens as a bag of (discrete) word ids, as it was done
    in the pre-neural days.

    Each sequence gets a vector of length vocabulary size, where the i'th entry in the vector
    corresponds to number of times the i'th token in the vocabulary appears in the sequence.

    By default, we ignore padding tokens.

    Parameters
    ----------
    vocab: ``Vocabulary``
    vocab_namespace: ``str``
        namespace of vocabulary to embed
    projection_dim : ``int``, optional (default = ``None``)
        if specified, will project the resulting bag of words representation
        to specified dimension.
    ignore_oov : ``bool``, optional (default = ``False``)
        If true, we ignore the OOV token.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 vocab_namespace: str,
                 projection_dim: int = None,
                 ignore_oov: bool = False) -> None:
        super().__init__()
        self.vocab = vocab
        self.vocab_size = vocab.get_vocab_size(vocab_namespace)
        if projection_dim:
            self._projection = torch.nn.Linear(self.vocab_size, projection_dim)
        else:
            self._projection = None
        self._ignore_oov = ignore_oov
        oov_token = vocab._oov_token  # pylint: disable=protected-access
        self._oov_idx = vocab.get_token_to_index_vocabulary(vocab_namespace).get(oov_token)
        if self._oov_idx is None:
            raise ConfigurationError("OOV token does not exist in vocabulary namespace {}".format(vocab_namespace))
        self.output_dim = projection_dim or self.vocab_size

    def get_output_dim(self):
        # Width of the per-document vector: projection_dim when projecting,
        # otherwise the raw vocabulary size.
        return self.output_dim

    def forward(self,  # pylint: disable=arguments-differ
                inputs: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        inputs: ``torch.Tensor``
            Shape ``(batch_size, sequence_length)`` of word ids representing the
            current batch.  (The upstream docstring claimed an extra ``timesteps``
            dimension, but the per-document loop below assumes a 2-D tensor.)

        Returns
        -------
        The bag-of-words representations for the input sequence, shape
        ``(batch_size, 1, output_dim)``.  The singleton "timestep" dimension is
        the correction this class exists for: it lets downstream Seq2Vec
        encoders treat the bag vector as a length-1 sequence.
        """
        bag_of_words_vectors = []

        # Mask out padding positions so they do not contribute to the counts.
        mask = get_text_field_mask({'tokens': inputs})
        if self._ignore_oov:
            # also mask out positions corresponding to oov
            mask *= (inputs != self._oov_idx).long()
        for document, doc_mask in zip(inputs, mask):
            document = torch.masked_select(document, doc_mask.to(dtype=torch.bool))
            vec = torch.bincount(document, minlength=self.vocab_size).float()
            vec = vec.view(1, -1)
            bag_of_words_vectors.append(vec)
        bag_of_words_output = torch.cat(bag_of_words_vectors, 0)

        if self._projection:
            projection = self._projection
            bag_of_words_output = projection(bag_of_words_output)

        # Re-insert a singleton sequence dimension (see Returns above).
        ori_shape = bag_of_words_output.shape
        return bag_of_words_output.reshape((ori_shape[0], 1, ori_shape[1]))

    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'CorrectedBagOfWordCountsTokenEmbedder':  # type: ignore
        # pylint: disable=arguments-differ
        """
        we look for a ``vocab_namespace`` key in the parameter dictionary
        to know which vocabulary to use.
        """
        # (Fixed: the return annotation previously named the upstream
        # ``BagOfWordCountsTokenEmbedder`` class rather than this one.)
        vocab_namespace = params.pop("vocab_namespace", "tokens")
        projection_dim = params.pop_int("projection_dim", None)
        ignore_oov = params.pop_bool("ignore_oov", False)
        params.assert_empty(cls.__name__)
        return cls(vocab=vocab,
                   vocab_namespace=vocab_namespace,
                   ignore_oov=ignore_oov,
                   projection_dim=projection_dim)
# TODO: CHECK THE COMPATIBILITY WITH BERT
@Model.register("basic_regressor")
class BasicRegressor(Model):
    """
    This ``Model`` implements a basic text regressor. After embedding the text into
    a text field, we will optionally encode the embeddings with a ``Seq2SeqEncoder``. The
    resulting sequence is pooled using a ``Seq2VecEncoder`` and then passed to
    a linear regression layer, which projects into a single value. If a
    ``Seq2SeqEncoder`` is not provided, we will pass the embedded text directly to the
    ``Seq2VecEncoder``.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the input text into a ``TextField``
    seq2seq_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        Optional Seq2Seq encoder layer for the input text.
    seq2vec_encoder : ``Seq2VecEncoder``
        Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder
        will pool its output. Otherwise, this encoder will operate directly on the output
        of the `text_field_embedder`.
    dropout : ``float``, optional (default = ``None``)
        Dropout percentage to use.
    scale : ``float``, optional (default = 1)
        Scale regression result is between 0 ~ scale.
        NOTE(review): stored but not applied anywhere in this class — confirm
        whether outputs should be multiplied by it.
    label_namespace: ``str``, optional (default = "labels")
        Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        If provided, will be used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 seq2vec_encoder: Seq2VecEncoder,
                 seq2seq_encoder: Seq2SeqEncoder = None,
                 dropout: float = None,
                 scale: float = 1,
                 label_namespace: str = "labels",
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if seq2seq_encoder:
            self._seq2seq_encoder = seq2seq_encoder
        else:
            self._seq2seq_encoder = None

        self._seq2vec_encoder = seq2vec_encoder
        self._classifier_input_dim = self._seq2vec_encoder.get_output_dim()

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
        else:
            self._dropout = None

        self._label_namespace = label_namespace

        self._num_labels = 1  # because we're running a regression task
        self._scale = scale

        self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
        self._metric = PearsonCorrelation()
        self._loss = torch.nn.MSELoss()
        initializer(self)

    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor]
            From a ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size, 1)`` representing
            unnormalized log probabilities of the label.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        embedded_text = self._text_field_embedder(tokens)
        mask = get_text_field_mask(tokens).float()

        if self._seq2seq_encoder:
            embedded_text = self._seq2seq_encoder(embedded_text, mask=mask)

        embedded_text = self._seq2vec_encoder(embedded_text, mask=mask)

        if self._dropout:
            embedded_text = self._dropout(embedded_text)

        logits = self._classification_layer(embedded_text)
        output_dict = {"logits": logits}

        if label is not None:  # convert the label into a float number and update the metric
            # The LabelField stores the gold score as a vocabulary string
            # (e.g. "3.0"); map the index back to that string, then to float.
            label_to_str = lambda l: self.vocab.get_index_to_token_vocabulary(self._label_namespace).get(l)
            label_tensor = torch.tensor([float(label_to_str(int(label[i]))) for i in range(label.shape[0])], device=logits.device)
            loss = self._loss(logits.view(-1), label_tensor)
            output_dict["loss"] = loss
            self._metric(logits, label_tensor)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts each predicted regression score to a one-decimal string and
        adds a ``"label"`` key to the dictionary with the results.
        (No argmax happens here — this is a regression head.)
        """
        predictions = output_dict["logits"]
        if predictions.dim() == 2:
            predictions_list = [predictions[i] for i in range(predictions.shape[0])]
        else:
            predictions_list = [predictions]
        classes = []
        for prediction in predictions_list:
            # FIX: use .item() to extract the float score. The previous
            # ``prediction.long()`` both truncated the value to an integer and
            # raised TypeError, because non-scalar tensors cannot be formatted
            # with a float spec.
            label_idx = "{:.1f}".format(prediction.item())
            label_str = (self.vocab.get_index_to_token_vocabulary(self._label_namespace)
                         .get(label_idx, str(label_idx)))
            classes.append(label_str)
        output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {'PearsonCorrelation': self._metric.get_metric(reset)}
        return metrics
@DatasetReader.register("binary_sentiment")
class BinarySentimentDatasetReader(DatasetReader):
    """
    Reads a headerless two-column TSV file of ``<text>\\t<label>`` rows (e.g.
    SST-2 / IMDB exports) and produces instances with a ``tokens`` ``TextField``
    and a ``label`` ``LabelField``.
    """

    @staticmethod
    def _read_tsv(input_file, quotechar=None):
        """Reads a tab separated value file."""
        # utf-8-sig transparently strips a leading BOM, if present.
        with open(input_file, "r", encoding="utf-8-sig") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            # The legacy Python 2 ``unicode`` re-decoding branch was removed:
            # this module's annotation syntax already requires Python 3.
            return list(reader)

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    @overrides
    def _read(self, file_path: str) -> Iterable[Instance]:
        file_path = cached_path(file_path)
        logger.info("Reading instances from lines in file at: %s", file_path)
        lines = BinarySentimentDatasetReader._read_tsv(file_path)
        for i, line in enumerate(lines):
            # Column 0 is the sentence, column 1 the sentiment label.
            yield self.text_to_instance(i, line[0], line[1])

    def text_to_instance(self, index: int, text: str, label: str) -> Instance:
        # ``index`` is accepted for interface compatibility but not stored in the instance.
        tokens = self._tokenizer.tokenize(text)
        return Instance({
            "tokens": TextField(tokens, self._token_indexers),
            "label": LabelField(label)
        })
@DatasetReader.register("similarity_regression")
class SimilarityRegressionDatasetReader(DatasetReader):
    """
    Reads an STS-B style TSV file — sentence pair in columns 7 and 8, gold
    similarity score in the last column — and produces regression instances,
    either as one joint ``tokens`` field or as split ``tokens_a``/``tokens_b``
    fields depending on the configured token indexers.
    """

    @staticmethod
    def _read_tsv(input_file, quotechar=None):
        """Reads a tab separated value file, dropping the header row."""
        with open(input_file, "r", encoding="utf-8-sig") as f:
            tsv_reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            rows = []
            for row in tsv_reader:
                if sys.version_info[0] == 2:
                    row = list(unicode(cell, 'utf-8') for cell in row)
                rows.append(row)
        return rows[1:]  # ignore the header

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        lazy: bool = False,
    ) -> None:
        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        self._token_indexers = token_indexers or {"tokens_a": SingleIdTokenIndexer(),
                                                  "tokens_b": SingleIdTokenIndexer()}

    # TODO: update the following part
    @overrides
    def _read(self, file_path: str) -> Iterable[Instance]:
        file_path = cached_path(file_path)
        logger.info("Reading instances from lines in file at: %s", file_path)
        for row_index, row in enumerate(SimilarityRegressionDatasetReader._read_tsv(file_path)):
            yield self.text_to_instance(row_index, row[7], row[8], row[-1])

    def text_to_instance(self, index: int, text_a: str, text_b: str, label: str) -> Instance:
        label_field = LabelField(label)
        if "tokens" in self._token_indexers:
            # Joint encoding: concatenate both sentences into a single field,
            # using [1:] to skip the second sequence's leading symbol.
            joint_tokens = self._tokenizer.tokenize(text_a) + self._tokenizer.tokenize(text_b)[1:]
            return Instance({
                "tokens": TextField(joint_tokens, self._token_indexers),
                "label": label_field
            })
        # Split encoding: one field per sentence.
        return Instance({
            "tokens_a": TextField(self._tokenizer.tokenize(text_a), self._token_indexers),
            "tokens_b": TextField(self._tokenizer.tokenize(text_b), self._token_indexers),
            "label": label_field
        })
from allennlp_glue_patch.dataset_readers.BinarySentiment import BinarySentimentDatasetReader
# Import so the @DatasetReader.register("similarity_regression") decorator runs;
# without this, the reader is invisible to --include-package allennlp_glue_patch.
from allennlp_glue_patch.dataset_readers.SimilarityRegression import SimilarityRegressionDatasetReader
allennlp train config/test/sst-lstm.json -s model/test-lstm[sst-2] --include-package allennlp_glue_patch 30 | ``` 31 | - Eval a model 32 | ``` 33 | allennlp evaluate 34 | ``` 35 | - Predict a model 36 | ``` 37 | allennlp predict 38 | ``` 39 | 40 | ## convenient commands 41 | ``` 42 | rm -rf model/test-lstm[sst-2] && allennlp train config/test/sst-lstm.json -s model/test-lstm[sst-2] --include-package allennlp_glue_patch 43 | rm -rf model/test-cnn[sst-2] && allennlp train config/test/sst-cnn.json -s model/test-cnn[sst-2] --include-package allennlp_glue_patch 44 | rm -rf model/test-cnn-bow[sst-2] && allennlp train config/test/sst-cnn-bow.json -s model/test-cnn-bow[sst-2] --include-package allennlp_glue_patch 45 | rm -rf model/test-sum[sst-2] && allennlp train config/test/sst-sum.json -s model/test-sum[sst-2] --include-package allennlp_glue_patch 46 | rm -rf model/test-bert[sst-2] && allennlp train config/test/sst-bert.jsonnet -s model/test-bert[sst-2] --include-package allennlp_glue_patch 47 | rm -rf model/test-roberta[sst-2] && allennlp train config/test/sst-roberta.jsonnet -s model/test-roberta[sst-2] --include-package allennlp_glue_patch 48 | 49 | allennlp train config/test/headlines-glove-lstm.json -s model/test-headlines --include-package allennlp_glue_patch -f 50 | allennlp train config/test/stsball-bert.jsonnet -s model/test-stsball --include-package allennlp_glue_patch -f 51 | ``` -------------------------------------------------------------------------------- /allennlp_glue_patch/pretrained_transformer_tokenizer_corrected.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Tuple 3 | 4 | from overrides import overrides 5 | from pytorch_transformers.tokenization_auto import AutoTokenizer 6 | 7 | from allennlp.data.tokenizers.token import Token 8 | from allennlp.data.tokenizers.tokenizer import Tokenizer 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | 
@Tokenizer.register("pretrained_transformer_corrected")
class CorrectedPretrainedTransformerTokenizer(Tokenizer):
    """
    A ``PretrainedTransformerTokenizer`` uses a model from HuggingFace's
    ``pytorch_transformers`` library to tokenize some input text. This often means wordpieces
    (where ``'AllenNLP is awesome'`` might get split into ``['Allen', '##NL', '##P', 'is',
    'awesome']``), but it could also use byte-pair encoding, or some other tokenization, depending
    on the pretrained model that you're using.

    We take a model name as an input parameter, which we will pass to
    ``AutoTokenizer.from_pretrained``.

    Parameters
    ----------
    model_name : ``str``
        The name of the pretrained wordpiece tokenizer to use.
    do_lowercase : ``bool``
        Whether the underlying wordpiece tokenizer should lowercase its input.
    start_tokens : ``List[str]``, optional
        If given, these tokens will be added to the beginning of every string we tokenize. We try
        to be a little bit smart about defaults here - e.g., if your model name contains ``bert``,
        we by default add ``[CLS]`` at the beginning and ``[SEP]`` at the end.
    end_tokens : ``List[str]``, optional
        If given, these tokens will be added to the end of every string we tokenize.
    max_length : ``int``, optional (default = 256)
        Wordpiece sequences are truncated so that, together with the two
        start/end special tokens, they never exceed this length.  (Previously a
        hard-coded ``256``.)
    """
    def __init__(self,
                 model_name: str,
                 do_lowercase: bool,
                 start_tokens: List[str] = None,
                 end_tokens: List[str] = None,
                 max_length: int = 256) -> None:
        if model_name.endswith("-cased") and do_lowercase:
            logger.warning("Your pretrained model appears to be cased, "
                           "but your tokenizer is lowercasing tokens.")
        elif model_name.endswith("-uncased") and not do_lowercase:
            logger.warning("Your pretrained model appears to be uncased, "
                           "but your tokenizer is not lowercasing tokens.")
        self._tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=do_lowercase)
        default_start_tokens, default_end_tokens = _guess_start_and_end_token_defaults(model_name)
        self._start_tokens = start_tokens if start_tokens is not None else default_start_tokens
        self._end_tokens = end_tokens if end_tokens is not None else default_end_tokens
        # Reserve two slots for the start/end special tokens.
        self._length_limit = max_length - 2

    @overrides
    def tokenize(self, text: str) -> List[Token]:
        # TODO(mattg): track character offsets. Might be too challenging to do it here, given that
        # pytorch-transformers is dealing with the whitespace...
        tokenized = self._tokenizer.tokenize(text)
        # Truncate over-long inputs before adding the special tokens.
        if len(tokenized) >= self._length_limit:
            tokenized = tokenized[:self._length_limit]
        token_strings = self._start_tokens + tokenized + self._end_tokens
        return [Token(t) for t in token_strings]
56 | length_limit = 256-2 57 | tokenized = self._tokenizer.tokenize(text) 58 | if len(tokenized) >= length_limit: 59 | tokenized = tokenized[:length_limit] 60 | token_strings = self._start_tokens + tokenized + self._end_tokens 61 | return [Token(t) for t in token_strings] 62 | 63 | 64 | def _guess_start_and_end_token_defaults(model_name: str) -> Tuple[List[str], List[str]]: 65 | if 'roberta' in model_name: 66 | return ([''], ['']) 67 | elif "bert" in model_name: 68 | return (['[CLS]'], ['[SEP]']) 69 | else: 70 | return ([], []) 71 | -------------------------------------------------------------------------------- /allennlp_glue_patch/roberta_pooler.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from overrides import overrides 4 | 5 | import torch 6 | import torch.nn 7 | from pytorch_pretrained_bert import BertModel 8 | 9 | from allennlp.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder 10 | from allennlp.modules.token_embedders.bert_token_embedder import PretrainedBertModel 11 | 12 | 13 | @Seq2VecEncoder.register("roberta_pooler") 14 | class RobertaPooler(Seq2VecEncoder): 15 | """ 16 | Dense with same hidden + activation using tanh, EASY 17 | """ 18 | def __init__(self, 19 | pretrained_model: Union[str, BertModel], 20 | embedding_dim: int = 1024, # default for large 21 | dropout: float = 0.0) -> None: 22 | super().__init__() 23 | 24 | self._dropout = torch.nn.Dropout(p=dropout) 25 | self._dense = torch.nn.Linear(embedding_dim, embedding_dim) 26 | self._activation = torch.nn.Tanh() 27 | self._embedding_dim = embedding_dim 28 | 29 | @overrides 30 | def get_input_dim(self) -> int: 31 | return self._embedding_dim 32 | 33 | @overrides 34 | def get_output_dim(self) -> int: 35 | return self._embedding_dim 36 | 37 | def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None): # pylint: disable=arguments-differ,unused-argument 38 | first_token_tensor = tokens[:, 0] 39 | return 
@Seq2VecEncoder.register("roberta_pooler")
class RobertaPooler(Seq2VecEncoder):
    """
    Pools a sequence by taking the first token's embedding and passing it
    through a same-width dense layer with a tanh activation (dense + tanh,
    EASY).  ``pretrained_model`` is accepted by the constructor but never
    referenced — presumably kept for config-file compatibility.
    """
    def __init__(self,
                 pretrained_model: Union[str, BertModel],
                 embedding_dim: int = 1024,  # default for large
                 dropout: float = 0.0) -> None:
        super().__init__()

        self._dropout = torch.nn.Dropout(p=dropout)
        self._dense = torch.nn.Linear(embedding_dim, embedding_dim)
        self._activation = torch.nn.Tanh()
        self._embedding_dim = embedding_dim

    @overrides
    def get_input_dim(self) -> int:
        return self._embedding_dim

    @overrides
    def get_output_dim(self) -> int:
        return self._embedding_dim

    def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None):  # pylint: disable=arguments-differ,unused-argument
        # Take the leading special-token state for every sequence in the batch.
        cls_state = tokens[:, 0]
        pooled = self._activation(self._dense(cls_state))
        return self._dropout(pooled)
@Model.register("stsb_regressor")
class STSBRegressor(Model):
    """
    This ``Model`` implements a sentence-pair similarity regressor.  Both
    sentences are embedded, optionally passed through a shared ``Seq2SeqEncoder``,
    pooled by a shared ``Seq2VecEncoder``, and scored with cosine similarity
    scaled to the 0-5 STS-B range.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the input text into a ``TextField``
    seq2seq_encoder : ``Seq2SeqEncoder``, optional (default=``None``)
        Optional Seq2Seq encoder layer for the input text.
    seq2vec_encoder : ``Seq2VecEncoder``
        Required Seq2Vec encoder layer. If `seq2seq_encoder` is provided, this encoder
        will pool its output. Otherwise, this encoder will operate directly on the output
        of the `text_field_embedder`.
    dropout : ``float``, optional (default = ``None``)
        Dropout percentage to use.
    scale : ``float``, optional (default = 1)
        Scale regression result is between 0 ~ scale.
        NOTE(review): ``forward`` currently hard-codes ``* 5`` instead of using
        this value — confirm before relying on ``scale``.
    label_namespace: ``str``, optional (default = "labels")
        Vocabulary namespace corresponding to labels. By default, we use the "labels" namespace.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        If provided, will be used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 seq2vec_encoder: Seq2VecEncoder,
                 seq2seq_encoder: Seq2SeqEncoder = None,
                 dropout: float = None,
                 scale: float = 1,
                 label_namespace: str = "labels",
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:

        super().__init__(vocab, regularizer)
        self._text_field_embedder = text_field_embedder

        if seq2seq_encoder:
            self._seq2seq_encoder = seq2seq_encoder
        else:
            self._seq2seq_encoder = None

        self._seq2vec_encoder = seq2vec_encoder

        self._classifier_input_dim = self._seq2vec_encoder.get_output_dim() * 2  # run encoder seperately and concat the result

        if dropout:
            self._dropout = torch.nn.Dropout(dropout)
            self._dropout_a = torch.nn.Dropout(dropout)
            self._dropout_b = torch.nn.Dropout(dropout)
        else:
            self._dropout = None

        self._label_namespace = label_namespace

        self._num_labels = 1  # because we're running a regression task
        self._scale = scale

        # NOTE: the MLP + classification-layer path is currently disabled in
        # ``forward`` (cosine similarity is used instead); the modules are kept
        # so saved configs/weights remain loadable.
        self._mlp_dims = [self._classifier_input_dim] * 3
        self._mlp_layers = torch.nn.ModuleList()
        for i, j in zip(self._mlp_dims, self._mlp_dims[1:]):
            self._mlp_layers.append(torch.nn.Linear(i, j))
            self._mlp_layers.append(torch.nn.ReLU())
            if dropout:
                self._mlp_layers.append(torch.nn.Dropout(dropout))
        self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels)
        self._metric = PearsonCorrelation()
        self._similarity = torch.nn.CosineSimilarity()
        self._loss = torch.nn.MSELoss()
        initializer(self)

    def forward(self,  # type: ignore
                tokens_a: Dict[str, torch.LongTensor],
                tokens_b: Dict[str, torch.LongTensor],
                label: torch.IntTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens_a : Dict[str, torch.LongTensor]
            From the first sentence's ``TextField``
        tokens_b : Dict[str, torch.LongTensor]
            From the second sentence's ``TextField``
        label : torch.IntTensor, optional (default = None)
            From a ``LabelField``

        Returns
        -------
        An output dictionary consisting of:

        logits : torch.FloatTensor
            A tensor of shape ``(batch_size,)`` holding the scaled cosine
            similarity of the two pooled sentence representations.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # Each field dict unexpectedly contains both keys (the reason is
        # unclear — see the dataset reader); re-split them explicitly here.
        tokens = {"tokens_a": tokens_a["tokens_a"], "tokens_b": tokens_b["tokens_b"]}
        tokens_a = {"tokens_a": tokens_a["tokens_a"]}
        tokens_b = {"tokens_b": tokens_b["tokens_b"]}
        embedded_text = self._text_field_embedder(tokens)
        embedded_text_a = embedded_text["tokens_a"]  # TODO: check the shape for this
        mask_a = get_text_field_mask(tokens_a).float()
        embedded_text_b = embedded_text["tokens_b"]
        mask_b = get_text_field_mask(tokens_b).float()

        if self._seq2seq_encoder:
            embedded_text_a = self._seq2seq_encoder(embedded_text_a, mask=mask_a)
            embedded_text_b = self._seq2seq_encoder(embedded_text_b, mask=mask_b)

        embedded_text_a = self._seq2vec_encoder(embedded_text_a, mask=mask_a)
        embedded_text_b = self._seq2vec_encoder(embedded_text_b, mask=mask_b)

        if self._dropout:
            embedded_text_a = self._dropout_a(embedded_text_a)
            embedded_text_b = self._dropout_b(embedded_text_b)

        # TODO: should arguably use ``self._scale`` instead of the literal 5,
        # but configs may rely on the current behaviour.
        logits = self._similarity(embedded_text_a, embedded_text_b) * 5
        output_dict = {"logits": logits}

        if label is not None:  # convert the label into a float number and update the metric
            label_to_str = lambda l: self.vocab.get_index_to_token_vocabulary(self._label_namespace).get(l)
            # HACK: requires_grad=True on the target keeps loss.backward() from
            # erroring when the whole encoding path has no trainable parameters
            # (e.g. frozen BOW + cosine); the target's gradient is discarded.
            label_tensor = torch.tensor([float(label_to_str(int(label[i]))) for i in range(label.shape[0])], device=logits.device, requires_grad=True)
            loss = self._loss(logits.view(-1), label_tensor)
            output_dict["loss"] = loss
            self._metric(logits, label_tensor)

        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Converts each predicted similarity score to a one-decimal string and
        adds a ``"label"`` key to the dictionary with the results.
        """
        predictions = output_dict["logits"]
        # FIX: flatten so we emit one label per example regardless of shape —
        # cosine similarity produces a 1-D ``(batch,)`` tensor, which the old
        # code wrapped whole as a single "prediction".
        classes = []
        for prediction in predictions.reshape(-1):
            # FIX: .item() keeps the fractional part; the old ``.long()``
            # truncated it and cannot be formatted with a float spec anyway.
            label_idx = "{:.1f}".format(prediction.item())
            label_str = (self.vocab.get_index_to_token_vocabulary(self._label_namespace)
                         .get(label_idx, str(label_idx)))
            classes.append(label_str)
        output_dict["label"] = classes
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics = {'PearsonCorrelation': self._metric.get_metric(reset)}
        return metrics

from typing import Dict, List, Union, Any
import inspect

import torch
from overrides import overrides

from allennlp.common import Params
from allennlp.common.checks import ConfigurationError
from allennlp.data import Vocabulary
from allennlp.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder
from allennlp.modules.time_distributed import TimeDistributed
from allennlp.modules.token_embedders.token_embedder import TokenEmbedder

# NOTE(review): this module is marked DEPRECATED at the top of the file; it is a
# patched copy of AllenNLP's BasicTextFieldEmbedder registered under the name "stsb".
@TextFieldEmbedder.register("stsb")
class BasicTextFieldEmbedder(TextFieldEmbedder):
    """
    This is a ``TextFieldEmbedder`` that wraps a collection of :class:`TokenEmbedder` objects. Each
    ``TokenEmbedder`` embeds or encodes the representation output from one
    :class:`~allennlp.data.TokenIndexer`. As the data produced by a
    :class:`~allennlp.data.fields.TextField` is a dictionary mapping names to these
    representations, we take ``TokenEmbedders`` with corresponding names. Each ``TokenEmbedders``
    embeds its input, and the result is concatenated in an arbitrary order.

    Parameters
    ----------

    token_embedders : ``Dict[str, TokenEmbedder]``, required.
        A dictionary mapping token embedder names to implementations.
        These names should match the corresponding indexer used to generate
        the tensor passed to the TokenEmbedder.
    embedder_to_indexer_map : ``Dict[str, Union[List[str], Dict[str, str]]]``, optional, (default = None)
        Optionally, you can provide a mapping between the names of the TokenEmbedders that
        you are using to embed your TextField and an ordered list of indexer names which
        are needed for running it, or a mapping between the parameters which the
        ``TokenEmbedder.forward`` takes and the indexer names which are viewed as arguments.
        In most cases, your TokenEmbedder will only require a single tensor, because it is
        designed to run on the output of a single TokenIndexer. For example, the ELMo Token
        Embedder can be used in two modes, one of which requires both character ids and word
        ids for the same text. Note that the list of token indexer names is `ordered`,
        meaning that the tensors produced by the indexers will be passed to the embedders in
        the order you specify in this list. You can also use `null` in the configuration to
        set some specified parameters to None.
    allow_unmatched_keys : ``bool``, optional (default = False)
        If True, then don't enforce the keys of the ``text_field_input`` to
        match those in ``token_embedders`` (useful if the mapping is specified
        via ``embedder_to_indexer_map``).
    output_key_specified : ``bool``, optional (default = False)
        If True, ``forward`` returns a dictionary mapping each embedder name to its
        output tensor instead of concatenating all outputs along the last dimension.
    """
    def __init__(self,
                 token_embedders: Dict[str, TokenEmbedder],
                 embedder_to_indexer_map: Dict[str, Union[List[str], Dict[str, str]]] = None,
                 allow_unmatched_keys: bool = False,
                 output_key_specified: bool = False) -> None:
        super(BasicTextFieldEmbedder, self).__init__()
        self._token_embedders = token_embedders
        self._embedder_to_indexer_map = embedder_to_indexer_map
        # Register each embedder as a submodule so its parameters are tracked.
        for key, embedder in token_embedders.items():
            name = 'token_embedder_%s' % key
            self.add_module(name, embedder)
        self._allow_unmatched_keys = allow_unmatched_keys
        self._output_key_specified = output_key_specified

    @overrides
    def get_output_dim(self) -> int:
        # Output is the concatenation of all embedders, so dims sum.
        output_dim = 0
        for embedder in self._token_embedders.values():
            output_dim += embedder.get_output_dim()
        return output_dim

    def forward(self, text_field_input: Dict[str, torch.Tensor],
                num_wrapping_dims: int = 0,
                **kwargs) -> torch.Tensor:
        """
        Embed each key of ``text_field_input`` with its matching embedder and either
        concatenate the results (default) or, when ``output_key_specified`` is True,
        return them as a dict keyed by embedder name.
        """
        embedder_keys = self._token_embedders.keys()
        input_keys = text_field_input.keys()

        # Check for unmatched keys
        if not self._allow_unmatched_keys:
            if embedder_keys < input_keys:
                # token embedder keys are a strict subset of text field input keys.
                message = (f"Your text field is generating more keys ({list(input_keys)}) "
                           f"than you have token embedders ({list(embedder_keys)}. "
                           f"If you are using a token embedder that requires multiple keys "
                           f"(for example, the OpenAI Transformer embedder or the BERT embedder) "
                           f"you need to add allow_unmatched_keys = True "
                           f"(and likely an embedder_to_indexer_map) to your "
                           f"BasicTextFieldEmbedder configuration. "
                           f"Otherwise, you should check that there is a 1:1 embedding "
                           f"between your token indexers and token embedders.")
                raise ConfigurationError(message)

            elif self._token_embedders.keys() != text_field_input.keys():
                # some other mismatch
                message = "Mismatched token keys: %s and %s" % (str(self._token_embedders.keys()),
                                                                str(text_field_input.keys()))
                raise ConfigurationError(message)

        # Result container type depends on the requested output form.
        if self._output_key_specified:
            embedded_representations = {}
        else:
            embedded_representations = []
        # Sort so concatenation order is deterministic.
        keys = sorted(embedder_keys)
        for key in keys:
            # Note: need to use getattr here so that the pytorch voodoo
            # with submodules works with multiple GPUs.
            embedder = getattr(self, 'token_embedder_{}'.format(key))
            # Forward only those kwargs the embedder's forward() actually accepts.
            forward_params = inspect.signature(embedder.forward).parameters
            forward_params_values = {}
            for param in forward_params.keys():
                if param in kwargs:
                    forward_params_values[param] = kwargs[param]

            for _ in range(num_wrapping_dims):
                embedder = TimeDistributed(embedder)
            # If we pre-specified a mapping explictly, use that.
            # make mypy happy
            tensors: Union[List[Any], Dict[str, Any]] = None
            if self._embedder_to_indexer_map is not None:
                indexer_map = self._embedder_to_indexer_map[key]
                if isinstance(indexer_map, list):
                    # A list maps to positional arguments of the embedder.
                    # If `indexer_key` is None, we map it to `None`.
                    tensors = [(text_field_input[indexer_key] if indexer_key is not None else None)
                               for indexer_key in indexer_map]
                    token_vectors = embedder(*tensors, **forward_params_values)
                elif isinstance(indexer_map, dict):
                    # A dict maps embedder parameter names to indexer names (keyword args).
                    tensors = {
                            name: text_field_input[argument]
                            for name, argument in indexer_map.items()
                    }
                    token_vectors = embedder(**tensors, **forward_params_values)
                else:
                    raise NotImplementedError
            else:
                # otherwise, we assume the mapping between indexers and embedders
                # is bijective and just use the key directly.
                tensors = [text_field_input[key]]
                token_vectors = embedder(*tensors, **forward_params_values)
            if self._output_key_specified:
                embedded_representations[key] = token_vectors
            else:
                embedded_representations.append(token_vectors)
        if self._output_key_specified:
            return embedded_representations
        else:
            return torch.cat(embedded_representations, dim=-1)

    # This is some unusual logic, it needs a custom from_params.
    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BasicTextFieldEmbedder':  # type: ignore
        # pylint: disable=arguments-differ,bad-super-call

        # The original `from_params` for this class was designed in a way that didn't agree
        # with the constructor. The constructor wants a 'token_embedders' parameter that is a
        # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those
        # key-value pairs to be top-level in the params object.
        #
        # This breaks our 'configuration wizard' and configuration checks. Hence, going forward,
        # the params need a 'token_embedders' key so that they line up with what the constructor wants.
        # For now, the old behavior is still supported, but produces a DeprecationWarning.

        embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None)
        if embedder_to_indexer_map is not None:
            embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True)
        allow_unmatched_keys = params.pop_bool("allow_unmatched_keys", False)
        output_key_specified = params.pop_bool("output_key_specified", False)
        token_embedder_params = params.pop('token_embedders', None)

        if token_embedder_params is not None:
            # New way: explicitly specified, so use it.
            token_embedders = {
                    name: TokenEmbedder.from_params(subparams, vocab=vocab)
                    for name, subparams in token_embedder_params.items()
            }

        else:
            # Legacy form: every remaining top-level key is treated as an embedder spec.
            token_embedders = {}
            keys = list(params.keys())
            for key in keys:
                embedder_params = params.pop(key)
                token_embedders[key] = TokenEmbedder.from_params(vocab=vocab, params=embedder_params)

        params.assert_empty(cls.__name__)
        return cls(token_embedders, embedder_to_indexer_map, allow_unmatched_keys, output_key_specified)

--------------------------------------------------------------------------------
/config/experiment/bert-base-pool[imdb]-40.jsonnet:
--------------------------------------------------------------------------------
local pretrained_transformer_model_name = "bert-base-uncased";

{
    "dataset_reader": {
        "type": "binary_sentiment",
        "tokenizer": {
            "type": "pretrained_transformer_corrected",
            "model_name": pretrained_transformer_model_name,
            "do_lowercase": true
        },
        "token_indexers": {
            "tokens": {
                "type": "pretrained_transformer",
                "model_name": pretrained_transformer_model_name,
                "do_lowercase": true
            }
        }
    },
    "train_data_path": "./data/IMDB/train.tsv",
    "validation_data_path": "./data/IMDB/dev.tsv",
    "model": {
        "type": "basic_classifier",
        "text_field_embedder": {
            "token_embedders": {
                "tokens": {
                    "type":
"pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 8 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-20.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 4, 42 | "patience": 2, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 1e-4 49 | } 50 | } 51 | } 
-------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-23.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 4, 42 | "patience": 2, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 1e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-24.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | 
"type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 5e-6 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-26.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | 
"pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-28.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 32 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 4e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-base-pool[sst-2]-29.jsonnet: 
-------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 32 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-large-pool[imdb]-39.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-large-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 
| "train_data_path": "./data/IMDB/train.tsv", 20 | "validation_data_path": "./data/IMDB/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 4 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 0, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 5e-6 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-large-pool[sst-2]-21.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-large-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 
41 | "num_epochs": 4, 42 | "patience": 2, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 0, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 1e-4 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-large-pool[sst-2]-25.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-large-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 0, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 5e-6 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-large-pool[sst-2]-27.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-large-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": 
"binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 0, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bert-large-pool[sst-2]-30.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-large-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | 
"token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 16 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 5.0, 44 | "validation_metric": "+accuracy", 45 | "cuda_device": 0, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 1e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment/bow-sum[imdb]-31.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "bag_of_word_counts_corrected", 16 | "projection_dim": 300 17 | } 18 | } 19 | }, 20 | "seq2vec_encoder": { 21 | "type": "bag_of_embeddings", 22 | "embedding_dim": 300 23 | } 24 | }, 25 | "iterator": { 26 | "type": "basic", 27 | "batch_size": 32 28 | }, 29 | "trainer": { 30 | "num_epochs": 16, 31 | "patience": 4, 32 | "grad_norm": 4.0, 33 | "validation_metric": "+accuracy", 34 | "cuda_device": 0, 35 | "optimizer": { 36 | "type": "adam", 37 | "lr": 1e-3 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /config/experiment/bow-sum[sst-2]-1.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | 
"train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "bag_of_word_counts_corrected", 16 | "projection_dim": 300 17 | } 18 | } 19 | }, 20 | "seq2vec_encoder": { 21 | "type": "bag_of_embeddings", 22 | "embedding_dim": 300 23 | } 24 | }, 25 | "iterator": { 26 | "type": "basic", 27 | "batch_size": 32 28 | }, 29 | "trainer": { 30 | "num_epochs": 16, 31 | "patience": 4, 32 | "grad_norm": 4.0, 33 | "validation_metric": "+accuracy", 34 | "cuda_device": 0, 35 | "optimizer": { 36 | "type": "adam", 37 | "lr": 1e-3 38 | } 39 | } 40 | } -------------------------------------------------------------------------------- /config/experiment/glove-cnn[imdb]-37.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 512, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } 
-------------------------------------------------------------------------------- /config/experiment/glove-cnn[sst-2]-17.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 1024, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/glove-cnn[sst-2]-18.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | 
"seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 512, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/glove-cnn[sst-2]-19.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 512, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 1e-3 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/glove-lstm[imdb]-36.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": 
"binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | }, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 4.0, 38 | "validation_metric": "+accuracy", 39 | "cuda_device": 0, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 1e-3 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/experiment/glove-lstm[sst-2]-15.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | }, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 4.0, 38 | 
"validation_metric": "+accuracy", 39 | "cuda_device": 0, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 5e-4 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/experiment/glove-lstm[sst-2]-16.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | }, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 4.0, 38 | "validation_metric": "+accuracy", 39 | "cuda_device": 0, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 1e-3 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[imdb]-35.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | 
"pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 5e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[sst-2]-10.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 5e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[sst-2]-11.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | 
"validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-3 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[sst-2]-12.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 
| "type": "adam", 39 | "lr": 1e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[sst-2]-13.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-2 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/glove-sum[sst-2]-14.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | 
"seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 5e-3 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/roberta-large-pool[imdb]-38.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "roberta-large"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/IMDB/train.tsv", 20 | "validation_data_path": "./data/IMDB/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "roberta_pooler" 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 4 38 | }, 39 | "trainer": { 40 | "num_epochs": 4, 41 | "patience": 2, 42 | "grad_norm": 5.0, 43 | "validation_metric": "+accuracy", 44 | "cuda_device": 1, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-5 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /config/experiment/roberta-large-pool[sst-2]-22.jsonnet: -------------------------------------------------------------------------------- 1 | 
local pretrained_transformer_model_name = "roberta-large"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "roberta_pooler" 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 16 38 | }, 39 | "trainer": { 40 | "num_epochs": 4, 41 | "patience": 2, 42 | "grad_norm": 5.0, 43 | "validation_metric": "+accuracy", 44 | "cuda_device": 1, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-5 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-cnn[imdb]-34.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": 
"cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 1024, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-cnn[sst-2]-7.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 512, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-cnn[sst-2]-8.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | 
}, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 1024, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-cnn[sst-2]-9.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 2048, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 4.0, 37 | 
"validation_metric": "+accuracy", 38 | "cuda_device": 0, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 5e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-lstm[imdb]-33.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | }, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 4.0, 38 | "validation_metric": "+accuracy", 39 | "cuda_device": 0, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 5e-4 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-lstm[sst-2]-6.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | 
"embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | }, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 4.0, 38 | "validation_metric": "+accuracy", 39 | "cuda_device": 0, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 5e-4 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-sum[imdb]-32.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/IMDB/train.tsv", 9 | "validation_data_path": "./data/IMDB/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 5e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-sum[sst-2]-2.jsonnet: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-3 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-sum[sst-2]-3.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | 
"batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-sum[sst-2]-4.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 2e-3 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment/word2vec-sum[sst-2]-5.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 
| "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 4.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 5e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/experiment2/bert-base-pool[msrpar]-57.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "similarity_regression", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 20 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 21 | "model": { 22 | "type": "basic_regressor", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 38 | "batch_size": 8 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 4.0, 44 | "validation_metric": "+PearsonCorrelation", 
45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/experiment2/bow-sum[msrpar]-41.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "bag_of_word_counts_corrected", 18 | "projection_dim": 300 19 | }, 20 | "tokens_b": { 21 | "type": "bag_of_word_counts_corrected", 22 | "projection_dim": 300 23 | } 24 | } 25 | }, 26 | "seq2vec_encoder": { 27 | "type": "bag_of_embeddings", 28 | "embedding_dim": 300 29 | } 30 | }, 31 | "iterator": { 32 | "type": "basic", 33 | "batch_size": 32 34 | }, 35 | "trainer": { 36 | "num_epochs": 128, 37 | "patience": 32, 38 | "grad_norm": 4.0, 39 | "validation_metric": "+PearsonCorrelation", 40 | "cuda_device": 3, 41 | "optimizer": { 42 | "type": "adam", 43 | "lr": 1e-4 44 | } 45 | } 46 | } -------------------------------------------------------------------------------- /config/experiment2/glove-cnn[msrpar]-54.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | 
"tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "cnn", 32 | "embedding_dim": 300, 33 | "num_filters": 512, 34 | "ngram_filter_sizes": [2, 3, 4, 5] 35 | }, 36 | "dropout": 0.1 37 | }, 38 | "iterator": { 39 | "type": "basic", 40 | "batch_size": 32 41 | }, 42 | "trainer": { 43 | "num_epochs": 128, 44 | "patience": 32, 45 | "grad_norm": 4.0, 46 | "validation_metric": "+PearsonCorrelation", 47 | "cuda_device": 3, 48 | "optimizer": { 49 | "type": "adam", 50 | "lr": 1e-4 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /config/experiment2/glove-cnn[msrpar]-55.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "cnn", 
32 | "embedding_dim": 300, 33 | "num_filters": 1024, 34 | "ngram_filter_sizes": [2, 3, 4, 5] 35 | }, 36 | "dropout": 0.1 37 | }, 38 | "iterator": { 39 | "type": "basic", 40 | "batch_size": 32 41 | }, 42 | "trainer": { 43 | "num_epochs": 128, 44 | "patience": 32, 45 | "grad_norm": 4.0, 46 | "validation_metric": "+PearsonCorrelation", 47 | "cuda_device": 3, 48 | "optimizer": { 49 | "type": "adam", 50 | "lr": 1e-4 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /config/experiment2/glove-cnn[msrpar]-56.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "cnn", 32 | "embedding_dim": 300, 33 | "num_filters": 300, 34 | "ngram_filter_sizes": [2, 3, 4, 5] 35 | }, 36 | "dropout": 0.1 37 | }, 38 | "iterator": { 39 | "type": "basic", 40 | "batch_size": 32 41 | }, 42 | "trainer": { 43 | "num_epochs": 128, 44 | "patience": 32, 45 | "grad_norm": 4.0, 46 | "validation_metric": "+PearsonCorrelation", 47 | "cuda_device": 3, 48 | "optimizer": { 49 | "type": "adam", 50 | "lr": 1e-4 51 | } 52 | } 53 | } 
-------------------------------------------------------------------------------- /config/experiment2/glove-lstm[msrpar]-49.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/glove-lstm[msrpar]-50.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": 
"./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-3 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/glove-lstm[msrpar]-51.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | 
"embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 8 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/glove-lstm[msrpar]-52.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 512, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 
128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/glove-lstm[msrpar]-53.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 300, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/glove-sum[msrpar]-48.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": 
"similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "bag_of_embeddings", 32 | "embedding_dim": 300 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 32 38 | }, 39 | "trainer": { 40 | "num_epochs": 1, 41 | "patience": 32, 42 | "grad_norm": 4.0, 43 | "validation_metric": "+PearsonCorrelation", 44 | "cuda_device": 3, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-4 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-cnn[msrpar]-47.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": 
"https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "cnn", 32 | "embedding_dim": 300, 33 | "num_filters": 512, 34 | "ngram_filter_sizes": [2, 3, 4, 5] 35 | }, 36 | "dropout": 0.1 37 | }, 38 | "iterator": { 39 | "type": "basic", 40 | "batch_size": 8 41 | }, 42 | "trainer": { 43 | "num_epochs": 128, 44 | "patience": 32, 45 | "grad_norm": 4.0, 46 | "validation_metric": "+PearsonCorrelation", 47 | "cuda_device": 3, 48 | "optimizer": { 49 | "type": "adam", 50 | "lr": 1e-4 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-lstm[msrpar]-43.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | 
"input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-lstm[msrpar]-44.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 8 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 
| } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-lstm[msrpar]-45.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 8 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-3 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-lstm[msrpar]-46.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": 
"./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 8 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-5 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/experiment2/word2vec-sum[msrpar]-42.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-MSRPAR/train.tsv", 9 | "validation_data_path": "./data/STS-B-MSRPAR/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": 
"https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "bag_of_embeddings", 32 | "embedding_dim": 300 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 32 38 | }, 39 | "trainer": { 40 | "num_epochs": 1, 41 | "patience": 32, 42 | "grad_norm": 4.0, 43 | "validation_metric": "+PearsonCorrelation", 44 | "cuda_device": 3, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-4 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /config/test/sst-cnn-bow.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "bag_of_word_counts_corrected", 16 | "projection_dim": 300 17 | } 18 | } 19 | }, 20 | "seq2vec_encoder": { 21 | "type": "cnn", 22 | "embedding_dim": 300, 23 | "num_filters": 512, 24 | "ngram_filter_sizes": [2, 3, 4, 5] 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 16, 33 | "patience": 4, 34 | "grad_norm": 5.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 0, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-4 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/test/sst-cnn.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "cnn", 24 | "embedding_dim": 300, 25 | "num_filters": 512, 26 | "ngram_filter_sizes": [2, 3, 4, 5] 27 | } 28 | }, 29 | "iterator": { 30 | "type": "basic", 31 | "batch_size": 32 32 | }, 33 | "trainer": { 34 | "num_epochs": 16, 35 | "patience": 4, 36 | "grad_norm": 5.0, 37 | "validation_metric": "+accuracy", 38 | "cuda_device": 7, 39 | "optimizer": { 40 | "type": "adam", 41 | "lr": 1e-4 42 | } 43 | } 44 | } -------------------------------------------------------------------------------- /config/test/sst-lstm.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | "train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "lstm", 24 | "input_size": 300, 25 | "hidden_size": 512, 26 | "num_layers": 2, 27 | "batch_first": true 28 | } 29 | 
}, 30 | "iterator": { 31 | "type": "basic", 32 | "batch_size": 32 33 | }, 34 | "trainer": { 35 | "num_epochs": 16, 36 | "patience": 4, 37 | "grad_norm": 5.0, 38 | "validation_metric": "+accuracy", 39 | "cuda_device": 7, 40 | "optimizer": { 41 | "type": "adam", 42 | "lr": 1e-4 43 | } 44 | } 45 | } -------------------------------------------------------------------------------- /config/test/sst-roberta.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "roberta-large"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "binary_sentiment", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/SST-2/train.tsv", 20 | "validation_data_path": "./data/SST-2/dev.tsv", 21 | "model": { 22 | "type": "basic_classifier", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "roberta_pooler" 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 16 38 | }, 39 | "trainer": { 40 | "num_epochs": 4, 41 | "patience": 2, 42 | "grad_norm": 5.0, 43 | "validation_metric": "+accuracy", 44 | "cuda_device": 1, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-5 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /config/test/sst-sum.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "binary_sentiment" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "binary_sentiment" 7 | }, 8 | 
"train_data_path": "./data/SST-2/train.tsv", 9 | "validation_data_path": "./data/SST-2/dev.tsv", 10 | "model": { 11 | "type": "basic_classifier", 12 | "text_field_embedder": { 13 | "token_embedders": { 14 | "tokens": { 15 | "type": "embedding", 16 | "embedding_dim": 300, 17 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 18 | "trainable": false 19 | } 20 | } 21 | }, 22 | "seq2vec_encoder": { 23 | "type": "bag_of_embeddings", 24 | "embedding_dim": 300 25 | } 26 | }, 27 | "iterator": { 28 | "type": "basic", 29 | "batch_size": 32 30 | }, 31 | "trainer": { 32 | "num_epochs": 8, 33 | "patience": 4, 34 | "grad_norm": 5.0, 35 | "validation_metric": "+accuracy", 36 | "cuda_device": 7, 37 | "optimizer": { 38 | "type": "adam", 39 | "lr": 1e-3 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /config/test/stsball-bert.jsonnet: -------------------------------------------------------------------------------- 1 | local pretrained_transformer_model_name = "bert-base-uncased"; 2 | 3 | { 4 | "dataset_reader": { 5 | "type": "similarity_regression", 6 | "tokenizer": { 7 | "type": "pretrained_transformer_corrected", 8 | "model_name": pretrained_transformer_model_name, 9 | "do_lowercase": true 10 | }, 11 | "token_indexers": { 12 | "tokens": { 13 | "type": "pretrained_transformer", 14 | "model_name": pretrained_transformer_model_name, 15 | "do_lowercase": true 16 | } 17 | } 18 | }, 19 | "train_data_path": "./data/STS-B-ALL/train.tsv", 20 | "validation_data_path": "./data/STS-B-ALL/dev.tsv", 21 | "model": { 22 | "type": "basic_regressor", 23 | "text_field_embedder": { 24 | "token_embedders": { 25 | "tokens": { 26 | "type": "pretrained_transformer", 27 | "model_name": pretrained_transformer_model_name, 28 | } 29 | } 30 | }, 31 | "seq2vec_encoder": { 32 | "type": "bert_pooler", 33 | "pretrained_model": pretrained_transformer_model_name 34 | } 35 | }, 36 | "iterator": { 37 | "type": "basic", 
38 | "batch_size": 8 39 | }, 40 | "trainer": { 41 | "num_epochs": 8, 42 | "patience": 4, 43 | "grad_norm": 4.0, 44 | "validation_metric": "+PearsonCorrelation", 45 | "cuda_device": 1, 46 | "optimizer": { 47 | "type": "adam", 48 | "lr": 2e-5 49 | } 50 | } 51 | } -------------------------------------------------------------------------------- /config/test/stsball-bow.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-ALL/train.tsv", 9 | "validation_data_path": "./data/STS-B-ALL/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "bag_of_word_counts_corrected", 18 | "projection_dim": 300 19 | }, 20 | "tokens_b": { 21 | "type": "bag_of_word_counts_corrected", 22 | "projection_dim": 300 23 | } 24 | } 25 | }, 26 | "seq2vec_encoder": { 27 | "type": "bag_of_embeddings", 28 | "embedding_dim": 300 29 | } 30 | }, 31 | "iterator": { 32 | "type": "basic", 33 | "batch_size": 32 34 | }, 35 | "trainer": { 36 | "num_epochs": 128, 37 | "patience": 32, 38 | "grad_norm": 4.0, 39 | "validation_metric": "+PearsonCorrelation", 40 | "cuda_device": 3, 41 | "optimizer": { 42 | "type": "adam", 43 | "lr": 1e-4 44 | } 45 | } 46 | } -------------------------------------------------------------------------------- /config/test/stsball-glove-lstm.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-ALL/train.tsv", 9 | "validation_data_path": "./data/STS-B-ALL/dev.tsv", 10 | "model": { 11 | "type": 
"stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "lstm", 32 | "input_size": 300, 33 | "hidden_size": 1024, 34 | "num_layers": 1, 35 | "batch_first": true 36 | }, 37 | "dropout": 0.1 38 | }, 39 | "iterator": { 40 | "type": "basic", 41 | "batch_size": 32 42 | }, 43 | "trainer": { 44 | "num_epochs": 128, 45 | "patience": 32, 46 | "grad_norm": 4.0, 47 | "validation_metric": "+PearsonCorrelation", 48 | "cuda_device": 3, 49 | "optimizer": { 50 | "type": "adam", 51 | "lr": 1e-4 52 | } 53 | } 54 | } -------------------------------------------------------------------------------- /config/test/stsball-w2v.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "similarity_regression" 4 | }, 5 | "validation_dataset_reader": { 6 | "type": "similarity_regression" 7 | }, 8 | "train_data_path": "./data/STS-B-ALL/train.tsv", 9 | "validation_data_path": "./data/STS-B-ALL/dev.tsv", 10 | "model": { 11 | "type": "stsb_regressor", 12 | "text_field_embedder": { 13 | "type": "stsb", 14 | "output_key_specified": true, 15 | "token_embedders": { 16 | "tokens_a": { 17 | "type": "embedding", 18 | "embedding_dim": 300, 19 | "pretrained_file": "https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 20 | "trainable": false 21 | }, 22 | "tokens_b": { 23 | "type": "embedding", 24 | "embedding_dim": 300, 25 | "pretrained_file": 
"https://allennlp.s3.amazonaws.com/datasets/word2vec/GoogleNews-vectors-negative300.txt.gz", 26 | "trainable": false 27 | } 28 | } 29 | }, 30 | "seq2vec_encoder": { 31 | "type": "bag_of_embeddings", 32 | "embedding_dim": 300 33 | } 34 | }, 35 | "iterator": { 36 | "type": "basic", 37 | "batch_size": 32 38 | }, 39 | "trainer": { 40 | "num_epochs": 128, 41 | "patience": 32, 42 | "grad_norm": 4.0, 43 | "validation_metric": "+PearsonCorrelation", 44 | "cuda_device": 3, 45 | "optimizer": { 46 | "type": "adam", 47 | "lr": 1e-4 48 | } 49 | } 50 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Pretrained Transformers Improve Out-of-Distribution Robustness 2 | 3 | How does pretraining affect out-of-distribution robustness? We create an OOD benchmark and use it to show that pretrained transformers such as BERT have substantially higher OOD accuracy and OOD detection rates compared to traditional NLP models. 4 | 5 | This repository most of the code for the paper [_Pretrained Transformers Improve Out-of-Distribution Robustness_](https://arxiv.org/abs/2004.06100), ACL 2020. 6 | 7 | 8 | 9 | Requires Python 3+ and PyTorch 1.0+. 10 | 11 | 12 | To correctly use RoBERTa model, see `allennlp_glue_patch/notes.py`. 
Author = {Dan Hendrycks and Xiaoyuan Liu and Eric Wallace and Adam Dziedzic and Rishabh Krishnan and Dawn Song},
-------------------------------------------------------------------------------- 1 | # util 2 | def fine_tune_model(model_type, task): 3 | return "model" 4 | def test_model_on_task(model, task): 5 | return "result" 6 | 7 | # models 8 | transformer_model_types = ["bert-base-uncased", "bert-large-uncased", "roberta-large"] 9 | other_model_types = ["Glove", "Word2vec", "LSTM"] 10 | model_types = transformer_model_types + other_model_types 11 | # tasks 12 | task_groups = [ 13 | # ("SST-2", "IMDb"), 14 | # ("STS-b-headlines", "STS-b-MSRpar"), 15 | ("STS-b-MSRvid", "STS-b-images") 16 | ] 17 | # run the experiment 18 | for model_type in model_types: 19 | for task_group in task_groups: 20 | tuned_models = {} 21 | for task in task_group: 22 | tuned_models[task] = fine_tune_model(model_type, task) 23 | for task in task_group: 24 | for source_task, model in tuned_models.items(): 25 | result = test_model_on_task(model, task) 26 | current = "{:20}([{:10}]-[{:10}])".format(model_type, source_task, task) 27 | result.save("./results/{}".format(current)) 28 | print("{} -> Finished").format(current) 29 | 30 | -------------------------------------------------------------------------------- /scripts/generate_config.py: -------------------------------------------------------------------------------- 1 | 2 | src = "MSRPAR" 3 | src_file_id = [ 4 | 42, 46, 47, 48, 52, 54, ... 
"""Clone experiment configs from one STS-B source task to several targets.

For every config in ``config/experiment2`` whose numeric id appears in
``src_file_id``, write a copy under a fresh id with every occurrence of
``src`` replaced by the target task name.
"""

import os

src = "MSRPAR"
src_file_id = [
    42, 46, 47, 48, 52, 54, ...  # NOTE(review): trailing `...` (Ellipsis) looks like a placeholder — fill in the real ids
]

cnt = 100  # first id assigned to generated configs  # TODO: Change this

targets = [
    'headlines',
    'msrvid',
    'images'
]


def create_config(old_path, new_path, old_task, new_task, dryrun=True):
    """Copy *old_path* to *new_path*, substituting *old_task* -> *new_task*.

    With ``dryrun=True`` (the default) nothing is written.
    """
    if dryrun:
        return  # dryrun mode
    with open(old_path, 'r') as f_old, open(new_path, 'w') as f_new:
        for line in f_old:
            # BUG FIX: lines read from a file already end with '\n'; the old
            # code wrote an extra '\n' per line, doubling every line break in
            # the generated config.
            f_new.write(line.replace(old_task, new_task))


if __name__ == "__main__":
    path_to_dir = "config/experiment2"
    created = []
    for t in targets:
        to_create = []
        for fname in os.listdir(path_to_dir):
            # pick up the numeric id from names like "glove-cnn[msrpar]-54.jsonnet"
            fid = int(fname.split(".json")[0].split("-")[-1])
            if fid in src_file_id:
                to_create.append((fid, fname))
        for fid, fname in sorted(to_create):
            path_to_src_file = os.path.join(path_to_dir, fname)
            new_fid = cnt
            cnt += 1
            # BUG FIX: str.replace requires string arguments; passing the int
            # `new_fid` raised TypeError. (NOTE(review): replace still matches
            # any occurrence of the digit run, not only the trailing id.)
            new_fname = fname.replace(str(fid), str(new_fid))
            created.append((new_fid, new_fname))
            create_config(path_to_src_file, os.path.join(path_to_dir, new_fname), src, t)

    print(created)
--save_steps 200 \ 29 | --overwrite_output_dir \ 30 | --output_dir ./model/${BATCH_ID}/${RUN_TAG} >> ./log.txt 2>&1 31 | 32 | echo ${RUN_TAG} 33 | } 34 | 35 | test(){ 36 | train_name=${1} 37 | test_name=${2} 38 | model_run_tag=${3} 39 | 40 | RUN_TAG="${BATCH_ID}-[${model_run_tag}]-${test_name}" 41 | TRAIN_TASK_NAME=${train_name} 42 | EVAL_TASK_NAME=${test_name} 43 | 44 | if [ ${model_run_tag} == "bert-base-uncased" ] 45 | then 46 | model_path="bert-base-uncased" 47 | else 48 | model_path="./model/${BATCH_ID}/${model_run_tag}" 49 | fi 50 | 51 | python ./src/run_glue.py \ 52 | --model_type bert \ 53 | --model_name_or_path ${model_path} \ 54 | --task_name ${EVAL_TASK_NAME} \ 55 | --do_train \ 56 | --do_eval \ 57 | --do_lower_case \ 58 | --data_dir ${DATA_DIR}/${EVAL_TASK_NAME} \ 59 | --max_seq_length 64 \ 60 | --per_gpu_eval_batch_size=32 \ 61 | --per_gpu_train_batch_size=32 \ 62 | --learning_rate 0 \ 63 | --num_train_epochs 0.0 \ 64 | --save_steps 200 \ 65 | --overwrite_output_dir \ 66 | --output_dir ./model/${BATCH_ID}/${RUN_TAG} >> ./log.txt 2>&1 67 | 68 | echo ${RUN_TAG} 69 | } 70 | 71 | # # untrained performance 72 | # test "STS-B-MSRPAR" "STS-B-MSRPAR" "bert-base-uncased" 73 | # test "STS-B-HEADLINES" "STS-B-HEADLINES" "bert-base-uncased" 74 | # test "STS-B-MSRVID" "STS-B-MSRVID" "bert-base-uncased" 75 | # test "STS-B-IMAGES" "STS-B-IMAGES" "bert-base-uncased" 76 | 77 | # # trained 78 | # MSRPAR_MODEL=`train "STS-B-MSRPAR"` 79 | # test "STS-B-MSRPAR" "STS-B-HEADLINES" ${MSRPAR_MODEL} 80 | # HEADLINES_MODEL=`train "STS-B-HEADLINES"` 81 | # test "STS-B-HEADLINES" "STS-B-MSRPAR" ${HEADLINES_MODEL} 82 | # MSRVID_MODEL=`train "STS-B-MSRVID"` 83 | # test "STS-B-MSRVID" "STS-B-IMAGES" ${MSRVID_MODEL} 84 | # IMAGES_MODEL=`train "STS-B-IMAGES"` 85 | # test "STS-B-IMAGES" "STS-B-MSRVID" ${IMAGES_MODEL} 86 | 87 | # # MNLI-experiments 88 | # test "MNLI-TELEPHONE" "MNLI-TELEPHONE" "bert-base-uncased" 89 | # test "MNLI-LETTERS" "MNLI-LETTERS" "bert-base-uncased" 90 | # test 
"MNLI-FACETOFACE" "MNLI-FACETOFACE" "bert-base-uncased" 91 | 92 | # TELEPHONE_MODEL=`train "MNLI-TELEPHONE"` 93 | # test "MNLI-TELEPHONE" "MNLI-LETTERS" ${TELEPHONE_MODEL} 94 | # test "MNLI-TELEPHONE" "MNLI-FACETOFACE" ${TELEPHONE_MODEL} 95 | 96 | # ALL_MODEL=`train "MNLI"` 97 | # test "MNLI" "MNLI-TELEPHONE" ${ALL_MODEL} 98 | # test "MNLI" "MNLI-LETTERS" ${ALL_MODEL} 99 | # test "MNLI" "MNLI-FACETOFACE" ${ALL_MODEL} 100 | 101 | # # Amazon 102 | 103 | # AMAZON_C_MODEL=`train "AMAZON-C-SAMPLE"` 104 | # test "AMAZON-C-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_C_MODEL} 105 | # test "AMAZON-C-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_C_MODEL} 106 | # test "AMAZON-C-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_C_MODEL} 107 | # test "AMAZON-C-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_C_MODEL} 108 | # test "AMAZON-C-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_C_MODEL} 109 | # test "AMAZON-C-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_C_MODEL} 110 | # test "AMAZON-C-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_C_MODEL} 111 | 112 | # AMAZON_WC_MODEL=`train "AMAZON-WC-SAMPLE"` 113 | # test "AMAZON-WC-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_WC_MODEL} 114 | # test "AMAZON-WC-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_WC_MODEL} 115 | # test "AMAZON-WC-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_WC_MODEL} 116 | # test "AMAZON-WC-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_WC_MODEL} 117 | # test "AMAZON-WC-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_WC_MODEL} 118 | # test "AMAZON-WC-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_WC_MODEL} 119 | # test "AMAZON-WC-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_WC_MODEL} 120 | 121 | # AMAZON_MC_MODEL=`train "AMAZON-MC-SAMPLE"` 122 | # test "AMAZON-MC-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_MC_MODEL} 123 | # test "AMAZON-MC-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_MC_MODEL} 124 | # test "AMAZON-MC-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_MC_MODEL} 125 | # test "AMAZON-MC-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_MC_MODEL} 126 | # test "AMAZON-MC-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_MC_MODEL} 127 | # test "AMAZON-MC-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_MC_MODEL} 128 | # 
test "AMAZON-MC-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_MC_MODEL} 129 | 130 | # AMAZON_BC_MODEL=`train "AMAZON-BC-SAMPLE"` 131 | # test "AMAZON-BC-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_BC_MODEL} 132 | # test "AMAZON-BC-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_BC_MODEL} 133 | # test "AMAZON-BC-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_BC_MODEL} 134 | # test "AMAZON-BC-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_BC_MODEL} 135 | # test "AMAZON-BC-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_BC_MODEL} 136 | # test "AMAZON-BC-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_BC_MODEL} 137 | # test "AMAZON-BC-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_BC_MODEL} 138 | 139 | # AMAZON_S_MODEL=`train "AMAZON-S-SAMPLE"` 140 | # test "AMAZON-S-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_S_MODEL} 141 | # test "AMAZON-S-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_S_MODEL} 142 | # test "AMAZON-S-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_S_MODEL} 143 | # test "AMAZON-S-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_S_MODEL} 144 | # test "AMAZON-S-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_S_MODEL} 145 | # test "AMAZON-S-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_S_MODEL} 146 | # test "AMAZON-S-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_S_MODEL} 147 | 148 | # AMAZON_MS_MODEL=`train "AMAZON-MS-SAMPLE"` 149 | # test "AMAZON-MS-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_MS_MODEL} 150 | # test "AMAZON-MS-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_MS_MODEL} 151 | # test "AMAZON-MS-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_MS_MODEL} 152 | # test "AMAZON-MS-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_MS_MODEL} 153 | # test "AMAZON-MS-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_MS_MODEL} 154 | # test "AMAZON-MS-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_MS_MODEL} 155 | # test "AMAZON-MS-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_MS_MODEL} 156 | 157 | # AMAZON_MV_MODEL=`train "AMAZON-MV-SAMPLE"` 158 | # test "AMAZON-MV-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_MV_MODEL} 159 | # test "AMAZON-MV-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_MV_MODEL} 160 | # test "AMAZON-MV-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_MV_MODEL} 161 | # test "AMAZON-MV-SAMPLE" "AMAZON-BC-SAMPLE" 
${AMAZON_MV_MODEL} 162 | # test "AMAZON-MV-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_MV_MODEL} 163 | # test "AMAZON-MV-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_MV_MODEL} 164 | # test "AMAZON-MV-SAMPLE" "AMAZON-B-SAMPLE" ${AMAZON_MV_MODEL} 165 | 166 | # AMAZON_B_MODEL=`train "AMAZON-B-SAMPLE"` 167 | # test "AMAZON-B-SAMPLE" "AMAZON-C-SAMPLE" ${AMAZON_B_MODEL} 168 | # test "AMAZON-B-SAMPLE" "AMAZON-WC-SAMPLE" ${AMAZON_B_MODEL} 169 | # test "AMAZON-B-SAMPLE" "AMAZON-MC-SAMPLE" ${AMAZON_B_MODEL} 170 | # test "AMAZON-B-SAMPLE" "AMAZON-BC-SAMPLE" ${AMAZON_B_MODEL} 171 | # test "AMAZON-B-SAMPLE" "AMAZON-S-SAMPLE" ${AMAZON_B_MODEL} 172 | # test "AMAZON-B-SAMPLE" "AMAZON-MS-SAMPLE" ${AMAZON_B_MODEL} 173 | # test "AMAZON-B-SAMPLE" "AMAZON-MV-SAMPLE" ${AMAZON_B_MODEL} 174 | # # untrained 175 | # test "AMAZON-C-SAMPLE" "AMAZON-C-SAMPLE" "bert-base-uncased" 176 | # test "AMAZON-B-SAMPLE" "AMAZON-B-SAMPLE" "bert-base-uncased" 177 | # test "AMAZON-WC-SAMPLE" "AMAZON-WC-SAMPLE" "bert-base-uncased" 178 | # test "AMAZON-MC-SAMPLE" "AMAZON-MC-SAMPLE" "bert-base-uncased" 179 | # test "AMAZON-BC-SAMPLE" "AMAZON-BC-SAMPLE" "bert-base-uncased" 180 | # test "AMAZON-S-SAMPLE" "AMAZON-S-SAMPLE" "bert-base-uncased" 181 | # test "AMAZON-MS-SAMPLE" "AMAZON-MS-SAMPLE" "bert-base-uncased" 182 | # test "AMAZON-MV-SAMPLE" "AMAZON-MV-SAMPLE" "bert-base-uncased" 183 | 184 | # Amazon Sample 2 (as classification) 185 | 186 | AMAZON_C_MODEL=`train "AMAZON-C-SAMPLE2"` 187 | test "AMAZON-C-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_C_MODEL} 188 | test "AMAZON-C-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_C_MODEL} 189 | test "AMAZON-C-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_C_MODEL} 190 | test "AMAZON-C-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_C_MODEL} 191 | test "AMAZON-C-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_C_MODEL} 192 | test "AMAZON-C-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_C_MODEL} 193 | test "AMAZON-C-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_C_MODEL} 194 | 195 | AMAZON_WC_MODEL=`train "AMAZON-WC-SAMPLE2"` 196 | test 
"AMAZON-WC-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_WC_MODEL} 197 | test "AMAZON-WC-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_WC_MODEL} 198 | test "AMAZON-WC-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_WC_MODEL} 199 | test "AMAZON-WC-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_WC_MODEL} 200 | test "AMAZON-WC-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_WC_MODEL} 201 | test "AMAZON-WC-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_WC_MODEL} 202 | test "AMAZON-WC-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_WC_MODEL} 203 | 204 | AMAZON_MC_MODEL=`train "AMAZON-MC-SAMPLE2"` 205 | test "AMAZON-MC-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_MC_MODEL} 206 | test "AMAZON-MC-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_MC_MODEL} 207 | test "AMAZON-MC-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_MC_MODEL} 208 | test "AMAZON-MC-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_MC_MODEL} 209 | test "AMAZON-MC-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_MC_MODEL} 210 | test "AMAZON-MC-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_MC_MODEL} 211 | test "AMAZON-MC-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_MC_MODEL} 212 | 213 | AMAZON_BC_MODEL=`train "AMAZON-BC-SAMPLE2"` 214 | test "AMAZON-BC-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_BC_MODEL} 215 | test "AMAZON-BC-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_BC_MODEL} 216 | test "AMAZON-BC-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_BC_MODEL} 217 | test "AMAZON-BC-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_BC_MODEL} 218 | test "AMAZON-BC-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_BC_MODEL} 219 | test "AMAZON-BC-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_BC_MODEL} 220 | test "AMAZON-BC-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_BC_MODEL} 221 | 222 | AMAZON_S_MODEL=`train "AMAZON-S-SAMPLE2"` 223 | test "AMAZON-S-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_S_MODEL} 224 | test "AMAZON-S-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_S_MODEL} 225 | test "AMAZON-S-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_S_MODEL} 226 | test "AMAZON-S-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_S_MODEL} 227 | test "AMAZON-S-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_S_MODEL} 228 | test "AMAZON-S-SAMPLE2" 
"AMAZON-MV-SAMPLE2" ${AMAZON_S_MODEL} 229 | test "AMAZON-S-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_S_MODEL} 230 | 231 | AMAZON_MS_MODEL=`train "AMAZON-MS-SAMPLE2"` 232 | test "AMAZON-MS-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_MS_MODEL} 233 | test "AMAZON-MS-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_MS_MODEL} 234 | test "AMAZON-MS-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_MS_MODEL} 235 | test "AMAZON-MS-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_MS_MODEL} 236 | test "AMAZON-MS-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_MS_MODEL} 237 | test "AMAZON-MS-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_MS_MODEL} 238 | test "AMAZON-MS-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_MS_MODEL} 239 | 240 | AMAZON_MV_MODEL=`train "AMAZON-MV-SAMPLE2"` 241 | test "AMAZON-MV-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_MV_MODEL} 242 | test "AMAZON-MV-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_MV_MODEL} 243 | test "AMAZON-MV-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_MV_MODEL} 244 | test "AMAZON-MV-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_MV_MODEL} 245 | test "AMAZON-MV-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_MV_MODEL} 246 | test "AMAZON-MV-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_MV_MODEL} 247 | test "AMAZON-MV-SAMPLE2" "AMAZON-B-SAMPLE2" ${AMAZON_MV_MODEL} 248 | 249 | AMAZON_B_MODEL=`train "AMAZON-B-SAMPLE2"` 250 | test "AMAZON-B-SAMPLE2" "AMAZON-C-SAMPLE2" ${AMAZON_B_MODEL} 251 | test "AMAZON-B-SAMPLE2" "AMAZON-WC-SAMPLE2" ${AMAZON_B_MODEL} 252 | test "AMAZON-B-SAMPLE2" "AMAZON-MC-SAMPLE2" ${AMAZON_B_MODEL} 253 | test "AMAZON-B-SAMPLE2" "AMAZON-BC-SAMPLE2" ${AMAZON_B_MODEL} 254 | test "AMAZON-B-SAMPLE2" "AMAZON-S-SAMPLE2" ${AMAZON_B_MODEL} 255 | test "AMAZON-B-SAMPLE2" "AMAZON-MS-SAMPLE2" ${AMAZON_B_MODEL} 256 | test "AMAZON-B-SAMPLE2" "AMAZON-MV-SAMPLE2" ${AMAZON_B_MODEL} 257 | # untrained 258 | test "AMAZON-C-SAMPLE2" "AMAZON-C-SAMPLE2" "bert-base-uncased" 259 | test "AMAZON-B-SAMPLE2" "AMAZON-B-SAMPLE2" "bert-base-uncased" 260 | test "AMAZON-WC-SAMPLE2" "AMAZON-WC-SAMPLE2" "bert-base-uncased" 261 | test "AMAZON-MC-SAMPLE2" "AMAZON-MC-SAMPLE2" 
"bert-base-uncased" 262 | test "AMAZON-BC-SAMPLE2" "AMAZON-BC-SAMPLE2" "bert-base-uncased" 263 | test "AMAZON-S-SAMPLE2" "AMAZON-S-SAMPLE2" "bert-base-uncased" 264 | test "AMAZON-MS-SAMPLE2" "AMAZON-MS-SAMPLE2" "bert-base-uncased" 265 | test "AMAZON-MV-SAMPLE2" "AMAZON-MV-SAMPLE2" "bert-base-uncased" 266 | -------------------------------------------------------------------------------- /scripts/run_record.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SQUAD_DIR="./data/record-squad" 3 | BATCH_ID="Record" 4 | num_train_epochs=8.0 5 | learning_rate=4e-5 6 | 7 | RUN_TAG="tryrecord2" 8 | python ./src/run_squad.py \ 9 | --model_type bert \ 10 | --model_name_or_path bert-base-cased \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_lower_case \ 14 | --train_file $SQUAD_DIR/train.json \ 15 | --predict_file $SQUAD_DIR/dev.json \ 16 | --per_gpu_train_batch_size 8 \ 17 | --learning_rate ${learning_rate} \ 18 | --num_train_epochs ${num_train_epochs} \ 19 | --max_seq_length 384 \ 20 | --doc_stride 128 \ 21 | --overwrite_output_dir \ 22 | --output_dir ./model/${BATCH_ID}/${RUN_TAG} # >> ./log.txt 2>&1 23 | -------------------------------------------------------------------------------- /scripts/run_squad.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SQUAD_DIR="./data/SQuAD" 3 | BATCH_ID="SQUAD" 4 | num_train_epochs=2.0 5 | learning_rate=3e-5 6 | 7 | RUN_TAG="trysquad" 8 | python ./src/run_squad.py \ 9 | --model_type bert \ 10 | --model_name_or_path bert-base-cased \ 11 | --do_train \ 12 | --do_eval \ 13 | --do_lower_case \ 14 | --train_file $SQUAD_DIR/train-v1.1.json \ 15 | --predict_file $SQUAD_DIR/dev-v1.1.json \ 16 | --per_gpu_train_batch_size 8 \ 17 | --learning_rate ${learning_rate} \ 18 | --num_train_epochs ${num_train_epochs} \ 19 | --max_seq_length 384 \ 20 | --doc_stride 128 \ 21 | --overwrite_output_dir \ 22 | --output_dir ./model/${BATCH_ID}/${RUN_TAG} # 
>> ./log.txt 2>&1 23 | -------------------------------------------------------------------------------- /scripts/temp_run.sh: -------------------------------------------------------------------------------- 1 | allennlp evaluate ./model/bow-sum[sst-2]-1/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bow-sum[sst-2]-1-[IMDB] 2 | allennlp evaluate ./model/word2vec-sum[sst-2]-5/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-sum[sst-2]-5-[IMDB] 3 | allennlp evaluate ./model/word2vec-lstm[sst-2]-6/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-lstm[sst-2]-6-[IMDB] 4 | allennlp evaluate ./model/word2vec-cnn[sst-2]-8/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-cnn[sst-2]-8-[IMDB] 5 | allennlp evaluate ./model/glove-sum[sst-2]-10/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-sum[sst-2]-10-[IMDB] 6 | allennlp evaluate ./model/glove-lstm[sst-2]-16/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-lstm[sst-2]-16-[IMDB] 7 | allennlp evaluate ./model/glove-cnn[sst-2]-18/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-cnn[sst-2]-18-[IMDB] 8 | allennlp evaluate ./model/roberta-large-pool[sst-2]-22/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/roberta-large-pool[sst-2]-22-[IMDB] 9 | allennlp evaluate ./model/bert-large-pool[sst-2]-25/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bert-large-pool[sst-2]-25-[IMDB] 10 | allennlp evaluate 
./model/bert-base-pool[sst-2]-26/model.tar.gz ./data/IMDB/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bert-base-pool[sst-2]-26-[IMDB] 11 | allennlp evaluate ./model/bow-sum[imdb]-31/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bow-sum[imdb]-31-[SST-2] 12 | allennlp evaluate ./model/word2vec-sum[imdb]-32/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-sum[imdb]-32-[SST-2] 13 | allennlp evaluate ./model/word2vec-lstm[imdb]-33/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-lstm[imdb]-33-[SST-2] 14 | allennlp evaluate ./model/word2vec-cnn[imdb]-34/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/word2vec-cnn[imdb]-34-[SST-2] 15 | allennlp evaluate ./model/glove-sum[imdb]-35/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-sum[imdb]-35-[SST-2] 16 | allennlp evaluate ./model/glove-lstm[imdb]-36/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-lstm[imdb]-36-[SST-2] 17 | allennlp evaluate ./model/glove-cnn[imdb]-37/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/glove-cnn[imdb]-37-[SST-2] 18 | allennlp evaluate ./model/roberta-large-pool[imdb]-38/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/roberta-large-pool[imdb]-38-[SST-2] 19 | allennlp evaluate ./model/bert-large-pool[imdb]-39/model.tar.gz ./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bert-large-pool[imdb]-39-[SST-2] 20 | allennlp evaluate ./model/bert-base-pool[imdb]-40/model.tar.gz 
./data/SST-2/dev.tsv --include-package allennlp_glue_patch --cuda-device 0 --output-file ./results/bert-base-pool[imdb]-40-[SST-2] 21 | -------------------------------------------------------------------------------- /src/amazon_preprocess/fetch_result.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from ast import literal_eval as l_eval 3 | 4 | 5 | target_dir = "model/AmazonSplit5" 6 | if sys.argv: 7 | target_dir = sys.argv[1] 8 | 9 | all_results = {} 10 | 11 | for d in os.listdir(target_dir): 12 | try: 13 | loc = os.path.join(target_dir, d) 14 | d = d.replace("SAMPLE2", "SAMPLE") 15 | #print(loc) 16 | if '-[' in d: 17 | document = (d.split('-[')[1].split(']-')[0].strip(), d.split('-[')[1].split(']-')[1].strip()) 18 | else: 19 | document = (d, "-".join(d.split('-')[1:])[6:]) 20 | if 'bert-' in d: 21 | document = ("BERT", document[1][7:-7]) 22 | else: 23 | document = (document[0][19:], document[1]) 24 | document = (document[0][7:-7], document[1][7:-7]) 25 | eval_result = os.path.join(loc, "eval_results.txt") 26 | r = {} 27 | with open(eval_result) as fin: 28 | for l in fin: 29 | cor = l.split(" = ") 30 | k = cor[0].strip() 31 | v = "%.1f" % (l_eval(cor[1].strip())*100) 32 | r[k] = v 33 | # all_results[document] = "%s/%s" % (r['pearson'], r['spearmanr']) 34 | all_results[document] = "%s" % r['acc'] 35 | except Exception as e: 36 | print(e) 37 | pass 38 | 39 | from pprint import pprint 40 | # labels = list(set([k[1] for k in all_results.keys()])) 41 | labels = ['C','WC','MC','BC','S','MS','MV','B',] 42 | 43 | print(labels) 44 | # print 45 | line = "\t" 46 | for l in labels: 47 | line = line+l+'\t' 48 | print(line) 49 | line = "BERT\t" 50 | for l in labels: 51 | if ("BERT", l) in all_results: 52 | line = line + all_results[("BERT", l)] + '\t' 53 | else: 54 | line = line + '\t' 55 | print(line) 56 | for l in labels: 57 | line = l+'\t' 58 | for l2 in labels: 59 | if (l, l2) in all_results: 60 | line = line + 
all_results[(l, l2)] + '\t' 61 | else: 62 | line = line + '\t' 63 | print(line) 64 | -------------------------------------------------------------------------------- /src/amazon_preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | from tqdm import tqdm 3 | import csv 4 | 5 | complete_path = "data/raw_dataset/AMAZON/complete.json" 6 | metadata_path = "data/raw_dataset/AMAZON/metadata.json" 7 | 8 | # generate all_cat.txt 9 | ''' 10 | all_cat = set() 11 | with open(metadata_path) as f: 12 | for l in tqdm(f, total=9400000): 13 | d = eval(l) 14 | if 'categories' not in d: 15 | continue 16 | cs = d['categories'] 17 | for c in cs: 18 | c = tuple(c) 19 | if c not in all_cat: 20 | all_cat.add(c) 21 | print(c) 22 | 23 | with open("all_cat.txt", 'w') as fout: 24 | for c in all_cat: 25 | fout.write(str(c)) 26 | fout.write('\n') 27 | ''' 28 | 29 | # sort all_cat 30 | ''' 31 | with open("all_cat.txt") as fin: 32 | all_cat = [] 33 | from ast import literal_eval 34 | for l in fin: 35 | t = literal_eval(l) 36 | all_cat.append(t) 37 | 38 | with open("all_cat.txt", 'w') as fout: 39 | for c in sorted(all_cat): 40 | fout.write(str(c)) 41 | fout.write('\n') 42 | ''' 43 | # generate focus list 44 | focus_list_output_loc = "focus.json" 45 | labels = { 46 | "Clothes": lambda x: "Clothing" in x and "Shoes" not in x, 47 | "Women Clothing": lambda x: "Clothing" in x and "Women" in x and "Shoes" not in x, 48 | "Men Clothing": lambda x: "Clothing" in x and "Men" in x and "Shoes" not in x, 49 | "Baby Clothing": lambda x: "Clothing" in x and "Baby" in x and "Shoes" not in x, 50 | "Shoes": lambda x: "Shoes" in x, 51 | "Music": lambda x: ("CDs & Vinyl" in x or "Digital Music" in x) and "Books" not in x, 52 | "Movies": lambda x: "Movies" in x and "Books" not in x, 53 | "Books": lambda x: 'Books' in x, 54 | } 55 | ''' 56 | with open(metadata_path) as fin, open(focus_list_output_loc, "w") as fout: 57 | for l in tqdm(fin, 
total=9500000): 58 | d = eval(l) 59 | if 'categories' not in d or 'asin' not in d: 60 | continue 61 | cs = sum(d['categories'], []) 62 | ls = [] 63 | for label, f in labels.items(): 64 | if f(cs): 65 | ls.append(label) 66 | if ls: 67 | fout.write(str({"asin": d['asin'], "label": ls})) 68 | fout.write('\n') 69 | ''' 70 | 71 | import time, pickle 72 | from ast import literal_eval as l_eval 73 | ''' 74 | focus = {} 75 | t1 = time.time() 76 | print("Loading focus list") 77 | with open(focus_list_output_loc) as fin: 78 | for l in tqdm(fin, total=4240092): 79 | d = l_eval(l) 80 | focus[d['asin']] = d['label'] # 200 seconds 81 | with open("focus.pickle", 'wb') as fout: 82 | pickle.dump(focus, fout) 83 | with open("focus.pickle", 'rb') as fin: 84 | x = pickle.load(fin) # 10 seconds 85 | print("Loading focus list finished, used_time: {}".format(time.time()-t1)) 86 | ''' 87 | 88 | ''' 89 | 90 | def parse(path): 91 | with open(path) as fin: 92 | for l in fin: 93 | yield l_eval(l) 94 | 95 | t1 = time.time() 96 | print("Loading focus list") 97 | with open("focus.pickle", 'rb') as fin: 98 | focus = pickle.load(fin) # 10 seconds 99 | print("Loading focus list finished, used_time: {}".format(time.time()-t1)) 100 | 101 | import csv 102 | fouts = {label:open("amazon_split/{}-all.tsv".format(label.replace(' ', '_').lower()), 'w', newline="", encoding="UTF-8") for label, _ in labels.items()} 103 | writers = {label:csv.writer(f, delimiter='\t') for label, f in fouts.items()} 104 | 105 | # tryit = 50 106 | 107 | for d in tqdm(parse(complete_path), total=142900000): 108 | if "asin" not in d or "reviewText" not in d or "overall" not in d: 109 | continue 110 | asin = d['asin'] 111 | if asin not in focus: 112 | continue 113 | reviewText = d['reviewText'] 114 | overall = d['overall'] 115 | for l in focus[asin]: 116 | writers[l].writerow([asin, overall, reviewText]) 117 | # tryit -= 1 118 | # if tryit < 0: 119 | # break 120 | 121 | for d in fouts.values(): 122 | d.close() 123 | 124 | ''' 125 
| 126 | # make folder for each task 127 | shorthand = { 128 | 'baby_clothing': "AMAZON-BC", 129 | 'books': "AMAZON-B", 130 | 'clothes': "AMAZON-C", 131 | 'men_clothing': "AMAZON-MC", 132 | 'women_clothing': "AMAZON-WC", 133 | 'movies': "AMAZON-MV", 134 | 'music': "AMAZON-MS", 135 | 'shoes': "AMAZON-S" 136 | } 137 | for k, v in shorthand.items(): 138 | print("begin -> {}".format(k)) 139 | def create_folder(loc): 140 | if not os.path.exists(loc): 141 | os.makedirs(loc) 142 | folder = "amazon_split/{}".format(v) 143 | create_folder(folder) 144 | with open("amazon_split/{}-all.tsv".format(k)) as fin: 145 | f_train = open(folder+'/train.tsv', 'w', newline='', encoding="UTF-8") 146 | f_dev = open(folder+'/dev.tsv', 'w', newline='', encoding="UTF-8") 147 | s_train, n_train = 0, 0 148 | s_dev, n_dev = 0, 0 149 | for i, l in enumerate(fin): 150 | if i % 5 == 0: 151 | f_dev.write(l) 152 | s_dev += l_eval(l.split('\t')[1]) 153 | n_dev += 1 154 | else: 155 | f_train.write(l) 156 | s_train += l_eval(l.split('\t')[1]) 157 | n_train += 1 158 | print(s_train, n_train, s_dev, n_dev) 159 | f_train.close() 160 | f_dev.close() -------------------------------------------------------------------------------- /src/amazon_preprocess/sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | dir_to_sample = [ 4 | "AMAZON-B", 5 | "AMAZON-BC", 6 | "AMAZON-C", 7 | "AMAZON-MC", 8 | "AMAZON-MS", 9 | "AMAZON-MV", 10 | "AMAZON-S", 11 | "AMAZON-WC", 12 | ] 13 | 14 | data_path = "data/raw_dataset" 15 | expected_num = int(5e4) 16 | train_num = int(expected_num*0.8) 17 | dev_num = expected_num - train_num 18 | 19 | import random 20 | from pprint import pprint 21 | from collections import Counter 22 | from ast import literal_eval as l_eval 23 | random.seed("AMAZON") 24 | 25 | for d in dir_to_sample: 26 | print("Processing {}".format(d)) 27 | ori_dir = os.path.join(data_path, d) 28 | target_dir = ori_dir+"-SAMPLE" 29 | if not 
os.path.exists(target_dir): 30 | os.makedirs(target_dir) 31 | dev_loc = os.path.join(ori_dir, "dev.tsv") 32 | train_loc = os.path.join(ori_dir, "train.tsv") 33 | results = {} 34 | statistic = [] 35 | with open(train_loc, 'r') as fin: 36 | for line in fin: 37 | sentence = line.split('\t')[2] 38 | score = l_eval(line.split('\t')[1]) 39 | if sentence in results: 40 | continue 41 | results[sentence] = line 42 | statistic.append(score) 43 | if len(results) >= expected_num: 44 | break 45 | pprint(sorted(dict(Counter(statistic)).items())) 46 | assert len(results) == expected_num 47 | dev_loc = os.path.join(target_dir, "dev.tsv") 48 | train_loc = os.path.join(target_dir, "train.tsv") 49 | output = list(results.values()) 50 | random.shuffle(output) 51 | with open(train_loc, 'w', newline='', encoding="UTF-8") as fout: 52 | for l in output[:train_num]: 53 | fout.write(l) 54 | with open(dev_loc, 'w', newline='', encoding="UTF-8") as fout: 55 | for l in output[train_num:]: 56 | fout.write(l) 57 | -------------------------------------------------------------------------------- /src/amazon_preprocess/sample2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | dir_to_sample = [ 4 | "AMAZON-BC", 5 | "AMAZON-B", 6 | "AMAZON-C", 7 | "AMAZON-MC", 8 | "AMAZON-MS", 9 | "AMAZON-MV", 10 | "AMAZON-S", 11 | "AMAZON-WC", 12 | ] 13 | # sample again 14 | # dir_to_sample = [d + "-SAMPLE" for d in dir_to_sample] 15 | # no, that is wrong 16 | 17 | data_path = "data/raw_dataset" 18 | expected_num = int(3e3*5) 19 | train_num = int(expected_num*0.8) 20 | dev_num = expected_num - train_num 21 | 22 | import random 23 | from pprint import pprint 24 | from collections import Counter 25 | from ast import literal_eval as l_eval 26 | random.seed("AMAZON") 27 | 28 | from tqdm import tqdm 29 | for d in dir_to_sample: 30 | print("Processing {}".format(d)) 31 | ori_dir = os.path.join(data_path, d) 32 | target_dir = ori_dir+"-SAMPLE2" 33 | if not 
os.path.exists(target_dir): 34 | os.makedirs(target_dir) 35 | dev_loc = os.path.join(ori_dir, "dev.tsv") 36 | train_loc = os.path.join(ori_dir, "train.tsv") 37 | results = {} 38 | statistic = Counter() 39 | for loc in [train_loc, dev_loc]: 40 | with open(loc, 'r') as fin, tqdm(total=expected_num) as pbar: 41 | for line in fin: 42 | sentence = line.split('\t')[2] 43 | score = l_eval(line.split('\t')[1]) 44 | if len(results) >= expected_num: 45 | break 46 | if sentence in results: 47 | continue 48 | if statistic[score] >= expected_num / 5: 49 | continue 50 | results[sentence] = line 51 | statistic[score] += 1 52 | pbar.update(1) 53 | pprint(sorted(dict(statistic).items())) 54 | assert len(results) == expected_num 55 | dev_loc = os.path.join(target_dir, "dev.tsv") 56 | train_loc = os.path.join(target_dir, "train.tsv") 57 | output = list(results.values()) 58 | random.shuffle(output) 59 | with open(train_loc, 'w', newline='', encoding="UTF-8") as fout: 60 | for l in output[:train_num]: 61 | fout.write(l) 62 | with open(dev_loc, 'w', newline='', encoding="UTF-8") as fout: 63 | for l in output[train_num:]: 64 | fout.write(l) 65 | -------------------------------------------------------------------------------- /src/data_loader.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/camelop/NLP-Robustness/de2761886cfb4c7d3cf0bf0a9d19d77588bfea36/src/data_loader.py -------------------------------------------------------------------------------- /src/deprecated_run.py: -------------------------------------------------------------------------------- 1 | import argparse, uuid, os 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from pytorch_transformers import (BertConfig, 8 | BertForSequenceClassification, BertTokenizer) 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | def set_seed(args): 13 | random.seed(args.seed) 14 | np.random.seed(args.seed) 15 | torch.manual_seed(args.seed) 16 | 
if args.n_gpu > 0: 17 | torch.cuda.manual_seed_all(args.seed) 18 | 19 | def main(): 20 | # parse the arguments 21 | parser = argparse.ArgumentParser(description='Process some integers.') 22 | # required parameters 23 | parser.add_argument("func", default='help', type=str, help="train/test/help") 24 | parser.add_argument("--data_dir", default="data", type=str, required=False) 25 | parser.add_argument("--task_name", default=None, type=str, required=False) 26 | parser.add_argument("--tag", default=None, type=str, required=False) 27 | parser.add_argument("--input_dir", default=None, type=str, required=False) 28 | parser.add_argument("--output_dir", default=None, type=str, required=False) 29 | parser.add_argument("--model_name", default="bert-base-uncased", type=str, required=False) 30 | 31 | args = parser.parse_args() 32 | 33 | # do the func 34 | if args.func == "help": 35 | print("train to generate model, test to evaluate model") 36 | else: 37 | # gather parameters 38 | tag = args.tag 39 | if tag == None: 40 | tag = args.tag = str(uuid.uuid1()) 41 | print("params: {}\ntag: {}".format(str(args), tag)) 42 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 43 | n_gpu = args.n_gpu = torch.cuda.device_count() 44 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt = '%m/%d/%Y %H:%M:%S',level = logging.INFO) 45 | logger.warning("device: %s, n_gpu: %s", device, n_gpu) 46 | set_seed(args) 47 | args.task_name = args.task_name.lower() 48 | # TODO task specific settings 49 | num_labels = None 50 | 51 | 52 | 53 | if args.func == "train": 54 | pass # train on the task 55 | # gather parameters 56 | config = BertConfig.from_pretrained() 57 | 58 | output_dir = args.output_dir = args.output_dir if args.output_dir else "model" 59 | if os.path.exists(output_dir) and os.list(output_dir): 60 | raise ValueError("Output dir exists") 61 | config = BertConfig.from_pretrained(args.model_name, 
num_labels=num_labels, finetuning_task=args.task_name) 62 | tokenizer = BertTokenizer.from_pretrained(args.model_name, do_lower_case="uncased" in args.model_name) 63 | model = BertForSequenceClassification.from_pretrained(args.model_name, from_tf=False, config=config) 64 | 65 | 66 | elif args.func == "test": 67 | pass # test on the task 68 | else: 69 | raise NotImplementedError 70 | 71 | 72 | # do corresponding task 73 | 74 | if __name__ == "__main__": 75 | main() -------------------------------------------------------------------------------- /src/mnli_preprocess/splitMNLI.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | from csv import QUOTE_NONE 4 | import random 5 | 6 | train_loc = "train.tsv" 7 | dev_loc = "dev.tsv" 8 | def read_table(loc): 9 | return pd.read_csv(loc, sep='\t', index_col=0, quoting=QUOTE_NONE) 10 | 11 | train = read_table(train_loc) 12 | dev = read_table(dev_loc) 13 | 14 | # telephone - face2face/letter 15 | header = ['index'] + list(train.columns) 16 | 17 | def write_split(s, tag, genre): 18 | cnt = 0 19 | with open("{}.tsv".format(tag), 'w', encoding="UTF-8") as fout: 20 | writer = csv.writer(fout, delimiter='\t') 21 | writer.writerow(header) 22 | for i in s.itertuples(): 23 | if i[3] == genre: 24 | cnt += 1 25 | writer.writerow(list(i)) 26 | print(genre, cnt) 27 | 28 | write_split(train, "train-telephone", "telephone") 29 | write_split(dev, "dev-telephone", "telephone") 30 | write_split(dev, "dev-letters", "letters") 31 | write_split(dev, "dev-facetoface", "facetoface") -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | from multiprocessing import Pool 3 | 4 | run_glue_loc = "/home/littleround/nlp-robustness/src" 5 | if run_glue_loc not in sys.path: 6 | sys.path.append(run_glue_loc) 7 | 8 | from run_glue import main 9 
| 10 | def find_source_task(task): 11 | task = task.lower() 12 | if task in ("sst-2", "imdb"): 13 | return "sst-2" 14 | elif task in ("sts-b-headlines", "sts-b-msrpar", "sts-b-msrvid", "sts-b-images"): 15 | return "sts-b" 16 | 17 | 18 | def run(model_type, task, ori_task=None, gpu="-1"): 19 | gpu = str(gpu) 20 | task = task.lower() 21 | ori_task = ori_task.lower() if ori_task is not None else None 22 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu 23 | if model_type in ["Glove", "Word2vec", "LSTM"]: 24 | print("{} not supported now.".format(model_type)) 25 | else: 26 | _model_type = model_type.split('-')[0] 27 | ori_model_name = "{}[{}]".format(model_type, ori_task) if ori_task is not None else None 28 | _model_name_or_path = model_type if ori_task is None else "./model/{}".format(ori_model_name) 29 | _task_name = find_source_task(task) 30 | _data_dir = "./data/{}".format(task.upper()) 31 | _max_seq_length = str(128) 32 | _per_gpu_train_batch_size = str(8) 33 | _learning_rate = "2e-5" if ori_task is None else "0.0" 34 | _num_train_epochs = str(3) 35 | new_model_name = "{}[{}]".format(model_type, task) if ori_task is None else "{}[{}]-[{}]".format(model_type, task, ori_task) 36 | _output_dir = "./model/{}".format(new_model_name) 37 | main([ 38 | "--model_type", _model_type, 39 | "--model_name_or_path", _model_name_or_path, 40 | "--task_name", _task_name, 41 | "--do_train", "--do_eval", "--do_lower_case", 42 | "--data_dir", _data_dir, 43 | "--max_seq_length", _max_seq_length, 44 | "--per_gpu_train_batch_size", _per_gpu_train_batch_size, 45 | "--learning_rate", _learning_rate, 46 | "--num_train_epochs", _num_train_epochs, 47 | "--output_dir", _output_dir, 48 | "--save_steps", "10000", 49 | "--overwrite_output_dir" 50 | ]) 51 | return new_model_name 52 | 53 | def packed_run(x): 54 | return run(*x) 55 | 56 | if __name__ == "__main__": 57 | # test "fine_tune_model" 58 | available_gpus = [0, 3, 4, 5, 6, 7] 59 | with Pool(6) as p, open("results.txt", 'w') as fout: 60 | model = 
"bert-base-uncased" # "roberta-large" 61 | print(p.map(packed_run, [ 62 | (model, "sst-2", None, 0), 63 | (model, "imdb", None, 3), 64 | (model, "sts-b-msrvid", None, 4), 65 | (model, "sts-b-msrpar", None, 5), 66 | (model, "sts-b-headlines", None, 6), 67 | (model, "sts-b-images", None, 7), 68 | ]), file=fout) -------------------------------------------------------------------------------- /src/sts_preprocess/resplit.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | from csv import QUOTE_NONE 4 | import random 5 | 6 | data_dir = "./" 7 | def read_table(loc): 8 | return pd.read_csv(loc, sep='\t', index_col=0, quoting=QUOTE_NONE) 9 | train = read_table(data_dir+"train.tsv") # actually all 10 | 11 | genres = [("main-news", ), ("main-captions", ), ("main-forum", "main-forums")] 12 | trains_num = dict(zip(genres, [3039, 2100, 660])) 13 | devs_num = dict(zip(genres, [760, 525, 165])) 14 | map_to_genre = {} 15 | for gs in genres: 16 | for g in gs: 17 | map_to_genre[g] = gs 18 | 19 | examples = {g:[] for g in genres} 20 | for i in train.itertuples(): 21 | g = map_to_genre[i[1]] 22 | examples[g].append(list(i)) 23 | 24 | random.seed("TAPE") 25 | new_trains = {g:None for g in genres} 26 | new_devs = {g:None for g in genres} 27 | for k, v in examples.items(): 28 | random.shuffle(v) 29 | new_trains[k] = v[:trains_num[k]] 30 | new_devs[k] = v[trains_num[k]:] 31 | assert len(new_devs[k]) == devs_num[k] 32 | new_trains_all = sum([v for k, v in new_trains.items()], []) 33 | random.shuffle(new_trains_all) 34 | new_devs_all = sum([v for k, v in new_devs.items()], []) 35 | random.shuffle(new_devs_all) 36 | 37 | header = ['index'] + list(train.columns) 38 | with open("train-{}.tsv".format("all"), 'w', newline='', encoding="UTF-8") as f: 39 | writer = csv.writer(f, delimiter='\t') 40 | writer.writerow(header) 41 | for l in new_trains_all: 42 | writer.writerow(l) 43 | with open("dev-{}.tsv".format("all"), 'w', 
newline='', encoding="UTF-8") as f: 44 | writer = csv.writer(f, delimiter='\t') 45 | writer.writerow(header) 46 | for l in new_devs_all: 47 | writer.writerow(l) 48 | 49 | for g in genres: 50 | with open("train-{}.tsv".format(g[0]), 'w', newline='', encoding="UTF-8") as f: 51 | writer = csv.writer(f, delimiter='\t') 52 | writer.writerow(header) 53 | for l in new_trains[g]: 54 | writer.writerow(l) 55 | with open("dev-{}.tsv".format(g[0]), 'w', newline='', encoding="UTF-8") as f: 56 | writer = csv.writer(f, delimiter='\t') 57 | writer.writerow(header) 58 | for l in new_devs[g]: 59 | writer.writerow(l) -------------------------------------------------------------------------------- /src/sts_preprocess/resplit2.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import csv 3 | from csv import QUOTE_NONE 4 | import random 5 | 6 | data_dir = "./" 7 | def read_table(loc): 8 | return pd.read_csv(loc, sep='\t', index_col=0, quoting=QUOTE_NONE) 9 | train = read_table(data_dir+"train.tsv") # actually all 10 | 11 | files = ['MSRpar', 'headlines', 'MSRvid', 'images'] 12 | trains_num = dict(zip(files, [1000, 1799, 1000, 1000])) 13 | devs_num = dict(zip(files, [250, 450, 250, 250])) 14 | 15 | examples = {f:[] for f in files} 16 | 17 | for i in train.itertuples(): 18 | f = i[2] 19 | if f not in files: 20 | continue 21 | examples[f].append(list(i)) 22 | 23 | random.seed("TAPE") 24 | new_trains = {f:None for f in files} 25 | new_devs = {f:None for f in files} 26 | for k, v in examples.items(): 27 | random.shuffle(v) 28 | new_trains[k] = v[:trains_num[k]] 29 | new_devs[k] = v[trains_num[k]:] 30 | assert len(new_devs[k]) == devs_num[k] 31 | 32 | header = ['index'] + list(train.columns) 33 | 34 | ''' 35 | new_trains_all = sum([v for k, v in new_trains.items()], []) 36 | random.shuffle(new_trains_all) 37 | new_devs_all = sum([v for k, v in new_devs.items()], []) 38 | random.shuffle(new_devs_all) 39 | with 
open("train-{}.tsv".format("all"), 'w', newline='', encoding="UTF-8") as f: 40 | writer = csv.writer(f, delimiter='\t') 41 | writer.writerow(header) 42 | for l in new_trains_all: 43 | writer.writerow(l) 44 | with open("dev-{}.tsv".format("all"), 'w', newline='', encoding="UTF-8") as f: 45 | writer = csv.writer(f, delimiter='\t') 46 | writer.writerow(header) 47 | for l in new_devs_all: 48 | writer.writerow(l) 49 | ''' 50 | 51 | for f in files: 52 | with open("train-{}.tsv".format(f), 'w', newline='', encoding="UTF-8") as fout: 53 | writer = csv.writer(fout, delimiter='\t') 54 | writer.writerow(header) 55 | for l in new_trains[f]: 56 | writer.writerow(l) 57 | with open("dev-{}.tsv".format(f), 'w', newline='', encoding="UTF-8") as fout: 58 | writer = csv.writer(fout, delimiter='\t') 59 | writer.writerow(header) 60 | for l in new_devs[f]: 61 | writer.writerow(l) -------------------------------------------------------------------------------- /src/utils_squad_evaluate.py: -------------------------------------------------------------------------------- 1 | """ Official evaluation script for SQuAD version 2.0. 2 | Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0 3 | 4 | In addition to basic functionality, we also compute additional statistics and 5 | plot precision-recall curves if an additional na_prob.json file is provided. 6 | This file is expected to map question ID's to the model's predicted probability 7 | that a question is unanswerable. 
8 | """ 9 | import argparse 10 | import collections 11 | import json 12 | import numpy as np 13 | import os 14 | import re 15 | import string 16 | import sys 17 | 18 | class EVAL_OPTS(): 19 | def __init__(self, data_file, pred_file, out_file="", 20 | na_prob_file="na_prob.json", na_prob_thresh=1.0, 21 | out_image_dir=None, verbose=False): 22 | self.data_file = data_file 23 | self.pred_file = pred_file 24 | self.out_file = out_file 25 | self.na_prob_file = na_prob_file 26 | self.na_prob_thresh = na_prob_thresh 27 | self.out_image_dir = out_image_dir 28 | self.verbose = verbose 29 | 30 | OPTS = None 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser('Official evaluation script for SQuAD version 2.0.') 34 | parser.add_argument('data_file', metavar='data.json', help='Input data JSON file.') 35 | parser.add_argument('pred_file', metavar='pred.json', help='Model predictions.') 36 | parser.add_argument('--out-file', '-o', metavar='eval.json', 37 | help='Write accuracy metrics to file (default is stdout).') 38 | parser.add_argument('--na-prob-file', '-n', metavar='na_prob.json', 39 | help='Model estimates of probability of no answer.') 40 | parser.add_argument('--na-prob-thresh', '-t', type=float, default=1.0, 41 | help='Predict "" if no-answer probability exceeds this (default = 1.0).') 42 | parser.add_argument('--out-image-dir', '-p', metavar='out_images', default=None, 43 | help='Save precision-recall curves to directory.') 44 | parser.add_argument('--verbose', '-v', action='store_true') 45 | if len(sys.argv) == 1: 46 | parser.print_help() 47 | sys.exit(1) 48 | return parser.parse_args() 49 | 50 | def make_qid_to_has_ans(dataset): 51 | qid_to_has_ans = {} 52 | for article in dataset: 53 | for p in article['paragraphs']: 54 | for qa in p['qas']: 55 | qid_to_has_ans[qa['id']] = bool(qa['answers']) 56 | return qid_to_has_ans 57 | 58 | def normalize_answer(s): 59 | """Lower text and remove punctuation, articles and extra whitespace.""" 60 | def 
remove_articles(text): 61 | regex = re.compile(r'\b(a|an|the)\b', re.UNICODE) 62 | return re.sub(regex, ' ', text) 63 | def white_space_fix(text): 64 | return ' '.join(text.split()) 65 | def remove_punc(text): 66 | exclude = set(string.punctuation) 67 | return ''.join(ch for ch in text if ch not in exclude) 68 | def lower(text): 69 | return text.lower() 70 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 71 | 72 | def get_tokens(s): 73 | if not s: return [] 74 | return normalize_answer(s).split() 75 | 76 | def compute_exact(a_gold, a_pred): 77 | return int(normalize_answer(a_gold) == normalize_answer(a_pred)) 78 | 79 | def compute_f1(a_gold, a_pred): 80 | gold_toks = get_tokens(a_gold) 81 | pred_toks = get_tokens(a_pred) 82 | common = collections.Counter(gold_toks) & collections.Counter(pred_toks) 83 | num_same = sum(common.values()) 84 | if len(gold_toks) == 0 or len(pred_toks) == 0: 85 | # If either is no-answer, then F1 is 1 if they agree, 0 otherwise 86 | return int(gold_toks == pred_toks) 87 | if num_same == 0: 88 | return 0 89 | precision = 1.0 * num_same / len(pred_toks) 90 | recall = 1.0 * num_same / len(gold_toks) 91 | f1 = (2 * precision * recall) / (precision + recall) 92 | return f1 93 | 94 | def get_raw_scores(dataset, preds): 95 | exact_scores = {} 96 | f1_scores = {} 97 | for article in dataset: 98 | for p in article['paragraphs']: 99 | for qa in p['qas']: 100 | qid = qa['id'] 101 | gold_answers = [a['text'] for a in qa['answers'] 102 | if normalize_answer(a['text'])] 103 | if not gold_answers: 104 | # For unanswerable questions, only correct answer is empty string 105 | gold_answers = [''] 106 | if qid not in preds: 107 | print('Missing prediction for %s' % qid) 108 | continue 109 | a_pred = preds[qid] 110 | # Take max over all gold answers 111 | exact_scores[qid] = max(compute_exact(a, a_pred) for a in gold_answers) 112 | f1_scores[qid] = max(compute_f1(a, a_pred) for a in gold_answers) 113 | return exact_scores, f1_scores 114 | 115 
def apply_no_ans_threshold(scores, na_probs, qid_to_has_ans, na_prob_thresh):
    """Zero out scores for questions the model predicts as unanswerable.

    A question is treated as predicted-no-answer when its no-answer
    probability exceeds ``na_prob_thresh``; in that case its score becomes
    1.0 iff the question truly has no answer, otherwise the raw score is
    kept unchanged.
    """
    new_scores = {}
    for qid, s in scores.items():
        pred_na = na_probs[qid] > na_prob_thresh
        if pred_na:
            new_scores[qid] = float(not qid_to_has_ans[qid])
        else:
            new_scores[qid] = s
    return new_scores

def make_eval_dict(exact_scores, f1_scores, qid_list=None):
    """Aggregate per-question exact/F1 scores into percentage metrics.

    If ``qid_list`` is given, only those question ids are averaged;
    otherwise every entry of ``exact_scores`` is used.
    """
    if not qid_list:
        total = len(exact_scores)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores.values()) / total),
            ('f1', 100.0 * sum(f1_scores.values()) / total),
            ('total', total),
        ])
    else:
        total = len(qid_list)
        return collections.OrderedDict([
            ('exact', 100.0 * sum(exact_scores[k] for k in qid_list) / total),
            ('f1', 100.0 * sum(f1_scores[k] for k in qid_list) / total),
            ('total', total),
        ])

def merge_eval(main_eval, new_eval, prefix):
    """Copy every metric of ``new_eval`` into ``main_eval`` as ``<prefix>_<key>``."""
    for k in new_eval:
        main_eval['%s_%s' % (prefix, k)] = new_eval[k]

def plot_pr_curve(precisions, recalls, out_image, title):
    """Save a precision-recall step plot to ``out_image``.

    matplotlib is imported lazily with the headless Agg backend so that
    importing this module — or calling main() programmatically with
    ``out_image_dir`` set — works without the ``__main__`` guard having
    run. Previously ``plt`` was only bound inside that guard, so any
    programmatic use of the plotting path raised NameError.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    plt.step(recalls, precisions, color='b', alpha=0.2, where='post')
    plt.fill_between(recalls, precisions, step='post', alpha=0.2, color='b')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.xlim([0.0, 1.05])
    plt.ylim([0.0, 1.05])
    plt.title(title)
    plt.savefig(out_image)
    plt.clf()

def make_precision_recall_eval(scores, na_probs, num_true_pos, qid_to_has_ans,
                               out_image=None, title=None):
    """Compute average precision by sweeping the no-answer threshold.

    Questions are visited in ascending no-answer-probability order; each
    distinct probability value is a candidate threshold, at which the
    running precision/recall pair is folded into the AP estimate.
    Optionally saves the PR curve via plot_pr_curve.
    """
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    true_pos = 0.0
    cur_p = 1.0
    cur_r = 0.0
    precisions = [1.0]
    recalls = [0.0]
    avg_prec = 0.0
    for i, qid in enumerate(qid_list):
        if qid_to_has_ans[qid]:
            true_pos += scores[qid]
        cur_p = true_pos / float(i+1)
        cur_r = true_pos / float(num_true_pos)
        if i == len(qid_list) - 1 or na_probs[qid] != na_probs[qid_list[i+1]]:
            # i.e., if we can put a threshold after this point
            avg_prec += cur_p * (cur_r - recalls[-1])
            precisions.append(cur_p)
            recalls.append(cur_r)
    if out_image:
        plot_pr_curve(precisions, recalls, out_image, title)
    return {'ap': 100.0 * avg_prec}

def run_precision_recall_analysis(main_eval, exact_raw, f1_raw, na_probs,
                                  qid_to_has_ans, out_image_dir):
    """Add PR-curve AP metrics (exact, F1, oracle) to ``main_eval``.

    Writes one PNG per curve into ``out_image_dir`` (created if missing).
    Skips everything when the dataset has no answerable questions.
    """
    if out_image_dir and not os.path.exists(out_image_dir):
        os.makedirs(out_image_dir)
    num_true_pos = sum(1 for v in qid_to_has_ans.values() if v)
    if num_true_pos == 0:
        return
    pr_exact = make_precision_recall_eval(
        exact_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_exact.png'),
        title='Precision-Recall curve for Exact Match score')
    pr_f1 = make_precision_recall_eval(
        f1_raw, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_f1.png'),
        title='Precision-Recall curve for F1 score')
    oracle_scores = {k: float(v) for k, v in qid_to_has_ans.items()}
    pr_oracle = make_precision_recall_eval(
        oracle_scores, na_probs, num_true_pos, qid_to_has_ans,
        out_image=os.path.join(out_image_dir, 'pr_oracle.png'),
        title='Oracle Precision-Recall curve (binary task of HasAns vs. NoAns)')
    merge_eval(main_eval, pr_exact, 'pr_exact')
    merge_eval(main_eval, pr_f1, 'pr_f1')
    merge_eval(main_eval, pr_oracle, 'pr_oracle')

def histogram_na_prob(na_probs, qid_list, image_dir, name):
    """Save a histogram of no-answer probabilities for ``qid_list``.

    Uses the same lazy headless matplotlib import as plot_pr_curve (see
    its docstring for why). No-op when ``qid_list`` is empty.
    """
    if not qid_list:
        return
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    x = [na_probs[k] for k in qid_list]
    weights = np.ones_like(x) / float(len(x))
    plt.hist(x, weights=weights, bins=20, range=(0.0, 1.0))
    plt.xlabel('Model probability of no-answer')
    plt.ylabel('Proportion of dataset')
    plt.title('Histogram of no-answer probability: %s' % name)
    plt.savefig(os.path.join(image_dir, 'na_prob_hist_%s.png' % name))
    plt.clf()

def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
    """Find the no-answer threshold maximizing the aggregate score.

    Starts from "predict no-answer for everything" (score = number of
    truly unanswerable questions) and sweeps thresholds in ascending
    na-probability order, tracking the best running score.
    Returns (best score as a percentage, best threshold).
    """
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            # Answering an unanswerable question costs a point; predicting
            # the empty string is neutral (it was already counted).
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]
    return 100.0 * best_score / len(scores), best_thresh

def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
    """XLNet variant of find_best_thresh.

    Same sweep as find_best_thresh, but additionally returns the mean
    score over answerable questions (``has_ans_score / has_ans_cnt``).
    """
    num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
    cur_score = num_no_ans
    best_score = cur_score
    best_thresh = 0.0
    qid_list = sorted(na_probs, key=lambda k: na_probs[k])
    for i, qid in enumerate(qid_list):
        if qid not in scores: continue
        if qid_to_has_ans[qid]:
            diff = scores[qid]
        else:
            if preds[qid]:
                diff = -1
            else:
                diff = 0
        cur_score += diff
        if cur_score > best_score:
            best_score = cur_score
            best_thresh = na_probs[qid]

    has_ans_score, has_ans_cnt = 0, 0
    for qid in qid_list:
        if not qid_to_has_ans[qid]: continue
        has_ans_cnt += 1

        if qid not in scores: continue
        has_ans_score += scores[qid]

    return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt

def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """Record best exact/F1 scores and thresholds in ``main_eval``."""
    best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh

def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
    """XLNet variant: also records has-answer exact/F1 means in ``main_eval``."""
    best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
    best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
    main_eval['best_exact'] = best_exact
    main_eval['best_exact_thresh'] = exact_thresh
    main_eval['best_f1'] = best_f1
    main_eval['best_f1_thresh'] = f1_thresh
    main_eval['has_ans_exact'] = has_ans_exact
    main_eval['has_ans_f1'] = has_ans_f1

def main(OPTS):
    """Run the full SQuAD 2.0 evaluation described by ``OPTS``.

    Loads the dataset and predictions (and optional no-answer
    probabilities), scores them, applies the no-answer threshold, and
    writes or prints the resulting metric dictionary, which is also
    returned.
    """
    with open(OPTS.data_file) as f:
        dataset_json = json.load(f)
        dataset = dataset_json['data']
    with open(OPTS.pred_file) as f:
        preds = json.load(f)
    if OPTS.na_prob_file:
        with open(OPTS.na_prob_file) as f:
            na_probs = json.load(f)
    else:
        # No probabilities supplied: treat every prediction as confident.
        na_probs = {k: 0.0 for k in preds}
    qid_to_has_ans = make_qid_to_has_ans(dataset)  # maps qid to True/False
    has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
    no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
    exact_raw, f1_raw = get_raw_scores(dataset, preds)
    exact_thresh = apply_no_ans_threshold(exact_raw, na_probs, qid_to_has_ans,
                                          OPTS.na_prob_thresh)
    f1_thresh = apply_no_ans_threshold(f1_raw, na_probs, qid_to_has_ans,
                                       OPTS.na_prob_thresh)
    out_eval = make_eval_dict(exact_thresh, f1_thresh)
    if has_ans_qids:
        has_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=has_ans_qids)
        merge_eval(out_eval, has_ans_eval, 'HasAns')
    if no_ans_qids:
        no_ans_eval = make_eval_dict(exact_thresh, f1_thresh, qid_list=no_ans_qids)
        merge_eval(out_eval, no_ans_eval, 'NoAns')
    if OPTS.na_prob_file:
        find_all_best_thresh(out_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans)
    if OPTS.na_prob_file and OPTS.out_image_dir:
        run_precision_recall_analysis(out_eval, exact_raw, f1_raw, na_probs,
                                      qid_to_has_ans, OPTS.out_image_dir)
        histogram_na_prob(na_probs, has_ans_qids, OPTS.out_image_dir, 'hasAns')
        histogram_na_prob(na_probs, no_ans_qids, OPTS.out_image_dir, 'noAns')
    if OPTS.out_file:
        with open(OPTS.out_file, 'w') as f:
            json.dump(out_eval, f)
    else:
        print(json.dumps(out_eval, indent=2))
    return out_eval

if __name__ == '__main__':
    # matplotlib is now imported lazily inside the plotting helpers, so no
    # backend setup is needed here.
    OPTS = parse_args()
    main(OPTS)
# --------------------------------------------------------------------------------
# /statistic.txt:
# https://raw.githubusercontent.com/camelop/NLP-Robustness/de2761886cfb4c7d3cf0bf0a9d19d77588bfea36/statistic.txt
# --------------------------------------------------------------------------------
# /sts.png:
# https://raw.githubusercontent.com/camelop/NLP-Robustness/de2761886cfb4c7d3cf0bf0a9d19d77588bfea36/sts.png
--------------------------------------------------------------------------------