├── .gitignore
├── bert_sentiment
│   ├── __init__.py
│   ├── data.py
│   └── train.py
├── requirements.txt
├── run.py
├── LICENSE
└── README.md
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
**__pycache__**
.vscode
--------------------------------------------------------------------------------
/bert_sentiment/__init__.py:
--------------------------------------------------------------------------------
from . import data, train
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch
pytorch-transformers
pytreebank
loguru
tqdm
click
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import click

CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"])


@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    "-c",
    "--bert-config",
    default="bert-large-uncased",
    help="Pretrained BERT configuration",
)
@click.option("-b", "--binary", is_flag=True, help="Use binary labels, ignore neutrals")
@click.option("-r", "--root", is_flag=True, help="Use only root nodes of SST")
@click.option(
    "-s", "--save", is_flag=True, help="Save the model files after every epoch"
)
def main(bert_config, binary, root, save):
    """Train BERT sentiment classifier."""
    from bert_sentiment.train import train

    train(binary=binary, root=root, bert=bert_config, save=save)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Manish Munikar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fine-grained Sentiment Classification using BERT

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/fine-grained-sentiment-classification-using/sentiment-analysis-on-sst-5-fine-grained)](https://paperswithcode.com/sota/sentiment-analysis-on-sst-5-fine-grained?p=fine-grained-sentiment-classification-using)

This repo contains the code that was used to obtain the results of the paper [Fine-grained Sentiment Classification using BERT](https://arxiv.org/abs/1910.03474).

## Usage

Experiments for the various configurations can be run with `run.py`. First, install the Python packages (preferably in a clean virtualenv): `pip install -r requirements.txt`

```
Usage: run.py [OPTIONS]

  Train BERT sentiment classifier.

Options:
  -c, --bert-config TEXT  Pretrained BERT configuration
  -b, --binary            Use binary labels, ignore neutrals
  -r, --root              Use only root nodes of SST
  -s, --save              Save the model files after every epoch
  -h, --help              Show this message and exit.
```

For example, to run the experiment with binary labels and root nodes only, run:

    python3 run.py -rb
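
The same experiment can also be launched from Python. Below is a minimal sketch (not part of the original README) that mirrors `python3 run.py -rb`, using the keyword arguments of the `train` function in `bert_sentiment/train.py`:

```python
# Minimal sketch: the programmatic equivalent of `python3 run.py -rb`.
from bert_sentiment.train import train

# Binary labels (neutral phrases ignored), root nodes only,
# default pretrained BERT configuration.
train(binary=True, root=True, bert="bert-large-uncased", save=False)
```

With `-s`/`--save`, the whole model object is pickled after every epoch via `torch.save`, so a checkpoint such as `bert-large-uncased__root__binary__e1.pickle` can later be restored with `torch.load` (the `bert_sentiment` package must be importable when unpickling).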
53 | """ 54 | logger.info(f"Loading SST {split} set") 55 | self.sst = sst[split] 56 | 57 | logger.info("Tokenizing") 58 | if root and binary: 59 | self.data = [ 60 | ( 61 | rpad( 62 | tokenizer.encode("[CLS] " + tree.to_lines()[0] + " [SEP]"), n=66 63 | ), 64 | get_binary_label(tree.label), 65 | ) 66 | for tree in self.sst 67 | if tree.label != 2 68 | ] 69 | elif root and not binary: 70 | self.data = [ 71 | ( 72 | rpad( 73 | tokenizer.encode("[CLS] " + tree.to_lines()[0] + " [SEP]"), n=66 74 | ), 75 | tree.label, 76 | ) 77 | for tree in self.sst 78 | ] 79 | elif not root and not binary: 80 | self.data = [ 81 | (rpad(tokenizer.encode("[CLS] " + line + " [SEP]"), n=66), label) 82 | for tree in self.sst 83 | for label, line in tree.to_labeled_lines() 84 | ] 85 | else: 86 | self.data = [ 87 | ( 88 | rpad(tokenizer.encode("[CLS] " + line + " [SEP]"), n=66), 89 | get_binary_label(label), 90 | ) 91 | for tree in self.sst 92 | for label, line in tree.to_labeled_lines() 93 | if label != 2 94 | ] 95 | 96 | def __len__(self): 97 | return len(self.data) 98 | 99 | def __getitem__(self, index): 100 | X, y = self.data[index] 101 | X = torch.tensor(X) 102 | return X, y 103 | -------------------------------------------------------------------------------- /bert_sentiment/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from loguru import logger 5 | from pytorch_transformers import BertConfig, BertForSequenceClassification 6 | from tqdm import tqdm 7 | 8 | from .data import SSTDataset 9 | 10 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 11 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | def train_one_epoch(model, lossfn, optimizer, dataset, batch_size=32): 15 | generator = torch.utils.data.DataLoader( 16 | dataset, batch_size=batch_size, shuffle=True 17 | ) 18 | model.train() 19 | train_loss, train_acc = 0.0, 0.0 20 | for batch, labels in tqdm(generator): 21 | batch, labels = batch.to(device), labels.to(device) 22 | optimizer.zero_grad() 23 | loss, logits = model(batch, labels=labels) 24 | err = lossfn(logits, labels) 25 | loss.backward() 26 | optimizer.step() 27 | 28 | train_loss += loss.item() 29 | pred_labels = torch.argmax(logits, axis=1) 30 | train_acc += (pred_labels == labels).sum().item() 31 | train_loss /= len(dataset) 32 | train_acc /= len(dataset) 33 | return train_loss, train_acc 34 | 35 | 36 | def evaluate_one_epoch(model, lossfn, optimizer, dataset, batch_size=32): 37 | generator = torch.utils.data.DataLoader( 38 | dataset, batch_size=batch_size, shuffle=True 39 | ) 40 | model.eval() 41 | loss, acc = 0.0, 0.0 42 | with torch.no_grad(): 43 | for batch, labels in tqdm(generator): 44 | batch, labels = batch.to(device), labels.to(device) 45 | logits = model(batch)[0] 46 | error = lossfn(logits, labels) 47 | loss += error.item() 48 | pred_labels = torch.argmax(logits, axis=1) 49 | acc += (pred_labels == labels).sum().item() 50 | loss /= len(dataset) 51 | acc /= len(dataset) 52 | return loss, acc 53 | 54 | 55 | def train( 56 | root=True, 57 | binary=False, 58 | bert="bert-large-uncased", 59 | epochs=30, 60 | batch_size=32, 61 | save=False, 62 | ): 63 | trainset = SSTDataset("train", root=root, binary=binary) 64 | devset = SSTDataset("dev", root=root, binary=binary) 65 | testset = SSTDataset("test", root=root, binary=binary) 66 | 67 | config = BertConfig.from_pretrained(bert) 68 | if not binary: 69 | config.num_labels = 5 70 | model = 
--------------------------------------------------------------------------------
/bert_sentiment/train.py:
--------------------------------------------------------------------------------
import os

import torch
from loguru import logger
from pytorch_transformers import BertConfig, BertForSequenceClassification
from tqdm import tqdm

from .data import SSTDataset

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def train_one_epoch(model, lossfn, optimizer, dataset, batch_size=32):
    generator = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )
    model.train()
    train_loss, train_acc = 0.0, 0.0
    for batch, labels in tqdm(generator):
        batch, labels = batch.to(device), labels.to(device)
        optimizer.zero_grad()
        # When labels are passed, the model returns its own cross-entropy
        # loss along with the logits, so `lossfn` is not needed here.
        loss, logits = model(batch, labels=labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pred_labels = torch.argmax(logits, dim=1)
        train_acc += (pred_labels == labels).sum().item()
    train_loss /= len(dataset)
    train_acc /= len(dataset)
    return train_loss, train_acc


def evaluate_one_epoch(model, lossfn, optimizer, dataset, batch_size=32):
    generator = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )
    model.eval()
    loss, acc = 0.0, 0.0
    with torch.no_grad():
        for batch, labels in tqdm(generator):
            batch, labels = batch.to(device), labels.to(device)
            logits = model(batch)[0]
            error = lossfn(logits, labels)
            loss += error.item()
            pred_labels = torch.argmax(logits, dim=1)
            acc += (pred_labels == labels).sum().item()
    loss /= len(dataset)
    acc /= len(dataset)
    return loss, acc


def train(
    root=True,
    binary=False,
    bert="bert-large-uncased",
    epochs=30,
    batch_size=32,
    save=False,
):
    trainset = SSTDataset("train", root=root, binary=binary)
    devset = SSTDataset("dev", root=root, binary=binary)
    testset = SSTDataset("test", root=root, binary=binary)

    config = BertConfig.from_pretrained(bert)
    if not binary:
        config.num_labels = 5
    model = BertForSequenceClassification.from_pretrained(bert, config=config)

    model = model.to(device)
    lossfn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Epochs are 1-indexed; range(1, epochs + 1) runs exactly `epochs` of them.
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(
            model, lossfn, optimizer, trainset, batch_size=batch_size
        )
        val_loss, val_acc = evaluate_one_epoch(
            model, lossfn, optimizer, devset, batch_size=batch_size
        )
        test_loss, test_acc = evaluate_one_epoch(
            model, lossfn, optimizer, testset, batch_size=batch_size
        )
        logger.info(f"epoch={epoch}")
        logger.info(
            f"train_loss={train_loss:.4f}, val_loss={val_loss:.4f}, test_loss={test_loss:.4f}"
        )
        logger.info(
            f"train_acc={train_acc:.3f}, val_acc={val_acc:.3f}, test_acc={test_acc:.3f}"
        )
        if save:
            label = "binary" if binary else "fine"
            nodes = "root" if root else "all"
            torch.save(model, f"{bert}__{nodes}__{label}__e{epoch}.pickle")

    logger.success("Done!")
--------------------------------------------------------------------------------