├── models
│   ├── __init__.py
│   ├── config.py
│   ├── fast_text.py
│   ├── text_lstm.py
│   ├── text_cnn.py
│   └── text_bert.py
├── LICENSE
├── scripts
│   └── train_trec.sh
├── .gitignore
├── utils.py
├── README.md
├── preprocessing.py
├── datasets.py
├── train_bert.py
└── train.py
/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .text_cnn import TextCNN 2 | from .text_lstm import TextLSTM 3 | from .fast_text import FastText 4 | from .text_bert import TextBERT 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 xashru 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
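The models package above re-exports the four classifiers so that train.py can resolve the --model argument by name. A minimal sketch of that dispatch for the word-level models (the numbers below are illustrative, not taken from the repo):

    import models

    # train.py looks the class up on the models package and builds it from the task config.
    model_cls = models.__dict__['TextCNN']          # TextCNN, TextLSTM or FastText
    model = model_cls(vocab_size=10000,             # size of the torchtext vocabulary
                      sequence_len=30,              # fixed length from utils.get_task_config
                      num_class=6,
                      word_embeddings=None,         # or a (vocab_size, 300) GloVe tensor
                      fine_tune=True,
                      dropout=0.5)
    # TextBERT is built differently in train_bert.py:
    # TextBERT(pretrained_model, num_class, fine_tune, dropout)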
22 | -------------------------------------------------------------------------------- /scripts/train_trec.sh: -------------------------------------------------------------------------------- 1 | # Random 2 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-none 3 | 4 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-embed --method=embed 5 | 6 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-sent --method=sent 7 | 8 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-dense --method=dense 9 | 10 | # nonstatic 11 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-nonstatic-none --w2v-file=data/glove.pickle --fine-tune=True 12 | 13 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-nonstatic-embed --method=embed --w2v-file=data/glove.pickle --fine-tune=True 14 | 15 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-nonstatic-sent --method=sent --w2v-file=data/glove.pickle --fine-tune=True 16 | 17 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-nonstatic-dense --method=dense --w2v-file=data/glove.pickle --fine-tune=True 18 | 19 | # static 20 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-static-none --w2v-file=data/glove.pickle --fine-tune=False 21 | 22 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-static-embed --method=embed --w2v-file=data/glove.pickle --fine-tune=False 23 | 24 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-static-sent --method=sent --w2v-file=data/glove.pickle --fine-tune=False 25 | 26 | python train.py --model=TextCNN --task=trec --save-path='logs-trec' --name=trec-textcnn-static-dense --method=dense --w2v-file=data/glove.pickle --fine-tune=False 27 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | from transformers import * 2 | 3 | # special tokens indices in different models available in transformers 4 | TOKEN_IDX = { 5 | 'bert': { 6 | 'START_SEQ': 101, 7 | 'PAD': 0, 8 | 'END_SEQ': 102, 9 | 'UNK': 100 10 | }, 11 | 'xlm': { 12 | 'START_SEQ': 0, 13 | 'PAD': 2, 14 | 'END_SEQ': 1, 15 | 'UNK': 3 16 | }, 17 | 'roberta': { 18 | 'START_SEQ': 0, 19 | 'PAD': 1, 20 | 'END_SEQ': 2, 21 | 'UNK': 3 22 | }, 23 | 'albert': { 24 | 'START_SEQ': 2, 25 | 'PAD': 0, 26 | 'END_SEQ': 3, 27 | 'UNK': 1 28 | }, 29 | } 30 | 31 | # pretrained model name: (model class, model tokenizer, output dimension, token style) 32 | # only BERT variants are implemented for mixup 33 | MODELS = { 34 | 'bert-base-uncased': (BertModel, BertTokenizer, 768, 'bert'), 35 | 'bert-large-uncased': (BertModel, BertTokenizer, 1024, 'bert'), 36 | 'bert-base-multilingual-cased': (BertModel, BertTokenizer, 768, 'bert'), 37 | 'bert-base-multilingual-uncased': (BertModel, BertTokenizer, 768, 'bert'), 38 | 'xlm-mlm-en-2048': (XLMModel, XLMTokenizer, 2048, 'xlm'), 39 | 'xlm-mlm-100-1280': (XLMModel, XLMTokenizer, 1280, 'xlm'), 40 | 'roberta-base': (RobertaModel, RobertaTokenizer, 768, 'roberta'), 41 | 'roberta-large': (RobertaModel, RobertaTokenizer, 1024, 'roberta'), 42 | 'distilbert-base-uncased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'), 43 | 
'distilbert-base-multilingual-cased': (DistilBertModel, DistilBertTokenizer, 768, 'bert'), 44 | 'xlm-roberta-base': (XLMRobertaModel, XLMRobertaTokenizer, 768, 'roberta'), 45 | 'xlm-roberta-large': (XLMRobertaModel, XLMRobertaTokenizer, 1024, 'roberta'), 46 | 'albert-base-v1': (AlbertModel, AlbertTokenizer, 768, 'albert'), 47 | 'albert-base-v2': (AlbertModel, AlbertTokenizer, 768, 'albert'), 48 | 'albert-large-v2': (AlbertModel, AlbertTokenizer, 1024, 'albert'), 49 | } 50 | -------------------------------------------------------------------------------- /models/fast_text.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | embed_size = 300 4 | hidden_size = 10 5 | 6 | 7 | class FastText(nn.Module): 8 | def __init__(self, vocab_size, sequence_len, num_class, word_embeddings=None, fine_tune=True, dropout=0.5): 9 | super(FastText, self).__init__() 10 | 11 | # Embedding Layer 12 | self.embeddings = nn.Embedding(vocab_size, embed_size) 13 | self.sequence_len = sequence_len 14 | if word_embeddings is not None: 15 | self.embeddings.weight = nn.Parameter(word_embeddings, requires_grad=fine_tune) 16 | 17 | # Hidden Layer 18 | self.fc1 = nn.Linear(embed_size, hidden_size) 19 | self.dropout = nn.Dropout(dropout) 20 | 21 | # Output Layer 22 | self.fc2 = nn.Linear(hidden_size, num_class) 23 | 24 | def forward(self, x): 25 | # (batch, seq_len, embed) 26 | x = self.embeddings(x).permute(1, 0, 2) 27 | 28 | # (batch, hidden_size) 29 | x = self.fc1(x.mean(1)) 30 | x = self.dropout(x) 31 | 32 | # (batch, num_class) 33 | x = self.fc2(x) 34 | return x 35 | 36 | def _forward_dense(self, x): 37 | x = self.embeddings(x).permute(1, 0, 2) 38 | x = self.fc1(x.mean(1)) 39 | return x 40 | 41 | def forward_mix_embed(self, x1, x2, lam): 42 | x1 = self.embeddings(x1).permute(1, 0, 2) 43 | x2 = self.embeddings(x2).permute(1, 0, 2) 44 | x = lam * x1 + (1.0 - lam) * x2 45 | 46 | x = self.fc1(x.mean(1)) 47 | x = self.fc2(x) 48 | return x 49 | 50 | def forward_mix_sent(self, x1, x2, lam): 51 | y1 = self.forward(x1) 52 | y2 = self.forward(x2) 53 | y = lam * y1 + (1.0 - lam) * y2 54 | return y 55 | 56 | def forward_mix_encoder(self, x1, x2, lam): 57 | y1 = self._forward_dense(x1) 58 | y2 = self._forward_dense(x2) 59 | y = lam * y1 + (1.0 - lam) * y2 60 | y = self.fc2(y) 61 | return y 62 | -------------------------------------------------------------------------------- /models/text_lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | embed_size = 300 5 | hidden_size = 100 6 | hidden_layers = 3 7 | bidirectional = True 8 | 9 | 10 | class TextLSTM(nn.Module): 11 | def __init__(self, vocab_size, sequence_len, num_class, word_embeddings=None, fine_tune=True, dropout=0.5): 12 | super(TextLSTM, self).__init__() 13 | 14 | # Embedding Layer 15 | self.embeddings = nn.Embedding(vocab_size, embed_size) 16 | self.sequence_len = sequence_len 17 | 18 | if word_embeddings is not None: 19 | self.embeddings.weight = nn.Parameter(word_embeddings, requires_grad=fine_tune) 20 | 21 | # LSTM layer 22 | self.lstm = nn.LSTM(input_size=embed_size, hidden_size=hidden_size, num_layers=hidden_layers, dropout=dropout, 23 | bidirectional=bidirectional) 24 | self.fc = nn.Linear(hidden_size * hidden_layers * (1 + bidirectional), num_class) 25 | 26 | def forward(self, x): 27 | # (seq_len, batch, embed) 28 | x = self.embeddings(x) 29 | 30 | _, (x, _) = self.lstm(x) 31 | # (num_layers * num_directions, batch, 
hidden) 32 | 33 | x = torch.cat([x[i, :, :] for i in range(x.shape[0])], dim=1) 34 | # (batch, num_layers * num_directions * hidden_size) 35 | x = self.fc(x) 36 | return x 37 | 38 | def _forward_dense(self, x): 39 | x = self.embeddings(x) 40 | _, (x, _) = self.lstm(x) 41 | x = torch.cat([x[i, :, :] for i in range(x.shape[0])], dim=1) 42 | return x 43 | 44 | def forward_mix_embed(self, x1, x2, lam): 45 | x1 = self.embeddings(x1) 46 | x2 = self.embeddings(x2) 47 | x = lam * x1 + (1.0-lam) * x2 48 | _, (x, _) = self.lstm(x) 49 | x = torch.cat([x[i, :, :] for i in range(x.shape[0])], dim=1) 50 | x = self.fc(x) 51 | return x 52 | 53 | def forward_mix_sent(self, x1, x2, lam): 54 | y1 = self.forward(x1) 55 | y2 = self.forward(x2) 56 | y = lam * y1 + (1.0-lam) * y2 57 | return y 58 | 59 | def forward_mix_encoder(self, x1, x2, lam): 60 | y1 = self._forward_dense(x1) 61 | y2 = self._forward_dense(x2) 62 | y = lam * y1 + (1.0-lam) * y2 63 | y = self.fc(y) 64 | return y 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | **/.idea/ 132 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | class TaskConfig: 2 | num_class = None 3 | train_file = None 4 | val_file = None 5 | test_file = None 6 | sequence_len = None 7 | eval_interval = None 8 | patience = None 9 | 10 | 11 | def get_task_config(task_name): 12 | config = TaskConfig() 13 | if task_name == 'trec': 14 | config.num_class = 6 15 | config.train_file = 'data/trec/train.csv' 16 | config.val_file = None 17 | config.test_file = 'data/trec/test.csv' 18 | config.sequence_len = 30 19 | config.eval_interval = 20 20 | config.patience = 50 21 | elif task_name == 'sst1': 22 | config.num_class = 5 23 | config.train_file = 'data/sst1/train_sent.csv' 24 | config.val_file = 'data/sst1/val.csv' 25 | config.test_file = 'data/sst1/test.csv' 26 | config.sequence_len = 50 27 | config.eval_interval = 30 28 | config.patience = 30 29 | elif task_name == 'imdb': 30 | config.num_class = 2 31 | config.train_file = 'data/imdb/train.csv' 32 | config.val_file = None 33 | config.test_file = 'data/imdb/test.csv' 34 | config.sequence_len = 400 35 | config.eval_interval = 50 36 | config.patience = 30 37 | elif task_name == 'agnews': 38 | config.num_class = 4 39 | config.train_file = 'data/agnews/train.csv' 40 | config.val_file = None 41 | config.test_file = 'data/agnews/test.csv' 42 | config.sequence_len = 80 43 | config.eval_interval = 50 44 | config.patience = 50 45 | elif task_name == 'dbpedia': 46 | config.num_class = 14 47 | config.train_file = 'data/dbpedia/train.csv' 48 | config.val_file = None 49 | config.test_file = 'data/dbpedia/test.csv' 50 | config.sequence_len = 100 51 | config.eval_interval = 100 52 | config.patience = 20 53 | elif task_name == 'yahoo': 54 | config.num_class = 10 55 | config.train_file = 'data/yahoo/train.csv' 56 | config.val_file = None 57 | config.test_file = 'data/yahoo/test.csv' 58 | config.sequence_len = 200 59 | config.eval_interval = 200 60 | config.patience = 20 61 | elif task_name == 'yelp-polar': 62 | config.num_class = 2 63 | config.train_file = 'data/yelp-polar/train.csv' 64 | config.val_file = None 65 | config.test_file = 'data/yelp-polar/test.csv' 66 | config.sequence_len = 400 67 | config.eval_interval = 200 68 | config.patience = 20 69 | else: 70 | raise ValueError('Task not supported') 71 | return config 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mixup-text 2 | This repository contains implementation of mixup strategy for text classification. The implementation is primarily based on the paper [Augmenting Data with Mixup for Sentence Classification: An Empirical Study 3 | ](https://arxiv.org/abs/1905.08941), although there is some difference. 
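Whatever layer is mixed, all variants in this repo share the same recipe (see `train.py`): pair every example with a random partner from the same batch, draw a mixing coefficient lam ~ Beta(alpha, alpha), interpolate the two representations, and weight the cross-entropy loss by the same lam. A self-contained sketch with dummy logits, shown here for the sentence-mixup case (the other variants interpolate the word embeddings or the penultimate features instead):

    import numpy as np
    import torch
    import torch.nn.functional as F

    logits_a = torch.randn(8, 6)       # model outputs for a batch (illustrative sizes)
    y_a = torch.randint(0, 6, (8,))
    perm = torch.randperm(8)           # random partner for every example in the batch
    logits_b, y_b = logits_a[perm], y_a[perm]
    lam = np.random.beta(1.0, 1.0)     # --alpha controls this Beta distribution

    # Sentence mixup: interpolate the two outputs ...
    mixed = lam * logits_a + (1.0 - lam) * logits_b
    # ... and weight the two losses by the same lam (mixup_criterion_cross_entropy in train.py).
    loss = lam * F.cross_entropy(mixed, y_a) + (1.0 - lam) * F.cross_entropy(mixed, y_b)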
4 | 5 | Three variants of mixup are considered for text classification: 6 | 1. Embedding mixup: Texts are mixed immediately after the word embedding layer 7 | 2. Hidden/Encoder mixup: Mixup is done prior to the last fully connected layer 8 | 3. Sentence mixup: Mixup is done on the logits, before the softmax 9 | 10 | #### Results 11 | 12 | 13 | Some experimental results on the TREC, SST-1, IMDB, AG's News and DBPedia datasets are shown below. *rand* refers to models whose word embeddings are initialized randomly. *finetune* refers to models initialized with pretrained embeddings (GloVe or BERT). 14 | 15 | | Model | TREC | SST-1 | IMDB | AG's News | DBPedia | 16 | |------------------------------|-------|-------|-------|-----------|---------| 17 | | CNN-rand | 88.58 | 37.00 | 86.74 | 91.07 | 98.03 | 18 | | CNN-rand + embed mixup | 88.38 | 35.93 | 87.34 | 91.67 | 97.85 | 19 | | CNN-rand + hidden mixup | 88.78 | 35.24 | 87.06 | 91.49 | 98.34 | 20 | | CNN-rand + sent mixup | 88.92 | 35.40 | 87.25 | 91.46 | 98.23 | 21 | | CNN-finetune | 90.50 | 46.38 | 88.57 | 92.67 | 98.81 | 22 | | CNN-finetune + embed mixup | 91.62 | 45.81 | 89.13 | 92.78 | 98.55 | 23 | | CNN-finetune + hidden mixup | 91.74 | 45.70 | 89.66 | 93.11 | 98.83 | 24 | | CNN-finetune + sent mixup | 91.70 | 46.10 | 89.60 | 93.12 | 98.83 | 25 | | LSTM-finetune | 89.26 | 44.38 | 86.04 | 92.87 | 98.95 | 26 | | LSTM-finetune + embed mixup | 89.82 | 44.04 | 85.82 | 92.76 | 98.98 | 27 | | LSTM-finetune + hidden mixup | 89.72 | 43.87 | 85.23 | 92.67 | 98.92 | 28 | | LSTM-finetune + sent mixup | 89.70 | 43.86 | 85.02 | 92.65 | 98.87 | 29 | | fastText-finetune | 86.88 | 43.26 | 88.33 | 91.93 | 97.85 | 30 | | fastText-finetune + mixup | 86.2 | 43.81 | 88.05 | 91.99 | 97.99 | 31 | | BERT-finetune | 97.04 | 53.05 | - | - | - | 32 | | BERT-finetune + embed mixup | 97.20 | 53.12 | - | - | - | 33 | | BERT-finetune + hidden mixup | 96.92 | 53.13 | - | - | - | 34 | | BERT-finetune + sent mixup | 96.86 | 53.32 | - | - | - | 35 | 36 | Results are the mean accuracy over 10 runs for all datasets, except for DBPedia, where the average is over 3 runs. 37 | Note that for the fastText model there is only one variant of mixup, as it is a linear model. 38 | 39 | #### TO-DO 40 | - [ ] Manifold mixup implementation 41 | - [ ] Results for BERT on IMDB, AG's News and DBPedia datasets -------------------------------------------------------------------------------- /preprocessing.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | import spacy 3 | import unidecode 4 | import contractions as cont 5 | from word2number import w2n 6 | nlp = spacy.load('en_core_web_md') 7 | 8 | # exclude words from spacy stopwords list 9 | deselect_stop_words = ['no', 'not'] 10 | for w in deselect_stop_words: 11 | nlp.vocab[w].is_stop = False 12 | 13 | 14 | def strip_html_tags(text): 15 | """remove html tags from text""" 16 | soup = BeautifulSoup(text, "html.parser") 17 | stripped_text = soup.get_text(separator=" ") 18 | return stripped_text 19 | 20 | 21 | def remove_whitespace(text): 22 | """remove extra whitespaces from text""" 23 | text = text.strip() 24 | return " ".join(text.split()) 25 | 26 | 27 | def remove_accented_chars(text): 28 | """remove accented characters from text, e.g. café""" 29 | text = unidecode.unidecode(text) 30 | return text 31 | 32 | 33 | def expand_contractions(text): 34 | """expand shortened words, e.g. 
don't to do not""" 35 | return cont.fix(text, slang=False) 36 | 37 | 38 | def preprocess_text(text, accented_chars=True, contractions=True, convert_num=False, extra_whitespace=True, 39 | lemmatization=False, lowercase=True, punctuations=False, remove_html=True, remove_num=False, 40 | special_chars=True, stop_words=False): 41 | """preprocess text with default option set to true for all steps""" 42 | if remove_html: 43 | text = strip_html_tags(text) 44 | if extra_whitespace: 45 | text = remove_whitespace(text) 46 | if accented_chars: 47 | text = remove_accented_chars(text) 48 | if contractions: 49 | text = expand_contractions(text) 50 | if lowercase: 51 | text = text.lower() 52 | 53 | doc = nlp(text) 54 | 55 | clean_text = [] 56 | 57 | for token in doc: 58 | flag = True 59 | edit = token.text 60 | # remove stop words 61 | if stop_words and token.is_stop and token.pos_ != 'NUM': 62 | flag = False 63 | # remove punctuations 64 | if punctuations and token.pos_ == 'PUNCT' and flag: 65 | flag = False 66 | # remove special characters 67 | if special_chars and token.pos_ == 'SYM' and flag: 68 | flag = False 69 | # remove numbers 70 | if remove_num and (token.pos_ == 'NUM' or token.text.isnumeric()) and flag: 71 | flag = False 72 | # convert number words to numeric numbers 73 | if convert_num and token.pos_ == 'NUM' and flag: 74 | edit = w2n.word_to_num(token.text) 75 | # convert tokens to base form 76 | elif lemmatization and token.lemma_ != "-PRON-" and flag: 77 | edit = token.lemma_ 78 | # append tokens edited and not removed to list 79 | if edit != "" and flag: 80 | clean_text.append(edit) 81 | 82 | clean_text = ' '.join(clean_text) 83 | return clean_text 84 | -------------------------------------------------------------------------------- /models/text_cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | embed_size = 300 7 | kernel_size = [3, 4, 5] 8 | num_channels = 100 9 | 10 | 11 | class TextCNN(nn.Module): 12 | def __init__(self, vocab_size, sequence_len, num_class, word_embeddings=None, fine_tune=True, dropout=0.5): 13 | super(TextCNN, self).__init__() 14 | 15 | # Embedding Layer 16 | self.embeddings = nn.Embedding(vocab_size, embed_size) 17 | self.sequence_len = sequence_len 18 | if word_embeddings is not None: 19 | self.embeddings.weight = nn.Parameter(word_embeddings, requires_grad=fine_tune) 20 | 21 | # Conv layers 22 | self.convs = nn.ModuleList([nn.Conv2d(1, num_channels, [k, embed_size]) for k in kernel_size]) 23 | 24 | self.dropout = nn.Dropout(dropout) 25 | self.fc = nn.Linear(num_channels * len(kernel_size), num_class) 26 | 27 | def forward(self, x): 28 | # (batch, seq_len, embed) 29 | x = self.embeddings(x).permute(1, 0, 2) 30 | # (batch, channel, seq_len, embed) 31 | x = torch.unsqueeze(x, 1) 32 | 33 | # (batch, channel, seq_len-k+1) 34 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs] 35 | 36 | # (batch, channel) 37 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 38 | 39 | # (batch, #filters * channel) 40 | x = torch.cat(x, 1) 41 | 42 | x = self.dropout(x) 43 | 44 | # (batch, #class) 45 | x = self.fc(x) 46 | return x 47 | 48 | def _forward_dense(self, x): 49 | x = self.embeddings(x).permute(1, 0, 2) 50 | x = torch.unsqueeze(x, 1) 51 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs] 52 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 53 | x = torch.cat(x, 1) 54 | return x 55 | 56 | @staticmethod 57 | 
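    # mix_embed_nonlinear below is an alternative mixing scheme that is currently unused
    # (its call in forward_mix_embed is commented out): instead of a convex combination,
    # it copies a contiguous block of embedding dimensions from x2 into x1. The block
    # width is round(embed * (1 - lam)), taken from either the left or the right end of
    # the embedding axis at random. Note that it modifies x1 in place, since
    # mixed_x = x1 only creates another reference to the same tensor.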
def mix_embed_nonlinear(x1, x2, lam): 58 | # x.shape: (batch, seq_len, embed) 59 | embed = x1.shape[2] 60 | stride = int(round(embed * (1 - lam))) 61 | mixed_x = x1 62 | aug_type = np.random.randint(2) 63 | if aug_type == 0: 64 | mixed_x[:, :, :stride] = x2[:, :, :stride] 65 | else: 66 | mixed_x[:, :, embed-stride:] = x2[:, :, embed-stride:] 67 | return mixed_x 68 | 69 | def forward_mix_embed(self, x1, x2, lam): 70 | # (seq_len, batch) -> (batch, seq_len, embed) 71 | x1 = self.embeddings(x1).permute(1, 0, 2) 72 | x2 = self.embeddings(x2).permute(1, 0, 2) 73 | x = lam * x1 + (1.0-lam) * x2 74 | # x = self.mix_embed_nonlinear(x1, x2, lam) 75 | 76 | x = torch.unsqueeze(x, 1) 77 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs] 78 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] 79 | x = torch.cat(x, 1) 80 | x = self.dropout(x) 81 | x = self.fc(x) 82 | return x 83 | 84 | def forward_mix_sent(self, x1, x2, lam): 85 | y1 = self.forward(x1) 86 | y2 = self.forward(x2) 87 | y = lam * y1 + (1.0-lam) * y2 88 | return y 89 | 90 | def forward_mix_encoder(self, x1, x2, lam): 91 | y1 = self._forward_dense(x1) 92 | y2 = self._forward_dense(x2) 93 | y = lam * y1 + (1.0-lam) * y2 94 | y = self.fc(y) 95 | return y 96 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | from torchtext import data 2 | # from torchtext.vocab import Vectors 3 | import spacy 4 | import pandas as pd 5 | import pickle 6 | from torch.utils.data import Dataset 7 | import torch 8 | from transformers import BertTokenizer 9 | 10 | 11 | class WordDataset(object): 12 | def __init__(self, sequence_len, batch_size): 13 | self.sequence_len = sequence_len 14 | self.batch_size = batch_size 15 | self.train_iterator = None 16 | self.test_iterator = None 17 | self.val_iterator = None 18 | self.vocab = [] 19 | self.word_embeddings = {} 20 | 21 | @staticmethod 22 | def get_pandas_df(filename, text_col, label_col): 23 | """ 24 | Load the data into Pandas.DataFrame object 25 | This will be used to convert data to torchtext object 26 | """ 27 | df = pd.read_csv(filename, sep='\t') 28 | texts = [] 29 | labels = [] 30 | for index, row in df.iterrows(): 31 | texts.append(row[text_col]) 32 | labels.append(int(row[label_col])) 33 | df = pd.DataFrame({"text": texts, "label": labels}) 34 | return df 35 | 36 | def load_data(self, train_file, test_file, val_file=None, w2v_file=None, text_col='text', label_col='label'): 37 | """ 38 | Loads the data from files 39 | Sets up iterators for training, validation and test data 40 | Also create vocabulary and word embeddings based on the data 41 | 42 | Inputs: 43 | w2v_file (String): absolute path to file containing word embeddings (GloVe/Word2Vec) 44 | train_file (String): absolute path to training file 45 | test_file (String): absolute path to test file 46 | val_file (String): absolute path to validation file 47 | """ 48 | 49 | nlp = spacy.load('en') 50 | tokenizer = lambda sent: [x.text for x in nlp.tokenizer(sent) if x.text != " "] 51 | 52 | # Creating Field for data 53 | text = data.Field(sequential=True, tokenize=tokenizer, lower=True, fix_length=self.sequence_len) 54 | label = data.Field(sequential=False, use_vocab=False) 55 | datafields = [("text", text), ("label", label)] 56 | 57 | # Load data from pd.DataFrame into torchtext.data.Dataset 58 | train_df = self.get_pandas_df(train_file, text_col, label_col) 59 | train_examples = [data.Example.fromlist(i, datafields) for 
i in train_df.values.tolist()] 60 | train_data = data.Dataset(train_examples, datafields) 61 | 62 | test_df = self.get_pandas_df(test_file, text_col, label_col) 63 | test_examples = [data.Example.fromlist(i, datafields) for i in test_df.values.tolist()] 64 | test_data = data.Dataset(test_examples, datafields) 65 | 66 | # If validation file exists, load it. Otherwise get validation data from training data 67 | if val_file: 68 | val_df = self.get_pandas_df(val_file, text_col, label_col) 69 | val_examples = [data.Example.fromlist(i, datafields) for i in val_df.values.tolist()] 70 | val_data = data.Dataset(val_examples, datafields) 71 | else: 72 | train_data, val_data = train_data.split(split_ratio=0.9) 73 | 74 | vectors = None 75 | if w2v_file is not None: 76 | with open(w2v_file, 'rb') as handle: 77 | vectors = pickle.load(handle) 78 | # vectors = Vectors(w2v_file) 79 | # with open('glove.pickle', 'wb') as handle: 80 | # pickle.dump(vectors, handle) 81 | text.build_vocab(train_data, vectors=vectors) 82 | self.word_embeddings = text.vocab.vectors 83 | self.vocab = text.vocab 84 | 85 | self.train_iterator = data.BucketIterator( 86 | train_data, 87 | batch_size=self.batch_size, 88 | sort_key=lambda x: len(x.text), 89 | repeat=False, 90 | shuffle=True) 91 | 92 | self.val_iterator, self.test_iterator = data.BucketIterator.splits( 93 | (val_data, test_data), 94 | batch_size=self.batch_size, 95 | sort_key=lambda x: len(x.text), 96 | repeat=False, 97 | shuffle=False) 98 | 99 | print("Loaded {} training examples".format(len(train_data))) 100 | print("Loaded {} test examples".format(len(test_data))) 101 | print("Loaded {} validation examples".format(len(val_data))) 102 | 103 | 104 | class BertDataset(Dataset): 105 | def __init__(self, filename, sequence_len): 106 | self.df = pd.read_csv(filename, delimiter='\t') 107 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 108 | self.sequence_len = sequence_len 109 | 110 | def __len__(self): 111 | return len(self.df) 112 | 113 | def __getitem__(self, index): 114 | sentence = self.df.loc[index, 'text'] 115 | label = self.df.loc[index, 'label'] 116 | tokens = self.tokenizer.tokenize(sentence) 117 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 118 | if len(tokens) < self.sequence_len: 119 | tokens = tokens + ['[PAD]' for _ in range(self.sequence_len - len(tokens))] 120 | else: 121 | tokens = tokens[:self.sequence_len - 1] + ['[SEP]'] 122 | 123 | tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens) 124 | tokens_ids_tensor = torch.tensor(tokens_ids) 125 | attn_mask = (tokens_ids_tensor != 0).long() 126 | return tokens_ids_tensor, attn_mask, label 127 | -------------------------------------------------------------------------------- /models/text_bert.py: -------------------------------------------------------------------------------- 1 | from transformers.modeling_bert import * 2 | from .config import MODELS 3 | 4 | 5 | class MyBertModel(BertPreTrainedModel): 6 | """ 7 | 8 | The model can behave as an encoder (with only self-attention) as well 9 | as a decoder, in which case a layer of cross-attention is added between 10 | the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani, 11 | Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. 
12 | 13 | To behave as an decoder the model needs to be initialized with the 14 | :obj:`is_decoder` argument of the configuration set to :obj:`True`; an 15 | :obj:`encoder_hidden_states` is expected as an input to the forward pass. 16 | 17 | .. _`Attention is all you need`: 18 | https://arxiv.org/abs/1706.03762 19 | 20 | """ 21 | 22 | def __init__(self, config): 23 | super().__init__(config) 24 | self.config = config 25 | 26 | self.embeddings = BertEmbeddings(config) 27 | self.encoder = BertEncoder(config) 28 | self.pooler = BertPooler(config) 29 | 30 | self.init_weights() 31 | 32 | def get_input_embeddings(self): 33 | return self.embeddings.word_embeddings 34 | 35 | def set_input_embeddings(self, value): 36 | self.embeddings.word_embeddings = value 37 | 38 | def _prune_heads(self, heads_to_prune): 39 | """ Prunes heads of the model. 40 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 41 | See base class PreTrainedModel 42 | """ 43 | for layer, heads in heads_to_prune.items(): 44 | self.encoder.layer[layer].attention.prune_heads(heads) 45 | 46 | def _forward_init( 47 | self, 48 | input_ids=None, 49 | attention_mask=None, 50 | token_type_ids=None, 51 | position_ids=None, 52 | head_mask=None, 53 | inputs_embeds=None, 54 | encoder_hidden_states=None, 55 | encoder_attention_mask=None, 56 | ): 57 | if input_ids is not None and inputs_embeds is not None: 58 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 59 | elif input_ids is not None: 60 | input_shape = input_ids.size() 61 | elif inputs_embeds is not None: 62 | input_shape = inputs_embeds.size()[:-1] 63 | else: 64 | raise ValueError("You have to specify either input_ids or inputs_embeds") 65 | 66 | device = input_ids.device if input_ids is not None else inputs_embeds.device 67 | 68 | if attention_mask is None: 69 | attention_mask = torch.ones(input_shape, device=device) 70 | if token_type_ids is None: 71 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 72 | 73 | # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] 74 | # ourselves in which case we just need to make it broadcastable to all heads. 
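        # get_extended_attention_mask() reshapes the (batch, seq_len) mask of 1s and 0s so it
        # broadcasts over the attention scores (roughly (batch, 1, 1, seq_len)) and converts it
        # to an additive mask: 0.0 where attention is allowed, a large negative number where it
        # is masked out. forward_mix_embed() below relies on this encoding when it takes the
        # element-wise max of the two extended masks of a mixed pair.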
75 | extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) 76 | 77 | # If a 2D ou 3D attention mask is provided for the cross-attention 78 | # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] 79 | if self.config.is_decoder and encoder_hidden_states is not None: 80 | encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() 81 | encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) 82 | if encoder_attention_mask is None: 83 | encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) 84 | encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) 85 | else: 86 | encoder_extended_attention_mask = None 87 | 88 | # Prepare head mask if needed 89 | # 1.0 in head_mask indicate we keep the head 90 | # attention_probs has shape bsz x n_heads x N x N 91 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] 92 | # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] 93 | head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) 94 | 95 | return input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, \ 96 | encoder_hidden_states, encoder_attention_mask, extended_attention_mask, encoder_extended_attention_mask 97 | 98 | def forward( 99 | self, 100 | input_ids=None, 101 | attention_mask=None, 102 | token_type_ids=None, 103 | position_ids=None, 104 | head_mask=None, 105 | inputs_embeds=None, 106 | encoder_hidden_states=None, 107 | encoder_attention_mask=None, 108 | ): 109 | r""" 110 | Return: 111 | :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: 112 | last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): 113 | Sequence of hidden-states at the output of the last layer of the model. 114 | pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`): 115 | Last layer hidden-state of the first token of the sequence (classification token) 116 | further processed by a Linear layer and a Tanh activation function. The Linear 117 | layer weights are trained from the next sentence prediction (classification) 118 | objective during pre-training. 119 | 120 | This output is usually *not* a good summary 121 | of the semantic content of the input, you're often better with averaging or pooling 122 | the sequence of hidden-states for the whole input sequence. 123 | hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): 124 | Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) 125 | of shape :obj:`(batch_size, sequence_length, hidden_size)`. 126 | 127 | Hidden-states of the model at the output of each layer plus the initial embedding outputs. 128 | attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): 129 | Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape 130 | :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 131 | 132 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 133 | heads. 
134 | 135 | Examples:: 136 | 137 | from transformers import BertModel, BertTokenizer 138 | import torch 139 | 140 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 141 | model = BertModel.from_pretrained('bert-base-uncased') 142 | 143 | input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1 144 | outputs = model(input_ids) 145 | 146 | last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple 147 | 148 | """ 149 | input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, \ 150 | encoder_attention_mask, extended_attention_mask, encoder_extended_attention_mask = self._forward_init( 151 | input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, 152 | encoder_hidden_states, encoder_attention_mask 153 | ) 154 | 155 | embedding_output = self.embeddings( 156 | input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds 157 | ) 158 | encoder_outputs = self.encoder( 159 | embedding_output, 160 | attention_mask=extended_attention_mask, 161 | head_mask=head_mask, 162 | encoder_hidden_states=encoder_hidden_states, 163 | encoder_attention_mask=encoder_extended_attention_mask, 164 | ) 165 | sequence_output = encoder_outputs[0] 166 | pooled_output = self.pooler(sequence_output) 167 | 168 | outputs = (sequence_output, pooled_output,) + encoder_outputs[ 169 | 1: 170 | ] # add hidden_states and attentions if they are here 171 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 172 | 173 | def forward_mix_embed(self, x1, att1, x2, att2, lam): 174 | x1, attention_mask1, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, \ 175 | encoder_attention_mask, extended_attention_mask1, encoder_extended_attention_mask = self._forward_init( 176 | input_ids=x1, attention_mask=att1) 177 | embedding_output1 = self.embeddings( 178 | input_ids=x1, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds 179 | ) 180 | 181 | x2, attention_mask2, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, \ 182 | encoder_attention_mask, extended_attention_mask2, encoder_extended_attention_mask = self._forward_init( 183 | input_ids=x2, attention_mask=att2) 184 | 185 | embedding_output2 = self.embeddings( 186 | input_ids=x2, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds 187 | ) 188 | 189 | embedding_output = lam * embedding_output1 + (1.0 - lam) * embedding_output2 190 | 191 | # need to take max of both to ensure we don't miss attending to any value 192 | extended_attention_mask = torch.max(extended_attention_mask1, extended_attention_mask2) 193 | encoder_outputs = self.encoder( 194 | embedding_output, 195 | attention_mask=extended_attention_mask, 196 | head_mask=head_mask, 197 | encoder_hidden_states=encoder_hidden_states, 198 | encoder_attention_mask=encoder_extended_attention_mask, 199 | ) 200 | 201 | sequence_output = encoder_outputs[0] 202 | pooled_output = self.pooler(sequence_output) 203 | 204 | outputs = (sequence_output, pooled_output,) + encoder_outputs[ 205 | 1: 206 | ] # add hidden_states and attentions if they are here 207 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 208 | 209 | 210 | class TextBERT(nn.Module): 211 | def __init__(self, pretrained_model, num_class, fine_tune, dropout): 212 | 
super(TextBERT, self).__init__() 213 | self.output_dim = num_class 214 | self.bert = MyBertModel.from_pretrained(pretrained_model) 215 | # Freeze bert layers 216 | if not fine_tune: 217 | for p in self.bert.parameters(): 218 | p.requires_grad = False 219 | 220 | bert_dim = MODELS[pretrained_model][2] 221 | self.dropout = nn.Dropout(dropout) 222 | self.classifier = nn.Linear(bert_dim, num_class) 223 | 224 | def forward(self, x, attn_masks): 225 | outputs = self.bert(x, attention_mask=attn_masks) 226 | pooled_output = outputs[1] 227 | pooled_output = self.dropout(pooled_output) 228 | logits = self.classifier(pooled_output) 229 | return logits 230 | 231 | def forward_mix_embed(self, x1, att1, x2, att2, lam): 232 | outputs = self.bert.forward_mix_embed(x1, att1, x2, att2, lam) 233 | pooled_output = outputs[1] 234 | pooled_output = self.dropout(pooled_output) 235 | logits = self.classifier(pooled_output) 236 | return logits 237 | 238 | def forward_mix_sent(self, x1, att1, x2, att2, lam): 239 | logits1 = self.forward(x1, att1) 240 | logits2 = self.forward(x2, att2) 241 | y = lam * logits1 + (1.0-lam) * logits2 242 | return y 243 | 244 | def forward_mix_encoder(self, x1, att1, x2, att2, lam): 245 | outputs1 = self.bert(x1, att1) 246 | outputs2 = self.bert(x2, att2) 247 | pooled_output1 = self.dropout(outputs1[1]) 248 | pooled_output2 = self.dropout(outputs2[1]) 249 | pooled_output = lam * pooled_output1 + (1.0-lam) * pooled_output2 250 | y = self.classifier(pooled_output) 251 | return y 252 | -------------------------------------------------------------------------------- /train_bert.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | 12 | import models 13 | from datasets import BertDataset 14 | from utils import get_task_config 15 | from models.config import MODELS 16 | from transformers import * 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Mixup for text classification') 21 | parser.add_argument('--task', default='trec', type=str, help='Task name') 22 | parser.add_argument('--name', default='cnn-text-fine-tune', type=str, help='name of the experiment') 23 | parser.add_argument('--text-column', default='text', type=str, help='text column name of csv file') 24 | parser.add_argument('--label-column', default='label', type=str, help='column column name of csv file') 25 | parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() == 'true'), help='use cuda if available') 26 | parser.add_argument('--lr', default=1e-5, type=float, help='learning rate') 27 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate') 28 | parser.add_argument('--decay', default=0., type=float, help='weight decay') 29 | parser.add_argument('--model', default="bert-base-uncased", type=str, help='pretrained BERT model name') 30 | parser.add_argument('--seed', default=1, type=int, help='random seed') 31 | parser.add_argument('--batch-size', default=50, type=int, help='batch size (default: 128)') 32 | parser.add_argument('--epoch', default=20, type=int, help='total epochs (default: 200)') 33 | parser.add_argument('--fine-tune', default=True, type=lambda x: (str(x).lower() == 'true'), 34 | help='whether to fine-tune embedding or not') 35 | parser.add_argument('--save-path', default='out', type=str, help='output 
log/result directory') 36 | parser.add_argument('--method', default='none', type=str, help='which mixing method to use (default: none)') 37 | parser.add_argument('--alpha', default=1., type=float, help='mixup interpolation coefficient (default: 1)') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | class Classification: 43 | def __init__(self, args): 44 | self.args = args 45 | 46 | self.use_cuda = args.cuda and torch.cuda.is_available() 47 | 48 | # for reproducibility 49 | torch.manual_seed(args.seed) 50 | torch.backends.cudnn.deterministic = True 51 | torch.backends.cudnn.benchmark = False 52 | np.random.seed(args.seed) 53 | random.seed(args.seed) 54 | 55 | self.config = get_task_config(args.task) 56 | 57 | # data loaders 58 | train_dataset = BertDataset(self.config.train_file, self.config.sequence_len) 59 | test_dataset = BertDataset(self.config.test_file, self.config.sequence_len) 60 | 61 | if self.config.val_file is None: 62 | train_samples = int(len(train_dataset) * 0.9) 63 | val_samples = len(train_dataset) - train_samples 64 | train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [train_samples, val_samples]) 65 | else: 66 | val_dataset = BertDataset(self.config.val_file, self.config.sequence_len) 67 | 68 | self.train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 69 | self.val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False) 70 | self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False) 71 | 72 | # model 73 | if MODELS[args.model][0] == BertModel: 74 | self.model = models.TextBERT(pretrained_model=args.model, num_class=self.config.num_class, 75 | fine_tune=args.fine_tune, dropout=args.dropout) 76 | 77 | self.device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu') 78 | self.model.to(self.device) 79 | 80 | # logs 81 | os.makedirs(args.save_path, exist_ok=True) 82 | self.model_save_path = os.path.join(args.save_path, args.name + '_weights.pt') 83 | self.log_path = os.path.join(args.save_path, args.name + '_logs.csv') 84 | print(str(args)) 85 | with open(self.log_path, 'a') as f: 86 | f.write(str(args) + '\n') 87 | with open(self.log_path, 'a', newline='') as out: 88 | writer = csv.writer(out) 89 | writer.writerow(['mode', 'epoch', 'step', 'loss', 'acc']) 90 | 91 | # optimizer 92 | self.criterion = nn.CrossEntropyLoss() 93 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.decay) 94 | 95 | # for early stopping 96 | self.best_val_acc = 0 97 | self.early_stop = False 98 | self.val_patience = 0 # successive iteration when validation acc did not improve 99 | 100 | self.iteration_number = 0 101 | 102 | def get_perm(self, x): 103 | """get random permutation""" 104 | batch_size = x.size()[0] 105 | if self.use_cuda: 106 | index = torch.randperm(batch_size).cuda() 107 | else: 108 | index = torch.randperm(batch_size) 109 | return index 110 | 111 | def mixup_criterion_cross_entropy(self, pred, y_a, y_b, lam): 112 | return lam * self.criterion(pred, y_a) + (1 - lam) * self.criterion(pred, y_b) 113 | 114 | def test(self, loader): 115 | self.model.eval() 116 | test_loss = 0 117 | total = 0 118 | correct = 0 119 | with torch.no_grad(): 120 | for x, att, y in loader: 121 | x, y, att = x.to(self.device), y.to(self.device), att.to(self.device) 122 | y_pred = self.model(x, att) 123 | loss = self.criterion(y_pred, y) 124 | test_loss += loss.item() * y.shape[0] 125 | 
total += y.shape[0] 126 | correct += torch.sum(torch.argmax(y_pred, dim=1) == y).item() 127 | 128 | avg_loss = test_loss / total 129 | acc = 100.0 * correct / total 130 | return avg_loss, acc 131 | 132 | def train(self, epoch): 133 | self.model.train() 134 | train_loss = 0 135 | total = 0 136 | correct = 0 137 | for x, att, y in self.train_loader: 138 | x, y, att = x.to(self.device), y.to(self.device), att.to(self.device) 139 | y_pred = self.model(x, att) 140 | loss = self.criterion(y_pred, y) 141 | train_loss += loss.item() * y.shape[0] 142 | total += y.shape[0] 143 | correct += torch.sum(torch.argmax(y_pred, dim=1) == y).item() 144 | 145 | self.optimizer.zero_grad() 146 | loss.backward() 147 | self.optimizer.step() 148 | 149 | # eval 150 | self.iteration_number += 1 151 | if self.iteration_number % self.config.eval_interval == 0: 152 | avg_loss = train_loss / total 153 | acc = 100.0 * correct / total 154 | # print('Train loss: {}, Train acc: {}'.format(avg_loss, acc)) 155 | train_loss = 0 156 | total = 0 157 | correct = 0 158 | 159 | val_loss, val_acc = self.test(self.val_loader) 160 | # print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 161 | if val_acc > self.best_val_acc: 162 | torch.save(self.model.state_dict(), self.model_save_path) 163 | self.best_val_acc = val_acc 164 | self.val_patience = 0 165 | else: 166 | self.val_patience += 1 167 | if self.val_patience == self.config.patience: 168 | self.early_stop = True 169 | return 170 | with open(self.log_path, 'a', newline='') as out: 171 | writer = csv.writer(out) 172 | writer.writerow(['train', epoch, self.iteration_number, avg_loss, acc]) 173 | writer.writerow(['val', epoch, self.iteration_number, val_loss, val_acc]) 174 | self.model.train() 175 | 176 | def train_mixup(self, epoch): 177 | self.model.train() 178 | train_loss = 0 179 | total = 0 180 | correct = 0 181 | for x, att, y in self.train_loader: 182 | x, y, att = x.to(self.device), y.to(self.device), att.to(self.device) 183 | lam = np.random.beta(self.args.alpha, self.args.alpha) 184 | index = self.get_perm(x) 185 | x1 = x[index] 186 | y1 = y[index] 187 | att1 = att[index] 188 | 189 | if self.args.method == 'embed': 190 | y_pred = self.model.forward_mix_embed(x, att, x1, att1, lam) 191 | elif self.args.method == 'sent': 192 | y_pred = self.model.forward_mix_sent(x, att, x1, att1, lam) 193 | elif self.args.method == 'encoder': 194 | y_pred = self.model.forward_mix_encoder(x, att, x1, att1, lam) 195 | else: 196 | raise ValueError('invalid method name') 197 | 198 | loss = self.mixup_criterion_cross_entropy(y_pred, y, y1, lam) 199 | train_loss += loss.item() * y.shape[0] 200 | total += y.shape[0] 201 | _, predicted = torch.max(y_pred.data, 1) 202 | correct += ((lam * predicted.eq(y.data).cpu().sum().float() 203 | + (1 - lam) * predicted.eq(y1.data).cpu().sum().float())).item() 204 | 205 | self.optimizer.zero_grad() 206 | loss.backward() 207 | self.optimizer.step() 208 | 209 | # eval 210 | self.iteration_number += 1 211 | if self.iteration_number % self.config.eval_interval == 0: 212 | avg_loss = train_loss / total 213 | acc = 100.0 * correct / total 214 | # print('Train loss: {}, Train acc: {}'.format(avg_loss, acc)) 215 | train_loss = 0 216 | total = 0 217 | correct = 0 218 | 219 | val_loss, val_acc = self.test(self.val_loader) 220 | # print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 221 | if val_acc > self.best_val_acc: 222 | torch.save(self.model.state_dict(), self.model_save_path) 223 | self.best_val_acc = val_acc 224 | self.val_patience = 0 225 | else: 
226 | self.val_patience += 1 227 | if self.val_patience == self.config.patience: 228 | self.early_stop = True 229 | return 230 | with open(self.log_path, 'a', newline='') as out: 231 | writer = csv.writer(out) 232 | writer.writerow(['train', epoch, self.iteration_number, avg_loss, acc]) 233 | writer.writerow(['val', epoch, self.iteration_number, val_loss, val_acc]) 234 | self.model.train() 235 | 236 | def run(self): 237 | for epoch in range(self.args.epoch): 238 | print('------------------------------------- Epoch {} -------------------------------------'.format(epoch)) 239 | if self.args.method == 'none': 240 | self.train(epoch) 241 | else: 242 | self.train_mixup(epoch) 243 | if self.early_stop: 244 | break 245 | print('Training complete!') 246 | print('Best Validation Acc: ', self.best_val_acc) 247 | 248 | self.model.load_state_dict(torch.load(self.model_save_path)) 249 | train_loss, train_acc = self.test(self.train_loader) 250 | val_loss, val_acc = self.test(self.val_loader) 251 | test_loss, test_acc = self.test(self.test_loader) 252 | 253 | with open(self.log_path, 'a', newline='') as out: 254 | writer = csv.writer(out) 255 | writer.writerow(['train', -1, -1, train_loss, train_acc]) 256 | writer.writerow(['val', -1, -1, val_loss, val_acc]) 257 | writer.writerow(['test', -1, -1, test_loss, test_acc]) 258 | 259 | print('Train loss: {}, Train acc: {}'.format(train_loss, train_acc)) 260 | print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 261 | print('Test loss: {}, Test acc: {}'.format(test_loss, test_acc)) 262 | 263 | return val_acc, test_acc 264 | 265 | 266 | if __name__ == '__main__': 267 | args = parse_args() 268 | num_runs = args.num_runs 269 | 270 | test_acc = [] 271 | val_acc = [] 272 | 273 | for i in range(num_runs): 274 | cls = Classification(args) 275 | val, test = cls.run() 276 | val_acc.append(val) 277 | test_acc.append(test) 278 | args.seed += 1 279 | 280 | with open(os.path.join(args.save_path, args.name + '_result.txt', 'a')) as f: 281 | f.write(str(args)) 282 | f.write('val acc:' + str(val_acc) + '\n') 283 | f.write('test acc:' + str(test_acc) + '\n') 284 | f.write('mean val acc:' + str(np.mean(val_acc)) + '\n') 285 | f.write('std val acc:' + str(np.std(val_acc, ddof=1)) + '\n') 286 | f.write('mean test acc:' + str(np.mean(test_acc)) + '\n') 287 | f.write('std test acc:' + str(np.std(test_acc, ddof=1)) + '\n\n\n') 288 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.nn as nn 10 | from tqdm import tqdm 11 | 12 | import models 13 | from datasets import WordDataset 14 | from utils import get_task_config 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='Mixup for text classification') 19 | parser.add_argument('--task', default='trec', type=str, help='Task name') 20 | parser.add_argument('--name', default='cnn-text-fine-tune', type=str, help='name of the experiment') 21 | parser.add_argument('--text-column', default='text', type=str, help='text column name of csv file') 22 | parser.add_argument('--label-column', default='label', type=str, help='column column name of csv file') 23 | parser.add_argument('--w2v-file', default=None, type=str, help='word embedding file') 24 | parser.add_argument('--cuda', default=True, type=lambda x: (str(x).lower() 
== 'true'), help='use cuda if available') 25 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 26 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout rate') 27 | parser.add_argument('--decay', default=0., type=float, help='weight decay') 28 | parser.add_argument('--model', default="TextCNN", type=str, help='model type (default: TextCNN)') 29 | parser.add_argument('--seed', default=1, type=int, help='random seed') 30 | parser.add_argument('--batch-size', default=50, type=int, help='batch size (default: 128)') 31 | parser.add_argument('--epoch', default=50, type=int, help='total epochs (default: 200)') 32 | parser.add_argument('--fine-tune', default=True, type=lambda x: (str(x).lower() == 'true'), 33 | help='whether to fine-tune embedding or not') 34 | parser.add_argument('--method', default='embed', type=str, help='which mixing method to use (default: none)') 35 | parser.add_argument('--alpha', default=1., type=float, help='mixup interpolation coefficient (default: 1)') 36 | parser.add_argument('--save-path', default='out', type=str, help='output log/result directory') 37 | parser.add_argument('--num-runs', default=10, type=int, help='number of runs') 38 | args = parser.parse_args() 39 | return args 40 | 41 | 42 | def mixup_criterion_cross_entropy(criterion, pred, y_a, y_b, lam): 43 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 44 | 45 | 46 | class Classification: 47 | def __init__(self, args): 48 | self.args = args 49 | 50 | self.use_cuda = args.cuda and torch.cuda.is_available() 51 | 52 | # for reproducibility 53 | torch.manual_seed(args.seed) 54 | torch.backends.cudnn.deterministic = True 55 | torch.backends.cudnn.benchmark = False 56 | np.random.seed(args.seed) 57 | random.seed(args.seed) 58 | 59 | self.config = get_task_config(args.task) 60 | 61 | # data loaders 62 | dataset = WordDataset(self.config.sequence_len, args.batch_size) 63 | dataset.load_data(self.config.train_file, self.config.test_file, self.config.val_file, args.w2v_file, 64 | args.text_column, args.label_column) 65 | self.train_iterator = dataset.train_iterator 66 | self.val_iterator = dataset.val_iterator 67 | self.test_iterator = dataset.test_iterator 68 | 69 | # model 70 | vocab_size = len(dataset.vocab) 71 | self.model = models.__dict__[args.model](vocab_size=vocab_size, sequence_len=self.config.sequence_len, 72 | num_class=self.config.num_class, 73 | word_embeddings=dataset.word_embeddings, fine_tune=args.fine_tune, 74 | dropout=args.dropout) 75 | self.device = torch.device('cuda' if (args.cuda and torch.cuda.is_available()) else 'cpu') 76 | self.model.to(self.device) 77 | 78 | # logs 79 | os.makedirs(args.save_path, exist_ok=True) 80 | self.model_save_path = os.path.join(args.save_path, args.name + '_weights.pt') 81 | self.log_path = os.path.join(args.save_path, args.name + '_logs.csv') 82 | print(str(args)) 83 | with open(self.log_path, 'a') as f: 84 | f.write(str(args) + '\n') 85 | with open(self.log_path, 'a', newline='') as out: 86 | writer = csv.writer(out) 87 | writer.writerow(['mode', 'epoch', 'step', 'loss', 'acc']) 88 | 89 | # optimizer 90 | self.criterion = nn.CrossEntropyLoss() 91 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr, weight_decay=args.decay) 92 | 93 | # for early stopping 94 | self.best_val_acc = 0 95 | self.early_stop = False 96 | self.val_patience = 0 # successive iteration when validation acc did not improve 97 | 98 | self.iteration_number = 0 99 | 100 | def get_perm(self, x): 101 | """get 
random permutation""" 102 | batch_size = x.size()[0] 103 | if self.use_cuda: 104 | index = torch.randperm(batch_size).cuda() 105 | else: 106 | index = torch.randperm(batch_size) 107 | return index 108 | 109 | def test(self, iterator): 110 | self.model.eval() 111 | test_loss = 0 112 | total = 0 113 | correct = 0 114 | with torch.no_grad(): 115 | # for _, batch in tqdm(enumerate(iterator), total=len(iterator), desc='test'): 116 | for _, batch in enumerate(iterator): 117 | x = batch.text 118 | y = batch.label 119 | x, y = x.to(self.device), y.to(self.device) 120 | y_pred = self.model(x) 121 | loss = self.criterion(y_pred, y) 122 | test_loss += loss.item() * y.shape[0] 123 | total += y.shape[0] 124 | correct += torch.sum(torch.argmax(y_pred, dim=1) == y).item() 125 | 126 | avg_loss = test_loss / total 127 | acc = 100.0 * correct / total 128 | return avg_loss, acc 129 | 130 | def train(self, epoch): 131 | self.model.train() 132 | train_loss = 0 133 | total = 0 134 | correct = 0 135 | # for _, batch in tqdm(enumerate(self.train_iterator), total=len(self.train_iterator), desc='train'): 136 | for _, batch in enumerate(self.train_iterator): 137 | x = batch.text 138 | y = batch.label 139 | x, y = x.to(self.device), y.to(self.device) 140 | y_pred = self.model(x) 141 | loss = self.criterion(y_pred, y) 142 | train_loss += loss.item() * y.shape[0] 143 | total += y.shape[0] 144 | correct += torch.sum(torch.argmax(y_pred, dim=1) == y).item() 145 | 146 | self.optimizer.zero_grad() 147 | loss.backward() 148 | self.optimizer.step() 149 | 150 | # eval 151 | self.iteration_number += 1 152 | if self.iteration_number % self.config.eval_interval == 0: 153 | avg_loss = train_loss / total 154 | acc = 100.0 * correct / total 155 | # print('Train loss: {}, Train acc: {}'.format(avg_loss, acc)) 156 | train_loss = 0 157 | total = 0 158 | correct = 0 159 | 160 | val_loss, val_acc = self.test(iterator=self.val_iterator) 161 | # print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 162 | if val_acc > self.best_val_acc: 163 | torch.save(self.model.state_dict(), self.model_save_path) 164 | self.best_val_acc = val_acc 165 | self.val_patience = 0 166 | else: 167 | self.val_patience += 1 168 | if self.val_patience == self.config.patience: 169 | self.early_stop = True 170 | return 171 | with open(self.log_path, 'a', newline='') as out: 172 | writer = csv.writer(out) 173 | writer.writerow(['train', epoch, self.iteration_number, avg_loss, acc]) 174 | writer.writerow(['val', epoch, self.iteration_number, val_loss, val_acc]) 175 | self.model.train() 176 | 177 | def train_mixup(self, epoch): 178 | self.model.train() 179 | train_loss = 0 180 | total = 0 181 | correct = 0 182 | # for _, batch in tqdm(enumerate(self.train_iterator), total=len(self.train_iterator), desc='train'): 183 | for _, batch in enumerate(self.train_iterator): 184 | x = batch.text 185 | y = batch.label 186 | x, y = x.to(self.device), y.to(self.device) 187 | lam = np.random.beta(self.args.alpha, self.args.alpha) 188 | index = self.get_perm(x) 189 | x1 = x[:, index] 190 | y1 = y[index] 191 | 192 | if self.args.method == 'embed': 193 | y_pred = self.model.forward_mix_embed(x, x1, lam) 194 | elif self.args.method == 'sent': 195 | y_pred = self.model.forward_mix_sent(x, x1, lam) 196 | elif self.args.method == 'encoder': 197 | y_pred = self.model.forward_mix_encoder(x, x1, lam) 198 | else: 199 | raise ValueError('invalid method name') 200 | 201 | loss = mixup_criterion_cross_entropy(self.criterion, y_pred, y, y1, lam) 202 | train_loss += loss.item() * y.shape[0] 
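            # Accuracy under mixup is only indicative: a prediction is credited with weight
            # lam if it matches the first label of the pair and with weight (1 - lam) if it
            # matches the partner's label, mirroring how the loss is interpolated above.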
203 | total += y.shape[0] 204 | _, predicted = torch.max(y_pred.data, 1) 205 | correct += ((lam * predicted.eq(y.data).cpu().sum().float() 206 | + (1 - lam) * predicted.eq(y1.data).cpu().sum().float())).item() 207 | 208 | self.optimizer.zero_grad() 209 | loss.backward() 210 | self.optimizer.step() 211 | 212 | # eval 213 | self.iteration_number += 1 214 | if self.iteration_number % self.config.eval_interval == 0: 215 | avg_loss = train_loss / total 216 | acc = 100.0 * correct / total 217 | # print('Train loss: {}, Train acc: {}'.format(avg_loss, acc)) 218 | train_loss = 0 219 | total = 0 220 | correct = 0 221 | 222 | val_loss, val_acc = self.test(iterator=self.val_iterator) 223 | # print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 224 | if val_acc > self.best_val_acc: 225 | torch.save(self.model.state_dict(), self.model_save_path) 226 | self.best_val_acc = val_acc 227 | self.val_patience = 0 228 | else: 229 | self.val_patience += 1 230 | if self.val_patience == self.config.patience: 231 | self.early_stop = True 232 | return 233 | with open(self.log_path, 'a', newline='') as out: 234 | writer = csv.writer(out) 235 | writer.writerow(['train', epoch, self.iteration_number, avg_loss, acc]) 236 | writer.writerow(['val', epoch, self.iteration_number, val_loss, val_acc]) 237 | self.model.train() 238 | 239 | def run(self): 240 | for epoch in range(self.args.epoch): 241 | print('------------------------------------- Epoch {} -------------------------------------'.format(epoch)) 242 | if self.args.method == 'none': 243 | self.train(epoch) 244 | else: 245 | self.train_mixup(epoch) 246 | if self.early_stop: 247 | break 248 | print('Training complete!') 249 | print('Best Validation Acc: ', self.best_val_acc) 250 | 251 | self.model.load_state_dict(torch.load(self.model_save_path)) 252 | # train_loss, train_acc = self.test(self.train_iterator) 253 | val_loss, val_acc = self.test(self.val_iterator) 254 | test_loss, test_acc = self.test(self.test_iterator) 255 | 256 | with open(self.log_path, 'a', newline='') as out: 257 | writer = csv.writer(out) 258 | # writer.writerow(['train', -1, -1, train_loss, train_acc]) 259 | writer.writerow(['val', -1, -1, val_loss, val_acc]) 260 | writer.writerow(['test', -1, -1, test_loss, test_acc]) 261 | 262 | # print('Train loss: {}, Train acc: {}'.format(train_loss, train_acc)) 263 | print('Val loss: {}, Val acc: {}'.format(val_loss, val_acc)) 264 | print('Test loss: {}, Test acc: {}'.format(test_loss, test_acc)) 265 | 266 | return val_acc, test_acc 267 | 268 | 269 | if __name__ == '__main__': 270 | args = parse_args() 271 | num_runs = args.num_runs 272 | 273 | test_acc = [] 274 | val_acc = [] 275 | 276 | for i in range(num_runs): 277 | cls = Classification(args) 278 | val, test = cls.run() 279 | val_acc.append(val) 280 | test_acc.append(test) 281 | args.seed += 1 282 | 283 | with open(os.path.join(args.save_path, args.name + '_result.txt', 'a')) as f: 284 | f.write(str(args)) 285 | f.write('val acc:' + str(val_acc) + '\n') 286 | f.write('test acc:' + str(test_acc) + '\n') 287 | f.write('mean val acc:' + str(np.mean(val_acc)) + '\n') 288 | f.write('std val acc:' + str(np.std(val_acc, ddof=1)) + '\n') 289 | f.write('mean test acc:' + str(np.mean(test_acc)) + '\n') 290 | f.write('std test acc:' + str(np.std(test_acc, ddof=1)) + '\n\n\n') 291 | --------------------------------------------------------------------------------
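Both training scripts end with the same multi-run driver: `Classification` is run `--num-runs` times with an incremented seed, and the per-run validation/test accuracies plus their mean and standard deviation are appended to `<name>_result.txt`. A minimal sketch of that final aggregation step as it appears to be intended, with the append mode passed to `open()` itself rather than to `os.path.join()`; the `write_summary` helper name is illustrative, and `train_bert.py` would also need a `--num-runs` argument in its parser for its driver loop to run:

    import os
    import numpy as np

    def write_summary(save_path, name, args_str, val_acc, test_acc):
        # Append per-run accuracies and summary statistics to <save_path>/<name>_result.txt.
        result_file = os.path.join(save_path, name + '_result.txt')
        with open(result_file, 'a') as f:   # 'a' is the file mode for open(), not a path part
            f.write(args_str + '\n')
            f.write('val acc: ' + str(val_acc) + '\n')
            f.write('test acc: ' + str(test_acc) + '\n')
            f.write('mean val acc: ' + str(np.mean(val_acc)) + '\n')
            f.write('std val acc: ' + str(np.std(val_acc, ddof=1)) + '\n')
            f.write('mean test acc: ' + str(np.mean(test_acc)) + '\n')
            f.write('std test acc: ' + str(np.std(test_acc, ddof=1)) + '\n\n\n')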