├── .gitignore
├── README.md
├── config.py
├── data
│   ├── readme.md
│   ├── rt-polarity.neg
│   └── rt-polarity.pos
├── data_pro.py
├── encoder.py
├── main.py
└── model.py

/.gitignore:
--------------------------------------------------------------------------------
*pyc
__pycache__/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Testing Bert/ELMo Word Embeddings
- Task: text classification
- Dataset: movie review sentiment classification
- Model: word embeddings + encoder
    - word embeddings:
        - Bert
        - ELMo
        - GloVe
    - encoder:
        - CNN + max pooling
        - RNN + last hidden states
        - position encoding + Transformer + average pooling
        - average over all words

Blog write-up: [Bert/ELMo text classification](http://shomy.top/2020/07/06/bert-elmo-cls/)


## Usage
- Environment:
    - python 3.6+
    - pytorch 1.4+
    - transformers
    - AllenNLP
    - sklearn
    - fire
- Clone the repository, then download the Bert/ELMo/GloVe embedding files following the instructions in `data/readme.md`.
- Run the code:
```
python main.py train --emb_method='elmo' --enc_method='cnn'
```
- Configurable options:
    - emb_method: [elmo, glove, bert]
    - enc_method: [cnn, rnn, transformer, mean]

See `config.py` for the remaining options, e.g. `--gpu_id=1` to select which GPU to use.

## Results

Environment:
- GPU: 1080Ti
- CPU: E5 2680 V4

In all experiments the Bert and ELMo parameters are frozen.

| Embedding | Encoder | Acc | Time/epoch |
| --- | --- | --- | --- |
| **Bert** | Mean | 0.8031 | 17.98s |
| **Bert** | CNN | 0.8397 | 18.35s |
| **Bert** | RNN | 0.8444 | 18.93s |
| **Bert** | Transformer | **0.8472** | 20.95s |
| *ELMo* | Mean | 0.7572 | 25.05s |
| *ELMo* | CNN | 0.8172 | 25.53s |
| *ELMo* | RNN | **0.8219** | 27.18s |
| *ELMo* | Transformer | 0.8051 | 26.20s |
| GloVe | Mean | 0.8003 | 0.60s |
| GloVe | CNN | 0.8031 | 0.76s |
| GloVe | RNN | **0.8219** | 1.45s |
| GloVe | Transformer | 0.8153 | 1.71s |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-


class Config():

    # ELMo
    elmo_options_file = "./data/elmo/elmo_2x2048_256_2048cnn_1xhighway_options.json"
    elmo_weight_file = "./data/elmo/elmo_2x2048_256_2048cnn_1xhighway_weights.hdf5"
    elmo_dim = 512

    # Bert
    bert_path = './data/bert/'
    bert_dim = 768

    # GloVe
    vocab_size = 18766
    glove_dim = 300
    glove_file = "./data/glove/glove_300d.npy"
    word2id_file = "./data/glove/word2id.npy"

    emb_method = 'glove'  # bert/elmo/glove
    enc_method = 'CNN'  # CNN/RNN/Transformer/mean
    hidden_size = 200
    out_size = 64
    num_labels = 2

    use_gpu = True
    seed = 2020
    gpu_id = 0

    dropout = 0.5
    epochs = 20

    test_size = 0.1
    lr = 1e-3
    weight_decay = 1e-4
    batch_size = 64
    device = "cuda:0"


def parse(self, kwargs):
    '''
    user can update the default hyperparameters
    '''
    for k, v in kwargs.items():
        if not hasattr(self, k):
            raise Exception('opt has No key: {}'.format(k))
        setattr(self, k, v)

    print('*************************************************')
    print('user config:')
    for k, v in self.__class__.__dict__.items():
        if not k.startswith('__'):
            print("{} => {}".format(k, getattr(self, k)))

    print('*************************************************')


Config.parse = parse
opt = Config()
--------------------------------------------------------------------------------
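The `parse` helper above is how the command-line flags reach the rest of the code: `fire` forwards the keyword arguments of `train` as a dict, and `parse` overwrites the matching `Config` attributes (an unknown key raises an exception). A minimal sketch of using it directly (not part of the repo):

```
from config import opt

# Override a few defaults programmatically, exactly what `fire` does for
# `python main.py train --emb_method='bert' --enc_method='rnn' --gpu_id=1`.
opt.parse({'emb_method': 'bert', 'enc_method': 'rnn', 'gpu_id': 1})
print(opt.emb_method, opt.enc_method)  # -> bert rnn

# A key that is not defined in Config raises an Exception:
# opt.parse({'no_such_option': 1})
```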
/data/readme.md:
--------------------------------------------------------------------------------
## Dataset

- Please first download the Bert/ELMo/GloVe pretrained weight files from one of the links:
    - [Google Drive](https://drive.google.com/file/d/1qlteRE6grcOw53EooMvewOwZH4I9I2ca/view?usp=sharing)
    - [Baidu Netdisk](https://pan.baidu.com/s/1BIF37GnTPpfTxGIha4oQQg), extraction code: 2ljr
- Unzip them into the `data/` directory so that it looks like:
```
data/
    ELMo/
    Bert/
    GloVe/
    readme.md
    rt-polarity.pos
    rt-polarity.neg
```
--------------------------------------------------------------------------------
/data_pro.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import re
import os
import sys
import numpy as np
import pickle

from torch.utils.data import Dataset


class Data(Dataset):
    def __init__(self, x, y):
        self.data = list(zip(x, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        assert idx < len(self)
        return self.data[idx]


def clean_str(string):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()


def extract_vocab(positive_data_file, negative_data_file):
    '''
    extract the vocabulary from the raw text files
    '''
    positive_examples = list(open(positive_data_file, "r", encoding='utf-8').readlines())
    positive_examples = [s.strip() for s in positive_examples]
    negative_examples = list(open(negative_data_file, "r", encoding='utf-8').readlines())
    negative_examples = [s.strip() for s in negative_examples]
    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    x_text = list(map(lambda x: x.split(), x_text))

    vocab = []
    for line in x_text:
        vocab.extend(line)

    vocab = list(set(vocab))
    print("vocab size: {}.".format(len(vocab)))
    open("./data/glove/vocab.txt", "w").write("\n".join(vocab))


def get_glove(w2v_path, vocab_path):

    vocab = {j.strip(): i for i, j in enumerate(open(vocab_path), 0)}
    id2word = {vocab[i]: i for i in vocab}

    dim = 0
    w2v = {}
    for line in open(w2v_path):
        line = line.strip().split()
        word = line[0]
        vec = list(map(float, line[1:]))
        dim = len(vec)
        w2v[word] = vec

    vecs = []
    # index 0 holds a random vector, used for out-of-vocabulary words
    vecs.append(np.random.uniform(low=-1.0, high=1.0, size=dim))

    hit = 0
    for i in range(1, len(vocab) - 1):
        if id2word[i] in w2v:
            hit += 1
            vecs.append(w2v[id2word[i]])
        else:
            vecs.append(vecs[0])
    # the last index holds an all-zero vector, used for padding
    vecs.append(np.zeros(dim))
    assert(len(vecs) == len(vocab))

    print("vocab size: {}, dim: {}; hit in glove:{}".format(len(vocab), dim, hit))
    np.save("./data/glove/glove_{}d.npy".format(dim), np.array(vecs, dtype=np.float32))
    np.save("./data/glove/word2id.npy", vocab)
    np.save("./data/glove/id2word.npy", id2word)


def load_data_and_labels(positive_data_file, negative_data_file):
    """
    Loads MR polarity data from files, splits the data into words and generates labels.
    Returns split sentences and labels.
    """
    # Load data from files
    positive_examples = list(open(positive_data_file, "r", encoding='utf-8').readlines())
    positive_examples = [s.strip() for s in positive_examples]
    negative_examples = list(open(negative_data_file, "r", encoding='utf-8').readlines())
    negative_examples = [s.strip() for s in negative_examples]
    # Split by words
    x_text = positive_examples + negative_examples
    x_text = [clean_str(sent) for sent in x_text]
    x_text = list(map(lambda x: x.split(), x_text))
    # Generate labels
    positive_labels = [1 for _ in positive_examples]
    negative_labels = [0 for _ in negative_examples]
    y = np.array(positive_labels + negative_labels)
    return [x_text, y]


if __name__ == "__main__":
    import fire
    fire.Fire()
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):

    def __init__(self, enc_method, input_size, hidden_size, out_size):
        '''
        input_size: dimension of the input word embeddings
        hidden_size: the output size of the CNN/RNN features
        out_size: the final size of the encoder (after pooling and the linear layer)

        CNN:
            - filters_num: hidden_size
            - filter_size: 3
            - pooling: mean over positions
        RNN:
            - bidirectional GRU with hidden_size // 2 units per direction
            - pooling: mean over hidden states
        Transformer:
            - nhead: 1
            - nlayer: 1
            - pooling: average
        -------
        '''
        super(Encoder, self).__init__()
        self.enc_method = enc_method.lower()
        if self.enc_method == 'cnn':
            self.conv = nn.Conv2d(1, hidden_size, (3, input_size))
            nn.init.xavier_uniform_(self.conv.weight)
            nn.init.constant_(self.conv.bias, 0.0)
            f_dim = hidden_size
        elif self.enc_method == 'rnn':
            self.rnn = nn.GRU(input_size, hidden_size//2, batch_first=True, bidirectional=True)
            f_dim = hidden_size
        elif self.enc_method == 'transformer':
            self.pe = PositionEmbedding(input_size, 512)
            self.layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=1)
            self.tr = nn.TransformerEncoder(self.layer, num_layers=1)
            f_dim = input_size
        else:
            f_dim = input_size

        self.fc = nn.Linear(f_dim, out_size)
        nn.init.uniform_(self.fc.weight, -0.5, 0.5)
        nn.init.uniform_(self.fc.bias, -0.1, 0.1)

    def forward(self, inputs):
        if self.enc_method == 'cnn':
            x = inputs.unsqueeze(1)
            x = F.relu(self.conv(x).squeeze(3))
            out = x.permute(0, 2, 1)
        elif self.enc_method == 'rnn':
            out, _ = self.rnn(inputs)
        elif self.enc_method == 'transformer':
            inputs = self.pe(inputs)
            out = self.tr(inputs.permute(1, 0, 2)).permute(1, 0, 2)
        else:
            out = inputs
        return self.fc(out.mean(1))


class PositionEmbedding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionEmbedding, self).__init__()
        self.pe = nn.Embedding(max_len, d_model)
        nn.init.uniform_(self.pe.weight, -0.1, 0.1)

    def forward(self, x):
        b, l, d = x.size()
        seq_len = torch.arange(l).to(x.device)
        return x + self.pe(seq_len).unsqueeze(0)


# performs worse in practice; kept for reference
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x
--------------------------------------------------------------------------------
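Whatever `enc_method` is chosen, the forward pass ends with mean pooling over the time dimension followed by the linear layer, so the encoder always returns a `(batch_size, out_size)` tensor. A quick shape check (the sizes are illustrative, not from the repo):

```
import torch
from encoder import Encoder

x = torch.randn(4, 20, 300)  # (batch, seq_len, emb_dim), e.g. GloVe vectors
for method in ["cnn", "rnn", "transformer", "mean"]:
    enc = Encoder(method, input_size=300, hidden_size=200, out_size=64)
    print(method, enc(x).shape)  # -> torch.Size([4, 64]) for every method
```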
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from data_pro import load_data_and_labels, Data
from model import Model
from config import opt


def now():
    return str(time.strftime('%Y-%m-%d %H:%M:%S'))


def collate_fn(batch):
    # keep the raw token lists; padding/tensorization happens inside Model
    data, label = zip(*batch)
    return data, label


def train(**kwargs):

    opt.parse(kwargs)
    device = torch.device("cuda:{}".format(opt.gpu_id) if torch.cuda.is_available() else "cpu")
    opt.device = device

    random.seed(opt.seed)
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    if opt.use_gpu:
        torch.cuda.manual_seed_all(opt.seed)

    x_text, y = load_data_and_labels("./data/rt-polarity.pos", "./data/rt-polarity.neg")
    x_train, x_test, y_train, y_test = train_test_split(x_text, y, test_size=opt.test_size)

    train_data = Data(x_train, y_train)
    test_data = Data(x_test, y_test)
    train_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True, collate_fn=collate_fn)
    test_loader = DataLoader(test_data, batch_size=opt.batch_size, shuffle=False, collate_fn=collate_fn)

    print(f"{now()} train data: {len(train_data)}, test data: {len(test_data)}")

    model = Model(opt)
    print(f"{now()} {opt.emb_method} init model finished")

    if opt.use_gpu:
        model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.7)
    best_acc = -0.1
    best_epoch = -1
    start_time = time.time()
    for epoch in range(1, opt.epochs + 1):
        total_loss = 0.0
        model.train()
        for step, batch_data in enumerate(train_loader):
            x, labels = batch_data
            labels = torch.LongTensor(labels)
            if opt.use_gpu:
                labels = labels.to(device)
            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        acc = test(model, test_loader)
        if acc > best_acc:
            best_acc = acc
            best_epoch = epoch
        print(f"{now()} Epoch{epoch}: loss: {total_loss}, test_acc: {acc}")
        lr_scheduler.step()

    end_time = time.time()
    print("*"*20)
    print(f"{now()} finished; epoch {best_epoch} best_acc: {best_acc}, time/epoch: {(end_time-start_time)/opt.epochs}")


def test(model, test_loader):
    correct = 0
    num = 0
    model.eval()
    with torch.no_grad():
        for data in test_loader:
            x, labels = data
            num += len(labels)
            output = model(x)
            labels = torch.LongTensor(labels)
            if opt.use_gpu:
                output = output.cpu()
            predict = torch.max(output.data, 1)[1]
            correct += (predict == labels).sum().item()
    model.train()
    return correct * 1.0 / num


if __name__ == "__main__":
    import fire
    fire.Fire()
--------------------------------------------------------------------------------
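One detail worth spelling out: `collate_fn` deliberately returns the raw token lists and labels untouched, because padding and conversion to ids happen inside `Model` and differ per embedding method. A small illustration with made-up sentences:

```
from main import collate_fn

batch = [(["a", "good", "movie"], 1), (["boring", "and", "slow"], 0)]
x, y = collate_fn(batch)
# x == (["a", "good", "movie"], ["boring", "and", "slow"])
# y == (1, 0), converted to a LongTensor later in train()/test()
```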
/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

from allennlp.modules.elmo import Elmo, batch_to_ids
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import numpy as np

from encoder import Encoder


class Model(nn.Module):
    def __init__(self, opt):

        super(Model, self).__init__()
        self.opt = opt
        self.use_gpu = self.opt.use_gpu

        if opt.emb_method == 'elmo':
            self.init_elmo()
        elif self.opt.emb_method == 'glove':
            self.init_glove()
        elif self.opt.emb_method == 'bert':
            self.init_bert()

        self.encoder = Encoder(opt.enc_method, self.word_dim, opt.hidden_size, opt.out_size)
        self.cls = nn.Linear(opt.out_size, opt.num_labels)
        nn.init.uniform_(self.cls.weight, -0.1, 0.1)
        nn.init.uniform_(self.cls.bias, -0.1, 0.1)
        self.dropout = nn.Dropout(self.opt.dropout)

    def forward(self, x):
        if self.opt.emb_method == 'elmo':
            word_embs = self.get_elmo(x)
        elif self.opt.emb_method == 'glove':
            word_embs = self.get_glove(x)
        elif self.opt.emb_method == 'bert':
            word_embs = self.get_bert(x)

        x = self.encoder(word_embs)
        x = self.dropout(x)
        x = self.cls(x)  # batch_size * num_labels
        return x

    def init_bert(self):
        '''
        initialize the Bert model
        '''
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.opt.bert_path)
        self.bert = AutoModel.from_pretrained(self.opt.bert_path)
        for param in self.bert.parameters():
            param.requires_grad = False
        self.word_dim = self.opt.bert_dim

    def init_elmo(self):
        '''
        initialize the ELMo model
        '''
        self.elmo = Elmo(self.opt.elmo_options_file, self.opt.elmo_weight_file, 1)
        for param in self.elmo.parameters():
            param.requires_grad = False
        self.word_dim = self.opt.elmo_dim

    def init_glove(self):
        '''
        load the GloVe embedding matrix
        '''
        self.word2id = np.load(self.opt.word2id_file, allow_pickle=True).tolist()
        self.glove = nn.Embedding(self.opt.vocab_size, self.opt.glove_dim)
        emb = torch.from_numpy(np.load(self.opt.glove_file, allow_pickle=True))
        if self.use_gpu:
            emb = emb.to(self.opt.device)
        self.glove.weight.data.copy_(emb)
        self.word_dim = self.opt.glove_dim

    def get_bert(self, sentence_lists):
        '''
        get the Bert word embedding vectors for a list of sentences
        '''
        sentence_lists = [' '.join(x) for x in sentence_lists]
        ids = self.bert_tokenizer(sentence_lists, padding=True, return_tensors="pt")
        inputs = ids['input_ids']
        if self.opt.use_gpu:
            inputs = inputs.to(self.opt.device)

        embeddings = self.bert(inputs)
        return embeddings[0]

    def get_elmo(self, sentence_lists):
        '''
        get the ELMo word embedding vectors for a list of sentences
        '''
        character_ids = batch_to_ids(sentence_lists)
        if self.opt.use_gpu:
            character_ids = character_ids.to(self.opt.device)
        embeddings = self.elmo(character_ids)
        return embeddings['elmo_representations'][0]
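
    # Like get_bert/get_elmo above, get_glove returns a float tensor of shape
    # (batch_size, seq_len, word_dim); the Encoder relies only on this contract,
    # which is why any embedding method can be combined with any encoder.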

    def get_glove(self, sentence_lists):
        '''
        get the GloVe word embedding vectors for a list of sentences
        '''
        max_len = max(map(lambda x: len(x), sentence_lists))
        sentence_lists = list(map(lambda x: list(map(lambda w: self.word2id.get(w, 0), x)), sentence_lists))
        sentence_lists = list(map(lambda x: x + [self.opt.vocab_size-1] * (max_len - len(x)), sentence_lists))
        sentence_lists = torch.LongTensor(sentence_lists)
        if self.use_gpu:
            sentence_lists = sentence_lists.to(self.opt.device)
        embeddings = self.glove(sentence_lists)

        return embeddings

--------------------------------------------------------------------------------
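Finally, a minimal end-to-end sketch of driving `Model` directly, outside of `main.py` (GloVe branch on CPU; it assumes the files described in `data/readme.md` are in place, and since the model is untrained the predictions are meaningless):

```
import torch
from config import opt
from model import Model

opt.parse({'emb_method': 'glove', 'enc_method': 'mean', 'use_gpu': False})
model = Model(opt)
model.eval()

sentences = [["a", "charming", "and", "funny", "film"],
             ["a", "dull", "and", "pointless", "movie"]]
with torch.no_grad():
    logits = model(sentences)  # shape: (2, num_labels)
    print(logits.argmax(dim=1))
```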