├── .gitignore ├── LICENSE ├── README.md ├── augment_index_builder.py ├── augment_utils.py ├── combined_data ├── README.md ├── SentiWordNet_3.0.0_20100705.txt ├── hotel_classification │ ├── dev.txt │ └── train.txt ├── hotel_pairing │ ├── dev.txt │ ├── train.txt.combined │ └── unlabeled.txt └── hotel_tagging │ ├── dev.txt │ ├── train.txt.combined │ └── unlabeled.txt ├── configs.json ├── requirements.txt ├── run_pipeline.py ├── snippext ├── __init__.py ├── augment.py ├── baseline.py ├── conlleval.py ├── dataset.py ├── mixda.py ├── mixmatchnl.py ├── model.py └── train_util.py ├── train_baseline.py ├── train_mixda.py └── train_mixmatchnl.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pt 3 | *.pth 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Megagon Labs 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Snippext 2 | Snippext is the extraction pipeline for mining opinions and customer experiences from user-generated content (e.g., online reviews). 
3 | 
4 | Paper: Zhengjie Miao, Yuliang Li, Xiaolan Wang, Wang-Chiew Tan, "Snippext: Semi-supervised Opinion Mining with Augmented Data", in The Web Conference (WWW) 2020
5 | 
6 | ## Requirements
7 | 
8 | * Python 3.7.5
9 | * PyTorch 1.3
10 | * HuggingFace Transformers
11 | * Spacy with the ``en_core_web_sm`` model
12 | * NLTK (stopwords, wordnet)
13 | * Gensim
14 | * NVIDIA Apex (fp16 training)
15 | 
16 | Install the required packages:
17 | ```
18 | conda install -c conda-forge nvidia-apex
19 | pip install -r requirements.txt
20 | ```
21 | 
22 | Download the pre-trained BERT models and word2vec models (for data augmentation):
23 | ```
24 | wget https://snippext.s3.us-east-2.amazonaws.com/finetuned_bert.zip
25 | unzip finetuned_bert.zip
26 | wget https://snippext.s3.us-east-2.amazonaws.com/word2vec.zip
27 | unzip word2vec.zip
28 | ```
29 | 
30 | ## Training with the baseline BERT finetuning
31 | 
32 | The baseline method performs BERT finetuning on a specific task:
33 | ```
34 | CUDA_VISIBLE_DEVICES=0 python train_baseline.py \
35 | --task restaurant_ae_tagging \
36 | --logdir results/ \
37 | --save_model \
38 | --finetuning \
39 | --batch_size 32 \
40 | --lr 5e-5 \
41 | --n_epochs 20 \
42 | --bert_path finetuned_bert/rest_model.bin
43 | ```
44 | 
45 | Parameters:
46 | * ``--task``: the name of the task (defined in ``configs.json``)
47 | * ``--logdir``: the logging directory (for Tensorboard)
48 | * ``--save_model``: whether to save the best model
49 | * ``--batch_size``, ``--lr``, ``--n_epochs``: batch size, learning rate, and the number of epochs
50 | * ``--bert_path`` (Optional): the path of a fine-tuned BERT checkpoint. Use the base uncased model if not specified.
51 | * ``--max_len`` (Optional): maximum sequence length
52 | 
53 | *(New)* Additional options (also available in MixDA and MixMatchNL):
54 | * ``--fp16`` (Optional): whether to train with fp16 acceleration
55 | * ``--lm`` (Optional): other language models, e.g., "distilbert" or "albert"
56 | 
57 | ### Task Specification
58 | 
59 | The train/dev/test sets of a task (tagging or span classification) are specified in the file ``configs.json``.
60 | The file ``configs.json`` is a list of entries where each one is of the following format:
61 | ```
62 | {
63 |     "name": "hotel_tagging",
64 |     "task_type": "tagging",
65 |     "vocab": [
66 |         "B-AS",
67 |         "I-AS",
68 |         "B-OP",
69 |         "I-OP"
70 |     ],
71 |     "trainset": "combined_data/hotel_tagging/train.txt.combined",
72 |     "validset": "combined_data/hotel_tagging/dev.txt",
73 |     "testset": "combined_data/hotel_tagging/dev.txt",
74 |     "unlabeled": "combined_data/hotel_tagging/unlabeled.txt"
75 | },
76 | ```
77 | 
78 | Fields:
79 | * ``name``: the name of the task. The name of a tagging task should end with the suffix ``_tagging``
80 | * ``task_type``: either ``tagging`` or ``classification``
81 | * ``vocab``: the list of class labels. For tagging tasks, all labels start with ``B-`` or ``I-``, indicating the beginning or the inside of a span. For classification tasks, the list contains all the possible class labels.
82 | * ``trainset``, ``validset``, ``testset``: the paths to the train/dev/test sets
83 | * ``unlabeled`` (Optional): the path to the unlabeled dataset for semi-supervised learning. The file has the same format as the train/test sets, but the labels are simply ignored.
84 | 
85 | ## Training with MixDA (data augmentation)
86 | 
87 | 1. Build the augmentation index:
88 | 
89 | ```
90 | python augment_index_builder.py \
91 | --task restaurant_ae_tagging \
92 | --w2v_path word2vec/rest_w2v.model \
93 | --idf_path word2vec/rest_finetune.txt \
94 | --bert_path finetuned_bert/rest_model.bin \
95 | --index_output_path augment/rest_index.json
96 | ```
97 | 
98 | Simply replace ``restaurant_ae_tagging`` with ``laptop_ae_tagging``, ``restaurant_asc``, or ``laptop_asc`` to generate the other indices.
99 | Replace ``rest`` with ``laptop`` for the ``laptop_ae_tagging`` and ``laptop_asc`` indices.
100 | 
101 | 2. Train with:
102 | ```
103 | CUDA_VISIBLE_DEVICES=0 python train_mixda.py \
104 | --task restaurant_ae_tagging \
105 | --logdir results/ \
106 | --finetuning \
107 | --batch_size 32 \
108 | --lr 5e-5 \
109 | --n_epochs 5 \
110 | --bert_path finetuned_bert/rest_model.bin \
111 | --alpha_aug 0.8 \
112 | --augment_index augment/rest_index.json \
113 | --augment_op token_repl_tfidf \
114 | --run_id 0
115 | ```
116 | 
117 | Parameters:
118 | * ``alpha_aug``: the [mixup](https://arxiv.org/abs/1710.09412) parameter between the original example and the augmented example ({0.2, 0.5, 0.8} are usually good).
119 | * ``augment_index``: the path to the augmentation index
120 | * ``augment_op``: the name of the DA operator. We currently support the following 9 operators:
121 | 
122 | | Operators        | Details                                            |
123 | |------------------|----------------------------------------------------|
124 | | token_del_tfidf  | Token deletion by importance (measured by TF-IDF)  |
125 | | token_del        | Token deletion (uniform)                            |
126 | | token_repl_tfidf | Token replacement by importance                     |
127 | | token_repl       | Token replacement (uniform)                         |
128 | | token_swap       | Swapping two tokens                                 |
129 | | token_ins        | Inserting new tokens                                |
130 | | span_sim         | Replacing a span with a similar one                 |
131 | | span_freq        | Replacing a span by frequency                       |
132 | | span             | Uniform span replacement                            |
133 | 
134 | 
135 | ## Training with MixMatchNL (MixDA + Semi-supervised Learning)
136 | 
137 | Our implementation of [MixMatch](https://arxiv.org/abs/1905.02249) with MixDA. To train with MixMatchNL:
138 | ```
139 | CUDA_VISIBLE_DEVICES=0 python train_mixmatchnl.py \
140 | --task restaurant_ae_tagging \
141 | --logdir results/ \
142 | --finetuning \
143 | --batch_size 32 \
144 | --lr 5e-5 \
145 | --n_epochs 5 \
146 | --bert_path finetuned_bert/rest_model.bin \
147 | --alpha_aug 0.8 \
148 | --alpha 0.2 \
149 | --u_lambda 50.0 \
150 | --num_aug 2 \
151 | --augment_index augment/rest_index.json \
152 | --augment_op token_repl_tfidf \
153 | --run_id 0
154 | ```
155 | 
156 | Additional parameters:
157 | * ``alpha``: the mixup parameter (between labeled and unlabeled data). We chose ``alpha`` from {0.2, 0.5, 0.8}.
158 | * ``u_lambda``: the weight of the unlabeled data loss, typically chosen from {10.0, 25.0, 50.0}
159 | * ``num_aug``: the number of augmented examples per unlabeled example (we chose from {2, 4})
160 | 
161 | 
162 | ### Hyperparameters and experiment scripts (coming soon)
163 | 
--------------------------------------------------------------------------------
/augment_index_builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import random
4 | import math
5 | import numpy as np
6 | import argparse
7 | import torch
8 | 
9 | from augment_utils import *
10 | from transformers import BertModel
11 | from nltk.corpus import wordnet
12 | from nltk.corpus.reader.sentiwordnet import SentiWordNetCorpusReader
13 | from gensim.models import Word2Vec
14 | 
15 | from snippext.dataset import get_tokenizer
16 | 
17 | 
18 | class IndexBuilder(object):
19 |     """Index builder for span-level and token-level data augmentation
20 | 
21 |     Support both token and span level augmentation operators.
22 | 
23 |     Attributes:
24 |         tokens (list of lists): tokens of training data
25 |         labels (list of lists): labels of each token of training data (for AOE task)
26 |         sents (list of dicts): training examples for ASC task
27 |         w2v: Word2Vec model for similar word replacement
28 |         index (dict): a dictionary containing both the token and span level index
29 |         all_spans (list of str): a list of all the spans in the index
30 |         span_freqs (list of int): the document frequency of each span in all_spans
31 |         lm (string): the language model; 'bert' by default
32 |     """
33 |     def __init__(self, train_fn, idf_fn, w2v, task, bert_path, lm='bert'):
34 |         if 'tagging' in task or 'qa' in task:
35 |             self.tokens, self.labels = read_tagging_file(train_fn)
36 |         else:
37 |             self.sents = read_asc_file(train_fn)
38 |             self.tokens = list(map(lambda x: x['token'], self.sents))
39 | 
40 |         idf_dict = json.load(open(idf_fn))
41 |         self.w2v = w2v
42 |         self.task = task
43 |         self.index = {'token': dict(), 'span': dict()}
44 |         self.all_spans = list()
45 |         self.span_freqs = list()
46 |         self.avg_senti = dict()
47 |         if self.task == 'classification':
48 |             # sentiment sensitive
49 |             self.calc_senti_score()
50 |         self.tokenizer = get_tokenizer(lm=lm)
51 |         self.init_token_index(idf_dict)
52 |         self.init_span_index(bert_path=bert_path)
53 |         self.index_token_replacement()
54 | 
55 |     def init_token_index(self, idf_dict):
56 |         oov_th = math.log(1e8)
57 |         for token in self.tokens:
58 |             for w in token:
59 |                 if w not in self.index['token']:
60 |                     self.index['token'][w] = dict()
61 |                     wl = w.lower()
62 | 
63 |                     if wl not in idf_dict:
64 |                         self.index['token'][w]['idf'] = oov_th
65 |                     else:
66 |                         self.index['token'][w]['idf'] = idf_dict[wl]
67 |                     tokenized_w = self.tokenizer.tokenize(w)
68 |                     self.index['token'][w]['bert_token'] = tokenized_w
69 |                     self.index['token'][w]['bert_length'] = len(tokenized_w)
70 |                     self.index['token'][w]['similar_words'] = None
71 |                     self.index['token'][w]['similar_words_bert'] = None
72 |                     self.index['token'][w]['similar_words_length'] = None
73 | 
74 |     def init_span_index(self, sim_token='cls', sim_topk=100, bert_path=None):
75 |         if bert_path is None:
76 |             bert_model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True)
77 |         else:
78 |             model_state_dict = torch.load(bert_path)
79 |             bert_model = BertModel.from_pretrained('bert-base-uncased',
80 |                 state_dict=model_state_dict,
81 |                 output_hidden_states=True)
82 |         bert_model.eval()
83 | 
84 |         aspect_dict = dict()
85 |         opinion_dict = 
dict() 86 | aspect_token_list = [] 87 | opinion_token_list = [] 88 | aspect_raw_token_list = [] 89 | opinion_raw_token_list = [] 90 | 91 | n = len(self.tokens) 92 | max_len_as = 0 93 | max_len_op = 0 94 | for j in range(n): 95 | if 'classification' in self.task: 96 | # for ASC datasets, the aspect is given in the term field 97 | as_term = self.sents[j]['term'] 98 | 99 | # as_labels 100 | as_labels = [] 101 | tokenized_as = self.tokenizer.tokenize(as_term) 102 | for idx_t, t in enumerate(tokenized_as): 103 | if idx_t == 0: 104 | as_labels.append(1 if idx_t == 0 else 2) 105 | else: 106 | as_labels.append(0) 107 | 108 | len_as = len(tokenized_as) 109 | max_len_as = max(len_as, max_len_as) 110 | as_str = ' '.join(tokenized_as) 111 | if as_term not in aspect_dict: 112 | aspect_dict[as_term] = { 113 | 'document_freq': 1, 114 | 'bert_token': tokenized_as, 115 | 'bert_length': len_as, 116 | 'bert_label': as_labels, 117 | 'similar_spans': [], 118 | 'similar_spans_length': [] 119 | } 120 | aspect_token_list.append(tokenized_as) 121 | aspect_raw_token_list.append(as_term) 122 | else: 123 | aspect_dict[as_term]['document_freq'] += 1 124 | 125 | else: 126 | # for AOE datasets, we have to enumerate tokens to find all aspects and opinions 127 | aspects = [] 128 | opinions = [] 129 | m = len(self.tokens[j]) 130 | k = 0 131 | while k < m: 132 | if 'B-AS' in self.labels[j][k]: 133 | aspects.append([self.tokens[j][k]]) 134 | k += 1 135 | elif 'I-AS' in self.labels[j][k]: 136 | # ignore spans that are incorrectly labeled 137 | if len(aspects) > 0: 138 | aspects[-1].append(self.tokens[j][k]) 139 | k += 1 140 | elif 'B-OP' in self.labels[j][k]: 141 | opinions.append([self.tokens[j][k]]) 142 | k += 1 143 | elif 'I-OP' in self.labels[j][k]: 144 | # ignore spans that are incorrectly labeled 145 | if len(opinions) > 0: 146 | opinions[-1].append(self.tokens[j][k]) 147 | k += 1 148 | else: 149 | k += 1 150 | 151 | for as_term in aspects: 152 | tokenized_as = [] 153 | as_labels = [] 154 | for idx_as, w in enumerate(as_term): 155 | tokenized_w = self.tokenizer.tokenize(w) 156 | for idx_t, t in enumerate(tokenized_w): 157 | if idx_t == 0: 158 | as_labels.append(1 if idx_as == 0 else 2) 159 | else: 160 | as_labels.append(0) 161 | tokenized_as += tokenized_w 162 | len_as = len(tokenized_as) 163 | max_len_as = max(len_as, max_len_as) 164 | as_str = ' '.join(tokenized_as) 165 | as_raw = ' '.join(as_term) 166 | if as_raw not in aspect_dict: 167 | aspect_dict[as_raw] = { 168 | 'document_freq': 1, 169 | 'bert_token': tokenized_as, 170 | 'bert_length': len_as, 171 | 'bert_label': as_labels, 172 | 'similar_spans': [], 173 | 'similar_spans_length': [] 174 | } 175 | aspect_token_list.append(tokenized_as) 176 | aspect_raw_token_list.append(as_raw) 177 | else: 178 | aspect_dict[as_raw]['document_freq'] += 1 179 | 180 | for op_term in opinions: 181 | tokenized_op = [] 182 | op_labels = [] 183 | for idx_op, w in enumerate(op_term): 184 | tokenized_w = self.tokenizer.tokenize(w) 185 | for idx_t, t in enumerate(tokenized_w): 186 | if idx_t == 0: 187 | op_labels.append(3 if idx_op == 0 else 4) 188 | else: 189 | op_labels.append(0) 190 | tokenized_op += tokenized_w 191 | len_op = len(tokenized_op) 192 | max_len_op = max(len_op, max_len_op) 193 | op_str = ' '.join(tokenized_op) 194 | op_raw = ' '.join(op_term) 195 | if op_raw not in opinion_dict: 196 | opinion_dict[op_raw] = { 197 | 'document_freq': 1, 198 | 'bert_token': tokenized_op, 199 | 'bert_length': len_op, 200 | 'bert_label': op_labels, 201 | 'similar_spans': [], 202 | 
'similar_spans_length': [] 203 | } 204 | opinion_token_list.append(tokenized_op) 205 | opinion_raw_token_list.append(op_raw) 206 | else: 207 | opinion_dict[op_raw]['document_freq'] += 1 208 | 209 | as_ids = [] 210 | op_ids = [] 211 | # Pad to max length and convert to ids 212 | for as_term in aspect_token_list: 213 | tk_as = ['[CLS]'] + as_term + ['[SEP]'] + ['[PAD]' for k in range(max_len_as - len(as_term))] 214 | as_ids.append(self.tokenizer.convert_tokens_to_ids(tk_as)) 215 | for op_term in opinion_token_list: 216 | tk_op = ['[CLS]'] + op_term + ['[SEP]'] + ['[PAD]' for k in range(max_len_op - len(op_term))] 217 | op_ids.append(self.tokenizer.convert_tokens_to_ids(tk_op)) 218 | 219 | # migrated to transformers 220 | if len(aspect_token_list) > 0: 221 | X_as = torch.LongTensor(as_ids) 222 | as_encoded_layers = bert_model(X_as)[2] 223 | X_as = as_encoded_layers[-2].detach() 224 | 225 | if len(opinion_token_list) > 0: 226 | X_op = torch.LongTensor(op_ids) 227 | op_encoded_layers = bert_model(X_op)[2] 228 | X_op = op_encoded_layers[-2].detach() 229 | 230 | # Compute the dot-product between all pairs of spans 231 | for i in range(len(aspect_token_list)): 232 | if sim_token == 'all': 233 | # using all tokens 234 | q = X_as[i] 235 | z = q * X_as 236 | score = torch.sum(z, dim=(1,2)) / torch.tensor( 237 | np.linalg.norm(q) * np.linalg.norm(X_as, axis=(1,2))) 238 | elif sim_token == 'cls': 239 | # using the CLS token 240 | q = X_as[i][0] 241 | z = q * X_as[:, 0, :] 242 | score = torch.sum(z, dim=(1)) / torch.tensor( 243 | np.linalg.norm(q) * np.linalg.norm(X_as[:, 0, :], axis=(1))) 244 | elif sim_token == 'bas': 245 | # using the first token of the span 246 | q = X_as[i][1] 247 | z = q * X_as[:, 1, :] 248 | score = torch.sum(z, dim=(1)) / torch.tensor( 249 | np.linalg.norm(q) * np.linalg.norm(X_as[:, 1, :], axis=(1))) 250 | 251 | topk_idx = torch.argsort(score, dim=0, descending=True) 252 | for idx in topk_idx: 253 | if idx == i: 254 | continue 255 | if len(aspect_dict[aspect_raw_token_list[i]]['similar_spans']) < sim_topk: 256 | aspect_dict[aspect_raw_token_list[i]]['similar_spans'].append( 257 | [aspect_raw_token_list[idx], score[idx].item()]) 258 | aspect_dict[aspect_raw_token_list[i]]['similar_spans_length'].append( 259 | aspect_dict[aspect_raw_token_list[idx]]['bert_length']) 260 | else: 261 | break 262 | 263 | for i in range(len(opinion_token_list)): 264 | if sim_token == 'all': 265 | # using all tokens 266 | q = X_op[i] 267 | z = q * X_op 268 | score = torch.sum(z, dim=(1,2)) / torch.tensor( 269 | np.linalg.norm(q) * np.linalg.norm(X_op, axis=(1,2))) 270 | elif sim_token == 'cls': 271 | # using the CLS token 272 | q = X_op[i][0] 273 | z = q * X_op[:, 0, :] 274 | score = torch.sum(z, dim=(1)) / torch.tensor( 275 | np.linalg.norm(q) * np.linalg.norm(X_op[:, 0, :], axis=(1))) 276 | elif sim_token == 'bas': 277 | # using the first token of the span 278 | q = X_op[i][1] 279 | z = q * X_op[:, 1, :] 280 | score = torch.sum(z, dim=(1)) / torch.tensor( 281 | np.linalg.norm(q) * np.linalg.norm(X_op[:, 1, :], axis=(1))) 282 | topk_idx = torch.argsort(score, dim=0, descending=True) 283 | for idx in topk_idx: 284 | if idx == i: 285 | continue 286 | if len(opinion_dict[opinion_raw_token_list[i]]['similar_spans']) < sim_topk: 287 | opinion_dict[opinion_raw_token_list[i]]['similar_spans'].append( 288 | [opinion_raw_token_list[idx], score[idx].item()]) 289 | opinion_dict[opinion_raw_token_list[i]]['similar_spans_length'].append( 290 | opinion_dict[opinion_raw_token_list[idx]]['bert_length']) 291 | else: 
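# stop once sim_topk similar spans have been collected for this span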
292 | break 293 | self.index['span'] = {'aspect': aspect_dict, 'opinion': opinion_dict} 294 | 295 | 296 | def index_token_replacement(self): 297 | # pre-compute all token replacement candidates and store them in the index 298 | for token in self.tokens: 299 | for w in token: 300 | if is_stopword(w) or self.index['token'][w]['similar_words'] is not None: 301 | continue 302 | self.index['token'][w]['similar_words'] = [] 303 | # self.index['token'][w]['similar_words_bert'] = [] 304 | self.index['token'][w]['similar_words_length'] = [] 305 | 306 | synonyms = self.find_word_replacement(word_str=w) 307 | similar_words_dict = dict() 308 | if len(synonyms) >= 1: 309 | for s in list(synonyms): 310 | s_arr = s[0].split('_') 311 | if s_arr[0] not in similar_words_dict: 312 | similar_words_dict[s_arr[0]] = True 313 | else: 314 | continue 315 | tokenized_s = self.tokenizer.tokenize(s_arr[0]) 316 | l_s = len(tokenized_s) 317 | self.index['token'][w]['similar_words'].append([s_arr[0], s[1]]) 318 | # self.index['token'][w]['similar_words_bert'].append(tokenized_s) 319 | self.index['token'][w]['similar_words_length'].append(l_s) 320 | 321 | def find_word_replacement(self, word_str, sim_topk=10, is_senti_sensitive=False): 322 | # find sim_topk similar words to word_str 323 | if is_senti_sensitive: 324 | # if sentiment sensitive, compute the senti score of word_str 325 | senti_score = 0 326 | word_str = word_str.lower() 327 | if word_str.lower() in self.avg_senti: 328 | senti_score = self.avg_senti[word_str.lower()]['pos_score'] - avg_senti[word_str.lower()]['neg_score'] 329 | 330 | if self.w2v is None: 331 | # if Word2Vec is not given, using wordnet 332 | syns = wordnet.synsets(word_str) 333 | syn_list = [] 334 | for syn in syns: 335 | for lem in syn.lemmas(): 336 | if lem.name() != word_str: 337 | if is_senti_sensitive: 338 | lem_senti_score = 0 339 | if lem_str in self.avg_senti: 340 | lem_senti_score = self.avg_senti[lem_str]['pos_score'] - self.avg_senti[lem_str]['neg_score'] 341 | ''' maybe we can use a different way to determine whether two words 342 | are of the same sentiment ''' 343 | if sign(lem_senti_score) == sign(senti_score): 344 | syn_list.append(lem_str) 345 | else: 346 | syn_list.append(lem.name()) 347 | if len(syn_list) == 0: 348 | return [] 349 | return list(zip(syn_list, [1.0 for i in range(len(syn_list))])) 350 | else: 351 | if word_str in self.w2v.wv.vocab: 352 | # if word_str appears in Word2Vec vocabulary 353 | similar_list = self.w2v.wv.most_similar(positive=[word_str], topn=sim_topk) 354 | if is_senti_sensitive: 355 | arr = [] 356 | for ws in similar_list: 357 | w_senti_score = 0 358 | w = ws[0] 359 | if w in self.avg_senti: 360 | w_senti_score = self.avg_senti[w]['pos_score'] - self.avg_senti[w]['neg_score'] 361 | if sign(w_senti_score) == sign(senti_score): 362 | arr.append(ws) 363 | return arr 364 | else: 365 | return similar_list 366 | else: 367 | # if word_str does not appear in Word2Vec vocabulary, find a synonym of it using WordNet 368 | # if the synonym appears in Word2Vec vocabulary, use similar words of this synonym 369 | syns = wordnet.synsets(word_str) 370 | syns_dict = dict() 371 | arr = [] 372 | for syn in syns: 373 | flag = False 374 | for lem in syn.lemmas(): 375 | if lem.name() != word_str: 376 | syns_dict[lem.name()] = True 377 | if lem.name() in self.w2v.wv.vocab: 378 | similar_list = self.w2v.wv.most_similar(positive=[lem.name()], topn=sim_topk) 379 | if is_senti_sensitive: 380 | for ws in similar_list: 381 | w_senti_score = 0 382 | w = ws[0] 383 | if w in 
self.avg_senti: 384 | w_senti_score = self.avg_senti[w]['pos_score'] - self.avg_senti[w]['neg_score'] 385 | if sign(w_senti_score) == sign(senti_score): 386 | arr.append(ws) 387 | else: 388 | arr = similar_list 389 | flag = True 390 | break 391 | if flag: 392 | break 393 | if len(arr) == 0: 394 | res = list(syns_dict.keys()) 395 | return list(zip(res, [1.0 for i in range(len(res))])) 396 | else: 397 | return arr 398 | 399 | def calc_senti_score(self, swn_filename='combined_data/SentiWordNet_3.0.0_20100705.txt'): 400 | # aggregate sentiment score of tokens using SentiWordNet 401 | swn = SentiWordNetCorpusReader('./', [swn_filename]) 402 | for senti_synset in swn.all_senti_synsets(): 403 | w = senti_synset.synset.name().split('.')[0] 404 | if w not in self.avg_senti: 405 | self.avg_senti[w] = { 406 | 'pos_score': 0, 407 | 'neg_score': 0, 408 | 'count': 0 409 | } 410 | self.avg_senti[w]['pos_score'] += senti_synset.pos_score() 411 | self.avg_senti[w]['neg_score'] += senti_synset.neg_score() 412 | self.avg_senti[w]['count'] += 1 413 | 414 | for w in self.avg_senti: 415 | self.avg_senti[w]['pos_score'] /= self.avg_senti[w]['count'] 416 | self.avg_senti[w]['neg_score'] /= self.avg_senti[w]['count'] 417 | 418 | def dump_index(self, index_filename='augment_index.json'): 419 | outfile = open(index_filename, 'w') 420 | json.dump(self.index, outfile) 421 | outfile.close() 422 | 423 | 424 | # def build_idf_dict(text_path): 425 | # from gensim.utils import simple_preprocess 426 | # from collections import Counter 427 | # 428 | # cnt = Counter() 429 | # N = 0 430 | # for line in open(text_path): 431 | # tokens = simple_preprocess(line.lower()) 432 | # tokens = set(tokens) 433 | # if len(tokens) > 0: 434 | # N += 1 435 | # for token in tokens: 436 | # cnt[token] += 1 437 | # 438 | # idf_dict = {} 439 | # for token in cnt: 440 | # idf_dict[token] = math.log(N / cnt[token]) 441 | # return idf_dict 442 | 443 | 444 | 445 | if __name__ == '__main__': 446 | parser = argparse.ArgumentParser() 447 | parser.add_argument("--task", type=str, default="hotel_tagging") 448 | parser.add_argument("--train_path", type=str, default=None) 449 | parser.add_argument("--w2v_path", type=str, default="../rest_w2v.model") 450 | parser.add_argument("--bert_path", type=str, default=None) 451 | parser.add_argument("--lm", type=str, default='bert') 452 | parser.add_argument("--idf_path", type=str, default=None) 453 | parser.add_argument("--index_output_path", type=str, default="augment_index.json") 454 | 455 | hp = parser.parse_args() 456 | configs = json.load(open('configs.json')) 457 | configs = {conf['name'] : conf for conf in configs} 458 | config = configs[hp.task] 459 | if hp.train_path is None: 460 | train_fn = config['trainset'] 461 | else: 462 | train_fn = hp.train_path 463 | w2v = Word2Vec.load(hp.w2v_path) 464 | if hp.idf_path[-5:] != '.json': 465 | idf_fn = hp.idf_path + '.json' 466 | if not os.path.exists(idf_fn): 467 | idf_dict = build_idf_dict(hp.idf_path) 468 | json.dump(idf_dict, open(idf_fn, 'w')) 469 | else: 470 | idf_fn = hp.idf_path 471 | 472 | ib = IndexBuilder(train_fn, idf_fn, w2v, 473 | config['task_type'], 474 | hp.bert_path, 475 | lm=hp.lm) 476 | ib.dump_index(hp.index_output_path) 477 | -------------------------------------------------------------------------------- /augment_utils.py: -------------------------------------------------------------------------------- 1 | from nltk.corpus import stopwords 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | # from dataset import tokenizer 4 | 
from gensim.utils import simple_preprocess
5 | 
6 | stopword_set = set(stopwords.words('english'))
7 | 
8 | def read_asc_file(fn):
9 |     res = []
10 |     for line in open(fn):
11 |         if len(line) < 3:
12 |             continue
13 |         else:
14 |             LL = line.strip().split('\t')
15 |             # token = twt.tokenize(LL[0])
16 |             # term = twt.tokenize(LL[1])
17 |             tokens = simple_preprocess(LL[0])
18 |             term = LL[1]
19 |             polarity = LL[2]
20 |             res.append({
21 |                 'raw': line,
22 |                 'token': tokens,
23 |                 'term': term,
24 |                 'polarity': polarity
25 |             })
26 |     return res
27 | 
28 | def is_stopword(token):
29 |     return token in ['[SEP]', '[CLS]'] or token in stopword_set or not token.isalpha()
30 | 
31 | def read_tagging_file(fn):
32 |     tokens = [[]]
33 |     labels = [[]]
34 |     for line in open(fn):
35 |         if len(line) < 3:
36 |             tokens.append([])
37 |             labels.append([])
38 |         else:
39 |             LL = line.strip().split(' ')
40 |             token = LL[0]
41 |             label = LL[-1]
42 |             tokens[-1].append(token)
43 |             labels[-1].append(label)
44 |     return tokens, labels
45 | 
46 | def build_idf_dict(fn):
47 |     corpus = ['']
48 |     cnt = 0
49 |     for line in open(fn):
50 |         if len(line) < 2:
51 |             corpus.append('')
52 |             cnt += 1
53 |             # if cnt % 10 == 0:
54 |             #     print(cnt)
55 |         else:
56 |             corpus[-1] += ' ' + line
57 | 
58 |     vectorizer = TfidfVectorizer()
59 |     vectorizer.fit(corpus)
60 |     idf_dict = dict()
61 |     for w in vectorizer.vocabulary_:
62 |         idf_dict[w] = vectorizer.idf_[vectorizer.vocabulary_[w]]
63 |     return idf_dict
64 | 
65 | def sign(a):
66 |     return (a > 0) - (a < 0)
67 | 
--------------------------------------------------------------------------------
/combined_data/README.md:
--------------------------------------------------------------------------------
1 | The ABSA datasets are obtained from this [repo](https://github.com/howardhsu/BERT-for-RRC-ABSA).
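For reference, here is a minimal sketch (inferred from ``read_tagging_file`` and ``read_asc_file`` in ``augment_utils.py``, not part of the original documentation) of how these data files are read. Tagging files contain one space-separated ``token ... label`` pair per line, with the label last and a blank line between sentences; ASC-style classification files contain one tab-separated ``sentence<TAB>term<TAB>polarity`` example per line. The label names themselves come from the ``vocab`` lists in ``configs.json``, and the paths below are illustrative:

```
# Sketch only: a quick way to inspect the data files with the loaders above.
from augment_utils import read_tagging_file, read_asc_file

# Tagging format: "token ... label" per line, blank line between sentences.
tokens, labels = read_tagging_file('combined_data/hotel_tagging/dev.txt')
print(len(tokens), 'tagged sentences; first:', list(zip(tokens[0], labels[0])))

# ASC format: "sentence<TAB>term<TAB>polarity" per line.
# Illustrative path: point this at any ASC set listed in configs.json.
examples = read_asc_file('combined_data/absa/asc/rest/dev.txt')
print(len(examples), 'examples; first term/polarity:', examples[0]['term'], examples[0]['polarity'])
```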
2 | -------------------------------------------------------------------------------- /combined_data/hotel_classification/dev.txt: -------------------------------------------------------------------------------- 1 | generous management staff 2 | adequate size rooms size 3 | very nice rooms room 4 | cleanly refurbished amenities cleanliness 5 | free high internet wifi 6 | gorgeous windows room 7 | pleasant staff staff 8 | clean closet cleanliness 9 | very comfy bed bed 10 | bad reviews general 11 | special hotel general 12 | clean rooms cleanliness 13 | perfect weather comfort 14 | good walk location 15 | well lighted rooms room 16 | free parking parking 17 | close to everything location 18 | heavenly bed bed 19 | great value price 20 | central location location 21 | worth search general 22 | free tvs facility 23 | close to bart location 24 | near mall location 25 | well appointed room room 26 | near restaurants food 27 | clean room cleanliness 28 | helpful staff staff 29 | shook room room 30 | disgusting bed bed 31 | helpful staff staff 32 | helpful reception staff staff 33 | one block off location location 34 | waiting room room 35 | very comfty bed bed 36 | helpful staff staff 37 | out dated rooms room 38 | lovely view of a net view 39 | closed business center location 40 | very clean towels cleanliness 41 | good experience general 42 | great sightseeing location 43 | great location location 44 | noise decor quietness 45 | felt elevator facility 46 | spaciousbathroom rooms bathroom 47 | nice hotels general 48 | great views view 49 | quick value wrong 50 | correct location location 51 | colorful furnishings style 52 | accomodating staff staff 53 | really unprepossessing cab wrong 54 | large wardrobe size 55 | ok check in checkin 56 | pretty lame organization staff 57 | wonderful accommodations general 58 | amazing views of the city view 59 | friendly service staff 60 | couple blocks from union square location 61 | many times larger rooms size 62 | bursting places general 63 | helpful concierge desk staff 64 | huge much size 65 | clean service staff 66 | large spa pool 67 | free parking parking 68 | excellent meal food 69 | old navy style 70 | nice pool area pool 71 | super clean kitchen cleanliness 72 | well furnishings room 73 | super friendly staff staff 74 | tired furniture room 75 | clean rooms cleanliness 76 | beautiful common areas style 77 | clean union cleanliness 78 | 2 mins walk from pier location 79 | not overrun with tourists general 80 | accommodating staff staff 81 | beautiful place style 82 | brand bloomingdale location 83 | within easy walking distance union square location 84 | contemporary lobby style 85 | most comfortable bed bed 86 | pretty clean room cleanliness 87 | very friendly staff staff 88 | fine bathroom bathroom 89 | located bart station location 90 | nearby scala location 91 | clean room cleanliness 92 | worn chairs cleanliness 93 | close to location location 94 | clean room cleanliness 95 | 2 mins walk away meal location 96 | not safe walking safety 97 | attentive front staff 98 | beautiful lobby style 99 | large closet size 100 | easy trams location 101 | spacious rooms size 102 | contemporary boutique style 103 | good location location 104 | get wireless wifi 105 | great area location 106 | easy to find access to the hotel location 107 | just outside cabs location 108 | absolutely disgusting place general 109 | very abrupt waiter staff 110 | extremely tiny bathroom bathroom 111 | very chic interior design style 112 | lovely restaurant food 113 | double 
paned windows room 114 | quick room room 115 | lovely lobby style 116 | very helpful staff staff 117 | unavailable room room 118 | incredible deal price 119 | within easy walking distance motel location 120 | plenty of room room 121 | weird hallways style 122 | sense layout style 123 | king bed bed 124 | is a bit washer dryers facility 125 | loved place general 126 | great city general 127 | always laundry room facility 128 | unhelpful staff staff 129 | friendly staff staff 130 | perfect vacation location 131 | lot traffic quietness 132 | great restaurants food 133 | elegant atmosphere vibe 134 | great balcony view 135 | amazing rooms room 136 | great concept style 137 | claustraphobic rooms room 138 | very friendly front desk staff staff 139 | great view of view 140 | quite good breakfast food 141 | great views of view 142 | barely warm restaurant food 143 | not impressed like wrong 144 | nice windows room 145 | great location location 146 | very relaxing setting vibe 147 | sometimes atmosphere vibe 148 | appalling food food 149 | helpful staff staff 150 | delicious grilled vegetables food 151 | clean rooms cleanliness 152 | complimentary numerous price 153 | very clean room cleanliness 154 | great location location 155 | clean service staff 156 | ok food food 157 | great view view 158 | very small closet space size 159 | fantastic staff staff 160 | a charming grounds style 161 | good tours general 162 | very clean bathroom cleanliness 163 | nearby restaurants food 164 | resonably priced cafe food 165 | best service staff 166 | very small bathroom bathroom 167 | slow staff staff 168 | large room size 169 | great location location 170 | nice bar area drink 171 | fantastic room room 172 | extremely popular indoor style 173 | average rooms room 174 | very small room size 175 | eggs warm food 176 | cute hotel style 177 | excellent reviews general 178 | narrow stairs size 179 | helpful staff staff 180 | beautiful location location 181 | prompt service staff 182 | consistently friendly hotel staff staff 183 | close to restaurants food 184 | somewhat past furniture room 185 | moldy bathroom bathroom 186 | clean hotel cleanliness 187 | worst evening service staff 188 | in walking distance everything location 189 | 5 feet from bed bed 190 | also bathtub bathroom 191 | just right location location 192 | aptly decorated bedding bed 193 | lucky people staff 194 | really small room size 195 | thin walls room 196 | similar room room 197 | perfect morning general 198 | well fitted rooms room 199 | distinguished hotel general 200 | fantastic cheese food 201 | dull brick room 202 | complimentary chocolates food 203 | not up to the standard bathroom bathroom 204 | excellent room general 205 | great cold comfort 206 | nice touches general 207 | lovely owners staff 208 | a little small room size 209 | extremely elegant hotel style 210 | like walking location 211 | smallest seating area size 212 | large chest size 213 | pretentious staff staff 214 | fairly new furnishings facility 215 | delicious decaf coffee 216 | quite spacious room size 217 | extremely noisy street quietness 218 | miles from room location 219 | better room room 220 | reasonable clean cleanliness 221 | located place location 222 | received concierge staff 223 | available parking parking 224 | relatively solid staff staff 225 | within walking distance financial district location 226 | less than a block from walgreen location 227 | ideally located hotel location 228 | wonderful shops location 229 | great views view 230 | great hotel general 
231 | bad closet facility 232 | free wifi wifi 233 | great location location 234 | clean rooms cleanliness 235 | helpful staff staff 236 | great massaging facility 237 | available coffee table coffee 238 | high speed internet wifi 239 | wonderful hotel general 240 | a little run down hotel style 241 | great hotel general 242 | close to buses location 243 | great place general 244 | 15 minute walk location location 245 | a few blocks from montgomery station location 246 | very near hotel location 247 | warm staff staff 248 | very comfortable bed bed 249 | superb hotel general 250 | not easy hotels general 251 | free parking parking 252 | small room size 253 | knowledgeable staff staff 254 | very charming hotel style 255 | lack cleaning standards cleanliness 256 | safe driver safety 257 | street pool pool 258 | fresh pastries food 259 | so luxurious shower bathroom 260 | small room size 261 | great source general 262 | so helpful directions staff 263 | great location location 264 | awesome rooms room 265 | well located hotel location 266 | attentive staff staff 267 | excellent hotel general 268 | pleasant mix comfort 269 | two blocks off cartwright location 270 | ok room room 271 | not open window room 272 | fun guests staff 273 | close to everything location 274 | very friendly staff staff 275 | no frills hotel style 276 | well equipped room room 277 | clean room cleanliness 278 | plenty of carpets room 279 | small rooms size 280 | better views of the city view 281 | very clean room cleanliness 282 | clean room cleanliness 283 | better place general 284 | very close to budget location 285 | very quiet upscale quietness 286 | old buildings style 287 | everyone was doorman staff 288 | cool weather comfort 289 | not outrgaeous housing general 290 | free buffet breakfast food 291 | very reasonably priced honor bar price 292 | nice staff staff 293 | well appointed bed bed 294 | somewhat past furniture room 295 | quiet part quietness 296 | lovely rather wrong 297 | claustrophobic rooms room 298 | awful street noise quietness 299 | complimentary coffee coffee 300 | favorite hotels general 301 | cool ocean comfort 302 | cheapest available amenities facility 303 | a bit disappointed rooms room 304 | perfect end general 305 | 20 mins walk location location 306 | located pretty centrally halcyon location 307 | spacious room size 308 | sooo comfortable beds bed 309 | good rate price 310 | clean bathroom cleanliness 311 | nice hotel general 312 | wonderful hotel general 313 | lousy service staff 314 | great service staff 315 | small bathroom bathroom 316 | light pants wrong 317 | fantastic cheese food 318 | good condition hallways facility 319 | very odd bathroom bathroom 320 | decorated lobby style 321 | great shower bathroom 322 | nice people staff 323 | very helpful reception staff staff 324 | friendly staff staff 325 | white room room 326 | always full of bustle style 327 | historic hotel style 328 | great dining area food 329 | older facilities facility 330 | beautiful hotel style 331 | very nice hotel general 332 | helpful staff staff 333 | great room room 334 | close to rooms location 335 | real wireless broadband wifi 336 | so big service staff 337 | helpful staff staff 338 | free car transportation price 339 | stand tub bathroom 340 | did not work well air conditioning facility 341 | noisy room quietness 342 | really pretty hotel style 343 | within walking distance local attractions location 344 | fully stocked mini bar drink 345 | huge chinatown location 346 | able map of bus routes location 
347 | ok room room 348 | lovely breakfast food 349 | lots of restaurants food 350 | tiny restaurant food 351 | spacious hotel size 352 | big room size 353 | fabulous staff staff 354 | great oasis general 355 | near location location 356 | fantastic location location 357 | free secure parking parking 358 | very comfortable bed bed 359 | good contintal wrong 360 | adjustable shower head bathroom 361 | very helpful staff staff 362 | favorite hotels general 363 | right in location location 364 | smelled room cleanliness 365 | loved staff staff 366 | beautiful bay bridge view 367 | very efficient lay out style 368 | nice walk location 369 | loved hotel general 370 | delightful hotel style 371 | friendly staff staff 372 | well done decor style 373 | helpful staff staff 374 | situated shops location 375 | very little room size 376 | good remodel of general 377 | close to everything location 378 | very good windows room 379 | kind room room 380 | great view view 381 | mediocre front desk service staff 382 | very close to location location 383 | great deal price 384 | great price price 385 | different retro furnishings style 386 | great fries food 387 | walking distance to chinatown location 388 | adequate bathroom bathroom 389 | perfectly located direct bus location 390 | european hotel style 391 | ideally located for trams system junction location 392 | excellent location location 393 | great place general 394 | really nice rooms size 395 | street huge wrong 396 | easy cab location 397 | great food food 398 | quite minimal noise level quietness 399 | amazing berries food 400 | newly renovated hotel style 401 | secure parking parking 402 | big shady forested paths style 403 | comfortable rooms comfort 404 | prepared latte coffee 405 | reasonably comfortable queen bed 406 | superb hotel restaurant food 407 | helpful hotel staff staff 408 | great everything general 409 | terrible admin staff staff 410 | internal outside noise quietness 411 | small room size 412 | friendly atmosphere vibe 413 | free internet access wifi 414 | ready room room 415 | loved location for location 416 | convenient location location 417 | very good drinks drink 418 | extremely comfortable drapes comfort 419 | free breakfast food 420 | within walking distance post office location 421 | marvellous price price 422 | much bigger room size 423 | very nice hotel general 424 | far bed linens bed 425 | wonderful everything general 426 | cozy room vibe 427 | a little more firm pillows bed 428 | nowhere near enough staff staff 429 | only ok bed bed 430 | horrible place general 431 | very grand lobby area style 432 | horribly disorganized garage general 433 | expensive hotel price 434 | great idea way general 435 | well located radisson location 436 | very unorganized storage area room 437 | rock hard beds bed 438 | serenely decor style 439 | friendly lobby staff style 440 | superior view rooms view 441 | average room room 442 | close to shopping location 443 | pretty dark room room 444 | so close to hotel location location 445 | very very expnsive hotel price 446 | very slow smile staff 447 | clean rooms cleanliness 448 | resonable place general 449 | most beautiful homes general 450 | a bit dated rooms room 451 | helpful desk staff staff 452 | free computer facility 453 | lots of coffee coffee 454 | large rooms size 455 | very elegant hotel style 456 | comfortable bed bed 457 | directly across from coffee service staff 458 | perfect location location 459 | very quiet street quietness 460 | free access to computer facility 461 | best 
marina wrong 462 | neglected room room 463 | cheap sfo price 464 | healthy food food 465 | very good location location 466 | spotlessly clean furnishings cleanliness 467 | quickly staff staff 468 | fresh cereal food 469 | wonderful bed bed 470 | perfect location location 471 | well appointed rooms room 472 | small room size 473 | small sink size 474 | helpful to staff staff 475 | non room room 476 | free internet wifi 477 | cheap carlton price 478 | sucked place general 479 | very easy taxi fare location 480 | safe location safety 481 | always staff staff 482 | plenty of wardrobe space size 483 | below average experience general 484 | small fitness room gym 485 | very comfortable rooms comfort 486 | clean bathroom cleanliness 487 | late checkins checkin 488 | nice touch fire pits style 489 | worn hotel cleanliness 490 | without views of the street view 491 | nearby union square location 492 | small rooms size 493 | ornate bar drink 494 | nice motel general 495 | very nice surprise general 496 | only minutes walk downtown shopping location 497 | most beautiful city general 498 | quite helpful ice facility 499 | free wine hour drink 500 | pretty tough critics general 501 | close enough walk location 502 | small room size 503 | brightly painted view view 504 | good time general 505 | ultra expensive restuarnat price 506 | delicious food food 507 | very noisy room quietness 508 | fine bedroom bed 509 | great location location 510 | were unquestionably stains cleanliness 511 | very close to major location 512 | lovely room style 513 | very basic breakfast food 514 | warmer fridge facility 515 | noisy rooms quietness 516 | very nice staff staff 517 | wonderful hotel general 518 | charming rooms style 519 | spotless the king cleanliness 520 | well located bars location 521 | musty room cleanliness 522 | good time general 523 | old hotel style 524 | helpful everyone staff 525 | spacious room size 526 | poorly remodeled room style 527 | clean mini fridge cleanliness 528 | good options general 529 | broken shower bathroom 530 | very uninspiring rooms room 531 | late , plus wrong 532 | very nice room room 533 | 1.5 blocks away cable location 534 | great rate price 535 | affordable lot price 536 | low prices price 537 | not fabulous room room 538 | fantastic room room 539 | not the most attractive stairs style 540 | excellent service staff 541 | stunningly beautiful dining room style 542 | king room room 543 | excellent blackout shade style 544 | right hotel general 545 | wonderful hotel service staff 546 | delicious breakfast food 547 | close to restaurants food 548 | airy space size 549 | upgraded heaters facility 550 | 6 blocks down diner location 551 | need bistros food 552 | not new furniture room 553 | a bit dark room room 554 | easy freeway access location 555 | understated elegance style 556 | much bigger room size 557 | nice pillows bed 558 | unlimited free printing facility 559 | tough trip general 560 | extremely helpful owners staff 561 | old carpets style 562 | beautiful days style 563 | helpful front desk staff staff 564 | 5 blocks form bart location 565 | comfortable pressure comfort 566 | accommodating rooms staff 567 | everywhere art location 568 | outdoor pool pool 569 | decent enough rooms room 570 | within short walking distance to bart location 571 | good value price 572 | free cable price 573 | great bakery food 574 | great restaurants food 575 | handy base facility 576 | few blocks away shopping location 577 | pristine condition facility facility 578 | very clean room 
cleanliness 579 | great rate price 580 | lots of restaurants food 581 | narrow parking lot parking 582 | fantastic area location 583 | located wharf location 584 | very modern looking hotel style 585 | too high workspace size 586 | comfortable room comfort 587 | way too many choices general 588 | basic amenities facility 589 | good distance cable location 590 | very comfortable beds bed 591 | comfortable hotel comfort 592 | knowledgable people staff 593 | boutique hotel style 594 | ok rooms room 595 | glowed annoying comfort 596 | great location location 597 | nice motel general 598 | perfect suite general 599 | spacious suite size 600 | no extra room room 601 | excellent almond food 602 | love beds bed 603 | good quality linens bed 604 | weird layout style 605 | perfect room general 606 | free parking parking 607 | perfect hotel general 608 | nice face staff 609 | significantly lower sink bathroom 610 | nice seating room 611 | great location location 612 | fantastic restaurants food 613 | clean bed cleanliness 614 | small tv facility 615 | well appointed rooms room 616 | awesome showers bathroom 617 | very nice hotel general 618 | great price price 619 | way out staff staff 620 | couple minutes elevators location 621 | perfect rooms general 622 | loved shampoos bathroom 623 | within bus lines location 624 | fresh room room 625 | very close pier location 626 | free parking parking 627 | better breakfast food 628 | decent size view of the size 629 | close to place location 630 | very nice bell room 631 | a bit underwhelming front desk staff 632 | quick check in checkin 633 | comfortable bed bed 634 | stale potatoes food 635 | nice hotel amenities facility 636 | clean hotel cleanliness 637 | attentive staff staff 638 | very slow internet access wifi 639 | soft sheets bed 640 | informative audio tour staff 641 | extremely helpful staff staff 642 | most friendly driver view 643 | helpful front desk staff staff 644 | the cable lamps style 645 | quite dated hotel style 646 | strange shutters facility 647 | convenient to sights location 648 | big room size 649 | very good for dining food 650 | lovely castro style 651 | small kitchenette facilities facility 652 | within restaurants food 653 | tremendously underwhelming hotel style 654 | bright hotel style 655 | excellent concierge service staff 656 | undesireable however bar drink 657 | nice accommodations general 658 | very clean room cleanliness 659 | way above average quality linens bed 660 | like mildew bathroom bathroom 661 | 10 mins bus location 662 | a little room room 663 | a bit cheaper web rates price 664 | love windown room 665 | pretty far from attractions location 666 | excellent food food 667 | good stay general 668 | free parking parking 669 | friendly staff staff 670 | great joints wrong 671 | pullout sofa facility 672 | walls floors wrong 673 | very good price price 674 | really close train stop location 675 | welcoming interaction staff 676 | tastefully trendy hotel style 677 | slightly more spacious towel bathroom 678 | extra breakfast food 679 | very social snacks food 680 | lovely hotel style 681 | decent size room size 682 | relatively spacious room size 683 | spacious views view 684 | clean room cleanliness 685 | exactly the same rooms room 686 | good entertainment facility 687 | slow room service staff 688 | reasonable rate price 689 | great part wrong 690 | small bar drink 691 | not used to city general 692 | clean rooms cleanliness 693 | close shopping location 694 | free parking parking 695 | very friendly staff staff 
696 | very clean everything cleanliness 697 | clean good wrong 698 | 20 dollars valet parking parking 699 | nicely done halls style 700 | right breakfast buffet food 701 | cozy oasis vibe 702 | elegant location location 703 | short distance from hotel location 704 | great hotel general 705 | fabulous time general 706 | large room size 707 | ample bath supply bathroom 708 | great location location 709 | very nice hotel general 710 | very friendly staff staff 711 | great lobby style 712 | close quarters location 713 | less than a station location 714 | very capable guest services staff 715 | great health club gym 716 | pretty bad noise quietness 717 | good sized room size 718 | courteous staff staff 719 | hot coffee coffee 720 | opened garrison style 721 | close to everything location 722 | amazing value for money price 723 | friendly clerks staff 724 | excellent see wrong 725 | great location location 726 | pricey union street price 727 | great room room 728 | small bay room size 729 | easily accomodated rooms location 730 | couple of blocks away cable car lines location 731 | nice large windows size 732 | bad towels bathroom 733 | close to all hotel location 734 | clean room cleanliness 735 | spacious enough corner size 736 | only 7 blocks to part location 737 | close to train location 738 | plenty of room room 739 | extremely friendly desk clerks staff 740 | enough room room 741 | well kept rooms room 742 | great diner food 743 | worst motel general 744 | great value price 745 | equally comfortable beds bed 746 | polite staff staff 747 | great experience general 748 | medium size hotel size 749 | friendly staff staff 750 | really unprepossessing area general 751 | great location location 752 | good customer service staff 753 | wonderful rooms room 754 | spacious enough room size 755 | smallest room size 756 | fantastic room room 757 | like recycling general 758 | no cool air comfort 759 | fabulous toiletries bathroom 760 | marvelous stay general 761 | fantastic lighting style 762 | accomodating staff staff 763 | only 2 blocks from property location 764 | decent space size 765 | clean rooms cleanliness 766 | too much costs price 767 | great show location 768 | free internet access wifi 769 | small room size 770 | delightful staff staff 771 | open windows room 772 | cheapest , dining price 773 | nice size desk size 774 | lovely hotel style 775 | adorable colorful hotel style 776 | comfortable bed bed 777 | accommodating staff staff 778 | tile furnishings room 779 | attractive hotel style 780 | varied menu food 781 | way hotels wrong 782 | minutes from trolley line location 783 | large room size 784 | clean everything cleanliness 785 | efficient everything staff 786 | complimentary hot beverage food 787 | always clean hotels cleanliness 788 | terrible staff staff 789 | adequate property facility 790 | prime location location 791 | free wireless internet wifi 792 | right out the front door main bus route location 793 | breeze location location 794 | okay breakfast food 795 | not good rooms room 796 | great room room 797 | really comfy bed bed 798 | great hotel general 799 | very near location location 800 | friendly staff staff 801 | lovely views of the bay view 802 | completely undrinkable coffee coffee 803 | rude checkout clerk staff 804 | very nice experience general 805 | great price price 806 | very neat room room 807 | finest bed bed 808 | very well maintained hotel general 809 | plenty of room room 810 | really nice villa general 811 | great bay view view 812 | quite standard 
facilities facility 813 | light hotel general 814 | just a block away car line location 815 | clean tub cleanliness 816 | very helpful concierge staff 817 | good location location 818 | 1 1/2 blocks to shopping center location 819 | ghetto area location 820 | two beds bed 821 | nice club drink 822 | nice tv facility 823 | professional staff staff 824 | great room room 825 | absolutley hotel general 826 | right size artwork size 827 | good public areas style 828 | wonderful location location 829 | great dining room food 830 | modern furnishings room 831 | most upmarket shopping areas location 832 | very annoying features general 833 | high quality linens bed 834 | countless broken bathtub bathroom 835 | terrific food food 836 | right on tram line location 837 | easy walking distance places to location 838 | is really quot wrong 839 | great lighting facility 840 | very basic hotel style 841 | very nice rooms room 842 | basic room room 843 | near hotel location 844 | knowledgeable owners staff 845 | great little wrong 846 | nice continental breakfast food 847 | right service staff 848 | complimentary goodies price 849 | great place general 850 | friendly desk staff staff 851 | a few blocks up villa florence location 852 | attentive staff staff 853 | joke heating comfort 854 | so nice staff staff 855 | greater restaurants food 856 | clean rooms cleanliness 857 | a little rough location location 858 | gleaming white linens bed 859 | just next door parking parking 860 | well kept place general 861 | good information staff 862 | convenient located hotel location 863 | free glasses of wine drink 864 | quiet section quietness 865 | tiny bathroom bathroom 866 | incredibly courteous staff staff 867 | and two doors great wrong 868 | outrageous rate price 869 | strait furnishings room 870 | well appointed rooms room 871 | one diner location 872 | are very reasonably cafe food 873 | helpful staff staff 874 | close to location location 875 | an incredibly pleasant wrong 876 | really comfortable common areas comfort 877 | very pleased with place general 878 | low chair room 879 | most amazing bay view 880 | available view view 881 | are extremely the desk wrong 882 | short blocks from transit central location 883 | adjacent to shopping location 884 | gorgeous function rooms room 885 | great lobby style 886 | love bay view view 887 | geniunely staff staff 888 | older hotel style 889 | nice enough pool pool 890 | clean rooms cleanliness 891 | well located hotel location 892 | great ammenities facility 893 | a bit more alive courtyard general 894 | nice room room 895 | incredibly food food 896 | loved room room 897 | good shopping location 898 | helpful hotel staff staff 899 | really quite poor room room 900 | non smoking smell room 901 | very welcoming staff staff 902 | across restaurant food 903 | freaky hotel style 904 | great restaurants food 905 | quiet rooms quietness 906 | uncomfortable bedspros bed 907 | was practically fish dishes food 908 | plenty of clothing general 909 | friendly staff staff 910 | nice size , internet wifi 911 | nicer hotels general 912 | awesome spa services staff 913 | funky decor style 914 | very nice looking hotel style 915 | free internet access wifi 916 | complimentary breakfast buffet food 917 | big enough jacuzzi pool 918 | delicious sushi food 919 | fine front desk staff 920 | not new furniture room 921 | way nicer hotel general 922 | very comfortable bed bed 923 | love empty wrong 924 | busy porters wrong 925 | double the street wrong 926 | lovely hotel style 927 | good 
restaurants food 928 | complimentary breakfast food 929 | huge trees size 930 | very comfortable pillows bed 931 | balanced view view 932 | short walk to center location 933 | fluffy sized size 934 | great location location 935 | really clean staff staff 936 | close to hotel location 937 | comfortable room comfort 938 | helpful concierge staff 939 | 10 minutes walk bus stop location 940 | friendly staff staff 941 | perfect hotel general 942 | shared bath bathroom 943 | awesome view view 944 | good restaurant food 945 | clean fridge cleanliness 946 | nicely decorated rooms style 947 | perfect sights view 948 | very helpful rooms staff 949 | excellent cafe food 950 | great hotel general 951 | very hip boutique style 952 | comfortable wedding comfort 953 | dark hallways style 954 | hit rooms room 955 | slightly iffy area safety 956 | nice hotel general 957 | beautiful view view 958 | big change general 959 | nice people staff 960 | fancy hotel style 961 | clean rooms cleanliness 962 | good check in checkin 963 | crowded restaurant food 964 | helpful staff staff 965 | excellent experience general 966 | fine rooms general 967 | very nicely appointed rooms room 968 | clean staff staff 969 | well maintained hotel general 970 | very hard overhead room 971 | closed bar drink 972 | always expensive everything price 973 | quot room room 974 | 3 long blocks to museum location 975 | dark halls style 976 | helpful staff staff 977 | free wine drink 978 | far too costly cost price 979 | not huge room size 980 | willing staff staff 981 | clean staff staff 982 | decently priced honor bar drink 983 | amazing experience general 984 | second staff staff 985 | 3 times bigger room size 986 | spotless bathroom bathroom 987 | ok rooms room 988 | faced away from street view 989 | right outside everything location 990 | great location location 991 | enjoyable neighborhood location 992 | very reasonable price price 993 | good scrub down cleanliness 994 | a bit small rooms size 995 | great shops location 996 | crushed room room 997 | nice renovation style 998 | outdated carpet room 999 | so comfortable bed bed 1000 | not that great area safety 1001 | -------------------------------------------------------------------------------- /configs.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "name": "restaurant_classification", 4 | "task_type": "classification", 5 | "vocab": [ 6 | "food-quantity", 7 | "recommendation", 8 | "food -> healthiness", 9 | "location", 10 | "staff", 11 | "food -> vegetarian option", 12 | "food -> variety", 13 | "drink -> alcohol", 14 | "restaurant -> comfort", 15 | "value-for-money", 16 | "wait-time", 17 | "food -> quality", 18 | "good-for-groups", 19 | "restaurant -> atmosphere", 20 | "kid-friendliness", 21 | "drink -> quality" 22 | ], 23 | "trainset": "combined_data/restaurant/classification/train.txt", 24 | "validset": "combined_data/restaurant/classification/dev.txt", 25 | "testset": "combined_data/restaurant/classification/dev.txt" 26 | }, 27 | { 28 | "name": "restaurant_tagging", 29 | "task_type": "tagging", 30 | "vocab": [ 31 | "B-AS", 32 | "I-AS", 33 | "B-OP", 34 | "I-OP" 35 | ], 36 | "trainset": "combined_data/restaurant/train.txt", 37 | "validset": "combined_data/restaurant/dev.txt", 38 | "testset": "combined_data/restaurant/dev.txt", 39 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 40 | }, 41 | { 42 | "name": "hotel_tagging", 43 | "task_type": "tagging", 44 | "vocab": [ 45 | "B-AS", 46 | "I-AS", 47 | "B-OP", 48 | "I-OP" 
49 | ], 50 | "trainset": "combined_data/hotel_tagging/train.txt.combined", 51 | "validset": "combined_data/hotel_tagging/dev.txt", 52 | "testset": "combined_data/hotel_tagging/dev.txt", 53 | "unlabeled": "combined_data/hotel_tagging/unlabeled.txt" 54 | }, 55 | { 56 | "name": "pairing", 57 | "task_type": "classification", 58 | "vocab": [ 59 | "UNPAIR", 60 | "PAIR" 61 | ], 62 | "trainset": "combined_data/hotel_pairing/train.txt.combined", 63 | "validset": "combined_data/hotel_pairing/dev.txt", 64 | "testset": "combined_data/hotel_pairing/dev.txt", 65 | "unlabeled": "combined_data/hotel_pairing/unlabeled.txt" 66 | }, 67 | { 68 | "name": "sf_hotel_classification", 69 | "task_type": "classification", 70 | "vocab": [ 71 | "room", 72 | "facility", 73 | "staff", 74 | "drink", 75 | "size", 76 | "bathroom", 77 | "view", 78 | "wrong", 79 | "price", 80 | "wifi", 81 | "parking", 82 | "cleanliness", 83 | "checkin", 84 | "pool", 85 | "gym", 86 | "style", 87 | "quietness", 88 | "location", 89 | "vibe", 90 | "comfort", 91 | "food", 92 | "bed", 93 | "general", 94 | "safety", 95 | "coffee" 96 | ], 97 | "trainset": "combined_data/hotel_classification/train.txt", 98 | "validset": "combined_data/hotel_classification/dev.txt", 99 | "testset": "combined_data/hotel_classification/dev.txt" 100 | }, 101 | { 102 | "name": "restaurant_asc", 103 | "task_type": "classification", 104 | "vocab": [ 105 | "positive", 106 | "negative", 107 | "neutral" 108 | ], 109 | "trainset": "combined_data/absa/asc/rest/train.txt", 110 | "validset": "combined_data/absa/asc/rest/dev.txt", 111 | "testset": "combined_data/absa/asc/rest/test.txt", 112 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 113 | }, 114 | { 115 | "name": "laptop_asc", 116 | "task_type": "classification", 117 | "vocab": [ 118 | "positive", 119 | "negative", 120 | "neutral" 121 | ], 122 | "trainset": "combined_data/absa/asc/laptop/train.txt", 123 | "validset": "combined_data/absa/asc/laptop/dev.txt", 124 | "testset": "combined_data/absa/asc/laptop/test.txt", 125 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 126 | }, 127 | { 128 | "name": "restaurant_ae_tagging", 129 | "task_type": "tagging", 130 | "vocab": [ 131 | "B-AS", 132 | "I-AS" 133 | ], 134 | "trainset": "combined_data/absa/ae/rest/train.txt", 135 | "validset": "combined_data/absa/ae/rest/dev.txt", 136 | "testset": "combined_data/absa/ae/rest/test.txt", 137 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 138 | }, 139 | { 140 | "name": "laptop_ae_tagging", 141 | "task_type": "tagging", 142 | "vocab": [ 143 | "B-AS", 144 | "I-AS" 145 | ], 146 | "trainset": "combined_data/absa/ae/laptop/train.txt", 147 | "validset": "combined_data/absa/ae/laptop/dev.txt", 148 | "testset": "combined_data/absa/ae/laptop/test.txt", 149 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt", 150 | "augmented": [ 151 | "combined_data/absa/ae/laptop/full/train.txt", 152 | "combined_data/absa/ae/laptop/full/train.txt" 153 | ] 154 | }, 155 | { 156 | "name": "restaurant_ae_tagging_250_1", 157 | "task_type": "tagging", 158 | "vocab": [ 159 | "B-AS", 160 | "I-AS" 161 | ], 162 | "trainset": "combined_data/absa/ae/rest16-old-sample-1/250/train.txt", 163 | "validset": "combined_data/absa/ae/rest/dev.txt", 164 | "testset": "combined_data/absa/ae/rest/test.txt", 165 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 166 | }, 167 | { 168 | "name": "restaurant_ae_tagging_500_1", 169 | "task_type": "tagging", 170 | "vocab": [ 171 | "B-AS", 172 | "I-AS" 173 | ], 174 | "trainset": 
"combined_data/absa/ae/rest16-old-sample-1/500/train.txt", 175 | "validset": "combined_data/absa/ae/rest/dev.txt", 176 | "testset": "combined_data/absa/ae/rest/test.txt", 177 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 178 | }, 179 | { 180 | "name": "restaurant_ae_tagging_750_1", 181 | "task_type": "tagging", 182 | "vocab": [ 183 | "B-AS", 184 | "I-AS" 185 | ], 186 | "trainset": "combined_data/absa/ae/rest16-old-sample-1/750/train.txt", 187 | "validset": "combined_data/absa/ae/rest/dev.txt", 188 | "testset": "combined_data/absa/ae/rest/test.txt", 189 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 190 | }, 191 | { 192 | "name": "restaurant_ae_tagging_1000_1", 193 | "task_type": "tagging", 194 | "vocab": [ 195 | "B-AS", 196 | "I-AS" 197 | ], 198 | "trainset": "combined_data/absa/ae/rest16-old-sample-1/1000/train.txt", 199 | "validset": "combined_data/absa/ae/rest/dev.txt", 200 | "testset": "combined_data/absa/ae/rest/test.txt", 201 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 202 | }, 203 | { 204 | "name": "restaurant_ae_tagging_full_1", 205 | "task_type": "tagging", 206 | "vocab": [ 207 | "B-AS", 208 | "I-AS" 209 | ], 210 | "trainset": "combined_data/absa/ae/rest16-old-sample-1/full/train.txt", 211 | "validset": "combined_data/absa/ae/rest/dev.txt", 212 | "testset": "combined_data/absa/ae/rest/test.txt", 213 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 214 | }, 215 | { 216 | "name": "restaurant_ae_tagging_250_2", 217 | "task_type": "tagging", 218 | "vocab": [ 219 | "B-AS", 220 | "I-AS" 221 | ], 222 | "trainset": "combined_data/absa/ae/rest16-old-sample-2/250/train.txt", 223 | "validset": "combined_data/absa/ae/rest/dev.txt", 224 | "testset": "combined_data/absa/ae/rest/test.txt", 225 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 226 | }, 227 | { 228 | "name": "restaurant_ae_tagging_500_2", 229 | "task_type": "tagging", 230 | "vocab": [ 231 | "B-AS", 232 | "I-AS" 233 | ], 234 | "trainset": "combined_data/absa/ae/rest16-old-sample-2/500/train.txt", 235 | "validset": "combined_data/absa/ae/rest/dev.txt", 236 | "testset": "combined_data/absa/ae/rest/test.txt", 237 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 238 | }, 239 | { 240 | "name": "restaurant_ae_tagging_750_2", 241 | "task_type": "tagging", 242 | "vocab": [ 243 | "B-AS", 244 | "I-AS" 245 | ], 246 | "trainset": "combined_data/absa/ae/rest16-old-sample-2/750/train.txt", 247 | "validset": "combined_data/absa/ae/rest/dev.txt", 248 | "testset": "combined_data/absa/ae/rest/test.txt", 249 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 250 | }, 251 | { 252 | "name": "restaurant_ae_tagging_1000_2", 253 | "task_type": "tagging", 254 | "vocab": [ 255 | "B-AS", 256 | "I-AS" 257 | ], 258 | "trainset": "combined_data/absa/ae/rest16-old-sample-2/1000/train.txt", 259 | "validset": "combined_data/absa/ae/rest/dev.txt", 260 | "testset": "combined_data/absa/ae/rest/test.txt", 261 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 262 | }, 263 | { 264 | "name": "restaurant_ae_tagging_full_2", 265 | "task_type": "tagging", 266 | "vocab": [ 267 | "B-AS", 268 | "I-AS" 269 | ], 270 | "trainset": "combined_data/absa/ae/rest16-old-sample-2/full/train.txt", 271 | "validset": "combined_data/absa/ae/rest/dev.txt", 272 | "testset": "combined_data/absa/ae/rest/test.txt", 273 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 274 | }, 275 | { 276 | "name": "restaurant_ae_tagging_250_3", 277 | "task_type": "tagging", 278 | "vocab": [ 279 | "B-AS", 280 | "I-AS" 281 | ], 
282 | "trainset": "combined_data/absa/ae/rest16-old-sample-3/250/train.txt", 283 | "validset": "combined_data/absa/ae/rest/dev.txt", 284 | "testset": "combined_data/absa/ae/rest/test.txt", 285 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 286 | }, 287 | { 288 | "name": "restaurant_ae_tagging_500_3", 289 | "task_type": "tagging", 290 | "vocab": [ 291 | "B-AS", 292 | "I-AS" 293 | ], 294 | "trainset": "combined_data/absa/ae/rest16-old-sample-3/500/train.txt", 295 | "validset": "combined_data/absa/ae/rest/dev.txt", 296 | "testset": "combined_data/absa/ae/rest/test.txt", 297 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 298 | }, 299 | { 300 | "name": "restaurant_ae_tagging_750_3", 301 | "task_type": "tagging", 302 | "vocab": [ 303 | "B-AS", 304 | "I-AS" 305 | ], 306 | "trainset": "combined_data/absa/ae/rest16-old-sample-3/750/train.txt", 307 | "validset": "combined_data/absa/ae/rest/dev.txt", 308 | "testset": "combined_data/absa/ae/rest/test.txt", 309 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 310 | }, 311 | { 312 | "name": "restaurant_ae_tagging_1000_3", 313 | "task_type": "tagging", 314 | "vocab": [ 315 | "B-AS", 316 | "I-AS" 317 | ], 318 | "trainset": "combined_data/absa/ae/rest16-old-sample-3/1000/train.txt", 319 | "validset": "combined_data/absa/ae/rest/dev.txt", 320 | "testset": "combined_data/absa/ae/rest/test.txt", 321 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 322 | }, 323 | { 324 | "name": "restaurant_ae_tagging_full_3", 325 | "task_type": "tagging", 326 | "vocab": [ 327 | "B-AS", 328 | "I-AS" 329 | ], 330 | "trainset": "combined_data/absa/ae/rest16-old-sample-3/full/train.txt", 331 | "validset": "combined_data/absa/ae/rest/dev.txt", 332 | "testset": "combined_data/absa/ae/rest/test.txt", 333 | "unlabeled": "combined_data/absa/ae/rest/unlabeled.txt" 334 | }, 335 | { 336 | "name": "laptop_ae_tagging_250_1", 337 | "task_type": "tagging", 338 | "vocab": [ 339 | "B-AS", 340 | "I-AS" 341 | ], 342 | "trainset": "combined_data/absa/ae/laptop-sample-1/250/train.txt", 343 | "validset": "combined_data/absa/ae/laptop/dev.txt", 344 | "testset": "combined_data/absa/ae/laptop/test.txt", 345 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 346 | }, 347 | { 348 | "name": "laptop_ae_tagging_500_1", 349 | "task_type": "tagging", 350 | "vocab": [ 351 | "B-AS", 352 | "I-AS" 353 | ], 354 | "trainset": "combined_data/absa/ae/laptop-sample-1/500/train.txt", 355 | "validset": "combined_data/absa/ae/laptop/dev.txt", 356 | "testset": "combined_data/absa/ae/laptop/test.txt", 357 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 358 | }, 359 | { 360 | "name": "laptop_ae_tagging_750_1", 361 | "task_type": "tagging", 362 | "vocab": [ 363 | "B-AS", 364 | "I-AS" 365 | ], 366 | "trainset": "combined_data/absa/ae/laptop-sample-1/750/train.txt", 367 | "validset": "combined_data/absa/ae/laptop/dev.txt", 368 | "testset": "combined_data/absa/ae/laptop/test.txt", 369 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 370 | }, 371 | { 372 | "name": "laptop_ae_tagging_1000_1", 373 | "task_type": "tagging", 374 | "vocab": [ 375 | "B-AS", 376 | "I-AS" 377 | ], 378 | "trainset": "combined_data/absa/ae/laptop-sample-1/1000/train.txt", 379 | "validset": "combined_data/absa/ae/laptop/dev.txt", 380 | "testset": "combined_data/absa/ae/laptop/test.txt", 381 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 382 | }, 383 | { 384 | "name": "laptop_ae_tagging_full_1", 385 | "task_type": "tagging", 386 | "vocab": [ 387 | "B-AS", 388 | "I-AS" 389 
| ], 390 | "trainset": "combined_data/absa/ae/laptop-sample-1/full/train.txt", 391 | "validset": "combined_data/absa/ae/laptop/dev.txt", 392 | "testset": "combined_data/absa/ae/laptop/test.txt", 393 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 394 | }, 395 | { 396 | "name": "laptop_ae_tagging_250_2", 397 | "task_type": "tagging", 398 | "vocab": [ 399 | "B-AS", 400 | "I-AS" 401 | ], 402 | "trainset": "combined_data/absa/ae/laptop-sample-2/250/train.txt", 403 | "validset": "combined_data/absa/ae/laptop/dev.txt", 404 | "testset": "combined_data/absa/ae/laptop/test.txt", 405 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 406 | }, 407 | { 408 | "name": "laptop_ae_tagging_500_2", 409 | "task_type": "tagging", 410 | "vocab": [ 411 | "B-AS", 412 | "I-AS" 413 | ], 414 | "trainset": "combined_data/absa/ae/laptop-sample-2/500/train.txt", 415 | "validset": "combined_data/absa/ae/laptop/dev.txt", 416 | "testset": "combined_data/absa/ae/laptop/test.txt", 417 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 418 | }, 419 | { 420 | "name": "laptop_ae_tagging_750_2", 421 | "task_type": "tagging", 422 | "vocab": [ 423 | "B-AS", 424 | "I-AS" 425 | ], 426 | "trainset": "combined_data/absa/ae/laptop-sample-2/750/train.txt", 427 | "validset": "combined_data/absa/ae/laptop/dev.txt", 428 | "testset": "combined_data/absa/ae/laptop/test.txt", 429 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 430 | }, 431 | { 432 | "name": "laptop_ae_tagging_1000_2", 433 | "task_type": "tagging", 434 | "vocab": [ 435 | "B-AS", 436 | "I-AS" 437 | ], 438 | "trainset": "combined_data/absa/ae/laptop-sample-2/1000/train.txt", 439 | "validset": "combined_data/absa/ae/laptop/dev.txt", 440 | "testset": "combined_data/absa/ae/laptop/test.txt", 441 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 442 | }, 443 | { 444 | "name": "laptop_ae_tagging_full_2", 445 | "task_type": "tagging", 446 | "vocab": [ 447 | "B-AS", 448 | "I-AS" 449 | ], 450 | "trainset": "combined_data/absa/ae/laptop-sample-2/full/train.txt", 451 | "validset": "combined_data/absa/ae/laptop/dev.txt", 452 | "testset": "combined_data/absa/ae/laptop/test.txt", 453 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 454 | }, 455 | { 456 | "name": "laptop_ae_tagging_250_3", 457 | "task_type": "tagging", 458 | "vocab": [ 459 | "B-AS", 460 | "I-AS" 461 | ], 462 | "trainset": "combined_data/absa/ae/laptop-sample-3/250/train.txt", 463 | "validset": "combined_data/absa/ae/laptop/dev.txt", 464 | "testset": "combined_data/absa/ae/laptop/test.txt", 465 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 466 | }, 467 | { 468 | "name": "laptop_ae_tagging_500_3", 469 | "task_type": "tagging", 470 | "vocab": [ 471 | "B-AS", 472 | "I-AS" 473 | ], 474 | "trainset": "combined_data/absa/ae/laptop-sample-3/500/train.txt", 475 | "validset": "combined_data/absa/ae/laptop/dev.txt", 476 | "testset": "combined_data/absa/ae/laptop/test.txt", 477 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 478 | }, 479 | { 480 | "name": "laptop_ae_tagging_750_3", 481 | "task_type": "tagging", 482 | "vocab": [ 483 | "B-AS", 484 | "I-AS" 485 | ], 486 | "trainset": "combined_data/absa/ae/laptop-sample-3/750/train.txt", 487 | "validset": "combined_data/absa/ae/laptop/dev.txt", 488 | "testset": "combined_data/absa/ae/laptop/test.txt", 489 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 490 | }, 491 | { 492 | "name": "laptop_ae_tagging_1000_3", 493 | "task_type": "tagging", 494 | "vocab": [ 495 | "B-AS", 496 | "I-AS" 497 | 
], 498 | "trainset": "combined_data/absa/ae/laptop-sample-3/1000/train.txt", 499 | "validset": "combined_data/absa/ae/laptop/dev.txt", 500 | "testset": "combined_data/absa/ae/laptop/test.txt", 501 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 502 | }, 503 | { 504 | "name": "laptop_ae_tagging_full_3", 505 | "task_type": "tagging", 506 | "vocab": [ 507 | "B-AS", 508 | "I-AS" 509 | ], 510 | "trainset": "combined_data/absa/ae/laptop-sample-3/full/train.txt", 511 | "validset": "combined_data/absa/ae/laptop/dev.txt", 512 | "testset": "combined_data/absa/ae/laptop/test.txt", 513 | "unlabeled": "combined_data/absa/ae/laptop/unlabeled.txt" 514 | }, 515 | { 516 | "name": "restaurant_asc_250_1", 517 | "task_type": "classification", 518 | "vocab": [ 519 | "positive", 520 | "negative", 521 | "neutral" 522 | ], 523 | "trainset": "combined_data/absa/asc/rest-sample-1/250/train.txt", 524 | "validset": "combined_data/absa/asc/rest/dev.txt", 525 | "testset": "combined_data/absa/asc/rest/test.txt", 526 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 527 | }, 528 | { 529 | "name": "restaurant_asc_500_1", 530 | "task_type": "classification", 531 | "vocab": [ 532 | "positive", 533 | "negative", 534 | "neutral" 535 | ], 536 | "trainset": "combined_data/absa/asc/rest-sample-1/500/train.txt", 537 | "validset": "combined_data/absa/asc/rest/dev.txt", 538 | "testset": "combined_data/absa/asc/rest/test.txt", 539 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 540 | }, 541 | { 542 | "name": "restaurant_asc_750_1", 543 | "task_type": "classification", 544 | "vocab": [ 545 | "positive", 546 | "negative", 547 | "neutral" 548 | ], 549 | "trainset": "combined_data/absa/asc/rest-sample-1/750/train.txt", 550 | "validset": "combined_data/absa/asc/rest/dev.txt", 551 | "testset": "combined_data/absa/asc/rest/test.txt", 552 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 553 | }, 554 | { 555 | "name": "restaurant_asc_1000_1", 556 | "task_type": "classification", 557 | "vocab": [ 558 | "positive", 559 | "negative", 560 | "neutral" 561 | ], 562 | "trainset": "combined_data/absa/asc/rest-sample-1/1000/train.txt", 563 | "validset": "combined_data/absa/asc/rest/dev.txt", 564 | "testset": "combined_data/absa/asc/rest/test.txt", 565 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 566 | }, 567 | { 568 | "name": "restaurant_asc_full_1", 569 | "task_type": "classification", 570 | "vocab": [ 571 | "positive", 572 | "negative", 573 | "neutral" 574 | ], 575 | "trainset": "combined_data/absa/asc/rest-sample-1/full/train.txt", 576 | "validset": "combined_data/absa/asc/rest/dev.txt", 577 | "testset": "combined_data/absa/asc/rest/test.txt", 578 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 579 | }, 580 | { 581 | "name": "restaurant_asc_250_2", 582 | "task_type": "classification", 583 | "vocab": [ 584 | "positive", 585 | "negative", 586 | "neutral" 587 | ], 588 | "trainset": "combined_data/absa/asc/rest-sample-2/250/train.txt", 589 | "validset": "combined_data/absa/asc/rest/dev.txt", 590 | "testset": "combined_data/absa/asc/rest/test.txt", 591 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 592 | }, 593 | { 594 | "name": "restaurant_asc_500_2", 595 | "task_type": "classification", 596 | "vocab": [ 597 | "positive", 598 | "negative", 599 | "neutral" 600 | ], 601 | "trainset": "combined_data/absa/asc/rest-sample-2/500/train.txt", 602 | "validset": "combined_data/absa/asc/rest/dev.txt", 603 | "testset": "combined_data/absa/asc/rest/test.txt", 604 | "unlabeled": 
"combined_data/absa/asc/rest/unlabeled.txt" 605 | }, 606 | { 607 | "name": "restaurant_asc_750_2", 608 | "task_type": "classification", 609 | "vocab": [ 610 | "positive", 611 | "negative", 612 | "neutral" 613 | ], 614 | "trainset": "combined_data/absa/asc/rest-sample-2/750/train.txt", 615 | "validset": "combined_data/absa/asc/rest/dev.txt", 616 | "testset": "combined_data/absa/asc/rest/test.txt", 617 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 618 | }, 619 | { 620 | "name": "restaurant_asc_1000_2", 621 | "task_type": "classification", 622 | "vocab": [ 623 | "positive", 624 | "negative", 625 | "neutral" 626 | ], 627 | "trainset": "combined_data/absa/asc/rest-sample-2/1000/train.txt", 628 | "validset": "combined_data/absa/asc/rest/dev.txt", 629 | "testset": "combined_data/absa/asc/rest/test.txt", 630 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 631 | }, 632 | { 633 | "name": "restaurant_asc_full_2", 634 | "task_type": "classification", 635 | "vocab": [ 636 | "positive", 637 | "negative", 638 | "neutral" 639 | ], 640 | "trainset": "combined_data/absa/asc/rest-sample-2/full/train.txt", 641 | "validset": "combined_data/absa/asc/rest/dev.txt", 642 | "testset": "combined_data/absa/asc/rest/test.txt", 643 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 644 | }, 645 | { 646 | "name": "restaurant_asc_250_3", 647 | "task_type": "classification", 648 | "vocab": [ 649 | "positive", 650 | "negative", 651 | "neutral" 652 | ], 653 | "trainset": "combined_data/absa/asc/rest-sample-3/250/train.txt", 654 | "validset": "combined_data/absa/asc/rest/dev.txt", 655 | "testset": "combined_data/absa/asc/rest/test.txt", 656 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 657 | }, 658 | { 659 | "name": "restaurant_asc_500_3", 660 | "task_type": "classification", 661 | "vocab": [ 662 | "positive", 663 | "negative", 664 | "neutral" 665 | ], 666 | "trainset": "combined_data/absa/asc/rest-sample-3/500/train.txt", 667 | "validset": "combined_data/absa/asc/rest/dev.txt", 668 | "testset": "combined_data/absa/asc/rest/test.txt", 669 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 670 | }, 671 | { 672 | "name": "restaurant_asc_750_3", 673 | "task_type": "classification", 674 | "vocab": [ 675 | "positive", 676 | "negative", 677 | "neutral" 678 | ], 679 | "trainset": "combined_data/absa/asc/rest-sample-3/750/train.txt", 680 | "validset": "combined_data/absa/asc/rest/dev.txt", 681 | "testset": "combined_data/absa/asc/rest/test.txt", 682 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 683 | }, 684 | { 685 | "name": "restaurant_asc_1000_3", 686 | "task_type": "classification", 687 | "vocab": [ 688 | "positive", 689 | "negative", 690 | "neutral" 691 | ], 692 | "trainset": "combined_data/absa/asc/rest-sample-3/1000/train.txt", 693 | "validset": "combined_data/absa/asc/rest/dev.txt", 694 | "testset": "combined_data/absa/asc/rest/test.txt", 695 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 696 | }, 697 | { 698 | "name": "restaurant_asc_full_3", 699 | "task_type": "classification", 700 | "vocab": [ 701 | "positive", 702 | "negative", 703 | "neutral" 704 | ], 705 | "trainset": "combined_data/absa/asc/rest-sample-3/full/train.txt", 706 | "validset": "combined_data/absa/asc/rest/dev.txt", 707 | "testset": "combined_data/absa/asc/rest/test.txt", 708 | "unlabeled": "combined_data/absa/asc/rest/unlabeled.txt" 709 | }, 710 | { 711 | "name": "laptop_asc_250_1", 712 | "task_type": "classification", 713 | "vocab": [ 714 | "positive", 715 | "negative", 716 | 
"neutral" 717 | ], 718 | "trainset": "combined_data/absa/asc/laptop-sample-1/250/train.txt", 719 | "validset": "combined_data/absa/asc/laptop/dev.txt", 720 | "testset": "combined_data/absa/asc/laptop/test.txt", 721 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 722 | }, 723 | { 724 | "name": "laptop_asc_500_1", 725 | "task_type": "classification", 726 | "vocab": [ 727 | "positive", 728 | "negative", 729 | "neutral" 730 | ], 731 | "trainset": "combined_data/absa/asc/laptop-sample-1/500/train.txt", 732 | "validset": "combined_data/absa/asc/laptop/dev.txt", 733 | "testset": "combined_data/absa/asc/laptop/test.txt", 734 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 735 | }, 736 | { 737 | "name": "laptop_asc_750_1", 738 | "task_type": "classification", 739 | "vocab": [ 740 | "positive", 741 | "negative", 742 | "neutral" 743 | ], 744 | "trainset": "combined_data/absa/asc/laptop-sample-1/750/train.txt", 745 | "validset": "combined_data/absa/asc/laptop/dev.txt", 746 | "testset": "combined_data/absa/asc/laptop/test.txt", 747 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 748 | }, 749 | { 750 | "name": "laptop_asc_1000_1", 751 | "task_type": "classification", 752 | "vocab": [ 753 | "positive", 754 | "negative", 755 | "neutral" 756 | ], 757 | "trainset": "combined_data/absa/asc/laptop-sample-1/1000/train.txt", 758 | "validset": "combined_data/absa/asc/laptop/dev.txt", 759 | "testset": "combined_data/absa/asc/laptop/test.txt", 760 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 761 | }, 762 | { 763 | "name": "laptop_asc_full_1", 764 | "task_type": "classification", 765 | "vocab": [ 766 | "positive", 767 | "negative", 768 | "neutral" 769 | ], 770 | "trainset": "combined_data/absa/asc/laptop-sample-1/full/train.txt", 771 | "validset": "combined_data/absa/asc/laptop/dev.txt", 772 | "testset": "combined_data/absa/asc/laptop/test.txt", 773 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 774 | }, 775 | { 776 | "name": "laptop_asc_250_2", 777 | "task_type": "classification", 778 | "vocab": [ 779 | "positive", 780 | "negative", 781 | "neutral" 782 | ], 783 | "trainset": "combined_data/absa/asc/laptop-sample-2/250/train.txt", 784 | "validset": "combined_data/absa/asc/laptop/dev.txt", 785 | "testset": "combined_data/absa/asc/laptop/test.txt", 786 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 787 | }, 788 | { 789 | "name": "laptop_asc_500_2", 790 | "task_type": "classification", 791 | "vocab": [ 792 | "positive", 793 | "negative", 794 | "neutral" 795 | ], 796 | "trainset": "combined_data/absa/asc/laptop-sample-2/500/train.txt", 797 | "validset": "combined_data/absa/asc/laptop/dev.txt", 798 | "testset": "combined_data/absa/asc/laptop/test.txt", 799 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 800 | }, 801 | { 802 | "name": "laptop_asc_750_2", 803 | "task_type": "classification", 804 | "vocab": [ 805 | "positive", 806 | "negative", 807 | "neutral" 808 | ], 809 | "trainset": "combined_data/absa/asc/laptop-sample-2/750/train.txt", 810 | "validset": "combined_data/absa/asc/laptop/dev.txt", 811 | "testset": "combined_data/absa/asc/laptop/test.txt", 812 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 813 | }, 814 | { 815 | "name": "laptop_asc_1000_2", 816 | "task_type": "classification", 817 | "vocab": [ 818 | "positive", 819 | "negative", 820 | "neutral" 821 | ], 822 | "trainset": "combined_data/absa/asc/laptop-sample-2/1000/train.txt", 823 | "validset": "combined_data/absa/asc/laptop/dev.txt", 824 | 
"testset": "combined_data/absa/asc/laptop/test.txt", 825 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 826 | }, 827 | { 828 | "name": "laptop_asc_full_2", 829 | "task_type": "classification", 830 | "vocab": [ 831 | "positive", 832 | "negative", 833 | "neutral" 834 | ], 835 | "trainset": "combined_data/absa/asc/laptop-sample-2/full/train.txt", 836 | "validset": "combined_data/absa/asc/laptop/dev.txt", 837 | "testset": "combined_data/absa/asc/laptop/test.txt", 838 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 839 | }, 840 | { 841 | "name": "laptop_asc_250_3", 842 | "task_type": "classification", 843 | "vocab": [ 844 | "positive", 845 | "negative", 846 | "neutral" 847 | ], 848 | "trainset": "combined_data/absa/asc/laptop-sample-3/250/train.txt", 849 | "validset": "combined_data/absa/asc/laptop/dev.txt", 850 | "testset": "combined_data/absa/asc/laptop/test.txt", 851 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 852 | }, 853 | { 854 | "name": "laptop_asc_500_3", 855 | "task_type": "classification", 856 | "vocab": [ 857 | "positive", 858 | "negative", 859 | "neutral" 860 | ], 861 | "trainset": "combined_data/absa/asc/laptop-sample-3/500/train.txt", 862 | "validset": "combined_data/absa/asc/laptop/dev.txt", 863 | "testset": "combined_data/absa/asc/laptop/test.txt", 864 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 865 | }, 866 | { 867 | "name": "laptop_asc_750_3", 868 | "task_type": "classification", 869 | "vocab": [ 870 | "positive", 871 | "negative", 872 | "neutral" 873 | ], 874 | "trainset": "combined_data/absa/asc/laptop-sample-3/750/train.txt", 875 | "validset": "combined_data/absa/asc/laptop/dev.txt", 876 | "testset": "combined_data/absa/asc/laptop/test.txt", 877 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 878 | }, 879 | { 880 | "name": "laptop_asc_1000_3", 881 | "task_type": "classification", 882 | "vocab": [ 883 | "positive", 884 | "negative", 885 | "neutral" 886 | ], 887 | "trainset": "combined_data/absa/asc/laptop-sample-3/1000/train.txt", 888 | "validset": "combined_data/absa/asc/laptop/dev.txt", 889 | "testset": "combined_data/absa/asc/laptop/test.txt", 890 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 891 | }, 892 | { 893 | "name": "laptop_asc_full_3", 894 | "task_type": "classification", 895 | "vocab": [ 896 | "positive", 897 | "negative", 898 | "neutral" 899 | ], 900 | "trainset": "combined_data/absa/asc/laptop-sample-3/full/train.txt", 901 | "validset": "combined_data/absa/asc/laptop/dev.txt", 902 | "testset": "combined_data/absa/asc/laptop/test.txt", 903 | "unlabeled": "combined_data/absa/asc/laptop/unlabeled.txt" 904 | } 905 | ] 906 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gensim==3.8.1 2 | numpy==1.19.2 3 | regex==2019.12.20 4 | spacy==2.2.3 5 | sentencepiece==0.1.85 6 | sklearn==0.0 7 | spacy==2.2.3 8 | tensorboardX==2.0 9 | torch==1.4.0 10 | tqdm==4.41.0 11 | transformers==3.1.0 12 | jsonlines==1.2.0 13 | nltk==3.4.5 14 | -------------------------------------------------------------------------------- /run_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | import numpy as np 5 | import random 6 | import json 7 | import jsonlines 8 | import csv 9 | import spacy 10 | import re 11 | import time 12 | import argparse 13 | import sys 14 | 15 | 
from torch.utils import data 16 | from tqdm import tqdm 17 | from collections import OrderedDict 18 | 19 | from snippext.model import MultiTaskNet 20 | from snippext.dataset import SnippextDataset 21 | 22 | csv.field_size_limit(sys.maxsize) 23 | nlp = spacy.load('en_core_web_sm') 24 | 25 | def handle_punct(text): 26 | """Basic handling of punctuations 27 | 28 | Args: 29 | text (str): the input text 30 | Returns: 31 | str: the string with the bad characters replaced and 32 | new characters inserted 33 | """ 34 | text = text.replace("''", "'").replace("\\n", ' ') 35 | new_text = '' 36 | i = 0 37 | N = len(text) 38 | while i < len(text): 39 | curr_chr = text[i] 40 | new_text += curr_chr 41 | if i > 0 and i < N - 1: 42 | next_chr = text[i + 1] 43 | prev_chr = text[i - 1] 44 | if next_chr.isalnum() and prev_chr.isalnum() and curr_chr in '!?.,();:': 45 | new_text += ' ' 46 | i += 1 47 | return new_text 48 | 49 | 50 | def sent_tokenizer(text): 51 | """Tokenizer a paragraph of text into a list of sentences. 52 | 53 | Args: 54 | text (str): the input paragraph 55 | 56 | Returns: 57 | list of spacy Sentence: the tokenized sentences 58 | """ 59 | text = handle_punct(text)[:1000000] 60 | ori_sentences = [] 61 | for line in text.split('\n'): 62 | for sent in nlp(line, disable=['tagger', 'ner']).sents: 63 | if len(sent) >= 2: 64 | ori_sentences.append(sent) 65 | 66 | return ori_sentences 67 | 68 | def do_tagging(text, config, model): 69 | """Apply the tagging model. 70 | 71 | Args: 72 | text (str): the input paragraph 73 | config (dict): the model configuration 74 | model (MultiTaskNet): the model in pytorch 75 | 76 | Returns: 77 | list of list of str: the tokens in each sentences 78 | list of list of int: each token's starting position in the original text 79 | list of list of str: the tags assigned to each token 80 | """ 81 | # load data and tokenization 82 | source = [] 83 | token_pos_list = [] 84 | # print('Tokenize sentences') 85 | for sent in sent_tokenizer(text): 86 | tokens = [token.text for token in sent] 87 | token_pos = [token.idx for token in sent] 88 | source.append(tokens) 89 | token_pos_list.append(token_pos) 90 | 91 | dataset = SnippextDataset(source, config['vocab'], config['name'], 92 | lm=model.lm, 93 | max_len=64) 94 | iterator = data.DataLoader(dataset=dataset, 95 | batch_size=32, 96 | shuffle=False, 97 | num_workers=0, 98 | collate_fn=SnippextDataset.pad) 99 | 100 | # prediction 101 | model.eval() 102 | Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], [] 103 | with torch.no_grad(): 104 | # print('Tagging') 105 | for i, batch in enumerate(iterator): 106 | try: 107 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 108 | taskname = taskname[0] 109 | _, _, y_hat = model(x, y, task=taskname) # y_hat: (N, T) 110 | 111 | Words.extend(words) 112 | Is_heads.extend(is_heads) 113 | Tags.extend(tags) 114 | Y.extend(y.numpy().tolist()) 115 | Y_hat.extend(y_hat.cpu().numpy().tolist()) 116 | except: 117 | print('error @', batch) 118 | 119 | # gets results and save 120 | results = [] 121 | for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat): 122 | y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1] 123 | # remove the first and the last token 124 | preds = [dataset.idx2tag[hat] for hat in y_hat][1:-1] 125 | results.append(preds) 126 | 127 | return source, token_pos_list, results 128 | 129 | def do_pairing(all_tokens, all_tags, config, model): 130 | """Apply the pairing model. 
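Candidate (aspect, opinion) span pairs are built from the tagger output, encoded as 'sentence [SEP] <earlier span> <later span>', and scored by the pairing classifier. Candidates are tried in order of increasing aspect-opinion distance; a positively classified pair is kept only if none of its tokens already belongs to an accepted pair, and spans that remain unpaired are still emitted as single-term extractions.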
131 | 132 | Args: 133 | all_tokens (list of list of str): the tokenized text 134 | all_tags (list of list of str): the tags assigned to each token 135 | config (dict): the model configuration 136 | model (MultiTaskNet): the model in pytorch 137 | 138 | Returns: 139 | list of dict: For each sentence, the list of extracted 140 | opinions/experiences from the sentence. Each dictionary includes 141 | an aspect term and an opinion term and the start/end 142 | position of the aspect/opinion term. 143 | """ 144 | samples = [] 145 | sent_ids = [] 146 | candidates = [] 147 | positions = [] 148 | all_spans = {} 149 | 150 | sid = 0 151 | for tokens, tags in zip(all_tokens, all_tags): 152 | aspects = [] 153 | opinions = [] 154 | # find aspects 155 | # find opinions 156 | for i, tag in enumerate(tags): 157 | if tag[0] == 'B': 158 | start = i 159 | end = i 160 | while end + 1 < len(tags) and tags[end + 1][0] == 'I': 161 | end += 1 162 | if tag == 'B-AS': 163 | aspects.append((start, end)) 164 | all_spans[(sid, start, end)] = {'aspect': ' '.join(tokens[start:end+1]), 165 | 'sid': sid, 166 | 'asp_start': start, 167 | 'asp_end': end} 168 | else: 169 | opinions.append((start, end)) 170 | all_spans[(sid, start, end)] = {'opinion': ' '.join(tokens[start:end+1]), 171 | 'sid': sid, 172 | 'op_start': start, 173 | 'op_end': end} 174 | 175 | candidate_pairs = [] 176 | for asp in aspects: 177 | for opi in opinions: 178 | candidate_pairs.append((asp, opi)) 179 | candidate_pairs.sort(key=lambda ao: abs(ao[0][0] - ao[1][0])) 180 | 181 | for asp, opi in candidate_pairs: 182 | asp_start, asp_end = asp 183 | op_start, op_end = opi 184 | token_ids = [] 185 | for i in range(asp_start, asp_end + 1): 186 | token_ids.append((sid, i)) 187 | for i in range(op_start, op_end + 1): 188 | token_ids.append((sid, i)) 189 | 190 | if op_start < asp_start: 191 | samples.append(' '.join(tokens) + ' [SEP] ' + \ 192 | ' '.join(tokens[op_start:op_end+1]) + ' ' + \ 193 | ' '.join(tokens[asp_start:asp_end+1])) 194 | else: 195 | samples.append(' '.join(tokens) + ' [SEP] ' + \ 196 | ' '.join(tokens[asp_start:asp_end+1]) + ' ' + \ 197 | ' '.join(tokens[op_start:op_end+1])) 198 | 199 | sent_ids.append(sid) 200 | candidates.append({'opinion': ' '.join(tokens[op_start:op_end+1]), 201 | 'aspect': ' '.join(tokens[asp_start:asp_end+1]), 202 | 'sid': sid, 203 | 'asp_start': asp_start, 204 | 'asp_end': asp_end, 205 | 'op_start': op_start, 206 | 'op_end': op_end}) 207 | positions.append(token_ids) 208 | sid += 1 209 | 210 | dataset = SnippextDataset(samples, config['vocab'], config['name'], 211 | lm=model.lm) 212 | iterator = data.DataLoader(dataset=dataset, 213 | batch_size=32, 214 | shuffle=False, 215 | num_workers=0, 216 | collate_fn=SnippextDataset.pad) 217 | 218 | # prediction 219 | Y_hat = [] 220 | Y = [] 221 | with torch.no_grad(): 222 | for i, batch in enumerate(iterator): 223 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 224 | taskname = taskname[0] 225 | _, y, y_hat = model(x, y, task=taskname) # y_hat: (N, T) 226 | Y_hat.extend(y_hat.cpu().numpy().tolist()) 227 | Y.extend(y.cpu().numpy().tolist()) 228 | 229 | results = [] 230 | for tokens in all_tokens: 231 | results.append({'sentence': ' '.join(tokens), 232 | 'extractions': []}) 233 | 234 | used = set([]) 235 | for i, yhat in enumerate(Y_hat): 236 | phrase = samples[i].split(' [SEP] ')[1] 237 | # print(phrase, yhat) 238 | if yhat == 1: 239 | # do some filtering 240 | assigned = False 241 | for tid in positions[i]: 242 | if tid in used: 243 | assigned = True 244 | break 245 | 
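# Greedy pair selection (note): the candidate pairs were sorted above by aspect-opinion distance, so a positively classified pair is accepted only if none of its token positions was already claimed by a closer pair (tracked in `used`).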
246 | if not assigned: 247 | results[sent_ids[i]]['extractions'].append(candidates[i]) 248 | for tid in positions[i]: 249 | used.add(tid) 250 | # drop from all_spans 251 | sid = candidates[i]['sid'] 252 | del all_spans[(sid, 253 | candidates[i]['asp_start'], 254 | candidates[i]['asp_end'])] 255 | del all_spans[(sid, 256 | candidates[i]['op_start'], 257 | candidates[i]['op_end'])] 258 | 259 | # add aspects/opinions that are not paired 260 | for sid, start, end in all_spans: 261 | results[sid]['extractions'].append(all_spans[(sid, start, end)]) 262 | 263 | return results 264 | 265 | 266 | def classify(extractions, config, model, sents=None): 267 | """Apply the classification models (for Sentiment and Attribute Classification). 268 | 269 | Args: 270 | extractions (list of dict): the partial extraction results by the pairing model 271 | config (dict): the model configuration 272 | model (MultiTaskNet): the model in pytorch 273 | 274 | Returns: 275 | list of dict: the extraction results with attribute name and sentiment score 276 | assigned to the field "attribute" and "sentiment". 277 | """ 278 | phrases = [] 279 | index = [] 280 | # print('Prepare classification data') 281 | for sid, sent in enumerate(extractions): 282 | for eid, ext in enumerate(sent['extractions']): 283 | if 'asc' in config['name']: 284 | if 'aspect' in ext: 285 | phrase = ' '.join(sents[ext['sid']]) + '\t' + ext['aspect'] 286 | else: 287 | phrase = ' '.join(sents[ext['sid']]) + '\t' + ext['opinion'] 288 | else: 289 | if 'aspect' in ext and 'opinion' in ext: 290 | phrase = ext['opinion'] + ' ' + ext['aspect'] 291 | elif 'aspect' in ext: 292 | phrase = ext['aspect'] 293 | else: 294 | phrase = ext['opinion'] 295 | phrases.append(phrase) 296 | index.append((sid, eid)) 297 | 298 | dataset = SnippextDataset(phrases, config['vocab'], config['name'], 299 | lm=model.lm) 300 | iterator = data.DataLoader(dataset=dataset, 301 | batch_size=32, 302 | shuffle=False, 303 | num_workers=0, 304 | collate_fn=SnippextDataset.pad) 305 | 306 | # prediction 307 | Y_hat = [] 308 | with torch.no_grad(): 309 | # print('Classification') 310 | for i, batch in enumerate(iterator): 311 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 312 | taskname = taskname[0] 313 | _, _, y_hat = model(x, y, task=taskname) # y_hat: (N, T) 314 | Y_hat.extend(y_hat.cpu().numpy().tolist()) 315 | 316 | for i in range(len(phrases)): 317 | attr = dataset.idx2tag[Y_hat[i]] 318 | sid, eid = index[i] 319 | if 'asc' in config['name']: 320 | extractions[sid]['extractions'][eid]['sentiment'] = attr 321 | else: 322 | extractions[sid]['extractions'][eid]['attribute'] = attr 323 | 324 | return extractions 325 | 326 | def extract(review, config_list, models): 327 | """Extract experiences and opinions from a paragraph of text 328 | 329 | Args: 330 | review (Dictionary): a review object with a text field to be extracted 331 | config_list (list of dictionary): a list of task config dictionary 332 | models (list of MultiTaskNet): the most of models 333 | 334 | Returns: 335 | Dictionary: the same review object with a new extraction field 336 | """ 337 | text = review['content'] 338 | 339 | start_time = time.time() 340 | # tagging 341 | all_tokens, token_pos, all_tags = do_tagging(text, config_list[0], models[0]) 342 | # pairing 343 | extractions = do_pairing(all_tokens, all_tags, config_list[1], models[1]) 344 | # classification 345 | extractions = classify(extractions, config_list[2], models[2]) 346 | # asc 347 | extractions = classify(extractions, config_list[3], 
models[3], sents=all_tokens) 348 | 349 | review['extractions'] = [] 350 | review['sentences'] = [] 351 | for sent, tokens in zip(extractions, all_tokens): 352 | review['extractions'] += sent['extractions'] 353 | review['sentences'].append(tokens) 354 | return review 355 | 356 | 357 | def load_model(config, 358 | path, 359 | device='gpu', 360 | lm='bert', 361 | fp16=False): 362 | """Load a model for a specific task. 363 | 364 | Args: 365 | config (dictionary): the task dictionary 366 | path (string): the path to the checkpoint 367 | lm (str, optional): the language model (bert, distilbert, or albert) 368 | fp16 (boolean): whether to use fp16 optimization 369 | 370 | Returns: 371 | MultiTaskNet: the model 372 | """ 373 | model = MultiTaskNet([config], device, True, lm=lm) 374 | saved_state = torch.load(path, map_location=lambda storage, loc: storage) 375 | model.load_state_dict(saved_state) 376 | model = model.to(device) 377 | 378 | if fp16 and 'cuda' in device: 379 | from apex import amp 380 | model = amp.initialize(model, opt_level='O2') 381 | 382 | return model 383 | 384 | def predict(input_fn, output_fn, config_list, models): 385 | """Run the extraction on an input csv file. 386 | 387 | Args: 388 | input_fn (str): the input file name (.csv) 389 | output_fn (str): the output file name (.jsonl) 390 | config_list (list of dict): the list of configuration 391 | models (list of MultiTaskNet): the list of models 392 | 393 | Returns: 394 | None 395 | """ 396 | with jsonlines.open(output_fn, mode='w') as writer: 397 | with open(input_fn) as fin: 398 | reader = csv.DictReader(fin) 399 | for idx, row in tqdm(enumerate(reader)): 400 | try: 401 | review = extract(row, config_list, models) 402 | writer.write(review) 403 | except: 404 | writer.write(row) 405 | 406 | def initialize(checkpoint_path, 407 | use_gpu=False, 408 | lm='bert', 409 | fp16=False, 410 | tasks=['hotel_tagging', 411 | 'pairing', 412 | 'sf_hotel_classification', 413 | 'restaurant_asc']): 414 | """load the models from a path storing the checkpoints. 
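Each task's checkpoint is expected at os.path.join(checkpoint_path, '<task>.pt'), and task configurations are read from configs.json. Typical use (sketch): config_list, models = initialize('checkpoints/', use_gpu=True), followed by predict(input_csv, output_jsonl, config_list, models).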
415 | 416 | Args: 417 | checkpoint_path (str): the path to the checkpoints 418 | use_gpu (boolean, optional): whether to use gpu 419 | lm (string, optional): the language model (default: bert) 420 | fp16 (boolean): whether to use fp16 421 | tasks (list of str, optional): the list of snippext tasks 422 | Returns: 423 | list of dictionary: the configuration list 424 | list of MultiTaskNet: the list of models 425 | """ 426 | # load models 427 | checkpoints = dict([(task, os.path.join(checkpoint_path, \ 428 | '%s.pt' % task)) for task in tasks]) 429 | configs = json.load(open('configs.json')) 430 | configs = {conf['name'] : conf for conf in configs} 431 | 432 | if use_gpu: 433 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 434 | else: 435 | device = 'cpu' 436 | 437 | models = [load_model(configs[task], checkpoints[task], device=device, 438 | lm=lm, fp16=fp16) for task in tasks] 439 | config_list = [configs[task] for task in tasks] 440 | 441 | return config_list, models 442 | 443 | # running the command line version 444 | if __name__ == "__main__": 445 | parser = argparse.ArgumentParser() 446 | parser.add_argument("--input_fn", type=str, default='input/trustyou_raw_review_sampled.csv') 447 | parser.add_argument("--output_fn", type=str, default='trustyou_reviews_with_extractions.jsonl') 448 | parser.add_argument("--use_gpu", dest="use_gpu", action="store_true") 449 | parser.add_argument("--fp16", dest="fp16", action="store_true") 450 | parser.add_argument("--checkpoint_path", type=str, default='checkpoints/') 451 | parser.add_argument("--lm", type=str, default='bert') 452 | parser.add_argument("--tasks", type=str, default='hotel_tagging,pairing,sf_hotel_classification,restaurant_asc') 453 | hp = parser.parse_args() 454 | 455 | config_list, models = initialize(hp.checkpoint_path, hp.use_gpu, 456 | lm=hp.lm, fp16=hp.fp16, tasks=hp.tasks.split(',')) 457 | predict(hp.input_fn, hp.output_fn, config_list, models) 458 | -------------------------------------------------------------------------------- /snippext/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dataset import * 2 | # from .baseline import initialize_and_train 3 | # from .mixda import initialize_and_train 4 | # from .mixmatchnl import initialize_and_train 5 | -------------------------------------------------------------------------------- /snippext/augment.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import numpy as np 4 | 5 | class Augmenter(object): 6 | """Data augmentation for the extractor. 7 | 8 | Support both token and span level augmentation operators. 
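The index is a JSON file (presumably produced by augment_index_builder.py) with a token-level part storing, per token, its idf, BERT sub-token length, and similar words, and a span-level part storing aspect/opinion spans with their document frequencies and similar spans.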
9 | 10 | Attributes: 11 | index (dict): a dictionary containing both the token and span level index 12 | all_spans (dict): a dictionary from span type to a list of all the spans in the index 13 | span_freqs (dict): a dictionary from span type to the document frequency of each span in all_spans 14 | """ 15 | 16 | def __init__(self, index_fn, valid_spans=None): 17 | self.index = json.load(open(index_fn)) 18 | self.all_spans = {} 19 | self.span_freqs = {} 20 | 21 | for span_type in self.index['span']: 22 | self.all_spans[span_type] = list(self.index['span'][span_type].keys()) 23 | span_freqs_tmp = [self.index['span'][span_type][sp]['document_freq'] \ 24 | for sp in self.all_spans[span_type]] 25 | self.span_freqs[span_type] = np.array(span_freqs_tmp) / np.sum(span_freqs_tmp) 26 | 27 | def augment(self, tokens, labels, op='token_del_tfidf'): 28 | """ Performs data augmentation on a tagging example. 29 | 30 | We support deletion (del), insertion (ins), replacement (repl), 31 | and swapping(swap) at the token level. At the span level, we support 32 | replacement with the options of replacing with random, frequent (freq), 33 | or similar spans (sim). 34 | 35 | The supported ops: 36 | ['token_del_tfidf', 37 | 'token_del', 38 | 'token_repl_tfidf', 39 | 'token_repl', 40 | 'token_swap', 41 | 'token_ins', 42 | 'span_sim', 43 | 'span_freq', 44 | 'span'] 45 | 46 | Args: 47 | tokens (list of strings): the input tokens 48 | labels (list of strings): the labels of the tokens 49 | op (str, optional): a string encoding of the operator to be applied 50 | 51 | Returns: 52 | list of strings: the augmented tokens 53 | list of strings: the augmented labels 54 | """ 55 | flags = op.split('_') 56 | if 'span' in flags: 57 | start, end = self.sample_span_position(tokens, labels) 58 | if start < 0: 59 | return tokens, labels 60 | span = ' '.join(tokens[start:end+1])# .lower() 61 | label = labels[start] 62 | if label.startswith('B-'): 63 | labelI = 'I' + labels[start][1:] 64 | else: 65 | labelI = label 66 | 67 | if 'AS' in label: 68 | span_type = 'aspect' 69 | else: 70 | span_type = 'opinion' 71 | 72 | if 'sim' in op: 73 | candidates = self.index['span'][span_type][span]['similar_spans'] 74 | new_span = random.choice(candidates)[0] 75 | elif 'freq' in op: 76 | candidates = self.all_spans[span_type] 77 | new_span = np.random.choice(candidates, 1, 78 | p=self.span_freqs[span_type])[0] 79 | else: 80 | candidates = self.all_spans[span_type] 81 | new_span = random.choice(candidates) 82 | 83 | new_span_len = len(new_span.split(' ')) 84 | new_tokens = tokens[:start] + \ 85 | new_span.split(' ') + tokens[end+1:] 86 | new_labels = labels[:start] + [label] + \ 87 | [labelI] * (new_span_len - 1) + labels[end+1:] 88 | return new_tokens, new_labels 89 | else: 90 | tfidf = 'tfidf' in op 91 | pos1 = self.sample_position(tokens, labels, tfidf) 92 | if pos1 < 0: 93 | return tokens, labels 94 | 95 | if 'del' in op: 96 | # insert padding to keep the length consistent 97 | if tokens[pos1] in self.index['token']: 98 | length = self.index['token'][tokens[pos1]]['bert_length'] 99 | else: 100 | length = 1 101 | new_tokens = tokens[:pos1] + ['[PAD]']*(length) + tokens[pos1+1:] 102 | new_labels = labels[:pos1] + ['']*(length) + labels[pos1+1:] 103 | elif 'ins' in op: 104 | ins_token = self.sample_token(tokens[pos1], same_length=False) 105 | new_tokens = tokens[:pos1] + [ins_token] + tokens[pos1:] 106 | new_labels = labels[:pos1] + ['O'] + labels[pos1:] 107 | elif 'repl' in op: 108 | ins_token = self.sample_token(tokens[pos1], same_length=False) 
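# The replacement below keeps the total BERT sub-token length unchanged: the sampled token is truncated to the original token's sub-token length, or the sequence is padded with [PAD], presumably so augmented examples stay aligned with the originals (e.g., for MixDA-style interpolation).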
109 | if tokens[pos1] in self.index['token'] and \ 110 | ins_token in self.index['token']: 111 | len1 = self.index['token'][tokens[pos1]]['bert_length'] 112 | len2 = self.index['token'][ins_token]['bert_length'] 113 | if len1 < len2: 114 | # truncate the new sequence 115 | bert_tokens = self.index['token'][ins_token]['bert_token'][:len1] 116 | bert_tokens = [token.replace('##', '') for token in bert_tokens] 117 | ins_token = ''.join(bert_tokens) 118 | new_tokens = tokens[:pos1] + [ins_token] + tokens[pos1+1:] 119 | new_labels = labels[:pos1] + ['O'] + labels[pos1+1:] 120 | else: 121 | # pad the new sequence 122 | more = len1 - len2 123 | new_tokens = tokens[:pos1] + [ins_token] + ['[PAD]']*more + tokens[pos1+1:] 124 | new_labels = labels[:pos1] + ['O'] + ['']*more + labels[pos1+1:] 125 | else: 126 | # backup 127 | new_tokens = tokens[:pos1] + [ins_token] + tokens[pos1+1:] 128 | new_labels = labels[:pos1] + ['O'] + labels[pos1+1:] 129 | elif 'swap' in op: 130 | pos2 = self.sample_position(tokens, labels, tfidf) 131 | new_tokens = list(tokens) 132 | new_labels = list(labels) 133 | new_tokens[pos1], new_tokens[pos2] = tokens[pos2], tokens[pos1] 134 | else: 135 | new_tokens, new_labels = tokens, labels 136 | 137 | return new_tokens, new_labels 138 | 139 | 140 | def augment_sent(self, text, op='token_del_tfidf'): 141 | """ Performs data augmentation on a classification example. 142 | 143 | Similar to augment(tokens, labels) but works for sentences 144 | or sentence-pairs. 145 | 146 | Args: 147 | text (str): the input sentence 148 | op (str, optional): a string encoding of the operator to be applied 149 | 150 | Returns: 151 | str: the augmented sentence 152 | """ 153 | # handling sentence pairs 154 | sents = text.split(' [SEP] ') 155 | text = sents[0] 156 | target_spans = sents[1:] 157 | 158 | # tokenize the sentence 159 | current = '' 160 | tokens = [] 161 | labels = [] 162 | for ch in text: 163 | if ch.isalnum(): 164 | current += ch 165 | else: 166 | if current != '': 167 | tokens.append(current) 168 | if ch not in ' \t\r\n': 169 | tokens.append(ch) 170 | current = '' 171 | if current != '': 172 | tokens.append(current) 173 | 174 | labels = ['O'] * len(tokens) 175 | for idx, span in enumerate(target_spans): 176 | span_tokens = span.split(' ') 177 | for tid in range(len(tokens)): 178 | if tid + len(span_tokens) <= len(tokens) and \ 179 | tokens[tid:tid+len(span_tokens)] == span_tokens: 180 | for i in range(tid, tid+len(span_tokens)): 181 | labels[i] = 'AS%d' % idx 182 | 183 | # print(tokens) 184 | # print(labels) 185 | # only augment the original sentence 186 | tokens, labels = self.augment(tokens, labels, op=op) 187 | 188 | # check consistency 189 | tid = 0 190 | while tid < len(tokens): 191 | if labels[tid][:2] == 'AS': 192 | new_span = tokens[tid] 193 | idx = int(labels[tid][2:]) 194 | while tid + 1 < len(tokens) and \ 195 | labels[tid + 1] == labels[tid]: 196 | tid += 1 197 | new_span += ' ' + tokens[tid] 198 | if target_spans[idx] != new_span: 199 | target_spans[idx] = new_span 200 | tid += 1 201 | 202 | # error handling 203 | results = ' '.join(tokens) 204 | for span in target_spans: 205 | results += ' [SEP] ' + span 206 | return results 207 | 208 | 209 | def sample_position(self, tokens, labels, tfidf=False): 210 | """ Randomly sample a token's position from a training example 211 | 212 | When tfidf is turned on, the weight of each token is proportional 213 | to MaxTfIdf - Tfidf of each token. When it is off, the sampling is uniform. 
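Concretely, each candidate position i is sampled with weight (max_w - w(token_i) + 1e-6), normalized over the candidates, where w sums the token's idf over its occurrences in the example and max_w is the largest such weight.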
214 | Only tokens with 'O' labels and at least 1 position away from a non 'O' 215 | labels will be sampled. 216 | 217 | Args: 218 | tokens (list of strings): the input tokens 219 | labels (list of strings): the labels of the tokens 220 | tfidf (bool, optional): whether the sampled position is by tfidf 221 | 222 | Returns: 223 | int: the sampled position (-1 if no such position) 224 | """ 225 | index = self.index['token'] 226 | candidates = [] 227 | for idx, token in enumerate(tokens): 228 | if labels[idx] == 'O' and \ 229 | token in index and \ 230 | index[token]['similar_words'] != None and \ 231 | len(index[token]['similar_words']) > 0: 232 | candidates.append(idx) 233 | # if token.lower() in index and \ 234 | # labels[idx] == 'O' and \ 235 | # (idx + 1 >= len(tokens) or labels[idx + 1] == 'O') and \ 236 | # (idx - 1 < 0 or labels[idx - 1] == 'O'): 237 | # candidates.append(idx) 238 | 239 | if len(candidates) <= 0: 240 | return -1 241 | if tfidf: 242 | weight = {} 243 | max_weight = 0.0 244 | for idx, token in enumerate(tokens): 245 | # token = token.lower() 246 | if token not in index: 247 | continue 248 | if token not in weight: 249 | weight[token] = 0.0 250 | weight[token] += index[token]['idf'] 251 | max_weight = max(max_weight, weight[token]) 252 | 253 | weights = [] 254 | for idx in candidates: 255 | weights.append(max_weight - weight[tokens[idx]] + 1e-6) 256 | # weights.append(max_weight - weight[tokens[idx].lower()] + 1e-6) 257 | weights = np.array(weights) / sum(weights) 258 | 259 | return np.random.choice(candidates, 1, p=weights)[0] 260 | else: 261 | return random.choice(candidates) 262 | 263 | def sample_token(self, token, same_length=True, max_candidates=10): 264 | """ Randomly sample a token's similar token stored in the index 265 | 266 | Args: 267 | token (str): the input token 268 | same_length (bool, optional): whether the return token should have the same 269 | length in BERT 270 | max_candidates (int, optional): the maximal number of candidates 271 | to be sampled 272 | 273 | Returns: 274 | str: the sampled token (unchanged if the input is not in index) 275 | """ 276 | # token = token.lower() 277 | index = self.index['token'] 278 | if token in index and index[token]['similar_words'] != None: 279 | if same_length: 280 | bert_length = index[token]['bert_length'] 281 | candidates = [] 282 | for ts, bl in zip(index[token]['similar_words'], 283 | index[token]['similar_words_length']): 284 | t, _ = ts 285 | if bl == bert_length: 286 | candidates.append(t) 287 | if len(candidates) >= max_candidates: 288 | break 289 | else: 290 | candidates = [t for t, _ in \ 291 | index[token]['similar_words'][:max_candidates]] 292 | if len(candidates) <= 0: 293 | return token 294 | else: 295 | return random.choice(candidates) 296 | else: 297 | return token 298 | 299 | def sample_span_position(self, tokens, labels): 300 | """ Uniformly sample a span from a training example and return its positions. 301 | 302 | The output is a pair (start_op, end_op) of the span. 
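Only spans whose text appears in the index's aspect or opinion entries are considered as candidates.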
303 | 304 | Args: 305 | tokens (list of strings): the input tokens 306 | labels (list of strings): the labels of the tokens 307 | 308 | Returns: 309 | int: the start position (-1 if no available span) 310 | int: the ending position (-1 if no available span) 311 | """ 312 | index = self.index['span'] 313 | candidates = [] 314 | idx = 0 315 | while idx < len(tokens): 316 | if labels[idx] != 'O': 317 | start = idx 318 | while idx + 1 < len(tokens) and \ 319 | labels[idx + 1][1:] == labels[idx][1:]: 320 | idx += 1 321 | end = idx 322 | 323 | span = ' '.join(tokens[start:end+1]) # .lower() 324 | if span in index['aspect'] or \ 325 | span in index['opinion']: 326 | candidates.append((start, end)) 327 | idx += 1 328 | 329 | if len(candidates) > 0: 330 | return random.choice(candidates) 331 | else: 332 | return (-1, -1) 333 | 334 | 335 | if __name__ == '__main__': 336 | from transformers import BertTokenizer 337 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 338 | 339 | ag = Augmenter('augment/laptop_index.json', []) 340 | tokens = 'this is a great drawback desktop keyboard'.split(' ') 341 | labels = 'O O O O O B-AS I-AS'.split(' ') 342 | for op in ['token_del_tfidf', 343 | 'token_del', 344 | 'token_repl_tfidf', 345 | 'token_repl', 346 | 'token_swap', 347 | 'token_ins', 348 | 'span_sim', 349 | 'span_freq', 350 | 'span']: 351 | print(op) 352 | result = ag.augment(tokens, labels, op=op) 353 | result = ' '.join(result[0]) 354 | print(result, len(tokenizer.encode(result))) 355 | original = ' '.join(tokens) 356 | print(original, len(tokenizer.encode(original))) 357 | 358 | tokens = 'I liked the macbook desktop keyboard . It is very good .'.split(' ') 359 | labels = 'O O O O B-AS I-AS O O O O O O'.split(' ') 360 | for op in ['span']: 361 | print(op) 362 | print(ag.augment(tokens, labels, op=op)) 363 | 364 | 365 | ag = Augmenter('augment/rest_index_asc.json', []) 366 | text = 'I liked the beef [SEP] beef' 367 | for op in ['token_del_tfidf', 368 | 'token_del', 369 | 'token_repl_tfidf', 370 | 'token_repl', 371 | 'token_swap', 372 | 'token_ins', 373 | 'span_sim', 374 | 'span_freq', 375 | 'span']: 376 | print(op) 377 | print(ag.augment_sent(text, op=op)) 378 | 379 | ag = Augmenter('augment/laptop_index_asc.json', []) 380 | text = 'I liked the desktop keyboard [SEP] desktop keyboard' 381 | for op in ['token_del_tfidf', 382 | 'token_del', 383 | 'token_repl_tfidf', 384 | 'token_repl', 385 | 'token_swap', 386 | 'token_ins', 387 | 'span_sim', 388 | 'span_freq', 389 | 'span']: 390 | print(op) 391 | print(ag.augment_sent(text, op=op)) 392 | -------------------------------------------------------------------------------- /snippext/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | import numpy as np 7 | import argparse 8 | import json 9 | 10 | from torch.utils import data 11 | from .model import MultiTaskNet 12 | from .dataset import * 13 | from .train_util import * 14 | from tensorboardX import SummaryWriter 15 | from transformers import AdamW, get_linear_schedule_with_warmup 16 | from apex import amp 17 | 18 | def train(model, train_set, optimizer, scheduler=None, batch_size=32, fp16=False): 19 | """Perfrom one epoch of the training process. 
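The loss depends on the task name: token-level cross-entropy (ignore_index=0) for '*_tagging' tasks, MSE for 'sts-b' regression, and sentence-level cross-entropy otherwise; the scheduler, if given, is stepped once per batch.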
20 | 21 | Args: 22 | model (MultiTaskNet): the current model state 23 | train_set (SnippextDataset): the training dataset 24 | optimizer: the optimizer for training (e.g., Adam) 25 | batch_size (int, optional): the batch size 26 | fp16 (boolean): whether to use fp16 27 | 28 | Returns: 29 | None 30 | """ 31 | iterator = data.DataLoader(dataset=train_set, 32 | batch_size=batch_size, 33 | shuffle=True, 34 | num_workers=1, 35 | collate_fn=SnippextDataset.pad) 36 | 37 | tagging_criterion = nn.CrossEntropyLoss(ignore_index=0) 38 | classifier_criterion = nn.CrossEntropyLoss() 39 | regression_criterion = nn.MSELoss() 40 | 41 | model.train() 42 | for i, batch in enumerate(iterator): 43 | # for monitoring 44 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 45 | taskname = taskname[0] 46 | _y = y 47 | 48 | if 'tagging' in taskname: 49 | criterion = tagging_criterion 50 | elif 'sts-b' in taskname: 51 | criterion = regression_criterion 52 | else: 53 | criterion = classifier_criterion 54 | 55 | # forward 56 | optimizer.zero_grad() 57 | logits, y, _ = model(x, y, task=taskname) 58 | if 'sts-b' in taskname: 59 | logits = logits.view(-1) 60 | else: 61 | logits = logits.view(-1, logits.shape[-1]) 62 | y = y.view(-1) 63 | loss = criterion(logits, y) 64 | 65 | # back propagation 66 | if fp16: 67 | with amp.scale_loss(loss, optimizer) as scaled_loss: 68 | scaled_loss.backward() 69 | else: 70 | loss.backward() 71 | optimizer.step() 72 | if scheduler: 73 | scheduler.step() 74 | 75 | if i == 0: 76 | print("=====sanity check======") 77 | print("words:", words[0]) 78 | print("x:", x.cpu().numpy()[0][:seqlens[0]]) 79 | print("tokens:", get_tokenizer().convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]]) 80 | print("is_heads:", is_heads[0]) 81 | y_sample = _y.cpu().numpy()[0] 82 | if np.isscalar(y_sample): 83 | print("y:", y_sample) 84 | else: 85 | print("y:", y_sample[:seqlens[0]]) 86 | print("tags:", tags[0]) 87 | print("mask:", mask[0]) 88 | print("seqlen:", seqlens[0]) 89 | print("task_name:", taskname) 90 | print("=======================") 91 | 92 | if i%10 == 0: # monitoring 93 | print(f"step: {i}, task: {taskname}, loss: {loss.item()}") 94 | del loss 95 | 96 | def initialize_and_train(task_config, 97 | trainset, 98 | validset, 99 | testset, 100 | hp, 101 | run_tag): 102 | """The train process. 
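    This is the entry point used by train_baseline.py, which passes the task
    entry from configs.json as `task_config`, the train/dev/test
    SnippextDataset splits, and the parsed command-line arguments as `hp`.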
103 | 104 | Args: 105 | task_config (dictionary): the configuration of the task 106 | trainset (SnippextDataset): the training set 107 | validset (SnippextDataset): the validation set 108 | testset (SnippextDataset): the testset 109 | hp (Namespace): the parsed hyperparameters 110 | run_tag (string): the tag of the run (for logging purpose) 111 | 112 | Returns: 113 | None 114 | """ 115 | # create iterators for validation and test 116 | padder = SnippextDataset.pad 117 | valid_iter = data.DataLoader(dataset=validset, 118 | batch_size=hp.batch_size * 4, 119 | shuffle=False, 120 | num_workers=0, 121 | collate_fn=padder) 122 | test_iter = data.DataLoader(dataset=testset, 123 | batch_size=hp.batch_size * 4, 124 | shuffle=False, 125 | num_workers=0, 126 | collate_fn=padder) 127 | 128 | # initialize model 129 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 130 | model = MultiTaskNet([task_config], 131 | device, 132 | hp.finetuning, 133 | lm=hp.lm, 134 | bert_path=hp.bert_path) 135 | if device == 'cpu': 136 | optimizer = AdamW(model.parameters(), lr=hp.lr) 137 | else: 138 | model = model.cuda() 139 | optimizer = AdamW(model.parameters(), lr=hp.lr) 140 | if hp.fp16: 141 | model, optimizer = amp.initialize(model, optimizer, opt_level='O2') 142 | 143 | # learning rate scheduler 144 | num_steps = (len(trainset) // hp.batch_size) * hp.n_epochs 145 | scheduler = get_linear_schedule_with_warmup(optimizer, 146 | num_warmup_steps=num_steps // 10, 147 | num_training_steps=num_steps) 148 | 149 | # create logging directory 150 | if not os.path.exists(hp.logdir): 151 | os.makedirs(hp.logdir) 152 | writer = SummaryWriter(log_dir=hp.logdir) 153 | 154 | # start training 155 | best_dev_f1 = best_test_f1 = 0.0 156 | epoch = 1 157 | while epoch <= hp.n_epochs: 158 | train(model, 159 | trainset, 160 | optimizer, 161 | scheduler=scheduler, 162 | batch_size=hp.batch_size, 163 | fp16=hp.fp16) 164 | 165 | print(f"=========eval at epoch={epoch}=========") 166 | dev_f1, test_f1 = eval_on_task(epoch, 167 | model, 168 | task_config['name'], 169 | valid_iter, 170 | validset, 171 | test_iter, 172 | testset, 173 | writer, 174 | run_tag) 175 | 176 | if dev_f1 > 1e-6: 177 | epoch += 1 178 | if hp.save_model: 179 | if dev_f1 > best_dev_f1: 180 | best_dev_f1 = dev_f1 181 | torch.save(model.state_dict(), run_tag + '_dev.pt') 182 | if test_f1 > best_test_f1: 183 | best_test_f1 = test_f1 184 | torch.save(model.state_dict(), run_tag + '_test.pt') 185 | 186 | writer.close() 187 | 188 | -------------------------------------------------------------------------------- /snippext/conlleval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script applies to IOB2 or IOBES tagging scheme. 3 | If you are using a different scheme, please convert to IOB2 or IOBES. 4 | 5 | IOB2: 6 | - B = begin, 7 | - I = inside but not the first, 8 | - O = outside 9 | 10 | e.g. 11 | John lives in New York City . 12 | B-PER O O B-LOC I-LOC I-LOC O 13 | 14 | IOBES: 15 | - B = begin, 16 | - E = end, 17 | - S = singleton, 18 | - I = inside but not the first or the last, 19 | - O = outside 20 | 21 | e.g. 22 | John lives in New York City . 23 | S-PER O O B-LOC I-LOC E-LOC O 24 | 25 | prefix: IOBES 26 | chunk_type: PER, LOC, etc. 27 | """ 28 | from __future__ import division, print_function, unicode_literals 29 | 30 | import sys 31 | from collections import defaultdict 32 | 33 | def split_tag(chunk_tag): 34 | """ 35 | split chunk tag into IOBES prefix and chunk_type 36 | e.g. 
37 | B-PER -> (B, PER) 38 | O -> (O, None) 39 | """ 40 | if chunk_tag == 'O': 41 | return ('O', None) 42 | return chunk_tag.split('-', maxsplit=1) 43 | 44 | def is_chunk_end(prev_tag, tag): 45 | """ 46 | check if the previous chunk ended between the previous and current word 47 | e.g. 48 | (B-PER, I-PER) -> False 49 | (B-LOC, O) -> True 50 | 51 | Note: in case of contradicting tags, e.g. (B-PER, I-LOC) 52 | this is considered as (B-PER, B-LOC) 53 | """ 54 | prefix1, chunk_type1 = split_tag(prev_tag) 55 | prefix2, chunk_type2 = split_tag(tag) 56 | 57 | if prefix1 == 'O': 58 | return False 59 | if prefix2 == 'O': 60 | return prefix1 != 'O' 61 | 62 | if chunk_type1 != chunk_type2: 63 | return True 64 | 65 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S'] 66 | 67 | def is_chunk_start(prev_tag, tag): 68 | """ 69 | check if a new chunk started between the previous and current word 70 | """ 71 | prefix1, chunk_type1 = split_tag(prev_tag) 72 | prefix2, chunk_type2 = split_tag(tag) 73 | 74 | if prefix2 == 'O': 75 | return False 76 | if prefix1 == 'O': 77 | return prefix2 != 'O' 78 | 79 | if chunk_type1 != chunk_type2: 80 | return True 81 | 82 | return prefix2 in ['B', 'S'] or prefix1 in ['E', 'S'] 83 | 84 | 85 | def calc_metrics(tp, p, t, percent=True): 86 | """ 87 | compute overall precision, recall and FB1 (default values are 0.0) 88 | if percent is True, return 100 * original decimal value 89 | """ 90 | precision = tp / p if p else 0 91 | recall = tp / t if t else 0 92 | fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0 93 | if percent: 94 | return 100 * precision, 100 * recall, 100 * fb1 95 | else: 96 | return precision, recall, fb1 97 | 98 | 99 | def count_chunks(true_seqs, pred_seqs): 100 | """ 101 | true_seqs: a list of true tags 102 | pred_seqs: a list of predicted tags 103 | 104 | return: 105 | correct_chunks: a dict (counter), 106 | key = chunk types, 107 | value = number of correctly identified chunks per type 108 | true_chunks: a dict, number of true chunks per type 109 | pred_chunks: a dict, number of identified chunks per type 110 | 111 | correct_counts, true_counts, pred_counts: similar to above, but for tags 112 | """ 113 | correct_chunks = defaultdict(int) 114 | true_chunks = defaultdict(int) 115 | pred_chunks = defaultdict(int) 116 | 117 | correct_counts = defaultdict(int) 118 | true_counts = defaultdict(int) 119 | pred_counts = defaultdict(int) 120 | 121 | prev_true_tag, prev_pred_tag = 'O', 'O' 122 | correct_chunk = None 123 | 124 | for true_tag, pred_tag in zip(true_seqs, pred_seqs): 125 | if true_tag == pred_tag: 126 | correct_counts[true_tag] += 1 127 | true_counts[true_tag] += 1 128 | pred_counts[pred_tag] += 1 129 | 130 | _, true_type = split_tag(true_tag) 131 | _, pred_type = split_tag(pred_tag) 132 | 133 | if correct_chunk is not None: 134 | true_end = is_chunk_end(prev_true_tag, true_tag) 135 | pred_end = is_chunk_end(prev_pred_tag, pred_tag) 136 | 137 | if pred_end and true_end: 138 | correct_chunks[correct_chunk] += 1 139 | correct_chunk = None 140 | elif pred_end != true_end or true_type != pred_type: 141 | correct_chunk = None 142 | 143 | true_start = is_chunk_start(prev_true_tag, true_tag) 144 | pred_start = is_chunk_start(prev_pred_tag, pred_tag) 145 | 146 | if true_start and pred_start and true_type == pred_type: 147 | correct_chunk = true_type 148 | if true_start: 149 | true_chunks[true_type] += 1 150 | if pred_start: 151 | pred_chunks[pred_type] += 1 152 | 153 | prev_true_tag, prev_pred_tag = true_tag, pred_tag 154 | if 
correct_chunk is not None: 155 | correct_chunks[correct_chunk] += 1 156 | 157 | return (correct_chunks, true_chunks, pred_chunks, 158 | correct_counts, true_counts, pred_counts) 159 | 160 | def get_result(correct_chunks, true_chunks, pred_chunks, 161 | correct_counts, true_counts, pred_counts, verbose=True): 162 | """ 163 | if verbose, print overall performance, as well as preformance per chunk type; 164 | otherwise, simply return overall prec, rec, f1 scores 165 | """ 166 | # sum counts 167 | sum_correct_chunks = sum(correct_chunks.values()) 168 | sum_true_chunks = sum(true_chunks.values()) 169 | sum_pred_chunks = sum(pred_chunks.values()) 170 | 171 | sum_correct_counts = sum(correct_counts.values()) 172 | sum_true_counts = sum(true_counts.values()) 173 | 174 | nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != 'O') 175 | nonO_true_counts = sum(v for k, v in true_counts.items() if k != 'O') 176 | 177 | chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks)))) 178 | 179 | # compute overall precision, recall and FB1 (default values are 0.0) 180 | prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks) 181 | res = (prec, rec, f1) 182 | if not verbose: 183 | return res 184 | 185 | # print overall performance, and performance per chunk type 186 | 187 | print("processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), end='') 188 | print("found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), end='') 189 | 190 | print("accuracy: %6.2f%%; (non-O)" % (100*nonO_correct_counts/nonO_true_counts)) 191 | print("accuracy: %6.2f%%; " % (100*sum_correct_counts/sum_true_counts), end='') 192 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1)) 193 | 194 | # for each chunk type, compute precision, recall and FB1 (default values are 0.0) 195 | for t in chunk_types: 196 | prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t]) 197 | print("%17s: " %t , end='') 198 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % 199 | (prec, rec, f1), end='') 200 | print(" %d" % pred_chunks[t]) 201 | 202 | return res 203 | # you can generate LaTeX output for tables like in 204 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 205 | # but I'm not implementing this 206 | 207 | def evaluate(true_seqs, pred_seqs, verbose=True): 208 | (correct_chunks, true_chunks, pred_chunks, 209 | correct_counts, true_counts, pred_counts) = count_chunks(true_seqs, pred_seqs) 210 | result = get_result(correct_chunks, true_chunks, pred_chunks, 211 | correct_counts, true_counts, pred_counts, verbose=verbose) 212 | return result 213 | 214 | def evaluate_conll_file(fileIterator): 215 | true_seqs, pred_seqs = [], [] 216 | 217 | for line in fileIterator: 218 | cols = line.strip().split() 219 | # each non-empty line must contain >= 3 columns 220 | if not cols: 221 | true_seqs.append('O') 222 | pred_seqs.append('O') 223 | elif len(cols) < 3: 224 | raise IOError("conlleval: too few columns in line %s\n" % line) 225 | else: 226 | # extract tags from last 2 columns 227 | true_seqs.append(cols[-2]) 228 | pred_seqs.append(cols[-1]) 229 | return evaluate(true_seqs, pred_seqs) 230 | 231 | if __name__ == '__main__': 232 | """ 233 | usage: conlleval < file 234 | """ 235 | evaluate_conll_file(sys.stdin) 236 | -------------------------------------------------------------------------------- /snippext/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 
as np 2 | import torch 3 | import random 4 | import jsonlines 5 | 6 | from torch.utils import data 7 | from .augment import Augmenter 8 | 9 | tokenizer = None 10 | 11 | def get_tokenizer(lm='bert'): 12 | """Return the tokenizer. Intiailize it if not initialized. 13 | 14 | Args: 15 | lm (string, optional): the name of the language model 16 | (bert, albert, roberta, distilbert, etc.) 17 | 18 | Returns: 19 | Tokenizer: the tokenizer to be used 20 | """ 21 | global tokenizer 22 | if tokenizer is None: 23 | if lm == 'bert': 24 | from transformers import BertTokenizer 25 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 26 | elif lm == 'distilbert': 27 | from transformers import DistilBertTokenizer 28 | tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased') 29 | elif lm == 'albert': 30 | from transformers import AlbertTokenizer 31 | tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2') 32 | elif lm == 'roberta': 33 | from transformers import RobertaTokenizer 34 | tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 35 | elif lm == 'xlnet': 36 | from transformers import XLNetTokenizer 37 | tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased') 38 | elif lm == 'longformer': 39 | from transformers import LongformerTokenizer 40 | tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') 41 | return tokenizer 42 | 43 | 44 | class SnippextDataset(data.Dataset): 45 | def __init__(self, 46 | source, 47 | vocab, 48 | taskname, 49 | max_len=512, 50 | lm='bert', 51 | augment_index=None, 52 | augment_op=None, 53 | size=None): 54 | """ TODO 55 | Args: 56 | 57 | """ 58 | # tokens and tags 59 | sents, tags_li = [], [] # list of lists 60 | self.max_len = max_len 61 | get_tokenizer(lm) 62 | 63 | if type(source) is str: 64 | # read from file (for training/evaluation) 65 | if '_tagging' in taskname or '_qa' in taskname: 66 | sents, tags_li = self.read_tagging_file(source) 67 | else: 68 | sents, tags_li = self.read_classification_file(source) 69 | if size is not None: 70 | sents, tags_li = sents[:size], tags_li[:size] 71 | else: 72 | # read from list of tokens (for prediction) 73 | if '_tagging' in taskname or '_qa' in taskname: 74 | for tokens in source: 75 | sents.append(["[CLS]"] + [token for token in tokens] + ["[SEP]"]) 76 | tags_li.append([""] + ['O' for token in tokens] + [""]) 77 | else: 78 | for sent in source: 79 | sents.append(sent) 80 | tags_li.append(vocab[0]) 81 | 82 | # handling QA datasets. Mark the question tokens with so that 83 | # the model does not predict those tokens. 
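        # The masked question positions are tagged with the empty label, which
        # maps to index 0 in tag2idx; the tagging losses ignore index 0
        # (e.g., CrossEntropyLoss(ignore_index=0)), so these positions do not
        # contribute to training.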
84 | if '_qa' in taskname: 85 | for tokens, labels in zip(sents, tags_li): 86 | if "[SEP]" in tokens[:-1]: 87 | for i, token in enumerate(tokens): 88 | labels[i] = "" 89 | if token == "[SEP]": 90 | break 91 | 92 | # assign class variables 93 | self.sents, self.tags_li = sents, tags_li 94 | self.vocab = vocab 95 | 96 | # add special tags for tagging 97 | if '_tagging' in taskname: 98 | if 'O' not in self.vocab: 99 | self.vocab.append('O') 100 | if self.vocab[0] != '': 101 | self.vocab.insert(0, '') 102 | 103 | # index for tags/labels 104 | self.tag2idx = {tag: idx for idx, tag in enumerate(self.vocab)} 105 | self.idx2tag = {idx: tag for idx, tag in enumerate(self.vocab)} 106 | self.taskname = taskname 107 | 108 | # augmentation index and op 109 | self.augment_op = augment_op 110 | if augment_op == 't5': 111 | self.load_t5_examples(source) 112 | elif augment_index != None: 113 | self.augmenter = Augmenter(augment_index) 114 | else: 115 | self.augmenter = None 116 | self.augment_op = None 117 | 118 | 119 | def load_t5_examples(self, source): 120 | self.augmenter = None 121 | # read augmented examples 122 | self.augmented_examples = [] 123 | if '_tagging' in self.taskname: 124 | with jsonlines.open(source + '.augment.jsonl', mode='r') as reader: 125 | for row in reader: 126 | exms = [] 127 | for entry in row['augment']: 128 | tokens, labels = self.read_tagging_file(entry, is_file=False) 129 | exms.append((tokens[0], labels[0])) 130 | self.augmented_examples.append(exms) 131 | else: 132 | with jsonlines.open(source + '.augment.jsonl', mode='r') as reader: 133 | for row in reader: 134 | exms = [] 135 | label = row['label'] 136 | for entry in row['augment']: 137 | sent = ' [SEP] '.join(entry.split('\t')) 138 | exms.append((sent, label)) 139 | self.augmented_examples.append(exms) 140 | 141 | 142 | def read_tagging_file(self, path, is_file=True): 143 | """Read a train/eval tagging dataset from file 144 | 145 | The input file should contain multiple entries separated by empty lines. 146 | The format of each entry: 147 | 148 | The O 149 | room B-AS 150 | is O 151 | very B-OP 152 | clean I-OP 153 | . O 154 | 155 | Args: 156 | path (str): the path to the dataset file 157 | 158 | Returns: 159 | list of list of str: the tokens 160 | list of list of str: the labels 161 | """ 162 | sents, tags_li = [], [] 163 | if is_file: 164 | entries = open(path, 'r').read().strip().split("\n\n") 165 | else: 166 | entries = [path.strip()] 167 | 168 | for entry in entries: 169 | try: 170 | words = [line.split()[0] for line in entry.splitlines()] 171 | tags = [line.split()[-1] for line in entry.splitlines()] 172 | sents.append(["[CLS]"] + words[:self.max_len] + ["[SEP]"]) 173 | tags_li.append([""] + tags[:self.max_len] + [""]) 174 | except: 175 | print('error @', entry) 176 | return sents, tags_li 177 | 178 | 179 | def read_classification_file(self, path): 180 | """Read a train/eval classification dataset from file 181 | 182 | The input file should contain multiple lines where each line is an example. 
183 | The format of each line: 184 | The room is clean.\troom\tpositive 185 | 186 | Args: 187 | path (str): the path to the dataset file 188 | 189 | Returns: 190 | list of str: the input sequences 191 | list of str: the labels 192 | """ 193 | sents, labels = [], [] 194 | lines = open(path).readlines() 195 | for line in lines: 196 | items = line.strip().split('\t') 197 | # only consider sentence and sentence pairs 198 | if len(items) < 2 or len(items) > 3: 199 | continue 200 | try: 201 | if len(items) == 2: 202 | sents.append(items[0]) 203 | labels.append(items[1]) 204 | else: 205 | sents.append(items[0] + ' [SEP] ' + items[1]) 206 | labels.append(items[2]) 207 | except: 208 | print('error @', line.strip()) 209 | return sents, labels 210 | 211 | 212 | def __len__(self): 213 | """Return the length of the dataset""" 214 | return len(self.sents) 215 | 216 | def get(self, idx, op=[]): 217 | ag = self.augmenter 218 | self.augmenter = None 219 | item = self.__getitem__(idx) 220 | self.augmenter = ag 221 | return item 222 | 223 | def __getitem__(self, idx): 224 | """Return the ith item of in the dataset. 225 | 226 | Args: 227 | idx (int): the element index 228 | Returns (TODO): 229 | words, x, is_heads, tags, mask, y, seqlen, self.taskname 230 | """ 231 | words, tags = self.sents[idx], self.tags_li[idx] 232 | 233 | if '_tagging' in self.taskname: 234 | # apply data augmentation if specified 235 | if self.augment_op == 't5': 236 | if len(self.augmented_examples[idx]) > 0: 237 | words, tags = random.choice(self.augmented_examples[idx]) 238 | elif self.augmenter != None: 239 | words, tags = self.augmenter.augment(words, tags, self.augment_op) 240 | 241 | # We give credits only to the first piece. 242 | x, y = [], [] # list of ids 243 | is_heads = [] # list. 1: the token is the first piece of a word 244 | 245 | for w, t in zip(words, tags): 246 | # avoid bad tokens 247 | w = w[:50] 248 | tokens = tokenizer.tokenize(w) if w not in ("[CLS]", "[SEP]") else [w] 249 | xx = tokenizer.convert_tokens_to_ids(tokens) 250 | if len(xx) == 0: 251 | continue 252 | 253 | is_head = [1] + [0]*(len(tokens) - 1) 254 | 255 | t = [t] + [""] * (len(tokens) - 1) # : no decision 256 | yy = [self.tag2idx[each] for each in t] # (T,) 257 | 258 | x.extend(xx) 259 | is_heads.extend(is_head) 260 | y.extend(yy) 261 | # make sure that the length of x is not too large 262 | if len(x) > self.max_len: 263 | break 264 | 265 | assert len(x)==len(y)==len(is_heads), \ 266 | f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}, {' '.join(tokens)}" 267 | 268 | # seqlen 269 | seqlen = len(y) 270 | 271 | mask = [1] * seqlen 272 | # masking for QA 273 | for i, t in enumerate(tags): 274 | if t != '': 275 | break 276 | mask[i] = 0 277 | 278 | # to string 279 | words = " ".join(words) 280 | tags = " ".join(tags) 281 | else: # classification 282 | if self.augment_op == 't5': 283 | if len(self.augmented_examples[idx]) > 0: 284 | words, tags = random.choice(self.augmented_examples[idx]) 285 | elif self.augmenter != None: 286 | words = self.augmenter.augment_sent(words, self.augment_op) 287 | 288 | if ' [SEP] ' in words: 289 | sent_a, sent_b = words.split(' [SEP] ') 290 | else: 291 | sent_a, sent_b = words, None 292 | 293 | x = tokenizer.encode(sent_a, text_pair=sent_b, 294 | truncation="longest_first", 295 | max_length=self.max_len, 296 | add_special_tokens=True) 297 | 298 | y = self.tag2idx[tags] # label 299 | is_heads = [1] * len(x) 300 | mask = [1] * len(x) 301 | 302 | assert len(x)==len(mask)==len(is_heads), \ 303 | 
f"len(x)={len(x)}, len(y)={len(y)}, len(is_heads)={len(is_heads)}" 304 | # seqlen 305 | seqlen = len(mask) 306 | 307 | return words, x, is_heads, tags, mask, y, seqlen, self.taskname 308 | 309 | @staticmethod 310 | def pad(batch): 311 | '''Pads to the longest sample 312 | 313 | Args: 314 | batch: 315 | 316 | Returns (TODO): 317 | return words, f(x), is_heads, tags, f(mask), f(y), seqlens, name 318 | ''' 319 | f = lambda x: [sample[x] for sample in batch] 320 | g = lambda x, seqlen, val: \ 321 | [sample[x] + [val] * (seqlen - len(sample[x])) \ 322 | for sample in batch] # 0: 323 | 324 | # get maximal sequence length 325 | seqlens = f(6) 326 | maxlen = np.array(seqlens).max() 327 | 328 | # get task name 329 | name = f(7) 330 | 331 | words = f(0) 332 | x = g(1, maxlen, 0) 333 | is_heads = f(2) 334 | tags = f(3) 335 | mask = g(4, maxlen, 1) 336 | if '_tagging' in name[0]: 337 | y = g(5, maxlen, 0) 338 | else: 339 | y = f(5) 340 | 341 | f = torch.LongTensor 342 | if isinstance(y[0], float): 343 | y = torch.Tensor(y) 344 | else: 345 | y = torch.LongTensor(y) 346 | return words, f(x), is_heads, tags, f(mask), y, seqlens, name 347 | -------------------------------------------------------------------------------- /snippext/mixda.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | import numpy as np 7 | import argparse 8 | import json 9 | import copy 10 | import random 11 | 12 | from torch.utils import data 13 | from .model import MultiTaskNet 14 | from .train_util import * 15 | from .dataset import * 16 | from tensorboardX import SummaryWriter 17 | from transformers import AdamW, get_linear_schedule_with_warmup 18 | from apex import amp 19 | 20 | # criterion for tagging 21 | tagging_criterion = nn.CrossEntropyLoss(ignore_index=0) 22 | 23 | # criterion for classification 24 | classifier_criterion = nn.CrossEntropyLoss() 25 | 26 | # criterion for regression 27 | regression_criterion = nn.MSELoss() 28 | 29 | def mixda(model, batch, alpha_aug=0.4): 30 | """Perform one iteration of MixDA 31 | 32 | Args: 33 | model (MultiTaskNet): the model state 34 | batch (tuple): the input batch 35 | alpha_aug (float, Optional): the parameter for MixDA 36 | 37 | Returns: 38 | Tensor: the loss (of 0-d) 39 | """ 40 | _, x, _, _, mask, y, _, taskname = batch 41 | taskname = taskname[0] 42 | # two batches 43 | batch_size = x.size()[0] // 2 44 | 45 | # augmented 46 | aug_x = x[batch_size:] 47 | aug_y = y[batch_size:] 48 | aug_lam = np.random.beta(alpha_aug, alpha_aug) 49 | 50 | # labeled 51 | x = x[:batch_size] 52 | 53 | # back prop 54 | logits, y, _ = model(x, y, 55 | augment_batch=(aug_x, aug_lam), 56 | task=taskname) 57 | if 'sts-b' in taskname: 58 | logits = logits.view(-1) 59 | else: 60 | logits = logits.view(-1, logits.shape[-1]) 61 | 62 | aug_y = y[batch_size:] 63 | y = y[:batch_size] 64 | aug_y = y.view(-1) 65 | y = y.view(-1) 66 | 67 | # cross entropy 68 | if 'tagging' in taskname: 69 | criterion = tagging_criterion 70 | elif 'sts-b' in taskname: 71 | criterion = regression_criterion 72 | else: 73 | criterion = classifier_criterion 74 | 75 | # mix the labels 76 | loss = criterion(logits, y) * aug_lam + \ 77 | criterion(logits, aug_y) * (1 - aug_lam) 78 | 79 | return loss 80 | 81 | 82 | def create_mixda_batches(l_set, aug_set, batch_size=16): 83 | """Create batches for mixda 84 | 85 | Each batch is the concatenation of (1) a labeled batch and (2) an 
augmented 86 | labeled batch (having the same order of (1) ) 87 | 88 | Args: 89 | l_set (SnippextDataset): the train set 90 | aug_set (SnippextDataset): the augmented train set 91 | batch_size (int, optional): batch size (of each component) 92 | 93 | Returns: 94 | list of list: the created batches 95 | """ 96 | num_labeled = len(l_set) 97 | l_index = np.random.permutation(num_labeled) 98 | 99 | l_batch = [] 100 | l_batch_aug = [] 101 | padder = l_set.pad 102 | 103 | for i, idx in enumerate(l_index): 104 | l_batch.append(l_set[idx]) 105 | l_batch_aug.append(aug_set[idx]) 106 | 107 | if len(l_batch) == batch_size or i == len(l_index) - 1: 108 | batches = l_batch + l_batch_aug 109 | yield padder(batches) 110 | l_batch.clear() 111 | l_batch_aug.clear() 112 | 113 | if len(l_batch) > 0: 114 | batches = l_batch + l_batch_aug 115 | yield padder(batches) 116 | 117 | 118 | def train(model, l_set, aug_set, optimizer, 119 | scheduler=None, 120 | fp16=False, 121 | batch_size=32, 122 | alpha_aug=0.8): 123 | """Perform one epoch of MixDA 124 | 125 | Args: 126 | model (MultiTaskModel): the model state 127 | train_dataset (SnippextDataset): the train set 128 | augment_dataset (SnippextDataset): the augmented train set 129 | optimizer (Optimizer): Adam 130 | fp16 (boolean, Optional): whether to use fp16 131 | batch_size (int, Optional): batch size 132 | alpha_aug (float, Optional): the alpha for MixDA 133 | 134 | Returns: 135 | None 136 | """ 137 | mixda_batches = create_mixda_batches(l_set, 138 | aug_set, 139 | batch_size=batch_size) 140 | 141 | model.train() 142 | for i, batch in enumerate(mixda_batches): 143 | # for monitoring 144 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 145 | taskname = taskname[0] 146 | _y = y 147 | 148 | # perform mixmatch 149 | optimizer.zero_grad() 150 | loss = mixda(model, batch, alpha_aug) 151 | if fp16: 152 | with amp.scale_loss(loss, optimizer) as scaled_loss: 153 | scaled_loss.backward() 154 | else: 155 | loss.backward() 156 | optimizer.step() 157 | if scheduler: 158 | scheduler.step() 159 | 160 | if i == 0: 161 | print("=====sanity check======") 162 | print("words:", words[0]) 163 | print("x:", x.cpu().numpy()[0][:seqlens[0]]) 164 | print("tokens:", get_tokenizer().convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]]) 165 | print("is_heads:", is_heads[0]) 166 | y_sample = _y.cpu().numpy()[0] 167 | if np.isscalar(y_sample): 168 | print("y:", y_sample) 169 | else: 170 | print("y:", y_sample[:seqlens[0]]) 171 | print("tags:", tags[0]) 172 | print("mask:", mask[0]) 173 | print("seqlen:", seqlens[0]) 174 | print("task_name:", taskname) 175 | print("=======================") 176 | 177 | if i%10 == 0: # monitoring 178 | print(f"step: {i}, task: {taskname}, loss: {loss.item()}") 179 | del loss 180 | 181 | 182 | 183 | def initialize_and_train(task_config, 184 | trainset, 185 | augmentset, 186 | validset, 187 | testset, 188 | hp, 189 | run_tag): 190 | """The train process. 
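    This is the entry point used by train_mixda.py; compared to the baseline
    trainer it additionally receives the augmented copy of the training set
    (`augmentset` below) that MixDA interpolates with the original examples.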
191 | 192 | Args: 193 | task_config (dictionary): the configuration of the task 194 | trainset (SnippextDataset): the training set 195 | augmentset (SnippextDataset): the augmented training set 196 | validset (SnippextDataset): the validation set 197 | testset (SnippextDataset): the testset 198 | hp (Namespace): the parsed hyperparameters 199 | run_tag (string): the tag of the run (for logging purpose) 200 | 201 | Returns: 202 | None 203 | """ 204 | padder = SnippextDataset.pad 205 | 206 | # iterators for dev/test set 207 | valid_iter = data.DataLoader(dataset=validset, 208 | batch_size=hp.batch_size * 4, 209 | shuffle=False, 210 | num_workers=0, 211 | collate_fn=padder) 212 | test_iter = data.DataLoader(dataset=testset, 213 | batch_size=hp.batch_size * 4, 214 | shuffle=False, 215 | num_workers=0, 216 | collate_fn=padder) 217 | 218 | 219 | # initialize model 220 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 221 | if device == 'cpu': 222 | model = MultiTaskNet([task_config], device, 223 | hp.finetuning, lm=hp.lm, bert_path=hp.bert_path) 224 | optimizer = AdamW(model.parameters(), lr=hp.lr) 225 | else: 226 | model = MultiTaskNet([task_config], device, 227 | hp.finetuning, lm=hp.lm, bert_path=hp.bert_path).cuda() 228 | optimizer = AdamW(model.parameters(), lr=hp.lr) 229 | if hp.fp16: 230 | model, optimizer = amp.initialize(model, optimizer, opt_level='O2') 231 | 232 | # learning rate scheduler 233 | num_steps = (len(trainset) // hp.batch_size) * hp.n_epochs 234 | scheduler = get_linear_schedule_with_warmup(optimizer, 235 | num_warmup_steps=num_steps // 10, 236 | num_training_steps=num_steps) 237 | # create logging 238 | if not os.path.exists(hp.logdir): 239 | os.makedirs(hp.logdir) 240 | writer = SummaryWriter(log_dir=hp.logdir) 241 | 242 | # start training 243 | best_dev_f1 = best_test_f1 = 0.0 244 | epoch = 1 245 | while epoch <= hp.n_epochs: 246 | train(model, 247 | trainset, 248 | augmentset, 249 | optimizer, 250 | scheduler=scheduler, 251 | fp16=hp.fp16, 252 | batch_size=hp.batch_size, 253 | alpha_aug=hp.alpha_aug) 254 | 255 | print(f"=========eval at epoch={epoch}=========") 256 | dev_f1, test_f1 = eval_on_task(epoch, 257 | model, 258 | task_config['name'], 259 | valid_iter, 260 | validset, 261 | test_iter, 262 | testset, 263 | writer, 264 | run_tag) 265 | 266 | # skip the epochs with zero f1 267 | if dev_f1 > 1e-6: 268 | epoch += 1 269 | if hp.save_model: 270 | if dev_f1 > best_dev_f1: 271 | best_dev_f1 = dev_f1 272 | torch.save(model.state_dict(), run_tag + '_dev.pt') 273 | if test_f1 > best_test_f1: 274 | best_test_f1 = test_f1 275 | torch.save(model.state_dict(), run_tag + '_test.pt') 276 | 277 | writer.close() 278 | 279 | -------------------------------------------------------------------------------- /snippext/mixmatchnl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import os 6 | import numpy as np 7 | import argparse 8 | import json 9 | import copy 10 | import random 11 | 12 | from torch.utils import data 13 | from tensorboardX import SummaryWriter 14 | from transformers import AdamW, get_linear_schedule_with_warmup 15 | from apex import amp 16 | 17 | from .model import MultiTaskNet 18 | from .train_util import * 19 | from .dataset import * 20 | 21 | def tagging_criterion(pred_labeled, y_labeled): 22 | """The loss function for tagging task (with float tensor input) 23 | 24 | Args: 25 | pred_labeled (Tensor): the 
predicted float tensor 26 | y_labeled (Tensor): the groundtruth float tensor 27 | 28 | Returns: 29 | Tensor: the cross entropy loss with the 0'th dimention ignored 30 | """ 31 | # cross-entropy, ignore the 0 class 32 | loss_x = torch.sum(-y_labeled[:, 1:] * pred_labeled[:, 1:].log(), -1).mean() 33 | return loss_x 34 | 35 | def classifier_criterion(pred_labeled, y_labeled): 36 | """The loss function for classification task (with float tensor input) 37 | 38 | Args: 39 | pred_labeled (Tensor): the predicted float tensor 40 | y_labeled (Tensor): the groundtruth float tensor 41 | 42 | Returns: 43 | Tensor: the cross entropy loss 44 | """ 45 | loss_x = torch.sum(-y_labeled * pred_labeled.log(), -1).mean() 46 | return loss_x 47 | 48 | 49 | def mixmatch(model, batch, num_aug=2, alpha=0.4, alpha_aug=0.4, u_lambda=0.5): 50 | """Perform one iteration of MixMatchNL 51 | 52 | Args: 53 | model (MultiTaskNet): the model state 54 | batch (tuple): the input batch 55 | num_aug (int, Optional): the number of augmented examples in the batch 56 | alpha (float, Optional): the parameter for MixUp 57 | alpha_aug (float, Optional): the parameter for MixDA 58 | u_lambda (float, Optional): the parameter controlling 59 | the weight of unlabeled data 60 | 61 | Returns: 62 | Tensor: the loss (of 0-d) 63 | """ 64 | _, x, _, _, mask, y, _, taskname = batch 65 | taskname = taskname[0] 66 | # two batches of labeled and two batches of unlabeled 67 | batch_size = x.size()[0] // (num_aug + 3) 68 | 69 | y = y[:batch_size] 70 | 71 | # the unlabeled half 72 | u0 = x[batch_size:2*batch_size] 73 | 74 | # augmented 75 | aug_x = x[2*batch_size:3*batch_size] 76 | 77 | # augmented unlabeled 78 | u_augs = [] 79 | for uid in range(num_aug): 80 | u_augs.append(x[(3+uid)*batch_size:(4+uid)*batch_size]) 81 | 82 | # labeled + original unlabeled 83 | x = torch.cat((x[:batch_size], x[3*batch_size:])) 84 | 85 | # label guessing 86 | model.eval() 87 | u_guesses = [] 88 | u_aug_enc_list = [] 89 | _, _, _, u_enc = model(u0, y, 90 | task=taskname, get_enc=True) 91 | 92 | for x_u in u_augs: 93 | if alpha_aug <= 0: 94 | u_aug_lam = 1.0 95 | else: 96 | u_aug_lam = np.random.beta(alpha_aug, alpha_aug) 97 | 98 | # it is fine to switch the order of x_u and u0 in this case 99 | u_logits, y, _, u_aug_enc = model(x_u, y, 100 | augment_batch=(u0, u_aug_lam), 101 | aug_enc=u_enc, 102 | task=taskname, 103 | get_enc=True) 104 | # softmax 105 | u_guess = F.softmax(u_logits, dim=-1) 106 | u_guess = u_guess.detach() 107 | u_guesses.append(u_guess) 108 | 109 | # save u_aug_enc 110 | u_aug_enc_list.append(u_aug_enc) 111 | 112 | # averaging 113 | u_guess = sum(u_guesses) / len(u_guesses) 114 | 115 | # temperature sharpening 116 | T = 0.5 117 | u_power = u_guess.pow(1/T) 118 | u_guess = u_power / u_power.sum(dim=-1, keepdim=True) 119 | 120 | # make duplicate of u_guess 121 | if len(u_guess.size()) == 2: 122 | u_guess = u_guess.repeat(num_aug, 1) 123 | else: 124 | u_guess = u_guess.repeat(num_aug, 1, 1) 125 | 126 | vocab = u_guess.shape[-1] 127 | # switch back to training mode 128 | model.train() 129 | 130 | # shuffle 131 | index = torch.randperm(batch_size + u_guess.size()[0]) 132 | lam = np.random.beta(alpha, alpha) 133 | lam = max(lam, 1.0 - lam) 134 | 135 | # convert y to one-hot 136 | y_onehot = F.one_hot(y, vocab).float() 137 | y_concat = torch.cat((y_onehot, u_guess)) 138 | y_mixed = y_concat[index, :] 139 | 140 | # x_aug_enc 141 | _, _, _, x_enc = model(x[:batch_size], y, 142 | task=taskname, 143 | get_enc=True) 144 | # concatenate the augmented encodings 145 
| x_enc = torch.cat([x_enc] + u_aug_enc_list) 146 | 147 | # forward 148 | if alpha_aug <= 0: 149 | aug_lam = 1.0 150 | else: 151 | aug_lam = np.random.beta(alpha_aug, alpha_aug) 152 | logits, y_concat, _ = model(x, y_concat, 153 | augment_batch=(aug_x, aug_lam), 154 | x_enc=x_enc, 155 | second_batch=(index, lam), 156 | task=taskname) 157 | logits = F.softmax(logits, dim=-1) 158 | l_pred = logits[:batch_size].view(-1, vocab) 159 | u_pred = logits[batch_size:].view(-1, vocab) 160 | 161 | # mixup y's 162 | y = lam * y_concat + (1.0 - lam) * y_mixed 163 | l_y = y[:batch_size].view(-1, vocab) 164 | u_y = y[batch_size:].view(-1, vocab) 165 | 166 | # cross entropy on label data + mse on unlabeled data 167 | if 'tagging' in taskname: 168 | loss_x = tagging_criterion(l_pred, l_y) 169 | loss_u = F.mse_loss(u_pred[:, 1:], u_y[:, 1:]) 170 | else: 171 | loss_x = classifier_criterion(l_pred, l_y) 172 | loss_u = F.mse_loss(u_pred, u_y) 173 | 174 | loss = loss_x + loss_u * u_lambda 175 | return loss 176 | 177 | 178 | # global bookkeeping variables for using the unlabeled set 179 | epoch_idx = 0 180 | u_order = [] 181 | 182 | def create_mixmatch_batches(l_set, aug_set, u_set, u_set_aug, 183 | num_aug=2, 184 | batch_size=16): 185 | """Create batches for mixmatchnl 186 | 187 | Each batch is the concatenation of (1) a labeled batch, (2) an augmented 188 | labeled batch (having the same order of (1) ), (3) an unlabeled batch, 189 | and (4) multiple augmented unlabeled batches of the same order 190 | of (3). 191 | 192 | Args: 193 | l_set (SnippextDataset): the train set 194 | aug_set (SnippextDataset): the augmented train set 195 | u_set (SnippextDataset): the unlabeled set 196 | u_set_aug (SnippextDataset): the augmented unlabeled set 197 | num_aug (int, optional): number of unlabeled augmentations to be created 198 | batch_size (int, optional): batch size (of each component) 199 | 200 | Returns: 201 | list of list: the created batches 202 | """ 203 | mixed_batches = [] 204 | num_labeled = len(l_set) 205 | l_index = np.random.permutation(num_labeled) 206 | # num_unlabeled = len(u_set) 207 | # u_index = np.random.permutation(num_unlabeled) 208 | 209 | global u_order 210 | if len(u_order) == 0: 211 | u_order = list(range(len(u_set))) 212 | random.shuffle(u_order) 213 | u_order = np.array(u_order) 214 | 215 | global epoch_idx 216 | u_index = np.random.permutation(num_labeled) + num_labeled * epoch_idx 217 | u_index %= len(u_set) 218 | u_index = u_order[u_index] 219 | epoch_idx += 1 220 | 221 | l_batch = [] 222 | l_batch_aug = [] 223 | u_batch = [] 224 | u_batch_aug = [[] for _ in range(num_aug)] 225 | padder = l_set.pad 226 | 227 | for i, idx in enumerate(l_index): 228 | u_idx = u_index[i] 229 | l_batch.append(l_set[idx]) 230 | l_batch_aug.append(aug_set[idx]) 231 | # add augmented examples of unlabeled 232 | u_batch.append(u_set[u_idx]) 233 | for uid in range(num_aug): 234 | u_batch_aug[uid].append(u_set_aug[u_idx]) 235 | 236 | if len(l_batch) == batch_size or i == len(l_index) - 1: 237 | batches = l_batch + u_batch + l_batch_aug 238 | for ub in u_batch_aug: 239 | batches += ub 240 | 241 | mixed_batches.append(padder(batches)) 242 | l_batch.clear() 243 | l_batch_aug.clear() 244 | u_batch.clear() 245 | for ub in u_batch_aug: 246 | ub.clear() 247 | random.shuffle(mixed_batches) 248 | 249 | return mixed_batches 250 | 251 | 252 | def train(model, l_set, aug_set, u_set, u_set_aug, optimizer, 253 | scheduler=None, 254 | batch_size=32, 255 | num_aug=2, 256 | alpha=0.4, 257 | alpha_aug=0.8, 258 | u_lambda=1.0, 259 | 
fp16=False): 260 | """Perform one epoch of MixMatchNL 261 | 262 | Args: 263 | model (MultiTaskModel): the model state 264 | train_dataset (SnippextDataset): the train set 265 | augment_dataset (SnippextDataset): the augmented train set 266 | u_dataset (SnippextDataset): the unlabeled set 267 | u_dataset_aug (SnippextDataset): the augmented unlabeled set 268 | optimizer (Optimizer): Adam 269 | scheduler (Scheduler, optional): the learning rate scheduler 270 | fp16 (boolean): whether to use fp16 271 | num_aug (int, Optional): 272 | batch_size (int, Optional): batch size 273 | alpha (float, Optional): the alpha for MixUp 274 | alpha_aug (float, Optional): the alpha for MixDA 275 | u_lambda (float, Optional): the weight of unlabeled data 276 | 277 | Returns: 278 | None 279 | """ 280 | mixed_batches = create_mixmatch_batches(l_set, 281 | aug_set, 282 | u_set, 283 | u_set_aug, 284 | num_aug=num_aug, 285 | batch_size=batch_size // 2) 286 | 287 | model.train() 288 | for i, batch in enumerate(mixed_batches): 289 | # for monitoring 290 | # print('memory:', torch.cuda.memory_allocated(), 'cached:', torch.cuda.memory_cached()) 291 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 292 | taskname = taskname[0] 293 | _y = y 294 | 295 | # perform mixmatch 296 | optimizer.zero_grad() 297 | try: 298 | loss = mixmatch(model, batch, num_aug, alpha, alpha_aug, u_lambda) 299 | if fp16: 300 | with amp.scale_loss(loss, optimizer) as scaled_loss: 301 | scaled_loss.backward() 302 | else: 303 | loss.backward() 304 | 305 | optimizer.step() 306 | if scheduler: 307 | scheduler.step() 308 | 309 | if i == 0: 310 | print("=====sanity check======") 311 | print("words:", words[0]) 312 | print("x:", x.cpu().numpy()[0][:seqlens[0]]) 313 | print("tokens:", get_tokenizer().convert_ids_to_tokens(x.cpu().numpy()[0])[:seqlens[0]]) 314 | print("is_heads:", is_heads[0]) 315 | y_sample = _y.cpu().numpy()[0] 316 | if np.isscalar(y_sample): 317 | print("y:", y_sample) 318 | else: 319 | print("y:", y_sample[:seqlens[0]]) 320 | print("tags:", tags[0]) 321 | print("mask:", mask[0]) 322 | print("seqlen:", seqlens[0]) 323 | print("task_name:", taskname) 324 | print("=======================") 325 | 326 | if i%10 == 0: # monitoring 327 | print(f"step: {i}, task: {taskname}, loss: {loss.item()}") 328 | del loss 329 | except: 330 | print("debug - seqlen:", max(seqlens)) 331 | torch.cuda.empty_cache() 332 | 333 | 334 | def initialize_and_train(task_config, 335 | trainset, 336 | augmentset, 337 | validset, 338 | testset, 339 | uset, 340 | uset_aug, 341 | hp, 342 | run_tag): 343 | """The train process. 
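    A minimal hypothetical invocation, mirroring the MixDA trainer but with the
    unlabeled set and its augmented copy added:

        initialize_and_train(config, trainset, augmentset, validset, testset,
                             uset, uset_aug, hp, run_tag)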
344 | 345 | Args: 346 | task_config (dictionary): the configuration of the task 347 | trainset (SnippextDataset): the training set 348 | augmentset (SnippextDataset): the augmented training set 349 | validset (SnippextDataset): the validation set 350 | testset (SnippextDataset): the testset 351 | uset (SnippextDataset): the unlabeled dataset 352 | uset_aug (SnippextDataset): the unlabeled dataset, augmented 353 | hp (Namespace): the parsed hyperparameters 354 | run_tag (string): the tag of the run (for logging purpose) 355 | 356 | Returns: 357 | None 358 | """ 359 | padder = SnippextDataset.pad 360 | 361 | # iterators for dev/test set 362 | valid_iter = data.DataLoader(dataset=validset, 363 | batch_size=hp.batch_size * 4, 364 | shuffle=False, 365 | num_workers=0, 366 | collate_fn=padder) 367 | test_iter = data.DataLoader(dataset=testset, 368 | batch_size=hp.batch_size * 4, 369 | shuffle=False, 370 | num_workers=0, 371 | collate_fn=padder) 372 | 373 | 374 | # initialize model 375 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 376 | if device == 'cpu': 377 | model = MultiTaskNet([task_config], device, 378 | hp.finetuning, bert_path=hp.bert_path) 379 | optimizer = AdamW(model.parameters(), lr = hp.lr) 380 | else: 381 | model = MultiTaskNet([task_config], device, 382 | hp.finetuning, bert_path=hp.bert_path).cuda() 383 | optimizer = AdamW(model.parameters(), lr = hp.lr) 384 | model, optimizer = amp.initialize(model, optimizer, opt_level='O2') 385 | 386 | # learning rate scheduler 387 | num_steps = (len(trainset) // hp.batch_size * 2) * hp.n_epochs 388 | scheduler = get_linear_schedule_with_warmup(optimizer, 389 | num_warmup_steps=num_steps // 10, 390 | num_training_steps=num_steps) 391 | 392 | # create logging 393 | if not os.path.exists(hp.logdir): 394 | os.makedirs(hp.logdir) 395 | writer = SummaryWriter(log_dir=hp.logdir) 396 | 397 | # start training 398 | best_dev_f1 = best_test_f1 = 0.0 399 | for epoch in range(1, hp.n_epochs+1): 400 | train(model, 401 | trainset, 402 | augmentset, 403 | uset, 404 | uset_aug, 405 | optimizer, 406 | scheduler=scheduler, 407 | batch_size=hp.batch_size, 408 | num_aug=hp.num_aug, 409 | alpha=hp.alpha, 410 | alpha_aug=hp.alpha_aug, 411 | u_lambda=hp.u_lambda, 412 | fp16=hp.fp16) 413 | 414 | print(f"=========eval at epoch={epoch}=========") 415 | dev_f1, test_f1 = eval_on_task(epoch, 416 | model, 417 | task_config['name'], 418 | valid_iter, 419 | validset, 420 | test_iter, 421 | testset, 422 | writer, 423 | run_tag) 424 | 425 | if hp.save_model: 426 | if dev_f1 > best_dev_f1: 427 | best_dev_f1 = dev_f1 428 | torch.save(model.state_dict(), run_tag + '_dev.pt') 429 | if test_f1 > best_test_f1: 430 | best_test_f1 = test_f1 431 | torch.save(model.state_dict(), run_tag + '_test.pt') 432 | 433 | writer.close() 434 | 435 | -------------------------------------------------------------------------------- /snippext/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel, AlbertModel, DistilBertModel, RobertaModel, XLNetModel, LongformerModel 4 | 5 | model_ckpts = {'bert': "bert-base-uncased", 6 | 'albert': "albert-base-v2", 7 | 'roberta': "roberta-base", 8 | 'xlnet': "xlnet-base-cased", 9 | 'distilbert': "distilbert-base-uncased", 10 | 'longformer': "allenai/longformer-base-4096"} 11 | 12 | class MultiTaskNet(nn.Module): 13 | def __init__(self, task_configs=[], 14 | device='cpu', 15 | finetuning=True, 16 | lm='bert', 17 | bert_pt=None, 18 | 
bert_path=None): 19 | super().__init__() 20 | 21 | assert len(task_configs) > 0 22 | 23 | # load the model or model checkpoint 24 | if bert_path == None: 25 | if lm == 'bert': 26 | self.bert = BertModel.from_pretrained(model_ckpts[lm]) 27 | elif lm == 'distilbert': 28 | self.bert = DistilBertModel.from_pretrained(model_ckpts[lm]) 29 | elif lm == 'albert': 30 | self.bert = AlbertModel.from_pretrained(model_ckpts[lm]) 31 | elif lm == 'xlnet': 32 | self.bert = XLNetModel.from_pretrained(model_ckpts[lm]) 33 | elif lm == 'roberta': 34 | self.bert = RobertaModel.from_pretrained(model_ckpts[lm]) 35 | elif lm == 'longformer': 36 | self.bert = LongformerModel.from_pretrained(model_ckpts[lm]) 37 | else: 38 | output_model_file = bert_path 39 | model_state_dict = torch.load(output_model_file, 40 | map_location=lambda storage, loc: storage) 41 | if lm == 'bert': 42 | self.bert = BertModel.from_pretrained(model_ckpts[lm], 43 | state_dict=model_state_dict) 44 | elif lm == 'distilbert': 45 | self.bert = DistilBertModel.from_pretrained(model_ckpts[lm], 46 | state_dict=model_state_dict) 47 | elif lm == 'albert': 48 | self.bert = AlbertModel.from_pretrained(model_ckpts[lm], 49 | state_dict=model_state_dict) 50 | elif lm == 'xlnet': 51 | self.bert = XLNetModel.from_pretrained(model_ckpts[lm], 52 | state_dict=model_state_dict) 53 | elif lm == 'roberta': 54 | self.bert = RobertaModel.from_pretrained(model_ckpts[lm], 55 | state_dict=model_state_dict) 56 | 57 | self.device = device 58 | self.finetuning = finetuning 59 | self.task_configs = task_configs 60 | self.module_dict = nn.ModuleDict({}) 61 | self.lm = lm 62 | 63 | # hard corded for now 64 | hidden_size = 768 65 | hidden_dropout_prob = 0.1 66 | 67 | for config in task_configs: 68 | name = config['name'] 69 | task_type = config['task_type'] 70 | vocab = config['vocab'] 71 | 72 | if task_type == 'tagging': 73 | # for tagging 74 | vocab_size = len(vocab) # 'O' and '' 75 | if 'O' not in vocab: 76 | vocab_size += 1 77 | if '' not in vocab: 78 | vocab_size += 1 79 | else: 80 | # for pairing and classification 81 | vocab_size = len(vocab) 82 | 83 | self.module_dict['%s_dropout' % name] = nn.Dropout(hidden_dropout_prob) 84 | self.module_dict['%s_fc' % name] = nn.Linear(hidden_size, vocab_size) 85 | 86 | 87 | def forward(self, x, y, 88 | augment_batch=None, 89 | aug_enc=None, 90 | second_batch=None, 91 | x_enc=None, 92 | task='hotel_tagging', 93 | get_enc=False): 94 | """Forward function of the BERT models for classification/tagging. 
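        In the simplest case the call is

            logits, y, y_hat = model(x, y, task='hotel_tagging')

        which encodes x with the underlying language model and applies the
        task-specific dropout and linear head; the optional MixDA/MixUp
        arguments (augment_batch, second_batch, aug_enc, x_enc) interpolate
        encodings before the head is applied.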
95 | 96 | Args: 97 | x (Tensor): 98 | y (Tensor): 99 | augment_batch (tuple of Tensor, optional): 100 | aug_enc (Tensor, optional): 101 | second_batch (Tensor, optional): 102 | task (string, optional): 103 | get_enc (boolean, optional): 104 | 105 | Returns: 106 | Tensor: logits 107 | Tensor: y 108 | Tensor: yhat 109 | Tensor (optional): enc""" 110 | 111 | # move input to GPU 112 | x = x.to(self.device) 113 | y = y.to(self.device) 114 | if second_batch != None: 115 | index, lam = second_batch 116 | lam = torch.tensor(lam).to(self.device) 117 | if augment_batch != None: 118 | aug_x, aug_lam = augment_batch 119 | aug_x = aug_x.to(self.device) 120 | aug_lam = torch.tensor(aug_lam).to(self.device) 121 | 122 | dropout = self.module_dict[task + '_dropout'] 123 | fc = self.module_dict[task + '_fc'] 124 | 125 | if 'tagging' in task: # TODO: this needs to be changed later 126 | if self.training and self.finetuning: 127 | self.bert.train() 128 | if x_enc is None: 129 | enc = self.bert(x)[0] 130 | else: 131 | enc = x_enc 132 | # Dropout 133 | enc = dropout(enc) 134 | else: 135 | self.bert.eval() 136 | with torch.no_grad(): 137 | enc = self.bert(x)[0] 138 | 139 | if augment_batch != None: 140 | if aug_enc is None: 141 | aug_enc = self.bert(aug_x)[0] 142 | enc[:aug_x.shape[0]] *= aug_lam 143 | enc[:aug_x.shape[0]] += aug_enc * (1 - aug_lam) 144 | 145 | if second_batch != None: 146 | enc = enc * lam + enc[index] * (1 - lam) 147 | enc = dropout(enc) 148 | 149 | logits = fc(enc) 150 | y_hat = logits.argmax(-1) 151 | if get_enc: 152 | return logits, y, y_hat, enc 153 | else: 154 | return logits, y, y_hat 155 | else: 156 | if self.training and self.finetuning: 157 | self.bert.train() 158 | if x_enc is None: 159 | output = self.bert(x) 160 | pooled_output = output[0][:, 0, :] 161 | pooled_output = dropout(pooled_output) 162 | else: 163 | pooled_output = x_enc 164 | else: 165 | self.bert.eval() 166 | with torch.no_grad(): 167 | output = self.bert(x) 168 | pooled_output = output[0][:, 0, :] 169 | pooled_output = dropout(pooled_output) 170 | 171 | if augment_batch != None: 172 | if aug_enc is None: 173 | aug_enc = self.bert(aug_x)[0][:, 0, :] 174 | pooled_output[:aug_x.shape[0]] *= aug_lam 175 | pooled_output[:aug_x.shape[0]] += aug_enc * (1 - aug_lam) 176 | 177 | if second_batch != None: 178 | pooled_output = pooled_output * lam + pooled_output[index] * (1 - lam) 179 | 180 | logits = fc(pooled_output) 181 | if 'sts-b' in task: 182 | y_hat = logits 183 | else: 184 | y_hat = logits.argmax(-1) 185 | if get_enc: 186 | return logits, y, y_hat, pooled_output 187 | else: 188 | return logits, y, y_hat 189 | -------------------------------------------------------------------------------- /snippext/train_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import sklearn.metrics as metrics 7 | import uuid 8 | 9 | from .conlleval import evaluate_conll_file 10 | from transformers.data import glue_processors, glue_compute_metrics 11 | 12 | def eval_tagging(model, iterator, idx2tag): 13 | """Evaluate a tagging model state on a dev/test set. 
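    Predictions are written to a temporary file in the CoNLL "word gold pred"
    format and scored with conlleval.evaluate_conll_file, so the reported
    precision/recall/F1 are chunk-level (per extracted span) rather than
    token-level.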
14 | 15 | Args: 16 | model (MultiTaskNet): the model state 17 | iterator (DataLoader): a batch iterator of the dev/test set 18 | idx2tag (dict): a mapping from tag indices to tag names 19 | 20 | Returns: 21 | float: precision 22 | float: recall 23 | float: f1 24 | float: loss 25 | """ 26 | model.eval() 27 | 28 | Words, Is_heads, Tags, Y, Y_hat = [], [], [], [], [] 29 | with torch.no_grad(): 30 | loss_list = [] 31 | total_size = 0 32 | for i, batch in enumerate(iterator): 33 | words, x, is_heads, tags, mask, y, seqlens, taskname = batch 34 | 35 | taskname = taskname[0] 36 | loss_fct = nn.CrossEntropyLoss(ignore_index=0) 37 | batch_size = y.shape[0] 38 | 39 | logits, y, y_hat = model(x, y, task=taskname) # y_hat: (N, T) 40 | 41 | logits = logits.view(-1, logits.shape[-1]) 42 | y = y.view(-1) 43 | loss = loss_fct(logits, y) 44 | loss_list.append(loss.item() * batch_size) 45 | total_size += batch_size 46 | 47 | Words.extend(words) 48 | Is_heads.extend(is_heads) 49 | Tags.extend(tags) 50 | Y.extend(y.cpu().numpy().tolist()) 51 | Y_hat.extend(y_hat.cpu().numpy().tolist()) 52 | 53 | # gets results and save 54 | eval_fname = "temp_da_" + taskname +'_' + uuid.uuid4().hex 55 | with open(eval_fname, 'w') as fout: 56 | for words, is_heads, tags, y_hat in zip(Words, Is_heads, Tags, Y_hat): 57 | y_hat = [hat for head, hat in zip(is_heads, y_hat) if head == 1] 58 | preds = [idx2tag[hat] for hat in y_hat] 59 | if len(preds)==len(words.split())==len(tags.split()): 60 | for w, t, p in zip(words.split()[1:-1], tags.split()[1:-1], preds[1:-1]): 61 | if p == '': 62 | p = 'O' 63 | if t == '': 64 | p = t = 'O' 65 | fout.write(f"{w} {t} {p}\n") 66 | fout.write("\n") 67 | 68 | ## calc metric 69 | precision, recall, f1 = evaluate_conll_file(open(eval_fname)) 70 | loss = sum(loss_list) / total_size 71 | os.remove(eval_fname) 72 | print("=============%s==================" % taskname) 73 | print("precision=%.3f"%precision) 74 | print("recall=%.3f"%recall) 75 | print("f1=%.3f"%f1) 76 | print("loss=%.3f"%loss) 77 | print("=====================================") 78 | return precision, recall, f1, loss 79 | 80 | def eval_classifier(model, iterator, threshold=None, get_threshold=False): 81 | """Evaluate a classification model state on a dev/test set. 
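    Note that the shape of the return value depends on the task: GLUE tasks
    return a metrics dictionary, binary classification returns
    (accuracy, precision, recall, f1, loss) plus the selected threshold when
    get_threshold=True, and multi-class classification returns
    (accuracy, macro-f1, loss); eval_on_task below branches on these cases.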
82 | 
83 |     Args:
84 |         model (MultiTaskNet): the model state
85 |         iterator (DataLoader): a batch iterator of the dev/test set
86 |         threshold (float, optional): the cut-off threshold for binary classification
87 |         get_threshold (bool, optional): return the selected threshold if True
88 | 
89 |     Returns:
90 |         float: Accuracy
91 |         float: Precision (binary classification only)
92 |         float: Recall (binary classification only)
93 |         float: F1 (macro F1 if more than 2 classes)
94 |         float: the loss (the selected threshold is also returned if get_threshold is True)
95 |     """
96 |     model.eval()
97 | 
98 |     Y = []
99 |     Y_hat = []
100 |     Y_prob = []
101 |     loss_list = []
102 |     total_size = 0
103 |     with torch.no_grad():
104 |         for i, batch in enumerate(iterator):
105 |             _, x, _, _, _, y, _, taskname = batch
106 |             taskname = taskname[0]
107 |             logits, y1, y_hat = model(x, y, task=taskname)
108 |             logits = logits.view(-1, logits.shape[-1])
109 |             y1 = y1.view(-1)
110 |             if 'sts-b' in taskname.lower():
111 |                 loss = nn.MSELoss()(logits, y1)
112 |             else:
113 |                 loss = nn.CrossEntropyLoss()(logits, y1)
114 | 
115 |             loss_list.append(loss.item() * y.shape[0])
116 |             total_size += y.shape[0]
117 | 
118 |             Y.extend(y.numpy().tolist())
119 |             Y_hat.extend(y_hat.cpu().numpy().tolist())
120 |             Y_prob.extend(logits.softmax(dim=-1).max(dim=-1)[0].cpu().numpy().tolist())
121 | 
122 |     loss = sum(loss_list) / total_size
123 | 
124 |     print("=============%s==================" % taskname)
125 | 
126 |     # for glue
127 |     if taskname in glue_processors:
128 |         Y_hat = np.array(Y_hat).squeeze()
129 |         Y = np.array(Y)
130 |         result = glue_compute_metrics(taskname, Y_hat, Y)
131 |         result['loss'] = loss
132 |         print(result)
133 |         return result
134 |     elif taskname[:5] == 'glue_':
135 |         task = taskname.split('_')[1].lower()
136 |         Y_hat = np.array(Y_hat).squeeze()
137 |         Y = np.array(Y)
138 |         result = glue_compute_metrics(task, Y_hat, Y)
139 |         result['loss'] = loss
140 |         print(result)
141 |         return result
142 |     else:
143 |         num_classes = len(set(Y))
144 |         # Binary classification
145 |         if num_classes <= 2:
146 |             accuracy = metrics.accuracy_score(Y, Y_hat)
147 |             precision = metrics.precision_score(Y, Y_hat)
148 |             recall = metrics.recall_score(Y, Y_hat)
149 |             f1 = metrics.f1_score(Y, Y_hat)
150 |             if any([prefix in taskname for prefix in \
151 |                     ['cleaning_', 'Structured', 'Textual', 'Dirty']]):  # handle class imbalance
152 |                 max_f1 = f1
153 |                 if threshold is None:
154 |                     for th in np.arange(0.9, 1.0, 0.005):
155 |                         Y_hat = [y if p > th else 0 for (y, p) in zip(Y_hat, Y_prob)]
156 |                         f1 = metrics.f1_score(Y, Y_hat)
157 |                         if f1 > max_f1:
158 |                             max_f1 = f1
159 |                             accuracy = metrics.accuracy_score(Y, Y_hat)
160 |                             precision = metrics.precision_score(Y, Y_hat)
161 |                             recall = metrics.recall_score(Y, Y_hat)
162 |                             threshold = th
163 |                     f1 = max_f1
164 |                 else:
165 |                     Y_hat = [y if p > threshold else 0 for (y, p) in zip(Y_hat, Y_prob)]
166 |                     accuracy = metrics.accuracy_score(Y, Y_hat)
167 |                     precision = metrics.precision_score(Y, Y_hat)
168 |                     recall = metrics.recall_score(Y, Y_hat)
169 |                     f1 = metrics.f1_score(Y, Y_hat)
170 | 
171 |             print("accuracy=%.3f"%accuracy)
172 |             print("precision=%.3f"%precision)
173 |             print("recall=%.3f"%recall)
174 |             print("f1=%.3f"%f1)
175 |             print("======================================")
176 |             if get_threshold:
177 |                 return accuracy, precision, recall, f1, loss, threshold
178 |             else:
179 |                 return accuracy, precision, recall, f1, loss
180 |         else:
181 |             accuracy = metrics.accuracy_score(Y, Y_hat)
182 |             f1 = metrics.f1_score(Y, Y_hat, average='macro')
183 |             precision = recall = accuracy  # placeholders; precision/recall are not returned for multi-class
184 |             print("accuracy=%.3f"%accuracy)
185 |             print("macro_f1=%.3f"%f1)
186 |             print("======================================")
187 |             return accuracy, f1, loss
188 | 
189 | 
190 | def eval_on_task(epoch,
191 |                  model,
192 |                  task,
193 |                  valid_iter,
194 |                  valid_dataset,
195 |                  test_iter,
196 |                  test_dataset,
197 |                  writer,
198 |                  run_tag):
199 |     """Run the eval function on the dev/test datasets and log the results.
200 | 
201 |     Args:
202 |         epoch (int): the epoch number of the training process
203 |         model (MultiTaskNet): the model state
204 |         task (str): the task name to be evaluated
205 |         valid_iter (DataLoader): the dev set iterator
206 |         valid_dataset (Dataset): the dev dataset
207 |         test_iter (DataLoader): the test set iterator
208 |         test_dataset (Dataset): the test dataset
209 |         writer (SummaryWriter): the logging writer for tensorboard
210 |         run_tag (str): the tag of the run
211 | 
212 |     Returns:
213 |         float: dev F1
214 |         float: test F1
215 |     """
216 |     t_prec = t_recall = t_f1 = t_loss = None
217 |     if 'tagging' in task:
218 |         print('Validation:')
219 |         prec, recall, f1, v_loss = eval_tagging(model,
220 |                                                 valid_iter,
221 |                                                 valid_dataset.idx2tag)
222 |         if test_iter is not None:
223 |             print('Test:')
224 |             t_prec, t_recall, t_f1, t_loss = eval_tagging(model,
225 |                                                           test_iter,
226 |                                                           test_dataset.idx2tag)
227 |         scalars = {'precision': prec,
228 |                    'recall': recall,
229 |                    'f1': f1,
230 |                    'v_loss': v_loss,
231 |                    't_precision': t_prec,
232 |                    't_recall': t_recall,
233 |                    't_f1': t_f1,
234 |                    't_loss': t_loss}
235 |     elif task in glue_processors:
236 |         print('Validation:')
237 |         scalars = eval_classifier(model, valid_iter)
238 |         f1, t_f1 = 0.0, 0.0
239 |     elif task[:5] == 'glue_':
240 |         print('Validation:')
241 |         scalars = eval_classifier(model, valid_iter)
242 | 
243 |         if test_iter is not None:
244 |             print('Test:')
245 |             t_output = eval_classifier(model, test_iter)
246 |             for key in t_output:
247 |                 scalars['t_' + key] = t_output[key]
248 | 
249 |         f1, t_f1 = 0.0, 0.0
250 |     elif any([prefix in task for prefix in \
251 |               ['cleaning_', 'Structured', 'Textual', 'Dirty']]):  # handle class imbalance
252 |         print('Validation:')
253 |         acc, prec, recall, f1, v_loss, th = eval_classifier(model, valid_iter, get_threshold=True)
254 |         print('Test:')
255 |         t_acc, t_prec, t_recall, t_f1, t_loss = eval_classifier(model, test_iter, threshold=th)
256 |         scalars = {'acc': acc,
257 |                    'precision': prec,
258 |                    'recall': recall,
259 |                    'f1': f1,
260 |                    'v_loss': v_loss,
261 |                    't_acc': t_acc,
262 |                    't_precision': t_prec,
263 |                    't_recall': t_recall,
264 |                    't_f1': t_f1,
265 |                    't_loss': t_loss}
266 |     else:
267 |         print('Validation:')
268 |         v_output = eval_classifier(model, valid_iter)
269 | 
270 |         if test_iter is not None:
271 |             print('Test:')
272 |             t_output = eval_classifier(model, test_iter)
273 | 
274 |         if len(v_output) == 5:
275 |             acc, prec, recall, f1, v_loss = v_output
276 |             t_acc, t_prec, t_recall, t_f1, t_loss = t_output
277 |             scalars = {'acc': acc,
278 |                        'precision': prec,
279 |                        'recall': recall,
280 |                        'f1': f1,
281 |                        'v_loss': v_loss,
282 |                        't_acc': t_acc,
283 |                        't_precision': t_prec,
284 |                        't_recall': t_recall,
285 |                        't_f1': t_f1,
286 |                        't_loss': t_loss}
287 |         else:
288 |             acc, f1, v_loss = v_output
289 |             t_acc, t_f1, t_loss = t_output
290 |             scalars = {'acc': acc,
291 |                        'f1': f1,
292 |                        'v_loss': v_loss,
293 |                        't_acc': t_acc,
294 |                        't_f1': t_f1,
295 |                        't_loss': t_loss}
296 | 
297 |     # logging
298 |     writer.add_scalars(run_tag, scalars, epoch)
299 |     return f1, t_f1
300 | 
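# Usage sketch (hypothetical, not part of the module above): how eval_on_task()
# is typically driven from a training loop.  `train_one_epoch` is a stand-in
# for the task-specific training step (cf. snippext/baseline.py, mixda.py,
# mixmatchnl.py); the remaining arguments are the objects documented above.
def example_eval_loop(model, task, train_one_epoch, valid_iter, valid_dataset,
                      test_iter, test_dataset, writer, run_tag, n_epochs=5):
    """Hypothetical driver loop that checkpoints on the best dev F1."""
    best_dev_f1 = 0.0
    for epoch in range(1, n_epochs + 1):
        train_one_epoch(model, epoch)               # one pass over the training data
        dev_f1, test_f1 = eval_on_task(epoch, model, task,
                                       valid_iter, valid_dataset,
                                       test_iter, test_dataset,
                                       writer, run_tag)
        if dev_f1 is not None and dev_f1 > best_dev_f1:
            best_dev_f1 = dev_f1
            torch.save(model.state_dict(), '%s_best.pt' % run_tag)
    return best_dev_f1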
--------------------------------------------------------------------------------
/train_baseline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | 
5 | from torch.utils import data
6 | from snippext.baseline import initialize_and_train
7 | from snippext.dataset import SnippextDataset
8 | 
9 | if __name__=="__main__":
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("--task", type=str, default="hotel_tagging")
12 |     parser.add_argument("--lm", type=str, default="bert")
13 |     parser.add_argument("--run_id", type=int, default=0)
14 |     parser.add_argument("--batch_size", type=int, default=128)
15 |     parser.add_argument("--lr", type=float, default=0.0001)
16 |     parser.add_argument("--n_epochs", type=int, default=30)
17 |     parser.add_argument("--max_len", type=int, default=64)
18 |     parser.add_argument("--finetuning", dest="finetuning", action="store_true")
19 |     parser.add_argument("--fp16", dest="fp16", action="store_true")
20 |     parser.add_argument("--save_model", dest="save_model", action="store_true")
21 |     parser.add_argument("--logdir", type=str, default="checkpoints/")
22 |     parser.add_argument("--bert_path", type=str, default=None)
23 | 
24 |     hp = parser.parse_args()
25 | 
26 |     # only a single task for baseline
27 |     task = hp.task
28 | 
29 |     # create the tag of the run
30 |     run_tag = 'baseline_task_%s_lm_%s_batch_size_%d_run_id_%d' % (task,
31 |                                                                   hp.lm,
32 |                                                                   hp.batch_size,
33 |                                                                   hp.run_id)
34 | 
35 |     # load task configuration
36 |     configs = json.load(open('configs.json'))
37 |     configs = {conf['name'] : conf for conf in configs}
38 |     config = configs[task]
39 | 
40 |     trainset = config['trainset']
41 |     validset = config['validset']
42 |     testset = config['testset']
43 |     task_type = config['task_type']
44 |     vocab = config['vocab']
45 |     tasknames = [task]
46 | 
47 |     # load train/dev/test sets
48 |     train_dataset = SnippextDataset(trainset, vocab, task,
49 |                                     lm=hp.lm,
50 |                                     max_len=hp.max_len)
51 |     valid_dataset = SnippextDataset(validset, vocab, task,
52 |                                     lm=hp.lm)
53 |     test_dataset = SnippextDataset(testset, vocab, task,
54 |                                    lm=hp.lm)
55 | 
56 |     # run the training process
57 |     initialize_and_train(config,
58 |                          train_dataset,
59 |                          valid_dataset,
60 |                          test_dataset,
61 |                          hp,
62 |                          run_tag)
63 | 
--------------------------------------------------------------------------------
/train_mixda.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | 
5 | from torch.utils import data
6 | from snippext.dataset import SnippextDataset
7 | from snippext.mixda import initialize_and_train
8 | 
9 | 
10 | if __name__=="__main__":
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument("--task", type=str, default="hotel_tagging")
13 |     parser.add_argument("--lm", type=str, default="bert")
14 |     parser.add_argument("--run_id", type=int, default=0)
15 |     parser.add_argument("--batch_size", type=int, default=128)
16 |     parser.add_argument("--max_len", type=int, default=64)
17 |     parser.add_argument("--lr", type=float, default=0.0001)
18 |     parser.add_argument("--n_epochs", type=int, default=30)
19 |     parser.add_argument("--finetuning", dest="finetuning", action="store_true")
20 |     parser.add_argument("--fp16", dest="fp16", action="store_true")
21 |     parser.add_argument("--save_model", dest="save_model", action="store_true")
22 |     parser.add_argument("--logdir", type=str, default="checkpoints/")
23 |     parser.add_argument("--bert_path", type=str, default=None)
24 |     parser.add_argument("--alpha_aug", type=float, default=0.8)
25 |     parser.add_argument("--augment_index", type=str, default=None)
26 |     parser.add_argument("--augment_op", type=str, default=None)
27 | 
28 |     hp = parser.parse_args()
29 | 
30 |     task = hp.task  # consider a single task for now
31 | 
32 |     # create the tag of the run
33 |     run_tag = 'mixda_task_%s_lm_%s_batch_size_%d_alpha_aug_%.1f_augment_op_%s_run_id_%d' % \
34 |         (task, hp.lm, hp.batch_size, hp.alpha_aug, hp.augment_op, hp.run_id)
35 | 
36 |     # task config
37 |     configs = json.load(open('configs.json'))
38 |     configs = {conf['name'] : conf for conf in configs}
39 |     config = configs[task]
40 | 
41 |     trainset = config['trainset']
42 |     validset = config['validset']
43 |     testset = config['testset']
44 |     task_type = config['task_type']
45 |     vocab = config['vocab']
46 |     tasknames = [task]
47 | 
48 |     # train dataset
49 |     train_dataset = SnippextDataset(trainset, vocab, task,
50 |                                     lm=hp.lm,
51 |                                     max_len=hp.max_len)
52 |     # train dataset augmented
53 |     augment_dataset = SnippextDataset(trainset, vocab, task,
54 |                                       lm=hp.lm,
55 |                                       max_len=hp.max_len,
56 |                                       augment_index=hp.augment_index,
57 |                                       augment_op=hp.augment_op)
58 |     # dev set
59 |     valid_dataset = SnippextDataset(validset, vocab, task, lm=hp.lm)
60 | 
61 |     # test set
62 |     test_dataset = SnippextDataset(testset, vocab, task, lm=hp.lm)
63 | 
64 |     # run the training process
65 |     initialize_and_train(config,
66 |                          train_dataset,
67 |                          augment_dataset,
68 |                          valid_dataset,
69 |                          test_dataset,
70 |                          hp, run_tag)
71 | 
--------------------------------------------------------------------------------
/train_mixmatchnl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | 
5 | from snippext.dataset import SnippextDataset
6 | from snippext.mixmatchnl import initialize_and_train
7 | 
8 | if __name__=="__main__":
9 |     parser = argparse.ArgumentParser()
10 |     parser.add_argument("--task", type=str, default="hotel_tagging")
11 |     parser.add_argument("--lm", type=str, default="bert")
12 |     parser.add_argument("--run_id", type=int, default=0)
13 |     parser.add_argument("--batch_size", type=int, default=128)
14 |     parser.add_argument("--lr", type=float, default=0.0001)
15 |     parser.add_argument("--n_epochs", type=int, default=30)
16 |     parser.add_argument("--max_len", type=int, default=64)
17 |     parser.add_argument("--finetuning", dest="finetuning", action="store_true")
18 |     parser.add_argument("--save_model", dest="save_model", action="store_true")
19 |     parser.add_argument("--fp16", dest="fp16", action="store_true")
20 |     parser.add_argument("--logdir", type=str, default="checkpoints/")
21 |     parser.add_argument("--bert_path", type=str, default=None)
22 |     parser.add_argument("--alpha", type=float, default=0.2)
23 |     parser.add_argument("--alpha_aug", type=float, default=0.8)
24 |     parser.add_argument("--num_aug", type=int, default=2)
25 |     parser.add_argument("--u_lambda", type=float, default=10.0)
26 |     parser.add_argument("--augment_index", type=str, default=None)
27 |     parser.add_argument("--augment_op", type=str, default=None)
28 | 
29 |     hp = parser.parse_args()
30 | 
31 |     task = hp.task  # consider a single task for now
32 | 
33 |     # create the tag of the run
34 |     run_tag = 'mixmatchnl_task_%s_lm_%s_batch_size_%d_alpha_%.1f_alpha_aug_%.1f_num_aug_%d_u_lambda_%.1f_augment_op_%s_run_id_%d' % \
35 |         (task, hp.lm, hp.batch_size, hp.alpha, hp.alpha_aug, \
36 |          hp.num_aug, hp.u_lambda, hp.augment_op, hp.run_id)
37 | 
38 |     # task config
39 |     configs = json.load(open('configs.json'))
40 |     configs = {conf['name'] : conf for conf in configs}
41 |     config = configs[task]
42 |     config_list = [config]
43 | 
44 |     trainset = config['trainset']
45 |     validset = config['validset']
46 |     testset = config['testset']
47 |     unlabeled = config['unlabeled']
48 |     task_type = config['task_type']
49 |     vocab = config['vocab']
50 |     tasknames = [task]
51 | 
52 |     # train dataset
53 |     train_dataset = SnippextDataset(trainset, vocab, task,
54 |                                     lm=hp.lm,
55 |                                     max_len=hp.max_len)
56 |     # train dataset augmented
57 |     augment_dataset = SnippextDataset(trainset, vocab, task,
58 |                                       lm=hp.lm,
59 |                                       max_len=hp.max_len,
60 |                                       augment_index=hp.augment_index,
61 |                                       augment_op=hp.augment_op)
62 |     # dev set
63 |     valid_dataset = SnippextDataset(validset, vocab, task, lm=hp.lm)
64 | 
65 |     # test set
66 |     test_dataset = SnippextDataset(testset, vocab, task, lm=hp.lm)
67 | 
68 |     # unlabeled dataset and augmented
69 |     u_dataset = SnippextDataset(unlabeled, vocab, task, max_len=hp.max_len, lm=hp.lm)
70 |     u_dataset_aug = SnippextDataset(unlabeled, vocab, task,
71 |                                     lm=hp.lm,
72 |                                     max_len=hp.max_len,
73 |                                     augment_index=hp.augment_index,
74 |                                     augment_op=hp.augment_op)
75 | 
76 |     # train the model
77 |     initialize_and_train(config,
78 |                          train_dataset,
79 |                          augment_dataset,
80 |                          valid_dataset,
81 |                          test_dataset,
82 |                          u_dataset,
83 |                          u_dataset_aug,
84 |                          hp, run_tag)
85 | 
--------------------------------------------------------------------------------
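
All three training scripts above resolve ``--task`` against ``configs.json`` and fail with a ``KeyError`` if the entry is missing or incomplete. The sketch below is a small standalone helper, not part of the repository, for sanity-checking a task entry before launching a run; the helper name and the exact list of required fields are assumptions based on how the scripts read the config.
```
import json
import os

# fields the training scripts read from each configs.json entry (assumed required)
REQUIRED = ['name', 'task_type', 'vocab', 'trainset', 'validset', 'testset']

def check_task(task_name, config_path='configs.json'):
    """Verify that a task entry exists and that its dataset files are present."""
    with open(config_path) as f:
        configs = {conf['name']: conf for conf in json.load(f)}
    if task_name not in configs:
        raise KeyError('task %s is not defined in %s' % (task_name, config_path))
    config = configs[task_name]
    missing = [key for key in REQUIRED if key not in config]
    if missing:
        raise KeyError('task %s is missing fields: %s' % (task_name, missing))
    # 'unlabeled' is optional and only needed for semi-supervised training
    for key in ['trainset', 'validset', 'testset', 'unlabeled']:
        path = config.get(key)
        if path and not os.path.exists(path):
            print('warning: %s file %s does not exist' % (key, path))
    return config

if __name__ == '__main__':
    check_task('hotel_tagging')
```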