├── README.md ├── models.py ├── utils.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-elmo-classification 2 | text classification using ELMO 3 | 4 | Require: allennlp, pytorch 0.4.1 5 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | import os 6 | import numpy as np 7 | from allennlp.modules.elmo import Elmo 8 | 9 | class SimpleELMOClassifier(nn.Module): 10 | def __init__(self, label_size, use_gpu, dropout=0.5): 11 | super(SimpleELMOClassifier, self).__init__() 12 | self.use_gpu = use_gpu 13 | self.dropout = dropout 14 | options_file = "elmo_2x4096_512_2048cnn_2xhighway_options.json" 15 | weight_file = "elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5" 16 | self.elmo = Elmo(options_file, weight_file, 1, dropout=dropout, do_layer_norm=False) 17 | # elmo output 18 | # Dict with keys: 19 | # ``'elmo_representations'``: ``List[torch.Tensor]`` 20 | # A ``num_output_representations`` list of ELMo representations for the input sequence. 21 | # Each representation is shape ``(batch_size, timesteps, embedding_dim)`` 22 | # ``'mask'``: ``torch.Tensor`` 23 | # Shape ``(batch_size, timesteps)`` long tensor with sequence mask. 24 | self.conv1 = nn.Conv1d(1024, 16, 3) 25 | self.p1 = nn.AdaptiveMaxPool1d(128) 26 | self.activation_func = nn.ReLU6() 27 | self.dropout_l = nn.Dropout(dropout) 28 | self.hidden2label = nn.Linear(2048, label_size) 29 | 30 | def init_weights(self): 31 | for name, param in self.hidden2label.named_parameters(): 32 | if 'bias' in name: 33 | nn.init.constant_(param, 0.0) 34 | elif 'weight' in name: 35 | nn.init.xavier_uniform_(param) 36 | for name, param in self.conv1.named_parameters(): 37 | if 'bias' in name: 38 | nn.init.constant_(param, 0.0) 39 | elif 'weight' in name: 40 | nn.init.xavier_uniform_(param) 41 | 42 | def forward(self, sentences): 43 | elmo_out = self.elmo(sentences) 44 | x = elmo_out['elmo_representations'][0] 45 | x = x.transpose(1,2) 46 | x = self.conv1(x) 47 | x = self.activation_func(x) 48 | x = self.p1(x) 49 | x = x.view(-1, 2048) 50 | x = self.dropout_l(x) 51 | y = self.hidden2label(x) 52 | return y 53 | 54 | def load_models(load_path, model_args, suffix='', on_gpu=False): 55 | classifier = SimpleELMOClassifier(model_args['label_size'], \ 56 | on_gpu, \ 57 | model_args['dropout']) 58 | 59 | print('Loading models from', load_path) 60 | cls_path = os.path.join(load_path, "classifier_model{}.pt".format(suffix)) 61 | classifier.load_state_dict(torch.load(cls_path)) 62 | return classifier 63 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | from itertools import dropwhile 6 | import collections 7 | import six 8 | from allennlp.modules.elmo import batch_to_ids 9 | 10 | def rindex(lst, item): 11 | def index_ne(x): 12 | return lst[x] != item 13 | try: 14 | return next(dropwhile(index_ne, reversed(range(len(lst))))) 15 | except StopIteration: 16 | raise ValueError("rindex(lst, item): item not in list") 17 | 18 | class Dictionary(object): 19 | def __init__(self): 20 | self.word2idx = {} 21 | self.idx2word = {} 22 | def __len__(self): 23 | return len(self.word2idx) 24 | 25 | class SpaceTokenizer(object): 26 | def __init__(self): 27 | super(SpaceTokenizer, self).__init__() 28 | def tokenize(self, sent): 29 | return sent.split(' ') 30 | 31 | class Corpus(object): 32 | def __init__(self, path, maxlen, lowercase=False, max_lines=-1, \ 33 | test_size=-1, train_path='train.txt', test_path='test.txt', tokenizer=None, \ 34 | label_dict=None): 35 | 36 | self.label_dict = label_dict 37 | self.maxlen = maxlen 38 | self.lowercase = lowercase 39 | self.train_path = os.path.join(path, train_path) 40 | self.test_path = os.path.join(path, test_path) 41 | self.max_lines = max_lines 42 | self.tokenizer = tokenizer if tokenizer else SpaceTokenizer() 43 | self.train = self.tokenize(self.train_path) 44 | if test_size > 0 and len(test_path) > 0: 45 | print("Test size and test path cannot both be present!") 46 | exit() 47 | if test_size > 0: 48 | print("Using {} in training set as test set".format(test_size)) 49 | self.train, self.test = self.train[:-test_size], self.train[-test_size:] 50 | return 51 | elif len(test_path) > 0: 52 | print("Using {} as test set".format(test_path)) 53 | self.test = self.tokenize(self.test_path) 54 | 55 | def tokenize(self, path): 56 | """Tokenizes a text file.""" 57 | if self.label_dict: 58 | print("Convert class names in label_dict") 59 | 60 | cropped = 0. 61 | 62 | with open(path, 'r') as f: 63 | linecount = 0 64 | lines = [] 65 | tags = [] 66 | for line in f.readlines(): 67 | linecount += 1 68 | if linecount % 10000 == 0: print("Read line", linecount, end='\r') 69 | if self.max_lines > 1 and linecount >= self.max_lines: 70 | break 71 | if self.lowercase: 72 | line = line.lower().strip().split('\t') 73 | else: 74 | line = line.strip().split('\t') 75 | tag, sent = line[0], line[1] 76 | sent = self.tokenizer.tokenize(sent) 77 | if len(sent) > self.maxlen: 78 | cropped += 1 79 | words = sent[:self.maxlen] 80 | if linecount == 2: print(words) 81 | lines.append(words) 82 | # Convert class label to int 83 | if self.label_dict: 84 | tag = self.label_dict[tag] 85 | tags.append(tag) 86 | oov_count = -1 87 | # oov_count = print([(1 if ii==unk_idx else 0) for l in lines for ii in l]) 88 | print("\nNumber of sentences cropped in {}: {:.0f} out of {:.0f} total, OOV {:.0f}". 89 | format(path, cropped, linecount, oov_count)) 90 | 91 | return list(zip(tags, lines)) 92 | 93 | 94 | def batchify(data, bsz, shuffle=False, gpu=False): 95 | if shuffle: 96 | random.shuffle(data) 97 | tags, sents = zip(*data) 98 | nbatch = (len(sents)+bsz-1) // bsz 99 | # downsample biggest class 100 | # sents, tags = balance_tags(sents, tags) 101 | 102 | for i in range(nbatch): 103 | 104 | batch = sents[i*bsz:(i+1)*bsz] 105 | batch_tags = tags[i*bsz:(i+1)*bsz] 106 | # lengths = [len(x) for x in batch] 107 | # sort items by length (decreasing) 108 | # batch, batch_tags, lengths = length_sort(batch, batch_tags, lengths) 109 | 110 | # Pad batches to maximum sequence length in batch 111 | # find length to pad to 112 | 113 | # maxlen = lengths[0] 114 | # for b_i in range(len(batch)): 115 | # pads = [pad_id] * (maxlen-len(batch[b_i])) 116 | # batch[b_i] = batch[b_i] + pads 117 | # batch = torch.tensor(batch).long() 118 | batch = batch_to_ids(batch) 119 | batch_tags = torch.tensor(batch_tags).long() 120 | # lengths = [torch.tensor(l).long() for l in lengths] 121 | 122 | # yield (batch, batch_tags, lengths) 123 | yield (batch, batch_tags) 124 | 125 | def filter_flip_polarity(data): 126 | flipped = [] 127 | tags, sents = zip(*data) 128 | 129 | for i in range(len(tags)): 130 | org_tag = tags[i] 131 | sent = sents[i] 132 | if org_tag == 1: new_tag = 0 133 | if org_tag == 0: new_tag = 1 134 | flipped.append((new_tag, sent)) 135 | print("Filtered and flipped {} sents from {} sents.".format(len(flipped), len(data))) 136 | return flipped 137 | 138 | def length_sort(items, tags, lengths, descending=True): 139 | """In order to use pytorch variable length sequence package""" 140 | old_items = list(zip(items, tags, lengths)) 141 | old_items.sort(key=lambda x: x[2], reverse=True) 142 | items, tags, lengths = zip(*old_items) 143 | return list(items), list(tags), list(lengths) 144 | 145 | def balance_tags(items, tags): 146 | """Downsample largest group of tags""" 147 | new_items = [] 148 | new_tags = [] 149 | 150 | biggest_class = 2 151 | drop_ratio = .666 152 | for i in range(len(items)): 153 | tag = tags[i] 154 | item = items[i] 155 | if tag == biggest_class: 156 | if random.random() < drop_ratio: 157 | continue 158 | new_items.append(item) 159 | new_tags.append(tag) 160 | return new_items, new_tags 161 | 162 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | import numpy as np 8 | import random 9 | import sys 10 | import json 11 | import collections 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | 18 | from utils import Corpus, batchify 19 | from models import SimpleELMOClassifier 20 | 21 | 22 | parser = argparse.ArgumentParser(description='Text') 23 | # Path Arguments 24 | parser.add_argument('--data_path', type=str, required=True, 25 | help='location of the corpus') 26 | parser.add_argument('--out_dir', type=str, default='output', 27 | help='output directory name') 28 | parser.add_argument('--checkpoint', type=str, default='', 29 | help='load checkpoint') 30 | 31 | # Data Processing Arguments 32 | parser.add_argument('--maxlen', type=int, default=128, 33 | help='maximum sentence length') 34 | parser.add_argument('--lowercase', action='store_true', 35 | help='lowercase all text') 36 | 37 | # Model Arguments 38 | parser.add_argument('--dropout', type=float, default=0.2, 39 | help='dropout applied to layers (0 = no dropout)') 40 | 41 | # Training Arguments 42 | parser.add_argument('--epochs', type=int, default=10, 43 | help='maximum number of epochs') 44 | parser.add_argument('--batch_size', type=int, default=200, 45 | help='batch size') 46 | parser.add_argument('--lr', type=float, default=2e-4, 47 | help='learning rate') 48 | parser.add_argument('--clip', type=float, default=5., 49 | help='gradient clipping, max norm') 50 | 51 | # Evaluation Arguments 52 | parser.add_argument('--log_interval', type=int, default=200, 53 | help='interval to log training results') 54 | 55 | # Other 56 | parser.add_argument('--seed', type=int, default=1337, 57 | help='random seed') 58 | parser.add_argument('--no_cuda', action='store_true', 59 | help='do not use CUDA') 60 | 61 | args = parser.parse_args() 62 | print(vars(args)) 63 | 64 | def save_model(model, suffix=''): 65 | print("Saving model") 66 | with open('{}/model_{}.pt'.format(args.out_dir, suffix), 'wb') as f: 67 | torch.save(model.state_dict(), f) 68 | 69 | ############################################################################### 70 | # Eval code 71 | ############################################################################### 72 | 73 | def evaluate(model, data): 74 | # Turn on evaluation mode which disables dropout. 75 | model.eval() 76 | all_accuracies = 0. 77 | nbatches = 0. 78 | 79 | for batch in data: 80 | nbatches += 1. 81 | # source, tags, lengths = batch 82 | source, tags = batch 83 | if args.cuda: 84 | source = source.to("cuda") 85 | tags = tags.to("cuda") 86 | # output = model(source, lengths) 87 | output = model(source) 88 | max_vals, max_indices = torch.max(output, -1) 89 | 90 | accuracy = torch.mean(max_indices.eq(tags).float()).item() 91 | all_accuracies += accuracy 92 | return all_accuracies/nbatches 93 | 94 | def train_classifier(args, classifier, train_batch, optimizer_, criterion_ce): 95 | classifier.train() 96 | classifier.zero_grad() 97 | # source, tags, lengths = train_batch 98 | source, tags = train_batch 99 | if args.cuda: 100 | source = source.to("cuda") 101 | tags = tags.to("cuda") 102 | 103 | # output: batch x nclasses 104 | # output = classifier(source, lengths) 105 | output = classifier(source) 106 | c_loss = criterion_ce(output, tags) 107 | 108 | c_loss.backward() 109 | 110 | # `clip_grad_norm` to prevent exploding gradient in RNNs / LSTMs 111 | torch.nn.utils.clip_grad_norm_(classifier.parameters(), args.clip) 112 | optimizer_.step() 113 | 114 | total_loss = c_loss.item() 115 | 116 | # probs = F.softmax(output, dim=-1) 117 | # max_vals, max_indices = torch.max(probs, -1) 118 | # accuracy = torch.mean(max_indices.eq(tags).float()).item() 119 | 120 | return total_loss 121 | 122 | def main(_): 123 | # make output directory if it doesn't already exist 124 | args.out_dir = os.path.join('.', args.out_dir) 125 | if not os.path.isdir(args.out_dir): 126 | os.makedirs(args.out_dir) 127 | 128 | # Set the random seed manually for reproducibility. 129 | random.seed(args.seed) 130 | np.random.seed(args.seed) 131 | torch.manual_seed(args.seed) 132 | if torch.cuda.is_available(): 133 | if args.no_cuda: 134 | print("WARNING: You have a CUDA device not used.") 135 | else: 136 | torch.cuda.manual_seed(args.seed) 137 | args.cuda = not args.no_cuda 138 | 139 | ############################################################################### 140 | # Load data 141 | ############################################################################### 142 | label_dict = dict([ 143 | ('1', 0), 144 | ('2', 1), 145 | ]) 146 | nclasses = len(label_dict) 147 | # create corpus 148 | corpus = Corpus(args.data_path, 149 | maxlen=args.maxlen, 150 | lowercase=args.lowercase, 151 | max_lines=-1, 152 | test_size=0, 153 | train_path='train.txt', 154 | test_path='test.txt', 155 | # tokenizer=tokenizer, 156 | label_dict=label_dict, 157 | ) 158 | args.nclasses = nclasses 159 | # save arguments 160 | with open('{}/args.json'.format(args.out_dir), 'w') as f: 161 | json.dump(vars(args), f) 162 | 163 | eval_batch_size = 20 164 | 165 | # Print corpus stats 166 | class_counts = collections.Counter([c[0] for c in corpus.train]) 167 | print("Train: {}".format(class_counts)) 168 | class_counts = collections.Counter([c[0] for c in corpus.test]) 169 | print("Test: {}".format(class_counts)) 170 | 171 | train_data = batchify(corpus.train, args.batch_size, shuffle=True) 172 | test_data = batchify(corpus.test, eval_batch_size, shuffle=False) 173 | 174 | print("Loaded data!") 175 | 176 | ############################################################################### 177 | # Build the models 178 | ############################################################################### 179 | 180 | classifier = SimpleELMOClassifier(label_size=args.nclasses, use_gpu=args.cuda, dropout=args.dropout,) 181 | # print(classifier) 182 | 183 | def count_parameters(model): 184 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 185 | 186 | print("Num params:", count_parameters(classifier)) 187 | # optimizer = optim.Adam(classifier.parameters(), lr=args.lr) 188 | # optimizer = optim.RMSprop(classifier.parameters(), lr=args.lr) 189 | optimizer = optim.Adam(classifier.parameters(), lr=args.lr, weight_decay=1e-4) 190 | learning_rate_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9) 191 | criterion_ce = nn.CrossEntropyLoss() 192 | 193 | if args.cuda: 194 | classifier.to("cuda") 195 | criterion_ce = criterion_ce.to("cuda") 196 | 197 | classifier.init_weights() 198 | if args.checkpoint: 199 | classifier.load_state_dict(torch.load(args.checkpoint)) 200 | 201 | print("Training...") 202 | with open("{}/logs.txt".format(args.out_dir), 'w') as f: 203 | f.write('Training...\n') 204 | niter_global = 0 205 | for epoch in range(1, args.epochs+1): 206 | print("Epoch ", epoch) 207 | 208 | # loop through all batches in training data 209 | for train_batch in train_data: 210 | loss = train_classifier(args, classifier, train_batch, optimizer, criterion_ce) 211 | niter_global += 1 212 | if niter_global % 10 == 0: 213 | msg = 'loss {:.5f}'.format(loss) 214 | print(msg, end='\r') 215 | with open("{}/logs.txt".format(args.out_dir), 'a') as f: 216 | f.write(msg) 217 | f.write('\n') 218 | f.flush() 219 | 220 | if niter_global % 1000 == 0: 221 | with torch.no_grad(): 222 | accuracy = evaluate(classifier, test_data) 223 | msg = 'test acc {:.4f}'.format(accuracy) 224 | # msg = 'test loss {:.5f} acc {:.2f}'.format(test_loss, accuracy) 225 | print('\n' + msg) 226 | with open("{}/logs.txt".format(args.out_dir), 'a') as f: 227 | f.write(msg) 228 | f.write('\n') 229 | f.flush() 230 | # save_model(classifier, suffix=niter_global) 231 | # print("Saved model step {}".format(niter_global)) 232 | # we use generator, so must re-gen test data 233 | test_data = batchify(corpus.test, eval_batch_size, shuffle=False) 234 | 235 | # end of epoch ---------------------------- 236 | # save model every epoch 237 | save_model(classifier, suffix=epoch) 238 | print("saved model epoch {}".format(epoch)) 239 | 240 | # clear cache between epoch 241 | torch.cuda.empty_cache() 242 | # decay learning rate 243 | learning_rate_scheduler.step() 244 | # shuffle between epochs 245 | train_data = batchify(corpus.train, args.batch_size, shuffle=True) 246 | 247 | 248 | if __name__ == "__main__": 249 | main(1) 250 | --------------------------------------------------------------------------------