├── models
│   ├── __init__.py
│   ├── TextCNN.py
│   └── TextAttnBiLSTM.py
├── requirements.txt
├── data
│   └── README.md
├── config.py
├── datasets.py
├── main.py
├── .gitignore
├── README.md
├── preprocess.py
└── utils.py

/models/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==0.24.2
2 | torch==1.1.0
3 | fire==0.1.3
4 | numpy==1.16.2
5 | gensim==3.7.3
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | Put the data here:
2 | * pretrained word vectors
3 | * SST datasets
4 | 
5 | For example:
6 | ```
7 | ├─data
8 | │  │  glove.6B.300d.txt
9 | │  │  GoogleNews-vectors-negative300.bin
10 | │  └─stanfordSentimentTreebank
11 | ```
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Gordon Lee
3 | @Date: 2019-08-12 21:53:17
4 | @LastEditors: Gordon Lee
5 | @LastEditTime: 2019-08-13 17:58:30
6 | @Description:
7 | '''
8 | 
9 | class Config(object):
10 |     '''
11 |     Global configuration parameters
12 |     '''
13 |     status = 'train' # run train_eval or test; defaults to train_eval
14 |     use_model = 'TextCNN' # which model to use; defaults to TextCNN
15 |     output_folder = 'output_data/' # folder holding the preprocessed data
16 |     data_name = 'SST-2' # SST-1 (fine-grained) or SST-2 (binary)
17 |     SST_path = 'data/stanfordSentimentTreebank/' # path to the dataset
18 |     emb_file = 'data/glove.6B.300d.txt' # path to the pre-trained word vectors
19 |     emb_format = 'glove' # embedding format: word2vec/glove
20 |     min_word_freq = 1 # minimum word frequency
21 |     max_len = 40 # maximum sentence length kept during sampling
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Gordon Lee
3 | @Date: 2019-08-09 14:22:50
4 | @LastEditors: Gordon Lee
5 | @LastEditTime: 2019-08-13 16:40:09
6 | @Description:
7 | '''
8 | import torch
9 | import pandas as pd
10 | from torch.utils.data import Dataset
11 | 
12 | class SSTreebankDataset(Dataset):
13 |     '''
14 |     Dataset used to build the DataLoader
15 |     '''
16 | 
17 |     def __init__(self, data_name, output_folder, split):
18 |         '''
19 |         :param output_folder: path to the processed data files
20 |         :param split: 'train', 'dev', or 'test'
21 |         '''
22 |         self.split = split
23 |         assert self.split in {'train', 'dev', 'test'}
24 | 
25 |         self.dataset = pd.read_csv(output_folder + data_name + '_' + split + '.csv')
26 | 
27 |         self.dataset_size = len(self.dataset)
28 | 
29 |     def __getitem__(self, i):
30 | 
31 |         sentence = torch.LongTensor(eval(self.dataset.iloc[i]['token_idx'])) # sentence shape [max_len]
32 |         sentence_label = self.dataset.iloc[i]['sentiment_label']
33 | 
34 |         return sentence, sentence_label
35 | 
36 |     def __len__(self):
37 | 
38 |         return self.dataset_size
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Gordon Lee
3 | @Date: 2019-08-12 21:05:31
4 | @LastEditors: Gordon Lee
5 | @LastEditTime: 2019-08-16 16:45:07
6 | @Description:
7 | '''
8 | import fire
9 | import models
10 | from config import Config
11 | from models.TextCNN import ModelCNN
12 | from models.TextAttnBiLSTM import ModelAttnBiLSTM
13 | 
14 | 
15 | def run(**kwargs):
16 | 
17 |     global_opt = Config()
18 | 
19 | 
20 |     for k,v in kwargs.items():
21 |         if getattr(global_opt, k, 'KeyError') != 'KeyError':
22 |             setattr(global_opt, k, v)
23 | 
24 |     if global_opt.use_model == 'TextCNN':
25 | 
26 |         model_opt = models.TextCNN.ModelConfig()
27 | 
28 |         for k,v in kwargs.items():
29 |             if getattr(model_opt, k,'KeyError') != 'KeyError':
30 |                 setattr(model_opt, k, v)
31 | 
32 |         if global_opt.status == 'train':
33 |             models.TextCNN.train_eval(model_opt)
34 | 
35 |         elif global_opt.status == 'test':
36 |             models.TextCNN.test(model_opt)
37 | 
38 |     elif global_opt.use_model == 'TextAttnBiLSTM':
39 | 
40 |         model_opt = models.TextAttnBiLSTM.ModelConfig()
41 | 
42 |         for k,v in kwargs.items():
43 |             if getattr(model_opt, k,'KeyError') != 'KeyError':
44 |                 setattr(model_opt, k, v)
45 | 
46 |         if global_opt.status == 'train':
47 |             models.TextAttnBiLSTM.train_eval(model_opt)
48 | 
49 |         elif global_opt.status == 'test':
50 |             models.TextAttnBiLSTM.test(model_opt)
51 | 
52 | 
53 | if __name__ == "__main__":
54 | 
55 |     fire.Fire()
56 | 
57 | 
58 | 
59 | 
60 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | #  Usually these files are written by a python script from a template
30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Text-Classification-PyTorch :whale2:
2 | 
3 | Here is a new boy :bow: who wants to become an NLPer, and this is his repository for Text Classification. Besides TextCNN and TextAttnBiLSTM, more models will be added in the near future.
4 | 
5 | Thanks for your Star:star:, Fork and Watch!
6 | 
7 | ## Dataset
8 | 
9 | * [Stanford Sentiment Treebank (SST)](https://nlp.stanford.edu/sentiment/code.html)
10 |   * SST-1: 5 classes (fine-grained), SST-2: 2 classes (binary)
11 | * Preprocess
12 |   * Map sentiment values to labels
13 |   * Remove tokens that consist entirely of non-alphanumeric characters, such as `...`
14 | 
15 | ## Pre-trained Word Vectors
16 | 
17 | * [Word2Vec](https://code.google.com/archive/p/word2vec/) : `GoogleNews-vectors-negative300.bin`
18 | * [GloVe](https://nlp.stanford.edu/projects/glove/) : `glove.840B.300d.txt`
19 | * Because *GloVe* gives a lower OOV rate and better experimental results than *Word2Vec*, we use *GloVe* as the pre-trained word vectors.
20 | * Options for word vectors in either format are still preserved in the code.
21 | 
22 | ## Model
23 | 
24 | * TextCNN
25 | 
26 |   * Paper: [Convolutional Neural Networks for Sentence Classification](https://www.aclweb.org/anthology/D14-1181)
27 |   * See: `models/TextCNN.py`
28 | 
29 |   ![](https://ws1.sinaimg.cn/large/72cf269fly1g6229o5a47j20m609c74t.jpg)
30 | 
31 | * TextAttnBiLSTM
32 | 
33 |   * Paper: [Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification](https://www.aclweb.org/anthology/P16-2034)
34 |   * See: `models/TextAttnBiLSTM.py`
35 | 
36 |   ![](https://ws1.sinaimg.cn/large/72cf269fly1g622af7rxij20la0axq3g.jpg)
37 | 
38 | ## Result
39 | 
40 | * Baseline from the paper
41 | 
42 | | model            | SST-1    | SST-2    |
43 | | ---------------- | -------- | -------- |
44 | | CNN-rand         | 45.0     | 82.7     |
45 | | CNN-static       | 45.5     | 86.8     |
46 | | CNN-non-static   | **48.0** | 87.2     |
47 | | CNN-multichannel | 47.4     | **88.1** |
48 | 
49 | * Re-Implementation
50 | 
51 | | model              | SST-1      | SST-2      |
52 | | ------------------ | ---------- | ---------- |
53 | | CNN-rand           | 34.841     | 74.500     |
54 | | CNN-static         | 45.056     | 84.125     |
55 | | CNN-non-static     | 46.974     | 85.886     |
56 | | CNN-multichannel   | 45.129     | **85.993** |
57 | | Attention + BiLSTM | 47.015     | 85.632     |
58 | | Attention + BiGRU  | **47.854** | 85.102     |
59 | 
60 | ## Requirement
61 | 
62 | Please install the following library requirements first.
63 | 
64 | ```markdown
65 | pandas==0.24.2
66 | torch==1.1.0
67 | fire==0.1.3
68 | numpy==1.16.2
69 | gensim==3.7.3
70 | ```
71 | 
72 | ## Structure
73 | 
74 | ```
75 | │  .gitignore
76 | │  config.py            # Global Configuration
77 | │  datasets.py          # Create Dataloader
78 | │  main.py
79 | │  preprocess.py
80 | │  README.md
81 | │  requirements.txt
82 | │  utils.py
83 | │
84 | ├─checkpoints           # Save checkpoint and best model
85 | │
86 | ├─data                  # pretrained word vectors and datasets
87 | │  │  glove.6B.300d.txt
88 | │  │  GoogleNews-vectors-negative300.bin
89 | │  └─stanfordSentimentTreebank   # datasets folder
90 | │
91 | ├─models
92 | │      TextAttnBiLSTM.py
93 | │      TextCNN.py
94 | │      __init__.py
95 | │
96 | └─output_data           # Preprocessed data and vocabulary, etc.
97 | ```
98 | 
99 | ## Usage
100 | 
101 | * Set global configuration parameters in `config.py`
102 | 
103 | * Preprocess the datasets
104 | 
105 | ```shell
106 | $python preprocess.py
107 | ```
108 | 
109 | * Train
110 | 
111 | ```shell
112 | $python main.py run
113 | ```
114 | 
115 | Parameters defined in `config.py` and in `models/TextCNN.py` or `models/TextAttnBiLSTM.py` can also be overridden from the command line, as sketched below.
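Under the hood, `fire` forwards every `--option=VALUE` flag to `run(**kwargs)` in `main.py`, which copies each recognised key onto the global `Config` object and onto the chosen model's `ModelConfig`. A minimal sketch of that override logic, simplified from `main.py` (the helper name `apply_overrides` is only for illustration):

```python
# Simplified from main.py: copy CLI keyword arguments onto a config object.
from config import Config

def apply_overrides(cfg, **kwargs):
    # Only attributes that already exist on the config object are overridden;
    # unknown keys are silently ignored.
    for key, value in kwargs.items():
        if hasattr(cfg, key):
            setattr(cfg, key, value)
    return cfg

# `python main.py run --status=test` effectively does:
opt = apply_overrides(Config(), status='test')
```

So any attribute defined in `config.py` or in a model's `ModelConfig` can be overridden as `--<attribute>=<value>`: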
116 | 
117 | ```shell
118 | $python main.py run [--option=VALUE]
119 | ```
120 | 
121 | For example,
122 | 
123 | ```shell
124 | $python main.py run --status='train' --use_model="TextAttnBiLSTM"
125 | ```
126 | 
127 | * Test
128 | 
129 | ```shell
130 | $python main.py run --status='test' --best_model="checkpoints/BEST_checkpoint_SST-2_TextCNN.pth"
131 | ```
132 | 
133 | ## Conclusion
134 | 
135 | * The `TextCNN` model uses convolution kernels of several widths to extract n-gram-like features, while the `TextAttnBiLSTM` model uses a BiLSTM to capture semantics and long-term dependencies, combined with an attention mechanism for classification.
136 | * TextCNN parameter tuning:
137 |   * GloVe works better than Word2Vec
138 |   * Use a smaller batch size
139 |   * Add weight decay ($l_2$ constraint), learning rate decay, early stopping, etc.
140 |   * Do not set `padding_idx=0` in the embedding layer
141 | * TextAttnBiLSTM parameter tuning:
142 |   * Apply dropout on the embedding, LSTM, and fully-connected layers
143 | 
144 | ## Acknowledgements
145 | 
146 | * Motivated by https://github.com/TobiasLee/Text-Classification
147 | * Thanks to https://github.com/bigboNed3/chinese_text_cnn
148 | * Thanks to https://github.com/ShawnyXiao/TextClassification-Keras
149 | 
150 | ## Reference
151 | 
152 | [1] [Convolutional Neural Networks for Sentence Classification](http://www.aclweb.org/anthology/D14-1181)
153 | 
154 | [2] [A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1510.03820)
155 | 
156 | [3] [Attention-Based Bidirectional Long Short-Term Memory Networks for Relation Classification](https://www.aclweb.org/anthology/P16-2034)
157 | 
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Gordon Lee
3 | @Date: 2019-08-11 19:12:09
4 | @LastEditors: Gordon Lee
5 | @LastEditTime: 2019-08-13 16:38:22
6 | @Description:
7 | '''
8 | import os
9 | import json
10 | import torch
11 | import pandas as pd
12 | import numpy as np
13 | import warnings
14 | from utils import load_embeddings
15 | from collections import Counter
16 | from config import Config
17 | 
18 | warnings.filterwarnings("ignore") # suppress warnings
19 | 
20 | def create_input_files(data_name, SST_path, emb_file, emb_format, output_folder, min_word_freq, max_len):
21 |     '''
22 |     Preprocess the dataset
23 | 
24 |     :param data_name: SST-1/SST-2
25 |     :param SST_path: path to the Stanford Sentiment Treebank dataset
26 |     :param emb_file: path to the pre-trained word vector file
27 |     :param emb_format: word vector format, glove or word2vec
28 |     :param output_folder: folder where the processed datasets are saved
29 |     :param min_word_freq: minimum word frequency
30 |     :param max_len: maximum sentence length kept
31 |     '''
32 | 
33 |     # Sanity check
34 |     assert data_name in {'SST-1', 'SST-2'}
35 | 
36 | 
37 |     # Read in the dataset files
38 |     print('Preprocess datasets...')
39 |     datasetSentences = pd.read_csv(SST_path + 'datasetSentences.txt', sep='\t')
40 |     dictionary = pd.read_csv(SST_path + 'dictionary.txt', sep='|', header=None, names=['sentence', 'phrase ids'])
41 |     datasetSplit = pd.read_csv(SST_path + 'datasetSplit.txt', sep=',')
42 |     sentiment_labels = pd.read_csv(SST_path + 'sentiment_labels.txt', sep='|')
43 | 
44 |     # Inner-join the tables
45 |     dataset = pd.merge(pd.merge(pd.merge(datasetSentences, datasetSplit), dictionary), sentiment_labels)
46 | 
47 | 
48 |     def labeling(data_name, sentiment_value):
49 |         '''
50 |         Convert a sentiment value into a label
51 | 
52 |         :param data_name: SST-1/SST-2
53 |         :param sentiment_value: sentiment value in [0, 1]
54 |         :return: label
55 |         '''
56 |         if data_name == 'SST-1':
57 |             if sentiment_value <= 0.2:
58 |                 return 0 # very negative
59 |             elif sentiment_value <= 0.4:
60 |                 return 1 # negative
61 |             elif sentiment_value <= 0.6:
62 |                 return 2 # neutral
63 |             elif sentiment_value <= 0.8:
64 |                 return 3 # positive
65 |             elif sentiment_value <= 1:
66 |                 return 4 # very positive
67 |         else:
68 |             if sentiment_value <= 0.4:
69 |                 return 0 # negative
70 |             elif sentiment_value > 0.6:
71 |                 return 1 # positive
72 |             else:
73 |                 return -1 # drop neutral
74 | 
75 |     # Map sentiment values to labels
76 |     dataset['sentiment_label'] = dataset['sentiment values'].apply(lambda x: labeling(data_name, x))
77 |     dataset = dataset[dataset['sentiment_label'] != -1]
78 | 
79 | 
80 | 
81 |     def check_not_punctuation(token):
82 |         '''
83 |         Check whether the token contains at least one alphanumeric character (filters out pure-punctuation tokens such as `...`)
84 | 
85 |         :param token: token
86 |         :return: bool
87 |         '''
88 |         for ch in token:
89 |             if ch.isalnum(): return True
90 |         return False
91 | 
92 |     def filter_punctuation(s):
93 |         '''
94 |         Lowercase the sentence and filter out punctuation-only tokens
95 | 
96 |         :param s: sentence
97 |         :return: token list
98 |         '''
99 |         s = s.lower().split(' ')
100 |         return [token for token in s if check_not_punctuation(token)]
101 | 
102 |     # Preprocess the sentences
103 |     dataset['sentence'] = dataset['sentence'].apply(lambda s: filter_punctuation(s))
104 | 
105 | 
106 | 
107 |     # Build the vocabulary
108 |     word_freq = Counter()
109 |     valid_idx = []
110 |     for i, tokens in enumerate(dataset['sentence']):
111 |         word_freq.update(tokens)
112 |         if len(tokens) <= max_len: # keep only sentences no longer than max_len
113 |             valid_idx.append(i)
114 |     dataset = dataset.iloc[valid_idx, :]
115 | 
116 |     words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
117 |     word_map = {k: v + 1 for v, k in enumerate(words)}
118 |     word_map['<unk>'] = len(word_map) + 1
119 |     word_map['<pad>'] = 0
120 | 
121 | 
122 | 
123 |     def tokens_to_idx(tokens):
124 |         '''
125 |         Convert tokens to indices and pad to max_len
126 | 
127 |         :param tokens: token list
128 |         :return: index list
129 |         '''
130 |         return [word_map.get(word, word_map['<unk>']) for word in tokens] + [word_map['<pad>']] * (max_len - len(tokens))
131 | 
132 |     # Convert tokens to indices
133 |     dataset['token_idx'] = dataset['sentence'].apply(lambda x: tokens_to_idx(x))
134 | 
135 | 
136 | 
137 |     # Load and save the pre-trained word vectors
138 |     pretrain_embed, embed_dim = load_embeddings(emb_file, emb_format, word_map)
139 |     embed = dict()
140 |     embed['pretrain'] = pretrain_embed
141 |     embed['dim'] = embed_dim
142 |     torch.save(embed, output_folder + data_name + '_' + 'pretrain_embed.pth')
143 | 
144 | 
145 |     # Save word_map
146 |     with open(os.path.join(output_folder, data_name + '_' + 'wordmap.json'), 'w') as j:
147 |         json.dump(word_map, j)
148 | 
149 | 
150 |     # Save the processed datasets
151 |     # train
152 |     dataset[dataset['splitset_label']==1][['token_idx','sentiment_label']].to_csv(output_folder + data_name + '_' + 'train.csv', index=False)
153 |     # test
154 |     dataset[dataset['splitset_label']==2][['token_idx','sentiment_label']].to_csv(output_folder + data_name + '_' + 'test.csv', index=False)
155 |     # dev
156 |     dataset[dataset['splitset_label']==3][['token_idx','sentiment_label']].to_csv(output_folder + data_name + '_' + 'dev.csv', index=False)
157 | 
158 |     print('Preprocess End\n')
159 | 
160 | 
161 | 
162 | if __name__ == "__main__":
163 |     opt = Config()
164 |     create_input_files(data_name=opt.data_name,
165 |                        SST_path=opt.SST_path,
166 |                        emb_file=opt.emb_file,
167 |                        emb_format=opt.emb_format,
168 |                        output_folder=opt.output_folder,
169 |                        min_word_freq=opt.min_word_freq,
170 |                        max_len=opt.max_len)
--------------------------------------------------------------------------------
/models/TextCNN.py:
--------------------------------------------------------------------------------
1 | '''
2 | @Author: Gordon Lee
3 | @Date: 
2019-08-09 16:29:55 4 | @LastEditors: Gordon Lee 5 | @LastEditTime: 2019-08-16 19:00:19 6 | @Description: 7 | ''' 8 | import json 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim 13 | import torch.utils.data 14 | from config import Config 15 | from datasets import SSTreebankDataset 16 | from utils import adjust_learning_rate, accuracy, save_checkpoint, AverageMeter, train, validate, testing 17 | 18 | class ModelConfig(object): 19 | ''' 20 | 模型配置参数 21 | ''' 22 | # 全局配置参数 23 | opt = Config() 24 | 25 | # 数据参数 26 | output_folder = opt.output_folder 27 | data_name = opt.data_name 28 | SST_path = opt.SST_path 29 | emb_file = opt.emb_file 30 | emb_format = opt.emb_format 31 | output_folder = opt.output_folder 32 | min_word_freq = opt.min_word_freq 33 | max_len = opt.max_len 34 | 35 | # 训练参数 36 | epochs = 120 # epoch数目,除非early stopping, 先开20个epoch不微调,再开多点epoch微调 37 | batch_size = 32 # batch_size 38 | workers = 4 # 多处理器加载数据 39 | lr = 1e-4 # 如果要微调时,学习率要小于1e-3,因为已经是很优化的了,不用这么大的学习率 40 | weight_decay = 1e-5 # 权重衰减率 41 | decay_epoch = 20 # 多少个epoch后执行学习率衰减 42 | improvement_epoch = 6 # 多少个epoch后执行early stopping 43 | is_Linux = True # 如果是Linux则设置为True,否则设置为else, 用于判断是否多处理器加载 44 | print_freq = 100 # 每隔print_freq个iteration打印状态 45 | checkpoint = None # 模型断点所在位置, 无则None 46 | best_model = None # 最优模型所在位置 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | 49 | # 模型参数 50 | model_name = 'TextCNN' # 模型名 51 | class_num = 5 if data_name == 'SST-1' else 2 # 分类类别 52 | kernel_num = 100 # kernel数量 53 | kernel_sizes = [3,4,5] # 不同尺寸的kernel 54 | dropout = 0.5 # dropout 55 | embed_dim = 128 # 未使用预训练词向量的默认值 56 | static = True # 是否使用预训练词向量, static=True, 表示使用预训练词向量 57 | non_static = True # 是否微调,non_static=True,表示微调 58 | multichannel = True # 是否多通道 59 | 60 | 61 | class ModelCNN(nn.Module): 62 | ''' 63 | TextCNN: CNN-rand, CNN-static, CNN-non-static, CNN-multichannel 64 | ''' 65 | def __init__(self, vocab_size, embed_dim, kernel_num, kernel_sizes, class_num, pretrain_embed, dropout, static, non_static, multichannel): 66 | ''' 67 | :param vocab_size: 词表大小 68 | :param embed_dim: 词向量维度 69 | :param kernel_num: kernel数目 70 | :param kernel_sizes: 不同kernel size 71 | :param class_num: 类别数 72 | :param pretrain_embed: 预训练词向量 73 | :param dropout: dropout 74 | :param static: 是否使用预训练词向量, static=True, 表示使用预训练词向量 75 | :param non_static: 是否微调,non_static=True,表示不微调 76 | :param multichannel: 是否多通道 77 | ''' 78 | super(ModelCNN, self).__init__() 79 | 80 | # 初始化为单通道 81 | channel_num = 1 82 | 83 | # 随机初始化词向量 84 | self.embedding = nn.Embedding(vocab_size, embed_dim) 85 | 86 | # 使用预训练词向量 87 | if static: 88 | self.embedding = self.embedding.from_pretrained(pretrain_embed, freeze=not non_static) 89 | 90 | # 微调+固定预训练词向量 91 | if multichannel: 92 | # defalut: freeze=True, 即默认embedding2是固定的 93 | self.embedding2 = nn.Embedding(vocab_size, embed_dim).from_pretrained(pretrain_embed) 94 | channel_num = 2 95 | else: 96 | self.embedding2 = None 97 | 98 | # 卷积层, kernel size: (size, embed_dim), output: [(batch_size, kernel_num, h,1)] 99 | self.convs = nn.ModuleList([ 100 | nn.Conv2d(channel_num, kernel_num, (size, embed_dim)) 101 | for size in kernel_sizes 102 | ]) 103 | 104 | 105 | # 1维最大池化层,因为无法确定feature map大小,所以放在forward里面 106 | 107 | # dropout 108 | self.dropout = nn.Dropout(dropout) 109 | 110 | # 全连接层 111 | self.fc = nn.Linear(len(kernel_sizes) * kernel_num, class_num) 112 | 113 | def forward(self, x): 114 | ''' 115 | :params x: (batch_size, max_len) 116 | :return x: logits 
117 | ''' 118 | 119 | if self.embedding2: 120 | x = torch.stack([self.embedding(x), self.embedding2(x)], dim=1) # (batch_size, 2, max_len, word_vec) 121 | else: 122 | x = self.embedding(x).unsqueeze(1) # (batch_size, 1, max_len, word_vec) 123 | 124 | # 卷积 125 | x = [F.relu(conv(x)).squeeze(3) for conv in self.convs] # [(batch_size, kernel_num, h)] 126 | # 池化 127 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(batch_size, kernel_num)] 128 | # flatten 129 | x = torch.cat(x, 1) # (batch_size, kernel_num * len(kernel_sizes)) 130 | # dropout 131 | x = self.dropout(x) 132 | # fc 133 | x = self.fc(x) # logits, 没有softmax, (batch_size, class_num) 134 | 135 | return x 136 | 137 | 138 | 139 | def train_eval(opt): 140 | ''' 141 | 训练和验证 142 | ''' 143 | # 初始化best accuracy 144 | best_acc = 0. 145 | 146 | # epoch 147 | start_epoch = 0 148 | epochs = opt.epochs 149 | epochs_since_improvement = 0 # 跟踪训练时的验证集上的BLEU变化,每过一个epoch没提升则加1 150 | 151 | # 读入词表 152 | word_map_file = opt.output_folder + opt.data_name + '_' + 'wordmap.json' 153 | with open(word_map_file, 'r') as j: 154 | word_map = json.load(j) 155 | 156 | # 加载预训练词向量 157 | embed_file = opt.output_folder + opt.data_name + '_' + 'pretrain_embed.pth' 158 | embed_file = torch.load(embed_file) 159 | pretrain_embed, embed_dim = embed_file['pretrain'], embed_file['dim'] 160 | 161 | # 初始化/加载模型 162 | if opt.checkpoint is None: 163 | if opt.static == False: embed_dim = opt.embed_dim 164 | model = ModelCNN(vocab_size=len(word_map), 165 | embed_dim=embed_dim, 166 | kernel_num=opt.kernel_num, 167 | kernel_sizes=opt.kernel_sizes, 168 | class_num=opt.class_num, 169 | pretrain_embed=pretrain_embed, 170 | dropout=opt.dropout, 171 | static=opt.static, 172 | non_static=opt.non_static, 173 | multichannel=opt.multichannel) 174 | 175 | optimizer = torch.optim.Adam(params=model.parameters(), 176 | lr=opt.lr, 177 | weight_decay=opt.weight_decay) 178 | 179 | else: 180 | # 载入checkpoint 181 | checkpoint = torch.load(opt.checkpoint, map_location='cpu') 182 | start_epoch = checkpoint['epoch'] + 1 183 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 184 | best_acc = checkpoint['acc'] 185 | model = checkpoint['model'] 186 | optimizer = checkpoint['optimizer'] 187 | 188 | # 移动到GPU 189 | model = model.to(opt.device) 190 | 191 | # loss function 192 | criterion = nn.CrossEntropyLoss().to(opt.device) 193 | 194 | # dataloader 195 | train_loader = torch.utils.data.DataLoader( 196 | SSTreebankDataset(opt.data_name, opt.output_folder, 'train'), 197 | batch_size=opt.batch_size, 198 | shuffle=True, 199 | num_workers = opt.workers if opt.is_Linux else 0, 200 | pin_memory=True) 201 | val_loader = torch.utils.data.DataLoader( 202 | SSTreebankDataset(opt.data_name, opt.output_folder, 'dev'), 203 | batch_size=opt.batch_size, 204 | shuffle=True, 205 | num_workers = opt.workers if opt.is_Linux else 0, 206 | pin_memory=True) 207 | 208 | # Epochs 209 | for epoch in range(start_epoch, epochs): 210 | 211 | # 学习率衰减 212 | if epoch > opt.decay_epoch: 213 | adjust_learning_rate(optimizer, epoch) 214 | 215 | # early stopping 如果dev上的acc在6个连续epoch上没有提升 216 | if epochs_since_improvement == opt.improvement_epoch: 217 | break 218 | 219 | # 一个epoch的训练 220 | train(train_loader=train_loader, 221 | model=model, 222 | criterion=criterion, 223 | optimizer=optimizer, 224 | epoch=epoch, 225 | vocab_size=len(word_map), 226 | print_freq=opt.print_freq, 227 | device=opt.device) 228 | 229 | # 一个epoch的验证 230 | recent_acc = validate(val_loader=val_loader, 231 | model=model, 232 | 
criterion=criterion, 233 | print_freq=opt.print_freq, 234 | device=opt.device) 235 | 236 | # 检查是否有提升 237 | is_best = recent_acc > best_acc 238 | best_acc = max(recent_acc, best_acc) 239 | if not is_best: 240 | epochs_since_improvement += 1 241 | print("Epochs since last improvement: %d\n" % (epochs_since_improvement,)) 242 | else: 243 | epochs_since_improvement = 0 244 | 245 | # 保存模型 246 | save_checkpoint(opt.model_name, opt.data_name, epoch, epochs_since_improvement, model, optimizer, recent_acc, is_best) 247 | 248 | def test(opt): 249 | 250 | # 载入best model 251 | best_model = torch.load(opt.best_model, map_location='cpu') 252 | model = best_model['model'] 253 | 254 | # 移动到GPU 255 | model = model.to(opt.device) 256 | 257 | # loss function 258 | criterion = nn.CrossEntropyLoss().to(opt.device) 259 | 260 | # dataloader 261 | test_loader = torch.utils.data.DataLoader( 262 | SSTreebankDataset(opt.data_name, opt.output_folder, 'test'), 263 | batch_size=opt.batch_size, 264 | shuffle=True, 265 | num_workers = opt.workers if opt.is_Linux else 0, 266 | pin_memory=True) 267 | 268 | # test 269 | testing(test_loader, model, criterion, opt.print_freq, opt.device) 270 | 271 | -------------------------------------------------------------------------------- /models/TextAttnBiLSTM.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Gordon Lee 3 | @Date: 2019-08-16 13:34:15 4 | @LastEditors: Gordon Lee 5 | @LastEditTime: 2019-08-17 01:49:03 6 | @Description: 7 | ''' 8 | import json 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim 13 | import torch.utils.data 14 | from config import Config 15 | from datasets import SSTreebankDataset 16 | from utils import adjust_learning_rate, accuracy, save_checkpoint, AverageMeter, train, validate, testing 17 | 18 | class ModelConfig(): 19 | ''' 20 | 模型配置参数 21 | ''' 22 | # 全局配置参数 23 | opt = Config() 24 | 25 | # 数据参数 26 | output_folder = opt.output_folder 27 | data_name = opt.data_name 28 | SST_path = opt.SST_path 29 | emb_file = opt.emb_file 30 | emb_format = opt.emb_format 31 | output_folder = opt.output_folder 32 | min_word_freq = opt.min_word_freq 33 | max_len = opt.max_len 34 | 35 | # 训练参数 36 | epochs = 120 # epoch数目,除非early stopping, 先开20个epoch不微调,再开多点epoch微调 37 | batch_size = 16 # batch_size 38 | workers = 4 # 多处理器加载数据 39 | lr = 1e-4 # 如果要微调时,学习率要小于1e-3,因为已经是很优化的了,不用这么大的学习率 40 | weight_decay = 1e-5 # 权重衰减率 41 | decay_epoch = 15 # 多少个epoch后执行学习率衰减 42 | improvement_epoch = 30 # 多少个epoch后执行early stopping 43 | is_Linux = True # 如果是Linux则设置为True,否则设置为else, 用于判断是否多处理器加载 44 | print_freq = 100 # 每隔print_freq个iteration打印状态 45 | checkpoint = None # 模型断点所在位置, 无则None 46 | best_model = None # 最优模型所在位置 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | 49 | # 模型参数 50 | model_name = 'TextAttnBiLSTM' # 模型名 51 | class_num = 5 if data_name == 'SST-1' else 2 # 分类类别 52 | embed_dropout = 0.3 # dropout 53 | model_dropout = 0.5 # dropout 54 | fc_dropout = 0.5 # dropout 55 | num_layers = 2 # LSTM层数 56 | embed_dim = 128 # 未使用预训练词向量的默认值 57 | use_embed = True # 是否使用预训练 58 | use_gru = True # 是否使用GRU 59 | grad_clip = 4. 
# 梯度裁剪阈值 60 | 61 | class Attn(nn.Module): 62 | ''' 63 | Attention Layer 64 | ''' 65 | def __init__(self, hidden_size): 66 | super(Attn, self).__init__() 67 | self.attn = nn.Linear(hidden_size, 1) 68 | 69 | def forward(self, x): 70 | ''' 71 | :param x: (batch_size, max_len, hidden_size) 72 | :return alpha: (batch_size, max_len) 73 | ''' 74 | x = torch.tanh(x) # (batch_size, max_len, hidden_size) 75 | x = self.attn(x).squeeze(2) # (batch_size, max_len) 76 | alpha = F.softmax(x, dim=1).unsqueeze(1) # (batch_size, 1, max_len) 77 | return alpha 78 | 79 | class ModelAttnBiLSTM(nn.Module): 80 | ''' 81 | BiLSTM: BiLSTM, BiGRU 82 | ''' 83 | def __init__(self, vocab_size, embed_dim, hidden_size, pretrain_embed, use_gru, embed_dropout, fc_dropout, model_dropout, num_layers, class_num, use_embed): 84 | 85 | super(ModelAttnBiLSTM, self).__init__() 86 | 87 | self.hidden_size = hidden_size 88 | 89 | if use_embed: 90 | self.embedding = nn.Embedding(vocab_size, embed_dim).from_pretrained(pretrain_embed, freeze=False) 91 | else: 92 | self.embedding = nn.Embedding(vocab_size, embed_dim) 93 | 94 | self.embed_dropout = nn.Dropout(embed_dropout) 95 | 96 | if use_gru: 97 | self.bilstm = nn.GRU(embed_dim, hidden_size, num_layers, dropout=(0 if num_layers == 1 else model_dropout), bidirectional=True, batch_first=True) 98 | else: 99 | self.bilstm = nn.LSTM(embed_dim, hidden_size, num_layers, dropout=(0 if num_layers == 1 else model_dropout), bidirectional=True, batch_first=True) 100 | 101 | self.fc = nn.Linear(hidden_size, class_num) 102 | 103 | self.fc_dropout = nn.Dropout(fc_dropout) 104 | 105 | self.attn = Attn(hidden_size) 106 | 107 | def forward(self, x): 108 | ''' 109 | :param x: [batch_size, max_len] 110 | :return logits: logits 111 | ''' 112 | x = self.embedding(x) # (batch_size, max_len, word_vec) 113 | x = self.embed_dropout(x) 114 | # 输入的x是所有time step的输入, 输出的y实际每个time step的hidden输出 115 | # _是最后一个time step的hidden输出 116 | # 因为双向,y的shape为(batch_size, max_len, hidden_size*num_directions), 其中[:,:,:hidden_size]是前向的结果,[:,:,hidden_size:]是后向的结果 117 | y, _ = self.bilstm(x) # (batch_size, max_len, hidden_size*num_directions) 118 | y = y[:,:,:self.hidden_size] + y[:,:,self.hidden_size:] # (batch_size, max_len, hidden_size) 119 | alpha = self.attn(y) # (batch_size, 1, max_len) 120 | r = alpha.bmm(y).squeeze(1) # (batch_size, hidden_size) 121 | h = torch.tanh(r) # (batch_size, hidden_size) 122 | logits = self.fc(h) # (batch_size, class_num) 123 | logits = self.fc_dropout(logits) 124 | return logits 125 | 126 | 127 | def train_eval(opt): 128 | ''' 129 | 训练和验证 130 | ''' 131 | # 初始化best accuracy 132 | best_acc = 0. 
133 | 134 | # epoch 135 | start_epoch = 0 136 | epochs = opt.epochs 137 | epochs_since_improvement = 0 # 跟踪训练时的验证集上的BLEU变化,每过一个epoch没提升则加1 138 | 139 | # 读入词表 140 | word_map_file = opt.output_folder + opt.data_name + '_' + 'wordmap.json' 141 | with open(word_map_file, 'r') as j: 142 | word_map = json.load(j) 143 | 144 | # 加载预训练词向量 145 | embed_file = opt.output_folder + opt.data_name + '_' + 'pretrain_embed.pth' 146 | embed_file = torch.load(embed_file) 147 | pretrain_embed, embed_dim = embed_file['pretrain'], embed_file['dim'] 148 | 149 | # 初始化/加载模型 150 | if opt.checkpoint is None: 151 | if opt.use_embed == False: embed_dim = opt.embed_dim 152 | model = ModelAttnBiLSTM(vocab_size=len(word_map), 153 | embed_dim=embed_dim, 154 | hidden_size=embed_dim, 155 | class_num=opt.class_num, 156 | pretrain_embed=pretrain_embed, 157 | num_layers=opt.num_layers, 158 | model_dropout=opt.model_dropout, 159 | fc_dropout=opt.fc_dropout, 160 | embed_dropout=opt.embed_dropout, 161 | use_gru=opt.use_gru, 162 | use_embed=opt.use_embed) 163 | 164 | optimizer = torch.optim.Adam(params=model.parameters(), 165 | lr=opt.lr, 166 | weight_decay=opt.weight_decay) 167 | 168 | else: 169 | # 载入checkpoint 170 | checkpoint = torch.load(opt.checkpoint, map_location='cpu') 171 | start_epoch = checkpoint['epoch'] + 1 172 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 173 | best_acc = checkpoint['acc'] 174 | model = checkpoint['model'] 175 | optimizer = checkpoint['optimizer'] 176 | 177 | # 移动到GPU 178 | model = model.to(opt.device) 179 | 180 | # loss function 181 | criterion = nn.CrossEntropyLoss().to(opt.device) 182 | 183 | # dataloader 184 | train_loader = torch.utils.data.DataLoader( 185 | SSTreebankDataset(opt.data_name, opt.output_folder, 'train'), 186 | batch_size=opt.batch_size, 187 | shuffle=True, 188 | num_workers = opt.workers if opt.is_Linux else 0, 189 | pin_memory=True) 190 | val_loader = torch.utils.data.DataLoader( 191 | SSTreebankDataset(opt.data_name, opt.output_folder, 'dev'), 192 | batch_size=opt.batch_size, 193 | shuffle=True, 194 | num_workers = opt.workers if opt.is_Linux else 0, 195 | pin_memory=True) 196 | 197 | # Epochs 198 | for epoch in range(start_epoch, epochs): 199 | 200 | # 学习率衰减 201 | if epoch > opt.decay_epoch: 202 | adjust_learning_rate(optimizer, epoch) 203 | 204 | # early stopping 如果dev上的acc在6个连续epoch上没有提升 205 | if epochs_since_improvement == opt.improvement_epoch: 206 | break 207 | 208 | # 一个epoch的训练 209 | train(train_loader=train_loader, 210 | model=model, 211 | criterion=criterion, 212 | optimizer=optimizer, 213 | epoch=epoch, 214 | vocab_size=len(word_map), 215 | print_freq=opt.print_freq, 216 | device=opt.device, 217 | grad_clip=opt.grad_clip) 218 | 219 | # 一个epoch的验证 220 | recent_acc = validate(val_loader=val_loader, 221 | model=model, 222 | criterion=criterion, 223 | print_freq=opt.print_freq, 224 | device=opt.device) 225 | 226 | # 检查是否有提升 227 | is_best = recent_acc > best_acc 228 | best_acc = max(recent_acc, best_acc) 229 | if not is_best: 230 | epochs_since_improvement += 1 231 | print("Epochs since last improvement: %d\n" % (epochs_since_improvement,)) 232 | else: 233 | epochs_since_improvement = 0 234 | 235 | # 保存模型 236 | save_checkpoint(opt.model_name, opt.data_name, epoch, epochs_since_improvement, model, optimizer, recent_acc, is_best) 237 | 238 | def test(opt): 239 | 240 | # 载入best model 241 | best_model = torch.load(opt.best_model, map_location='cpu') 242 | model = best_model['model'] 243 | 244 | # 移动到GPU 245 | model = model.to(opt.device) 246 | 247 | # 
loss function 248 | criterion = nn.CrossEntropyLoss().to(opt.device) 249 | 250 | # dataloader 251 | test_loader = torch.utils.data.DataLoader( 252 | SSTreebankDataset(opt.data_name, opt.output_folder, 'test'), 253 | batch_size=opt.batch_size, 254 | shuffle=True, 255 | num_workers = opt.workers if opt.is_Linux else 0, 256 | pin_memory=True) 257 | 258 | # test 259 | testing(test_loader, model, criterion, opt.print_freq, opt.device) 260 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | @Author: Gordon Lee 3 | @Date: 2019-08-09 13:48:17 4 | @LastEditors: Gordon Lee 5 | @LastEditTime: 2019-08-16 16:29:07 6 | @Description: 7 | ''' 8 | import math 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from gensim.models import KeyedVectors as Vectors 13 | 14 | class AverageMeter(object): 15 | ''' 16 | 跟踪指标的最新值,平均值,和,count 17 | ''' 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.val = 0. #value 24 | self.avg = 0. #average 25 | self.sum = 0. #sum 26 | self.count = 0 #count 27 | 28 | def update(self, val, n=1): 29 | self.val = val # 当前batch的val 30 | self.sum += val * n # 从第一个batch到现在的累加值 31 | self.count += n # 累加数目加1 32 | self.avg = self.sum / self.count # 从第一个batch到现在的平均值 33 | 34 | 35 | def init_embeddings(embeddings): 36 | ''' 37 | 使用均匀分布U(-bias, bias)来随机初始化 38 | 39 | :param embeddings: 词向量矩阵 40 | ''' 41 | bias = np.sqrt(3.0 / embeddings.size(1)) 42 | torch.nn.init.uniform_(embeddings, -bias, bias) 43 | 44 | 45 | def load_embeddings(emb_file, emb_format, word_map): 46 | ''' 47 | 加载预训练词向量 48 | 49 | :param emb_file: 词向量文件路径 50 | :param emb_format: 词向量格式: 'glove' or 'word2vec' 51 | :param word_map: 词表 52 | :return: 词向量矩阵, 词向量维度 53 | ''' 54 | assert emb_format in {'glove', 'word2vec'} 55 | 56 | vocab = set(word_map.keys()) 57 | 58 | print("Loading embedding...") 59 | cnt = 0 # 记录读入的词数 60 | 61 | if emb_format == 'glove': 62 | 63 | with open(emb_file, 'r', encoding='utf-8') as f: 64 | emb_dim = len(f.readline().split(' ')) - 1 65 | 66 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 67 | #初始化词向量(对OOV进行随机初始化,即对那些在词表上的词但不在预训练词向量中的词) 68 | init_embeddings(embeddings) 69 | 70 | 71 | # 读入词向量文件 72 | for line in open(emb_file, 'r', encoding='utf-8'): 73 | line = line.split(' ') 74 | emb_word = line[0] 75 | 76 | # 过滤空值并转为float型 77 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) 78 | 79 | # 如果不在词表上 80 | if emb_word not in vocab: 81 | continue 82 | else: 83 | cnt+=1 84 | 85 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 86 | 87 | print("Number of words read: ", cnt) 88 | print("Number of OOV: ", len(vocab)-cnt) 89 | 90 | return embeddings, emb_dim 91 | 92 | else: 93 | 94 | vectors = Vectors.load_word2vec_format(emb_file,binary=True) 95 | print("Load successfully") 96 | emb_dim = 300 97 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 98 | #初始化词向量(对OOV进行随机初始化,即对那些在词表上的词但不在预训练词向量中的词) 99 | init_embeddings(embeddings) 100 | 101 | for emb_word in vocab: 102 | 103 | if emb_word in vectors.index2word: 104 | 105 | embedding = vectors[emb_word] 106 | cnt += 1 107 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 108 | 109 | else: 110 | continue 111 | 112 | print("Number of words read: ", cnt) 113 | print("Number of OOV: ", len(vocab)-cnt) 114 | 115 | return embeddings, emb_dim 116 | 117 | def clip_gradient(optimizer, grad_clip): 118 | """ 119 | 
梯度裁剪防止梯度爆炸 120 | 121 | :param optimizer: 需要梯度裁剪的优化器 122 | :param grad_clip: 裁剪阈值 123 | """ 124 | for group in optimizer.param_groups: 125 | for param in group['params']: 126 | if param.grad is not None: 127 | # inplace操作,直接修改这个tensor,而不是返回新的 128 | # 将梯度限制在(-grad_clip, grad_clip)间 129 | param.grad.data.clamp_(-grad_clip, grad_clip) 130 | 131 | def accuracy(logits, targets): 132 | ''' 133 | 计算单个batch的正确率 134 | :param logits: (batch_size, class_num) 135 | :param targets: (batch_size) 136 | :return: 137 | ''' 138 | corrects = (torch.max(logits, 1)[1].view(targets.size()).data == targets.data).sum() 139 | return corrects.item() * (100.0 / targets.size(0)) 140 | 141 | def adjust_learning_rate(optimizer, current_epoch): 142 | ''' 143 | 学习率衰减 144 | ''' 145 | frac = float(current_epoch - 20) / 50 146 | shrink_factor = math.pow(0.5, frac) 147 | 148 | print("DECAYING learning rate.") 149 | for param_group in optimizer.param_groups: 150 | param_group['lr'] = param_group['lr'] * shrink_factor 151 | 152 | print("The new learning rate is {}".format(optimizer.param_groups[0]['lr'])) 153 | 154 | 155 | def save_checkpoint(model_name, data_name, epoch, epochs_since_improvement, model, optimizer, acc, is_best): 156 | ''' 157 | 保存模型 158 | 159 | :param model_name: model name 160 | :param data_name: SST-1 or SST-2, 161 | :param epoch: epoch number 162 | :param epochs_since_improvement: 自上次提升正确率后经过的epoch数 163 | :param model: model 164 | :param optimizer: optimizer 165 | :param acc: 每个epoch的验证集上的acc 166 | :param is_best: 该模型参数是否是目前最优的 167 | ''' 168 | state = {'epoch': epoch, 169 | 'epochs_since_improvement': epochs_since_improvement, 170 | 'acc': acc, 171 | 'model': model, 172 | 'optimizer': optimizer} 173 | filename = 'checkpoint_' + data_name + '_' + model_name + '.pth' 174 | torch.save(state, 'checkpoints/' + filename) 175 | # 如果目前的checkpoint是最优的,添加备份以防被重写 176 | if is_best: 177 | torch.save(state, 'checkpoints/' + 'BEST_' + filename) 178 | 179 | 180 | def train(train_loader, model, criterion, optimizer, epoch, vocab_size, print_freq, device, grad_clip=None): 181 | ''' 182 | 执行一个epoch的训练 183 | 184 | :param train_loader: DataLoader 185 | :param model: model 186 | :param criterion: 交叉熵loss 187 | :param optimizer: optimizer 188 | :param epoch: 执行到第几个epoch 189 | :param vocab_size: 词表大小 190 | :param print_freq: 打印频率 191 | :param device: device 192 | :param grad_clip: 梯度裁剪阈值 193 | ''' 194 | # 切换模式(使用dropout) 195 | model.train() 196 | 197 | losses = AverageMeter() # 一个batch的平均loss 198 | accs = AverageMeter() # 一个batch的平均正确率 199 | 200 | for i, (sents, labels) in enumerate(train_loader): 201 | 202 | # 移动到GPU 203 | sents = sents.to(device) 204 | targets = labels.to(device) 205 | 206 | # 前向计算 207 | logits = model(sents) 208 | 209 | # 计算整个batch上的平均loss 210 | loss = criterion(logits, targets) 211 | 212 | # 反向传播 213 | optimizer.zero_grad() 214 | loss.backward() 215 | 216 | # 梯度裁剪 217 | if grad_clip is not None: 218 | clip_gradient(optimizer, grad_clip) 219 | 220 | # 更新参数 221 | optimizer.step() 222 | 223 | # 计算准确率 224 | accs.update(accuracy(logits, targets)) 225 | losses.update(loss.item()) 226 | 227 | # 打印状态 228 | if i % print_freq == 0: 229 | print('Epoch: [{0}][{1}/{2}]\t' 230 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 231 | 'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(epoch, i, len(train_loader), 232 | loss=losses, 233 | acc=accs)) 234 | 235 | 236 | 237 | def validate(val_loader, model, criterion, print_freq, device): 238 | ''' 239 | 执行一个epoch的验证(跑完整个验证集) 240 | 241 | :param val_loader: 验证集的DataLoader 242 | :param model: 
model 243 | :param criterion: 交叉熵loss 244 | :param print_freq: 打印频率 245 | :param device: device 246 | :return: accuracy 247 | ''' 248 | 249 | #切换模式 250 | model = model.eval() 251 | 252 | losses = AverageMeter() # 一个batch的平均loss 253 | accs = AverageMeter() # 一个batch的平均正确率 254 | 255 | # 设置不计算梯度 256 | with torch.no_grad(): 257 | # 迭代每个batch 258 | for i, (sents, labels) in enumerate(val_loader): 259 | 260 | # 移动到GPU 261 | sents = sents.to(device) 262 | targets = labels.to(device) 263 | 264 | # 前向计算 265 | logits = model(sents) 266 | 267 | # 计算整个batch上的平均loss 268 | loss = criterion(logits, targets) 269 | 270 | # 计算准确率 271 | accs.update(accuracy(logits, targets)) 272 | losses.update(loss.item()) 273 | 274 | if i % print_freq == 0: 275 | print('Validation: [{0}/{1}]\t' 276 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 277 | 'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'.format(i, len(val_loader), 278 | loss=losses, acc=accs)) 279 | # 计算整个验证集上的正确率 280 | print('LOSS - {loss.avg:.3f}, ACCURACY - {acc.avg:.3f}\n'.format(loss=losses, acc=accs)) 281 | 282 | return accs.avg 283 | 284 | 285 | def testing(test_loader, model, criterion, print_freq, device): 286 | ''' 287 | 执行测试 288 | 289 | :param test_loader: 测试集的DataLoader 290 | :param model: model 291 | :param criterion: 交叉熵loss 292 | :param print_freq: 打印频率 293 | :param device: device 294 | :return: accuracy 295 | ''' 296 | 297 | #切换模式 298 | model = model.eval() 299 | 300 | losses = AverageMeter() # 一个batch的平均loss 301 | accs = AverageMeter() # 一个batch的平均正确率 302 | 303 | # 设置不计算梯度 304 | with torch.no_grad(): 305 | # 迭代每个batch 306 | for i, (sents, labels) in enumerate(test_loader): 307 | 308 | # 移动到GPU 309 | sents = sents.to(device) 310 | targets = labels.to(device) 311 | 312 | # 前向计算 313 | logits = model(sents) 314 | 315 | # 计算整个batch上的平均loss 316 | loss = criterion(logits, targets) 317 | 318 | # 计算准确率 319 | accs.update(accuracy(logits, targets)) 320 | losses.update(loss.item()) 321 | 322 | if i % print_freq == 0: 323 | print('Test: [{0}/{1}]\t' 324 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 325 | 'Accuracy {acc.val:.3f} ({acc.avg:.3f})\t'.format(i, len(test_loader), 326 | loss=losses, acc=accs)) 327 | 328 | # 计算整个测试集上的正确率 329 | print('LOSS - {loss.avg:.3f}, ACCURACY - {acc.avg:.3f}'.format(loss=losses, acc=accs)) 330 | 331 | return accs.avg 332 | 333 | 334 | 335 | 336 | --------------------------------------------------------------------------------