├── .gitignore ├── Bert_Script └── extract_bert_word_features_speed.py ├── Config ├── README.md ├── config.cfg └── config.py ├── Data ├── Bert_Features │ └── stsa_binary_train_bert_dim300_16.json ├── README.md └── sst_binary │ ├── stsa.binary.dev │ ├── stsa.binary.test │ └── stsa.binary.trai ├── DataUtils ├── Alphabet.py ├── Batch_Iterator.py ├── Common.py ├── Embed.py ├── Optim.py ├── Pickle.py ├── README.md ├── __init__.py ├── mainHelp.py └── utils.py ├── Dataloader ├── DataLoader_SST_Binary.py ├── Instance.py └── README.md ├── LICENSE ├── README.md ├── __init__.py ├── main.py ├── models ├── Text_Classification │ ├── BiLSTM.py │ ├── CNN.py │ ├── Text_Classification.py │ ├── __init__.py │ ├── initialize.py │ └── modelHelp.py ├── Text_Classification_BertFeature │ ├── Bert_Encoder.py │ ├── Bert_Encoder_Pool.py │ ├── BiLSTM.py │ ├── CNN.py │ ├── Text_Classification.py │ ├── __init__.py │ ├── initialize.py │ └── modelHelp.py └── __init__.py ├── run_train_p.sh ├── test.py └── trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /Bert_Script/extract_bert_word_features_speed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Extract pre-computed feature vectors from a PyTorch BERT model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import sys 23 | import argparse 24 | import collections 25 | import logging 26 | import json 27 | import re 28 | import numpy as np 29 | 30 | import torch 31 | import torch.nn.functional as F 32 | from torch.utils.data import TensorDataset, DataLoader, SequentialSampler 33 | from torch.utils.data.distributed import DistributedSampler 34 | 35 | from pytorch_pretrained_bert.tokenization import BertTokenizer 36 | from pytorch_pretrained_bert.modeling import BertModel 37 | 38 | logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', 39 | datefmt = '%m/%d/%Y %H:%M:%S', 40 | level = logging.INFO) 41 | logger = logging.getLogger(__name__) 42 | 43 | 44 | class InputExample(object): 45 | 46 | def __init__(self, unique_id, text_a, text_b): 47 | self.unique_id = unique_id 48 | self.text_a = text_a 49 | self.text_b = text_b 50 | 51 | 52 | class InputFeatures(object): 53 | """A single set of features of data.""" 54 | 55 | def __init__(self, unique_id, tokens, input_ids, input_mask, input_type_ids): 56 | self.unique_id = unique_id 57 | self.tokens = tokens 58 | self.input_ids = input_ids 59 | self.input_mask = input_mask 60 | self.input_type_ids = input_type_ids 61 | 62 | 63 | def convert_examples_to_features(examples, seq_length, tokenizer): 64 | """Loads a data file into a list of `InputBatch`s.""" 65 | 66 | features = [] 67 | for (ex_index, example) in enumerate(examples): 68 | # print(example.text_a) 69 | tokens_a = tokenizer.tokenize(example.text_a) 70 | 71 | tokens_b = None 72 | if example.text_b: 73 | tokens_b = tokenizer.tokenize(example.text_b) 74 | 75 | if tokens_b: 76 | # Modifies `tokens_a` and 
`tokens_b` in place so that the total 77 | # length is less than the specified length. 78 | # Account for [CLS], [SEP], [SEP] with "- 3" 79 | _truncate_seq_pair(tokens_a, tokens_b, seq_length - 3) 80 | else: 81 | # Account for [CLS] and [SEP] with "- 2" 82 | if len(tokens_a) > seq_length - 2: 83 | tokens_a = tokens_a[0:(seq_length - 2)] 84 | 85 | # The convention in BERT is: 86 | # (a) For sequence pairs: 87 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 88 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 89 | # (b) For single sequences: 90 | # tokens: [CLS] the dog is hairy . [SEP] 91 | # type_ids: 0 0 0 0 0 0 0 92 | # 93 | # Where "type_ids" are used to indicate whether this is the first 94 | # sequence or the second sequence. The embedding vectors for `type=0` and 95 | # `type=1` were learned during pre-training and are added to the wordpiece 96 | # embedding vector (and position vector). This is not *strictly* necessary 97 | # since the [SEP] token unambigiously separates the sequences, but it makes 98 | # it easier for the model to learn the concept of sequences. 99 | # 100 | # For classification tasks, the first vector (corresponding to [CLS]) is 101 | # used as as the "sentence vector". Note that this only makes sense because 102 | # the entire model is fine-tuned. 103 | tokens = [] 104 | input_type_ids = [] 105 | tokens.append("[CLS]") 106 | input_type_ids.append(0) 107 | for token in tokens_a: 108 | tokens.append(token) 109 | input_type_ids.append(0) 110 | tokens.append("[SEP]") 111 | input_type_ids.append(0) 112 | 113 | if tokens_b: 114 | for token in tokens_b: 115 | tokens.append(token) 116 | input_type_ids.append(1) 117 | tokens.append("[SEP]") 118 | input_type_ids.append(1) 119 | 120 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 121 | 122 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 123 | # tokens are attended to. 
124 | input_mask = [1] * len(input_ids) 125 | 126 | # Zero-pad up to the sequence length. 127 | while len(input_ids) < seq_length: 128 | input_ids.append(0) 129 | input_mask.append(0) 130 | input_type_ids.append(0) 131 | 132 | assert len(input_ids) == seq_length 133 | assert len(input_mask) == seq_length 134 | assert len(input_type_ids) == seq_length 135 | 136 | if ex_index < 2: 137 | logger.info("*** Example ***") 138 | logger.info("unique_id: %s" % (example.unique_id)) 139 | logger.info("tokens: %s" % " ".join([str(x) for x in tokens])) 140 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 141 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 142 | logger.info( 143 | "input_type_ids: %s" % " ".join([str(x) for x in input_type_ids])) 144 | 145 | features.append( 146 | InputFeatures( 147 | unique_id=example.unique_id, 148 | tokens=tokens, 149 | input_ids=input_ids, 150 | input_mask=input_mask, 151 | input_type_ids=input_type_ids)) 152 | return features 153 | 154 | 155 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 156 | """Truncates a sequence pair in place to the maximum length.""" 157 | 158 | # This is a simple heuristic which will always truncate the longer sequence 159 | # one token at a time. This makes more sense than truncating an equal percent 160 | # of tokens from each, since if one sequence is very short then each token 161 | # that's truncated likely contains more information than a longer sequence. 
162 | while True: 163 | total_length = len(tokens_a) + len(tokens_b) 164 | if total_length <= max_length: 165 | break 166 | if len(tokens_a) > len(tokens_b): 167 | tokens_a.pop() 168 | else: 169 | tokens_b.pop() 170 | 171 | 172 | def cut_text_by_len(text, length): 173 | """ 174 | :param text: 175 | :param length: 176 | :return: 177 | """ 178 | textArr = re.findall('.{' + str(length) + '}', text) 179 | textArr.append(text[(len(textArr) * length):]) 180 | return textArr 181 | 182 | 183 | def _clean_str(string): 184 | """ 185 | Tokenization/string cleaning for all datasets except for SST. 186 | Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py 187 | """ 188 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 189 | string = re.sub(r"\'s", " \'s", string) 190 | string = re.sub(r"\'ve", " \'ve", string) 191 | string = re.sub(r"n\'t", " n\'t", string) 192 | string = re.sub(r"\'re", " \'re", string) 193 | string = re.sub(r"\'d", " \'d", string) 194 | string = re.sub(r"\'ll", " \'ll", string) 195 | string = re.sub(r",", " , ", string) 196 | string = re.sub(r"!", " ! ", string) 197 | string = re.sub(r"\(", " \( ", string) 198 | string = re.sub(r"\)", " \) ", string) 199 | string = re.sub(r"\?", " \? 
", string) 200 | string = re.sub(r"\s{2,}", " ", string) 201 | return string.strip().lower() 202 | 203 | 204 | def read_examples(input_file, max_seq_length): 205 | """Read a list of `InputExample`s from an input file.""" 206 | examples = [] 207 | unique_id = 0 208 | line_index = 0 209 | uniqueid_to_line = collections.OrderedDict() 210 | with open(input_file, "r", encoding='utf-8') as reader: 211 | while True: 212 | line = reader.readline() 213 | if not line: 214 | break 215 | line = line.strip().split() 216 | line = " ".join(line[1:]) 217 | line = _clean_str(line) 218 | # print(line) 219 | # exit() 220 | # line = "".join(json.loads(line)["fact"].split()) 221 | # line_cut = cut_text_by_len(line, max_seq_length) 222 | line_cut = [line] 223 | for l in line_cut: 224 | uniqueid_to_line[str(unique_id)] = line_index 225 | text_a = None 226 | text_b = None 227 | m = re.match(r"^(.*) \|\|\| (.*)$", l) 228 | if m is None: 229 | text_a = l 230 | else: 231 | text_a = m.group(1) 232 | text_b = m.group(2) 233 | examples.append( 234 | InputExample(unique_id=unique_id, text_a=text_a, text_b=text_b)) 235 | unique_id += 1 236 | line_index += 1 237 | # print(uniqueid_to_line) 238 | return examples, uniqueid_to_line 239 | 240 | 241 | def to_json(args, output_file, model, eval_dataloader, device, features, layer_indexes, uniqueid_to_line): 242 | """ 243 | :param args: 244 | :param output_file: 245 | :param model: 246 | :param eval_dataloader: 247 | :param device: 248 | :param features: 249 | :param layer_indexes: 250 | :return: 251 | """ 252 | model.eval() 253 | batch_count = len(eval_dataloader) 254 | batch_num = 0 255 | line_index_exist = [] 256 | result = [] 257 | file = open(output_file, mode="w", encoding="utf-8") 258 | # with open(output_file, "w", encoding='utf-8') as writer: 259 | for input_ids, input_mask, example_indices in eval_dataloader: 260 | batch_num += 1 261 | sys.stdout.write("\rBert Model For the {} Batch, All {} batch.".format(batch_num, batch_count)) 262 | 
input_ids = input_ids.to(device) 263 | input_mask = input_mask.to(device) 264 | all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask) 265 | all_encoder_layers = all_encoder_layers 266 | layer_index = int(-1) 267 | layer_output_all = all_encoder_layers[layer_index].detach().cpu().numpy()[:, :, :args.bert_dim] 268 | 269 | for b, example_index in enumerate(example_indices): 270 | feature = features[example_index.item()] 271 | tokens = feature.tokens 272 | token_length = len(tokens) 273 | layer_output = np.round(layer_output_all[b][:token_length].tolist(), 6).tolist() 274 | 275 | out_features = collections.OrderedDict() 276 | out_features["tokens"] = tokens 277 | out_features["values"] = layer_output 278 | 279 | unique_id = int(feature.unique_id) 280 | line_index = uniqueid_to_line[str(unique_id)] 281 | if line_index in line_index_exist: 282 | output_json["features"]["tokens"].extend(tokens) 283 | output_json["features"]["values"].extend(layer_output) 284 | continue 285 | else: 286 | if len(line_index_exist) != 0: 287 | result.append(output_json) 288 | if len(result) % 10000 == 0: 289 | to_file(file=file, result=result, output_file=output_file) 290 | result.clear() 291 | # writer.write(json.dumps(output_json, ensure_ascii=False) + "\n") 292 | line_index_exist.clear() 293 | line_index_exist.append(line_index) 294 | output_json = collections.OrderedDict() 295 | output_json["linex_index"] = line_index 296 | output_json["layer_index"] = layer_index 297 | output_json["features"] = out_features 298 | # continue 299 | # writer.write(json.dumps(output_json, ensure_ascii=False) + "\n") 300 | result.append(output_json) 301 | 302 | to_file(file, result, output_file) 303 | # print("\nTo Json File {}".format(output_file)) 304 | # line_num = 0 305 | # file = open(output_file, mode="w", encoding="utf-8") 306 | # for js in result: 307 | # line_num += 1 308 | # if line_num % 1000 == 0: 309 | # sys.stdout.write("\rBert Model Result For the {} line, All 
{} lines.".format(line_num, len(result))) 310 | # file.write(json.dumps(js, ensure_ascii=False) + "\n") 311 | file.close() 312 | 313 | 314 | def to_file(file, result, output_file): 315 | """ 316 | :param file: 317 | :param result: 318 | :param output_file: 319 | :return: 320 | """ 321 | print("\nAdd To Json File {}".format(output_file)) 322 | line_num = 0 323 | # file = open(output_file, mode="w", encoding="utf-8") 324 | for js in result: 325 | line_num += 1 326 | if line_num % 1000 == 0: 327 | sys.stdout.write("\rBert Model Result For the {} line, All {} lines.".format(line_num, len(result))) 328 | file.write(json.dumps(js, ensure_ascii=False) + "\n") 329 | 330 | 331 | def main(args): 332 | """ 333 | :param args: 334 | :return: 335 | """ 336 | # np.set_printoptions(precision=6) 337 | if args.local_rank == -1 or args.no_cuda: 338 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 339 | n_gpu = torch.cuda.device_count() 340 | else: 341 | device = torch.device("cuda", args.local_rank) 342 | n_gpu = 1 343 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 344 | torch.distributed.init_process_group(backend='nccl') 345 | logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1))) 346 | # exit() 347 | layer_indexes = [int(x) for x in args.layers.split(",")] 348 | 349 | # tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) 350 | tokenizer = BertTokenizer.from_pretrained(os.path.join(args.bert_model, args.vocab), do_lower_case=args.do_lower_case) 351 | examples, uniqueid_to_line = read_examples(args.input_file, args.max_seq_length) 352 | # print(max_seq_length) 353 | # exit() 354 | 355 | features = convert_examples_to_features( 356 | examples=examples, seq_length=args.max_seq_length, tokenizer=tokenizer) 357 | 358 | unique_id_to_feature = {} 359 | for feature in features: 360 | 
unique_id_to_feature[feature.unique_id] = feature 361 | 362 | model = BertModel.from_pretrained(args.bert_model) 363 | model.to(device) 364 | 365 | if args.local_rank != -1: 366 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 367 | output_device=args.local_rank) 368 | elif n_gpu > 1: 369 | model = torch.nn.DataParallel(model) 370 | 371 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 372 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 373 | all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) 374 | 375 | eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index) 376 | if args.local_rank == -1: 377 | eval_sampler = SequentialSampler(eval_data) 378 | else: 379 | eval_sampler = DistributedSampler(eval_data) 380 | eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.batch_size) 381 | 382 | to_json(args, args.output_file, model, eval_dataloader, device, features, layer_indexes, uniqueid_to_line) 383 | 384 | 385 | if __name__ == "__main__": 386 | 387 | max_seq_length = 60 388 | batch_size = 2 389 | bert_dim = 3 390 | input = "../Data/sst_binary/stsa.binary-t.test" 391 | output = "../sst_bert_features/stsa_binary_test_bert_dim{}.json".format(bert_dim) 392 | bert_model = "../bert-base-uncased" 393 | vocab = "bert-base-uncased-vocab.txt" 394 | do_lower_case = True 395 | 396 | parser = argparse.ArgumentParser() 397 | 398 | # Required parameters 399 | parser.add_argument("--input_file", default=input, type=str) 400 | parser.add_argument("--output_file", default=output, type=str) 401 | parser.add_argument("--bert_model", default=bert_model, type=str, 402 | help="Bert pre-trained model selected in the list: bert-base-uncased, " 403 | "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") 404 | parser.add_argument("--vocab", default=vocab, type=str) 405 | 406 | # Other parameters 
407 | parser.add_argument("--do_lower_case", default=do_lower_case, action='store_true', help="Set this flag if you are using an uncased model.") 408 | parser.add_argument("--layers", default="-1", type=str) 409 | parser.add_argument("--max_seq_length", default=max_seq_length, type=int, 410 | help="The maximum total input sequence length after WordPiece tokenization. Sequences longer " 411 | "than this will be truncated, and sequences shorter than this will be padded.") 412 | parser.add_argument("--batch_size", default=batch_size, type=int, help="Batch size for predictions.") 413 | parser.add_argument("--bert_dim", default=bert_dim, type=int, help="bert_dim.") 414 | parser.add_argument("--local_rank", 415 | type=int, 416 | default=-1, 417 | help="local_rank for distributed training on gpus") 418 | parser.add_argument("--no_cuda", 419 | action='store_true', 420 | help="Whether not to use CUDA when available") 421 | 422 | args = parser.parse_args() 423 | # print(args.no_cuda) 424 | # exit() 425 | main(args) 426 | -------------------------------------------------------------------------------- /Config/README.md: -------------------------------------------------------------------------------- 1 | ## Config ## 2 | 3 | - Use `ConfigParser` to config parameter 4 | - `from configparser import ConfigParser` . 5 | - Detail see `config.py` and `config.cfg`, please. 6 | 7 | - Following is `config.cfg` Parameter details. 8 | 9 | - [Embed] 10 | - `pretrained_embed` (True or False) ------ whether to use pretrained embedding. 11 | - `zeros`(True or False) ------ OOV by zeros . 12 | - `avg`(True or False) ------ OOV by avg . 13 | - `uniform`(True or False) ------ OOV by uniform . 14 | - `nnembed`(True or False) ------ OOV by nn.Embedding . 15 | - `pretrained_embed_file` --- pre train path 16 | 17 | - [Data] 18 | - `train-file/dev-file/test-file`(path) ------ train/dev/test data path(`Data`). 19 | - `min_freq` (integer number) ------ The smallest Word frequency when build vocab. 
20 | - `max_count` (integer number) ------ The maximum number of instances to read (useful for `debug`ging). 21 | - `shuffle/epochs_shuffle`(True or False) ------ whether to shuffle the data. 22 | 23 | - [Save] 24 | - `save_pkl` (True or False) ------ save pkl files for the test option. 25 | - `pkl_directory` (path) ------ save pkl directory path. 26 | - `pkl_data` (path) ------ save pkl data path. 27 | - `pkl_alphabet` (path) ------ save pkl alphabet path. 28 | - `pkl_iter` (path) ------ save pkl batch iterator path. 29 | - `pkl_embed` (path) ------ save pkl pre-train embedding path. 30 | - `save_dict` (True or False) ------ save dict to file. 31 | - `dict_directory` (path) ------ save dict directory path. 32 | - `word_dict` (path) ------ save word dict path. 33 | - `label_dict` (path) ------ save label dict path. 34 | - `save_model` (True or False) ------ save model to file. 35 | - `save_all_model` (True or False) ------ save every model to file. 36 | - `save_best_model` (True or False) ------ save the best model to file. 37 | - `model_name` (str) ------ model name. 38 | - `rm_model` (True or False) ------ remove models to save space (currently unused). 39 | 40 | - [Model] 41 | - `model_***`(True or False) ------ whether to use the *** model (e.g. `model_cnn`, `model_bilstm`). 42 | - `wide_conv`(True or False) ------ use the wide convolution model. 43 | - `lstm_layers` (integer) ------ number of LSTM layers. 44 | - `embed_dim` (integer) ------ embedding dim; must equal the pre-trained embedding dim. 45 | - `embed_finetune` (True or False) ------ whether to fine-tune the word embeddings. 46 | - `lstm_hiddens` (integer) ------ number of LSTM hidden units. 47 | - `dropout_emb/dropout`(float) ------ dropout rates to prevent overfitting. 48 | - `conv_filter_sizes` (str(3,4,5)) ------ comma-separated conv filter sizes. 49 | - `conv_filter_nums` (int(200)) ------ number of conv filters. 50 | 51 | - [Optimizer] 52 | - `adam` (True or False) ------ `torch.optim.Adam` 53 | - `sgd` (True or False) ------ `torch.optim.SGD` 54 | - `learning_rate`(float) ------ learning rate. 55 | - `weight_decay` (float) ------ L2 weight decay.
56 | - `momentum` (float) ------ SGD momentum. 57 | - `clip_max_norm_use` (True or False) ------ whether to clip the gradient norm. 58 | - `clip_max_norm` (number) ------ maximum gradient norm used for clipping. 59 | - `use_lr_decay` (True or False) ------ whether to decay the learning rate. 60 | - `lr_rate_decay`(float) ------ learning-rate decay factor. 61 | - `min_lrate`(float) ------ minimum learning rate. 62 | - `max_patience`(Integer) ------ patience before the learning rate is decayed. 63 | 64 | - [Train] 65 | - `num_threads` (Integer) ------ number of threads. 66 | - `use-cuda` (True or False) ------ use `cuda` to speed up training. 67 | - `epochs` (Integer) ------ maximum number of training epochs. 68 | - `early_max_patience` (Integer) ------ maximum number of times the dev score may fail to improve before early stopping. 69 | - `backward_batch_size` (Integer) ------ number of batches per parameter update. 70 | - `batch_size/dev_batch_size/test_batch_size` (Integer) ------ batch size for train/dev/test. 71 | - `log_interval`(Integer) ------ number of steps between log prints. -------------------------------------------------------------------------------- /Config/config.cfg: -------------------------------------------------------------------------------- 1 | [Embed] 2 | pretrained_embed = False 3 | zeros = False 4 | avg = False 5 | uniform = False 6 | nnembed = True 7 | pretrained_embed_file = ./Data/embed/glove.sentiment.conj.pretrained.txt 8 | 9 | [Data] 10 | train_file = ./Data/sst_binary/stsa.binary.trai 11 | dev_file = ./Data/sst_binary/stsa.binary.dev 12 | test_file = ./Data/sst_binary/stsa.binary.test 13 | max_count = -1 14 | min_freq = 1 15 | shuffle = True 16 | epochs_shuffle = True 17 | 18 | [Bert] 19 | use_bert = True 20 | bert_dim = 300 21 | bert_train_file = ./Data/Bert_Features/stsa_binary_train_bert_dim300_16.json 22 | bert_dev_file = ./Data/Bert_Features/stsa_binary_train_bert_dim300_16.json 23 | bert_test_file = ./Data/Bert_Features/stsa_binary_train_bert_dim300_16.json 24 | 25 | [BertModel] 26 | use_bert_model = False 27 | bert_model_path = ./Data/bert-base-chinese 28 | bert_model_vocab = bert-base-chinese-vocab.txt 29 |
bert_max_char_length = 300 30 | bert_model_max_seq_length = 302 31 | bert_model_batch_size = 2 32 | extract_dim = 3 33 | layers = -1,-2,-3,-4 34 | local_rank = -1 35 | no_cuda = False 36 | do_lower_case = True 37 | 38 | [Save] 39 | save_pkl = False 40 | pkl_directory = ./Save_pkl 41 | pkl_data = pkl_data.pkl 42 | pkl_alphabet = pkl_alphabet.pkl 43 | pkl_iter = pkl_iter.pkl 44 | pkl_embed = pkl_embed.pkl 45 | save_dict = True 46 | dict_directory = ./Save_dictionary 47 | word_dict = dictionary_word.txt 48 | label_dict = dictionary_label.txt 49 | save_direction = ./Save_model 50 | save_best_model_dir = ./Save_BModel 51 | save_model = True 52 | save_all_model = False 53 | save_best_model = True 54 | model_name = text_model 55 | rm_model = True 56 | 57 | [Model] 58 | wide_conv = True 59 | model_cnn = False 60 | model_bilstm = True 61 | lstm_layers = 1 62 | embed_dim = 300 63 | embed_finetune = True 64 | lstm_hiddens = 150 65 | dropout_emb = 0.5 66 | dropout = 0.5 67 | conv_filter_sizes = 1,2,3,4 68 | conv_filter_nums = 200 69 | 70 | [Optimizer] 71 | adam = True 72 | sgd = False 73 | learning_rate = 0.001 74 | weight_decay = 1.0e-8 75 | momentum = 0.0 76 | clip_max_norm_use = True 77 | clip_max_norm = 10 78 | use_lr_decay = False 79 | lr_rate_decay = 0.05 80 | min_lrate = 0.000005 81 | max_patience = 1 82 | 83 | [Train] 84 | num_threads = 1 85 | epochs = 1000 86 | early_max_patience = 30 87 | backward_batch_size = 1 88 | batch_size = 10 89 | dev_batch_size = 10 90 | test_batch_size = 10 91 | log_interval = 1 92 | 93 | -------------------------------------------------------------------------------- /Config/config.py: -------------------------------------------------------------------------------- 1 | 2 | from configparser import ConfigParser 3 | import os 4 | 5 | 6 | class myconf(ConfigParser): 7 | """ 8 | MyConf 9 | """ 10 | def __init__(self, defaults=None): 11 | ConfigParser.__init__(self, defaults=defaults) 12 | self.add_sec = "Additional" 13 | 14 | def 
optionxform(self, optionstr): 15 | return optionstr 16 | 17 | 18 | class Configurable(myconf): 19 | def __init__(self, config_file): 20 | # config = ConfigParser() 21 | super().__init__() 22 | 23 | self.test = None 24 | self.train = None 25 | config = myconf() 26 | config.read(config_file) 27 | # if config.has_section(self.add_sec) is False: 28 | # config.add_section(self.add_sec) 29 | self._config = config 30 | self.config_file = config_file 31 | 32 | print('Loaded config file sucessfully.') 33 | for section in config.sections(): 34 | for k, v in config.items(section): 35 | print(k, ":", v) 36 | if not os.path.isdir(self.save_direction): 37 | os.mkdir(self.save_direction) 38 | config.write(open(config_file, 'w')) 39 | 40 | def add_args(self, key, value): 41 | self._config.set(self.add_sec, key, value) 42 | self._config.write(open(self.config_file, 'w')) 43 | 44 | # Embed 45 | @property 46 | def pretrained_embed(self): 47 | return self._config.getboolean('Embed', 'pretrained_embed') 48 | 49 | @property 50 | def zeros(self): 51 | return self._config.getboolean('Embed', 'zeros') 52 | 53 | @property 54 | def avg(self): 55 | return self._config.getboolean('Embed', 'avg') 56 | 57 | @property 58 | def uniform(self): 59 | return self._config.getboolean('Embed', 'uniform') 60 | 61 | @property 62 | def nnembed(self): 63 | return self._config.getboolean('Embed', 'nnembed') 64 | 65 | @property 66 | def pretrained_embed_file(self): 67 | return self._config.get('Embed', 'pretrained_embed_file') 68 | 69 | # Data 70 | @property 71 | def train_file(self): 72 | return self._config.get('Data', 'train_file') 73 | 74 | @property 75 | def dev_file(self): 76 | return self._config.get('Data', 'dev_file') 77 | 78 | @property 79 | def test_file(self): 80 | return self._config.get('Data', 'test_file') 81 | 82 | @property 83 | def max_count(self): 84 | return self._config.getint('Data', 'max_count') 85 | 86 | @property 87 | def min_freq(self): 88 | return self._config.getint('Data', 
'min_freq') 89 | 90 | @property 91 | def shuffle(self): 92 | return self._config.getboolean('Data', 'shuffle') 93 | 94 | @property 95 | def epochs_shuffle(self): 96 | return self._config.getboolean('Data', 'epochs_shuffle') 97 | 98 | # Bert 99 | @property 100 | def use_bert(self): 101 | return self._config.getboolean('Bert', 'use_bert') 102 | 103 | @property 104 | def bert_dim(self): 105 | return self._config.getint('Bert', 'bert_dim') 106 | 107 | @property 108 | def bert_train_file(self): 109 | return self._config.get('Bert', 'bert_train_file') 110 | 111 | @property 112 | def bert_dev_file(self): 113 | return self._config.get('Bert', 'bert_dev_file') 114 | 115 | @property 116 | def bert_test_file(self): 117 | return self._config.get('Bert', 'bert_test_file') 118 | 119 | # BertModel 120 | @property 121 | def use_bert_model(self): 122 | return self._config.getboolean('BertModel', 'use_bert_model') 123 | 124 | @property 125 | def bert_model_path(self): 126 | return self._config.get('BertModel', 'bert_model_path') 127 | 128 | @property 129 | def bert_model_vocab(self): 130 | return self._config.get('BertModel', 'bert_model_vocab') 131 | 132 | @property 133 | def bert_max_char_length(self): 134 | return self._config.getint('BertModel', 'bert_max_char_length') 135 | 136 | @property 137 | def bert_model_max_seq_length(self): 138 | return self._config.getint('BertModel', 'bert_model_max_seq_length') 139 | 140 | @property 141 | def bert_model_batch_size(self): 142 | return self._config.getint('BertModel', 'bert_model_batch_size') 143 | 144 | @property 145 | def extract_dim(self): 146 | return self._config.getint('BertModel', 'extract_dim') 147 | 148 | @property 149 | def layers(self): 150 | return self._config.get('BertModel', 'layers') 151 | 152 | @property 153 | def local_rank(self): 154 | return self._config.getint('BertModel', 'local_rank') 155 | 156 | @property 157 | def no_cuda(self): 158 | return self._config.getboolean('BertModel', 'no_cuda') 159 | 160 | @property 
161 | def do_lower_case(self): 162 | return self._config.getboolean('BertModel', 'do_lower_case') 163 | 164 | # Save 165 | @property 166 | def save_pkl(self): 167 | return self._config.getboolean('Save', 'save_pkl') 168 | 169 | @property 170 | def pkl_directory(self): 171 | return self._config.get('Save', 'pkl_directory') 172 | 173 | @property 174 | def pkl_data(self): 175 | return self._config.get('Save', 'pkl_data') 176 | 177 | @property 178 | def pkl_alphabet(self): 179 | return self._config.get('Save', 'pkl_alphabet') 180 | 181 | @property 182 | def pkl_iter(self): 183 | return self._config.get('Save', 'pkl_iter') 184 | 185 | @property 186 | def pkl_embed(self): 187 | return self._config.get('Save', 'pkl_embed') 188 | 189 | @property 190 | def save_dict(self): 191 | return self._config.getboolean('Save', 'save_dict') 192 | 193 | @property 194 | def save_direction(self): 195 | return self._config.get('Save', 'save_direction') 196 | 197 | @property 198 | def dict_directory(self): 199 | return self._config.get('Save', 'dict_directory') 200 | 201 | @property 202 | def word_dict(self): 203 | return self._config.get('Save', 'word_dict') 204 | 205 | @property 206 | def label_dict(self): 207 | return self._config.get('Save', 'label_dict') 208 | 209 | @property 210 | def model_name(self): 211 | return self._config.get('Save', 'model_name') 212 | 213 | @property 214 | def save_best_model_dir(self): 215 | return self._config.get('Save', 'save_best_model_dir') 216 | 217 | @property 218 | def save_model(self): 219 | return self._config.getboolean('Save', 'save_model') 220 | 221 | @property 222 | def save_all_model(self): 223 | return self._config.getboolean('Save', 'save_all_model') 224 | 225 | @property 226 | def save_best_model(self): 227 | return self._config.getboolean('Save', 'save_best_model') 228 | 229 | @property 230 | def rm_model(self): 231 | return self._config.getboolean('Save', 'rm_model') 232 | 233 | # Model 234 | @property 235 | def wide_conv(self): 236 | 
return self._config.getboolean("Model", "wide_conv") 237 | 238 | @property 239 | def model_cnn(self): 240 | return self._config.getboolean("Model", "model_cnn") 241 | 242 | @property 243 | def model_bilstm(self): 244 | return self._config.getboolean("Model", "model_bilstm") 245 | 246 | @property 247 | def lstm_layers(self): 248 | return self._config.getint("Model", "lstm_layers") 249 | 250 | @property 251 | def embed_dim(self): 252 | return self._config.getint("Model", "embed_dim") 253 | 254 | @property 255 | def embed_finetune(self): 256 | return self._config.getboolean("Model", "embed_finetune") 257 | 258 | @property 259 | def lstm_hiddens(self): 260 | return self._config.getint("Model", "lstm_hiddens") 261 | 262 | @property 263 | def dropout_emb(self): 264 | return self._config.getfloat("Model", "dropout_emb") 265 | 266 | @property 267 | def dropout(self): 268 | return self._config.getfloat("Model", "dropout") 269 | 270 | @property 271 | def conv_filter_sizes(self): 272 | return self._config.get("Model", "conv_filter_sizes") 273 | 274 | @property 275 | def conv_filter_nums(self): 276 | return self._config.getint("Model", "conv_filter_nums") 277 | 278 | # Optimizer 279 | @property 280 | def adam(self): 281 | return self._config.getboolean("Optimizer", "adam") 282 | 283 | @property 284 | def sgd(self): 285 | return self._config.getboolean("Optimizer", "sgd") 286 | 287 | @property 288 | def learning_rate(self): 289 | return self._config.getfloat("Optimizer", "learning_rate") 290 | 291 | @property 292 | def weight_decay(self): 293 | return self._config.getfloat("Optimizer", "weight_decay") 294 | 295 | @property 296 | def momentum(self): 297 | return self._config.getfloat("Optimizer", "momentum") 298 | 299 | @property 300 | def clip_max_norm_use(self): 301 | return self._config.getboolean("Optimizer", "clip_max_norm_use") 302 | 303 | @property 304 | def clip_max_norm(self): 305 | return self._config.get("Optimizer", "clip_max_norm") 306 | 307 | @property 308 | def 
use_lr_decay(self): 309 | return self._config.getboolean("Optimizer", "use_lr_decay") 310 | 311 | @property 312 | def lr_rate_decay(self): 313 | return self._config.getfloat("Optimizer", "lr_rate_decay") 314 | 315 | @property 316 | def min_lrate(self): 317 | return self._config.getfloat("Optimizer", "min_lrate") 318 | 319 | @property 320 | def max_patience(self): 321 | return self._config.getint("Optimizer", "max_patience") 322 | 323 | # Train 324 | @property 325 | def num_threads(self): 326 | return self._config.getint("Train", "num_threads") 327 | 328 | @property 329 | def epochs(self): 330 | return self._config.getint("Train", "epochs") 331 | 332 | @property 333 | def early_max_patience(self): 334 | return self._config.getint("Train", "early_max_patience") 335 | 336 | @property 337 | def backward_batch_size(self): 338 | return self._config.getint("Train", "backward_batch_size") 339 | 340 | @property 341 | def batch_size(self): 342 | return self._config.getint("Train", "batch_size") 343 | 344 | @property 345 | def dev_batch_size(self): 346 | return self._config.getint("Train", "dev_batch_size") 347 | 348 | @property 349 | def test_batch_size(self): 350 | return self._config.getint("Train", "test_batch_size") 351 | 352 | @property 353 | def log_interval(self): 354 | return self._config.getint("Train", "log_interval") 355 | 356 | 357 | 358 | 359 | -------------------------------------------------------------------------------- /Data/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch_Bert_Text_Classification 2 | 3 | - Datasets: SST-Binary 4 | 5 | 6 | -------------------------------------------------------------------------------- /DataUtils/Alphabet.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/1/30 15:54 3 | # @File : Alphabet.py 4 | # @Last Modify Time : 2018/1/30 15:54 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 
6 | 7 | """ 8 | FILE : Alphabet.py 9 | FUNCTION : None 10 | """ 11 | 12 | import os 13 | import sys 14 | import torch 15 | import random 16 | import collections 17 | from DataUtils.Common import seed_num, unkkey, paddingkey 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class CreateAlphabet: 23 | """ 24 | Class: Create_Alphabet 25 | Function: Build Alphabet By Alphabet Class 26 | Notice: The Class Need To Change So That Complete All Kinds Of Tasks 27 | """ 28 | def __init__(self, min_freq=1, train_data=None, dev_data=None, test_data=None, config=None): 29 | 30 | # minimum vocab size 31 | self.min_freq = min_freq 32 | self.config = config 33 | self.train_data = train_data 34 | self.dev_data = dev_data 35 | self.test_data = test_data 36 | 37 | # storage word and label 38 | self.word_state = collections.OrderedDict() 39 | self.label_state = collections.OrderedDict() 40 | 41 | # unk and pad 42 | self.word_state[unkkey] = self.min_freq 43 | self.word_state[paddingkey] = self.min_freq 44 | 45 | # word and label Alphabet 46 | self.word_alphabet = Alphabet(min_freq=self.min_freq) 47 | self.label_alphabet = Alphabet() 48 | 49 | # unk key 50 | self.word_unkId = 0 51 | self.label_unkId = 0 52 | 53 | # padding key 54 | self.word_paddingId = 0 55 | self.label_paddingId = 0 56 | 57 | @staticmethod 58 | def _build_data(train_data=None, dev_data=None, test_data=None): 59 | """ 60 | :param train_data: 61 | :param dev_data: 62 | :param test_data: 63 | :return: 64 | """ 65 | assert train_data is not None, "The Train Data Is Not Allow Empty." 
66 | datasets = [] 67 | datasets.extend(train_data) 68 | print("the length of train data {}".format(len(datasets))) 69 | if dev_data is not None: 70 | print("the length of dev data {}".format(len(dev_data))) 71 | datasets.extend(dev_data) 72 | if test_data is not None: 73 | print("the length of test data {}".format(len(test_data))) 74 | datasets.extend(test_data) 75 | print("the length of data that create Alphabet {}".format(len(datasets))) 76 | return datasets 77 | 78 | def build_vocab(self): 79 | """ 80 | :param train_data: 81 | :param dev_data: 82 | :param test_data: 83 | :param debug_index: 84 | :return: 85 | """ 86 | train_data = self.train_data 87 | dev_data = self.dev_data 88 | test_data = self.test_data 89 | print("Build Vocab Start...... ") 90 | datasets = self._build_data(train_data=train_data, dev_data=dev_data, test_data=test_data) 91 | # create the word Alphabet 92 | 93 | for index, data in enumerate(datasets): 94 | # word 95 | for word in data.words: 96 | if word not in self.word_state: 97 | self.word_state[word] = 1 98 | else: 99 | self.word_state[word] += 1 100 | 101 | # label 102 | for label in data.labels: 103 | if label not in self.label_state: 104 | self.label_state[label] = 1 105 | else: 106 | self.label_state[label] += 1 107 | 108 | # Create id2words and words2id by the Alphabet Class 109 | self.word_alphabet.initial(self.word_state) 110 | self.label_alphabet.initial(self.label_state) 111 | 112 | # unkId and paddingId 113 | self.word_unkId = self.word_alphabet.from_string(unkkey) 114 | self.word_paddingId = self.word_alphabet.from_string(paddingkey) 115 | 116 | # fix the vocab 117 | self.word_alphabet.set_fixed_flag(True) 118 | self.label_alphabet.set_fixed_flag(True) 119 | 120 | 121 | class Alphabet: 122 | """ 123 | Class: Alphabet 124 | Function: Build vocab 125 | Params: 126 | ****** id2words: type(list), 127 | ****** word2id: type(dict) 128 | ****** vocab_size: vocab size 129 | ****** min_freq: vocab minimum freq 130 | ****** fixed_vocab: 
fix the vocab after build vocab 131 | ****** max_cap: max vocab size 132 | """ 133 | def __init__(self, min_freq=1): 134 | self.id2words = [] 135 | self.words2id = collections.OrderedDict() 136 | self.vocab_size = 0 137 | self.min_freq = min_freq 138 | self.max_cap = 1e8 139 | self.fixed_vocab = False 140 | 141 | def initial(self, data): 142 | """ 143 | :param data: 144 | :return: 145 | """ 146 | for key in data: 147 | if data[key] >= self.min_freq: 148 | self.from_string(key) 149 | self.set_fixed_flag(True) 150 | 151 | def set_fixed_flag(self, bfixed): 152 | """ 153 | :param bfixed: 154 | :return: 155 | """ 156 | self.fixed_vocab = bfixed 157 | if (not self.fixed_vocab) and (self.vocab_size >= self.max_cap): 158 | self.fixed_vocab = True 159 | 160 | def from_string(self, string): 161 | """ 162 | :param string: 163 | :return: 164 | """ 165 | if string in self.words2id: 166 | return self.words2id[string] 167 | else: 168 | if not self.fixed_vocab: 169 | newid = self.vocab_size 170 | self.id2words.append(string) 171 | self.words2id[string] = newid 172 | self.vocab_size += 1 173 | if self.vocab_size >= self.max_cap: 174 | self.fixed_vocab = True 175 | return newid 176 | else: 177 | return -1 178 | 179 | def from_id(self, qid, defineStr=""): 180 | """ 181 | :param qid: 182 | :param defineStr: 183 | :return: 184 | """ 185 | if int(qid) < 0 or self.vocab_size <= qid: 186 | return defineStr 187 | else: 188 | return self.id2words[qid] 189 | 190 | def initial_from_pretrain(self, pretrain_file, unk, padding): 191 | """ 192 | :param pretrain_file: 193 | :param unk: 194 | :param padding: 195 | :return: 196 | """ 197 | print("initial alphabet from {}".format(pretrain_file)) 198 | self.initial(unk) 199 | self.initial(padding) 200 | now_line = 0 201 | with open(pretrain_file, encoding="UTF-8") as f: 202 | for line in f.readlines(): 203 | now_line += 1 204 | sys.stdout.write("\rhandling with {} line".format(now_line)) 205 | info = line.split(" ") 206 | self.from_string(info[0]) 207 
| f.close() 208 | print("\nHandle Finished.") 209 | 210 | 211 | 212 | -------------------------------------------------------------------------------- /DataUtils/Batch_Iterator.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/1/30 15:55 3 | # @File : Batch_Iterator.py.py 4 | # @Last Modify Time : 2018/1/30 15:55 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Batch_Iterator.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch 13 | from torch.autograd import Variable 14 | import random 15 | import numpy as np 16 | 17 | from DataUtils.Common import * 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class Batch_Features: 23 | """ 24 | Batch_Features 25 | """ 26 | def __init__(self): 27 | 28 | self.batch_length = 0 29 | self.inst = None 30 | self.word_features = None 31 | self.bert_features = None 32 | self.label_features = None 33 | self.sentence_length = [] 34 | 35 | @staticmethod 36 | def cuda(features, device): 37 | """ 38 | :param features: 39 | :param device: 40 | :return: 41 | """ 42 | features.word_features = features.word_features.to(device) 43 | # features.bert_features = features.bert_features.to(device) 44 | features.label_features = features.label_features.to(device) 45 | 46 | 47 | class Iterators: 48 | """ 49 | Iterators 50 | """ 51 | def __init__(self, batch_size=None, data=None, alphabet=None, config=None): 52 | self.config = config 53 | self.device = self.config.device 54 | self.batch_size = batch_size 55 | self.data = data 56 | self.alphabet = alphabet 57 | self.operator_static = None 58 | self.iterator = [] 59 | self.batch = [] 60 | self.features = [] 61 | self.data_iter = [] 62 | 63 | def createIterator(self): 64 | """ 65 | :param batch_size: batch size 66 | :param data: data 67 | :param operator: 68 | :param config: 69 | :return: 70 | """ 71 | assert isinstance(self.data, list), "ERROR: data must be in list 
[train_data,dev_data]" 72 | assert isinstance(self.batch_size, list), "ERROR: batch_size must be in list [16,1,1]" 73 | for id_data in range(len(self.data)): 74 | print("***************** create {} iterator **************".format(id_data + 1)) 75 | self._convert_word2id(self.data[id_data], self.alphabet) 76 | self.features = self._Create_Each_Iterator(insts=self.data[id_data], batch_size=self.batch_size[id_data], 77 | alphabet=self.alphabet) 78 | self.data_iter.append(self.features) 79 | self.features = [] 80 | if len(self.data_iter) == 2: 81 | return self.data_iter[0], self.data_iter[1] 82 | if len(self.data_iter) == 3: 83 | return self.data_iter[0], self.data_iter[1], self.data_iter[2] 84 | 85 | @staticmethod 86 | def _convert_word2id(insts, operator): 87 | """ 88 | :param insts: 89 | :param operator: 90 | :return: 91 | """ 92 | for inst in insts: 93 | # word 94 | for index in range(inst.words_size): 95 | word = inst.words[index] 96 | wordId = operator.word_alphabet.from_string(word) 97 | if wordId == -1: 98 | wordId = operator.word_unkId 99 | inst.words_index.append(wordId) 100 | 101 | # label 102 | label = inst.labels[0] 103 | labelId = operator.label_alphabet.from_string(label) 104 | inst.label_index.append(labelId) 105 | 106 | def _Create_Each_Iterator(self, insts, batch_size, alphabet): 107 | """ 108 | :param insts: 109 | :param batch_size: 110 | :param operator: 111 | :return: 112 | """ 113 | batch = [] 114 | count_inst = 0 115 | for index, inst in enumerate(insts): 116 | batch.append(inst) 117 | count_inst += 1 118 | if len(batch) == batch_size or count_inst == len(insts): 119 | one_batch = self._Create_Each_Batch(insts=batch, alphabet=alphabet) 120 | self.features.append(one_batch) 121 | batch = [] 122 | print("The all data has created iterator.") 123 | return self.features 124 | 125 | def _Create_Each_Batch(self, insts, alphabet): 126 | """ 127 | :param insts: 128 | :param batch_size: 129 | :param operator: 130 | :return: 131 | """ 132 | # print("create 
one batch......") 133 | batch_length = len(insts) 134 | # copy with the max length for padding 135 | max_word_size = -1 136 | max_bert_tokens_size = -1 137 | sentence_length = [] 138 | for inst in insts: 139 | sentence_length.append(inst.words_size) 140 | word_size = inst.words_size 141 | if word_size > max_word_size: 142 | max_word_size = word_size 143 | 144 | bert_token_size = len(inst.bert_tokens) 145 | if bert_token_size > max_bert_tokens_size: 146 | max_bert_tokens_size = bert_token_size 147 | 148 | # create with the Tensor/Variable 149 | # word/label features 150 | bert_dim = self.config.bert_dim 151 | batch_word_features = np.zeros((batch_length, max_word_size)) 152 | batch_bert_features = np.zeros((batch_length, max_bert_tokens_size, bert_dim)) 153 | batch_label_features = np.zeros((batch_length * 1)) 154 | 155 | for id_inst in range(batch_length): 156 | inst = insts[id_inst] 157 | # copy with the word features 158 | for id_word_index in range(max_word_size): 159 | if id_word_index < inst.words_size: 160 | batch_word_features[id_inst][id_word_index] = inst.words_index[id_word_index] 161 | else: 162 | batch_word_features[id_inst][id_word_index] = alphabet.word_paddingId 163 | length = len(inst.bert_tokens) 164 | batch_bert_features[id_inst][:length] = np.array(inst.bert_feature) 165 | 166 | # label 167 | batch_label_features[id_inst] = inst.label_index[0] 168 | 169 | # numpy to embedding 170 | batch_word_features = torch.from_numpy(batch_word_features).long() 171 | batch_label_features = torch.from_numpy(batch_label_features).long() 172 | 173 | # batch 174 | features = Batch_Features() 175 | features.inst = insts 176 | features.batch_length = batch_length 177 | features.word_features = batch_word_features 178 | features.bert_features = batch_bert_features 179 | features.label_features = batch_label_features 180 | features.sentence_length = sentence_length 181 | 182 | if self.config.use_bert: 183 | batch_bert_features = 
torch.from_numpy(batch_bert_features).float() 184 | features.bert_features = batch_bert_features 185 | features.bert_features.to(self.device) 186 | 187 | if self.config.device != cpu_device: 188 | features.cuda(features, self.device) 189 | return features 190 | 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /DataUtils/Common.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/1/30 15:55 3 | # @File : common.py 4 | # @Last Modify Time : 2018/1/30 15:55 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : common.py 9 | FUNCTION : Common File 10 | """ 11 | 12 | seed_num = 66 13 | unkkey = "" 14 | paddingkey = "" 15 | cpu_device = "cpu" 16 | -------------------------------------------------------------------------------- /DataUtils/Embed.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/27 15:34 3 | # @File : Embed.py 4 | # @Last Modify Time : 2018/8/27 15:34 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Embed.py 9 | FUNCTION : None 10 | """ 11 | 12 | import os 13 | import sys 14 | import time 15 | import tqdm 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.init as init 20 | from collections import OrderedDict 21 | from DataUtils.Common import * 22 | torch.manual_seed(seed_num) 23 | np.random.seed(seed_num) 24 | 25 | 26 | class Embed(object): 27 | """ 28 | Embed 29 | """ 30 | def __init__(self, path, words_dict, embed_type, pad): 31 | self.embed_type_enum = ["zero", "avg", "uniform", "nn"] 32 | self.path = path 33 | self.words_dict = words_dict 34 | self.embed_type = embed_type 35 | self.pad = pad 36 | # print(self.words_dict) 37 | if not isinstance(self.words_dict, dict): 38 | self.words_dict, self.words_list = self._list2dict(self.words_dict) 39 | if 
pad is not None: self.padID = self.words_dict[pad] 40 | # print(self.words_dict) 41 | self.dim, self.words_count = self._get_dim(path=self.path), len(self.words_dict) 42 | self.exact_count, self.fuzzy_count, self.oov_count = 0, 0, 0 43 | self.use_time = 0 44 | 45 | def get_embed(self): 46 | """ 47 | :return: 48 | """ 49 | start_time = time.time() 50 | embed_dict = None 51 | if self.embed_type in self.embed_type_enum: 52 | embed_dict = self._read_file(path=self.path) 53 | else: 54 | print("embed_type illegal, must be in {}".format(self.embed_type_enum)) 55 | exit() 56 | # print(embed_dict) 57 | embed = None 58 | if self.embed_type == "nn": 59 | embed = self._nn_embed(embed_dict=embed_dict, words_dict=self.words_dict) 60 | elif self.embed_type == "zero": 61 | embed = self._zeros_embed(embed_dict=embed_dict, words_dict=self.words_dict) 62 | elif self.embed_type == "uniform": 63 | embed = self._uniform_embed(embed_dict=embed_dict, words_dict=self.words_dict) 64 | elif self.embed_type == "avg": 65 | embed = self._avg_embed(embed_dict=embed_dict, words_dict=self.words_dict) 66 | # print(embed) 67 | end_time = time.time() 68 | self.use_time = end_time - start_time 69 | self.info() 70 | return embed 71 | 72 | def _zeros_embed(self, embed_dict, words_dict): 73 | """ 74 | :param embed_dict: 75 | :param words_dict: 76 | """ 77 | print("loading pre_train embedding by zeros for out of vocabulary.") 78 | embeddings = np.zeros((int(self.words_count), int(self.dim))) 79 | for word in words_dict: 80 | if word in embed_dict: 81 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word]], dtype='float32') 82 | self.exact_count += 1 83 | elif word.lower() in embed_dict: 84 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word.lower()]], dtype='float32') 85 | self.fuzzy_count += 1 86 | else: 87 | self.oov_count += 1 88 | final_embed = torch.from_numpy(embeddings).float() 89 | return final_embed 90 | 91 | def _nn_embed(self, embed_dict, 
words_dict): 92 | """ 93 | :param embed_dict: 94 | :param words_dict: 95 | """ 96 | print("loading pre_train embedding by nn.Embedding for out of vocabulary.") 97 | embed = nn.Embedding(int(self.words_count), int(self.dim)) 98 | init.xavier_uniform_(embed.weight.data) 99 | embeddings = np.array(embed.weight.data) 100 | for word in words_dict: 101 | if word in embed_dict: 102 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word]], dtype='float32') 103 | self.exact_count += 1 104 | elif word.lower() in embed_dict: 105 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word.lower()]], dtype='float32') 106 | self.fuzzy_count += 1 107 | else: 108 | self.oov_count += 1 109 | embeddings[self.padID] = 0 110 | final_embed = torch.from_numpy(embeddings).float() 111 | return final_embed 112 | 113 | def _uniform_embed(self, embed_dict, words_dict): 114 | """ 115 | :param embed_dict: 116 | :param words_dict: 117 | """ 118 | print("loading pre_train embedding by uniform for out of vocabulary.") 119 | embeddings = np.zeros((int(self.words_count), int(self.dim))) 120 | inword_list = {} 121 | for word in words_dict: 122 | if word in embed_dict: 123 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word]], dtype='float32') 124 | inword_list[words_dict[word]] = 1 125 | self.exact_count += 1 126 | elif word.lower() in embed_dict: 127 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word.lower()]], dtype='float32') 128 | inword_list[words_dict[word]] = 1 129 | self.fuzzy_count += 1 130 | else: 131 | self.oov_count += 1 132 | uniform_col = np.random.uniform(-0.25, 0.25, int(self.dim)).round(6) # uniform 133 | for i in range(len(words_dict)): 134 | if i not in inword_list and i != self.padID: 135 | embeddings[i] = uniform_col 136 | final_embed = torch.from_numpy(embeddings).float() 137 | return final_embed 138 | 139 | def _avg_embed(self, embed_dict, words_dict): 140 | """ 141 | :param embed_dict: 
142 | :param words_dict: 143 | """ 144 | print("loading pre_train embedding by avg for out of vocabulary.") 145 | embeddings = np.zeros((int(self.words_count), int(self.dim))) 146 | inword_list = {} 147 | for word in words_dict: 148 | if word in embed_dict: 149 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word]], dtype='float32') 150 | inword_list[words_dict[word]] = 1 151 | self.exact_count += 1 152 | elif word.lower() in embed_dict: 153 | embeddings[words_dict[word]] = np.array([float(i) for i in embed_dict[word.lower()]], dtype='float32') 154 | inword_list[words_dict[word]] = 1 155 | self.fuzzy_count += 1 156 | else: 157 | self.oov_count += 1 158 | sum_col = np.sum(embeddings, axis=0) / len(inword_list) # avg 159 | for i in range(len(words_dict)): 160 | if i not in inword_list and i != self.padID: 161 | embeddings[i] = sum_col 162 | final_embed = torch.from_numpy(embeddings).float() 163 | return final_embed 164 | 165 | @staticmethod 166 | def _read_file(path): 167 | """ 168 | :param path: embed file path 169 | :return: 170 | """ 171 | embed_dict = {} 172 | with open(path, encoding='utf-8') as f: 173 | lines = f.readlines() 174 | lines = tqdm.tqdm(lines) 175 | for line in lines: 176 | values = line.strip().split(' ') 177 | if len(values) == 1 or len(values) == 2 or len(values) == 3: 178 | continue 179 | w, v = values[0], values[1:] 180 | embed_dict[w] = v 181 | return embed_dict 182 | 183 | def info(self): 184 | """ 185 | :return: 186 | """ 187 | total_count = self.exact_count + self.fuzzy_count 188 | print("Words count {}, Embed dim {}.".format(self.words_count, self.dim)) 189 | print("Exact count {} / {}".format(self.exact_count, self.words_count)) 190 | print("Fuzzy count {} / {}".format(self.fuzzy_count, self.words_count)) 191 | print(" INV count {} / {}".format(total_count, self.words_count)) 192 | print(" OOV count {} / {}".format(self.oov_count, self.words_count)) 193 | print(" OOV radio ===> {}%".format(np.round((self.oov_count 
/ self.words_count) * 100, 2))) 194 | print("Pre-Train Embed Time {:.4f}".format(self.use_time)) 195 | print(40 * "*") 196 | 197 | @staticmethod 198 | def _get_dim(path): 199 | """ 200 | :param path: 201 | :return: 202 | """ 203 | embedding_dim = -1 204 | with open(path, encoding='utf-8') as f: 205 | for line in f: 206 | line_split = line.strip().split(' ') 207 | if len(line_split) == 1: 208 | embedding_dim = line_split[0] 209 | break 210 | elif len(line_split) == 2: 211 | embedding_dim = line_split[1] 212 | break 213 | else: 214 | embedding_dim = len(line_split) - 1 215 | break 216 | return embedding_dim 217 | 218 | @staticmethod 219 | def _list2dict(convert_list): 220 | """ 221 | :param convert_list: 222 | :return: 223 | """ 224 | list_dict = OrderedDict() 225 | list_lower = [] 226 | for index, word in enumerate(convert_list): 227 | list_lower.append(word.lower()) 228 | list_dict[word] = index 229 | assert len(list_lower) == len(list_dict) 230 | return list_dict, list_lower 231 | 232 | -------------------------------------------------------------------------------- /DataUtils/Optim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch.optim 3 | from torch.nn.utils.clip_grad import clip_grad_norm_ 4 | 5 | # Setup optimizer (should always come after model.cuda()) 6 | # iterable of dicts for per-param options where each dict 7 | # is {'params' : [p1, p2, p3...]}.update(generic optimizer args) 8 | # Example: 9 | # optim.SGD([ 10 | # {'params': model.base.parameters()}, 11 | # {'params': model.classifier.parameters(), 'lr': 1e-3} 12 | # ], lr=1e-2, momentum=0.9) 13 | 14 | 15 | def decay_learning_rate(optimizer, epoch, init_lr, lr_decay): 16 | """衰减学习率 17 | 18 | Args: 19 | epoch: int, 迭代次数 20 | init_lr: 初始学习率 21 | """ 22 | lr = init_lr / (1 + lr_decay * epoch) 23 | print('learning rate: {0}'.format(lr)) 24 | for param_group in optimizer.param_groups: 25 | param_group['lr'] = lr 26 | return 
optimizer 27 | 28 | 29 | class Optimizer(object): 30 | # Class dict to map lowercase identifiers to actual classes 31 | methods = { 32 | 'Adadelta': torch.optim.Adadelta, 33 | 'Adagrad': torch.optim.Adagrad, 34 | 'Adam': torch.optim.Adam, 35 | 'SGD': torch.optim.SGD, 36 | 'ASGD': torch.optim.ASGD, 37 | 'Rprop': torch.optim.Rprop, 38 | 'RMSprop': torch.optim.RMSprop, 39 | } 40 | 41 | @staticmethod 42 | def get_params(model): 43 | """Returns all name, parameter pairs with requires_grad=True.""" 44 | return list( 45 | filter(lambda p: p[1].requires_grad, model.named_parameters())) 46 | 47 | def __init__(self, 48 | name, 49 | model, 50 | lr=0, 51 | weight_decay=0, 52 | grad_clip=None, 53 | optim_args=None, 54 | momentum=None, 55 | **kwargs): 56 | """ 57 | :param decay_method: Method of learning rate decay. 58 | 59 | """ 60 | 61 | self.name = name 62 | self.model = model 63 | self.init_lr = lr 64 | self.weight_decay = weight_decay 65 | self.momentum = momentum 66 | # self.gclip = grad_clip 67 | self.gclip = None if grad_clip == "None" else float(grad_clip) 68 | # print(self.gclip) 69 | 70 | self._count = 0 71 | 72 | # TODO: 73 | # pass external optimizer configs 74 | if optim_args is None: 75 | optim_args = {} 76 | 77 | self.optim_args = optim_args 78 | 79 | # If an explicit lr given, pass it to torch optimizer 80 | if self.init_lr > 0: 81 | self.optim_args['lr'] = self.init_lr 82 | 83 | if self.name == "SGD" and self.momentum is not None: 84 | self.optim_args['momentum'] = self.momentum 85 | 86 | # Get all parameters that require grads 87 | self.named_params = self.get_params(self.model) 88 | 89 | # Filter out names for gradient clipping 90 | self.params = [param for (name, param) in self.named_params] 91 | 92 | if self.weight_decay > 0: 93 | weight_group = { 94 | 'params': [p for n, p in self.named_params if 'bias' not in n], 95 | 'weight_decay': self.weight_decay, 96 | } 97 | bias_group = { 98 | 'params': [p for n, p in self.named_params if 'bias' in n], 99 | } 100 | 
self.param_groups = [weight_group, bias_group] 101 | 102 | # elif self.name == "SGD" and self.momentum is not None: 103 | 104 | 105 | else: 106 | self.param_groups = [{'params': self.params}] 107 | 108 | # Safety check 109 | n_params = len(self.params) 110 | for group in self.param_groups: 111 | n_params -= len(group['params']) 112 | assert n_params == 0, "Not all params are passed to the optimizer." 113 | 114 | # Create the actual optimizer 115 | self.optim = self.methods[self.name](self.param_groups, 116 | **self.optim_args) 117 | 118 | # Assign shortcuts 119 | self.zero_grad = self.optim.zero_grad 120 | 121 | # Skip useless if evaluation logic if gradient_clip not requested 122 | if self.gclip == 0 or self.gclip is None: 123 | self.step = self.optim.step 124 | 125 | def zero_grad(self): 126 | self.optim.zero_grad() 127 | 128 | def step(self, closure=None): 129 | """Gradient clipping aware step().""" 130 | if self.gclip is not None and self.gclip > 0: 131 | # print("aaaa") 132 | clip_grad_norm_(self.params, self.gclip) 133 | self.optim.step(closure) 134 | 135 | def rescale_lrate(self, scale, min_lrate=-1.0): 136 | if isinstance(scale, list): 137 | for scale_, group in zip(scale, self.optim.param_groups): 138 | group['lr'] = max(group['lr'] * scale_, min_lrate) 139 | else: 140 | for group in self.optim.param_groups: 141 | group['lr'] = max(group['lr'] * scale, min_lrate) 142 | 143 | def get_lrate(self): 144 | for group in self.optim.param_groups: 145 | yield group['lr'] 146 | 147 | def set_lrate(self, lr): 148 | if isinstance(lr, list): 149 | for lr_, group in zip(lr, self.optim.param_groups): 150 | group['lr'] = lr_ 151 | else: 152 | for group in self.optim.param_groups: 153 | group['lr'] = lr 154 | 155 | def __repr__(self): 156 | s = "Optimizer => {} (lr: {}, weight_decay: {}, g_clip: {})".format( 157 | self.name, self.init_lr, self.weight_decay, self.gclip) 158 | return s -------------------------------------------------------------------------------- 
/DataUtils/Pickle.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/23 12:26 3 | # @File : Pickle.py 4 | # @Last Modify Time : 2018/8/23 12:26 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Pickle.py 9 | FUNCTION : None 10 | """ 11 | 12 | # Introduce python packages 13 | import sys 14 | import os 15 | import pickle 16 | 17 | # Introduce missing packages in here 18 | 19 | 20 | class Pickle(object): 21 | def __init__(self): 22 | print("Pickle") 23 | self.obj_count = 0 24 | 25 | @staticmethod 26 | def save(obj, path, mode="wb"): 27 | """ 28 | :param obj: obj dict to dump 29 | :param path: save path 30 | :param mode: file mode 31 | """ 32 | print("save obj to {}".format(path)) 33 | # print("obj", obj) 34 | assert isinstance(obj, dict), "The type of obj must be a dict type." 35 | if os.path.exists(path): 36 | os.remove(path) 37 | pkl_file = open(path, mode=mode) 38 | pickle.dump(obj, pkl_file) 39 | pkl_file.close() 40 | 41 | @staticmethod 42 | def load(path, mode="rb"): 43 | """ 44 | :param path: pkl path 45 | :param mode: file mode 46 | :return: data dict 47 | """ 48 | print("load obj from {}".format(path)) 49 | if os.path.exists(path) is False: 50 | print("Path {} illegal.".format(path)) 51 | pkl_file = open(path, mode=mode) 52 | data = pickle.load(pkl_file) 53 | pkl_file.close() 54 | return data 55 | 56 | 57 | pcl = Pickle 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /DataUtils/README.md: -------------------------------------------------------------------------------- 1 | ## Role of document in `DataUtils` directory ## 2 | - DataUtils 3 | - `Alphabet.py` ------ Build vocab by train data or dev/test data 4 | 5 | - `Batch_Iterator.py` ------ Build batch and iterator for train/dev/test data, get train/dev/test iterator 6 | 7 | - `Common.py` ------ The file contains some common attribute, like 
random seeds, padding, unk and others 8 | 9 | - `Embed.py` ------ Loading Pre-trained word embedding( `glove` or `word2vec` ), `zeros,avg, uniform, nn.Embedding for OOV`. 10 | 11 | - `Optim.py` ------ Encapsulate the `optimizer`. 12 | 13 | - `Pickle.py` ------ Encapsulate the `pickle`. 14 | 15 | - `utils.py` ------ common function. 16 | 17 | - `mainHelp.py` ------ main help file. 18 | -------------------------------------------------------------------------------- /DataUtils/__init__.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/23 17:39 3 | # @File : __init__.py 4 | # @Last Modify Time : 2018/8/23 17:39 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : __init__.py 9 | FUNCTION : None 10 | """ -------------------------------------------------------------------------------- /DataUtils/mainHelp.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/9/3 10:50 3 | # @File : mainHelp.py 4 | # @Last Modify Time : 2018/9/3 10:50 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : mainHelp.py 9 | FUNCTION : None 10 | """ 11 | 12 | import shutil 13 | import time 14 | from DataUtils.Alphabet import * 15 | from DataUtils.Batch_Iterator import * 16 | from DataUtils.Embed import Embed 17 | from Dataloader.DataLoader_SST_Binary import DataLoader 18 | from models.Text_Classification.Text_Classification import Text_Classification 19 | from models.Text_Classification_BertFeature.Text_Classification import Text_Classification_BertFeature 20 | from test import load_test_model 21 | 22 | # solve default encoding problem 23 | from imp import reload 24 | defaultencoding = 'utf-8' 25 | if sys.getdefaultencoding() != defaultencoding: 26 | reload(sys) 27 | sys.setdefaultencoding(defaultencoding) 28 | 29 | # random seed 30 | torch.manual_seed(seed_num) 31 | 
random.seed(seed_num)


def get_learning_algorithm(config):
    """Return the name of the optimizer selected in ``config``.

    :param config: config exposing boolean ``adam`` and ``sgd`` flags
    :return: ``"Adam"`` if ``config.adam`` is True, otherwise ``"SGD"`` if
             ``config.sgd`` is True, otherwise ``None``
    """
    algorithm = None
    if config.adam is True:
        algorithm = "Adam"
    elif config.sgd is True:
        algorithm = "SGD"
    print("learning algorithm is {}.".format(algorithm))
    return algorithm


def get_params(config, alphabet):
    """Derive run parameters from ``alphabet`` and store them on ``config``.

    Sets ``learning_algorithm``, ``save_best_model_path``, ``embed_num``,
    ``label_num``, ``paddingId`` and ``alphabet`` on ``config``. When
    ``config.test`` is False, an existing best-model directory is deleted so
    the new run starts clean.

    :param config: config (mutated in place)
    :param alphabet: alphabet exposing ``word_alphabet``, ``label_alphabet``
                     and ``word_paddingId``
    :return: None
    """
    # get algorithm
    config.learning_algorithm = get_learning_algorithm(config)

    # save best model path
    config.save_best_model_path = config.save_best_model_dir
    if config.test is False:
        if os.path.exists(config.save_best_model_path):
            shutil.rmtree(config.save_best_model_path)

    # get params
    config.embed_num = alphabet.word_alphabet.vocab_size  # word number
    config.label_num = alphabet.label_alphabet.vocab_size  # label number
    config.paddingId = alphabet.word_paddingId
    config.alphabet = alphabet
    print("embed_num : {}, class_num : {}".format(config.embed_num, config.label_num))
    print("PaddingID {}".format(config.paddingId))


def save_dict2file(dict, path):
    """Write a mapping to ``path`` as UTF-8 text, one ``key<TAB>value`` per line.

    An existing file at ``path`` is overwritten (opened in ``"w"`` mode).

    :param dict: mapping to save (parameter name kept for keyword callers,
                 although it shadows the builtin inside this function)
    :param path: path to save dict
    :return: None
    """
    print("Saving dictionary")
    if os.path.exists(path):
        print("path {} is exist, deleted.".format(path))
    # ``with`` guarantees the file is closed even if a write fails.
    with open(path, encoding="UTF-8", mode="w") as file:
        for word, index in dict.items():
            file.write(str(word) + "\t" + str(index) + "\n")
    print("Save dictionary finished.")


def save_dictionary(config):
    """Dump the word/label vocabularies and copy them into ``config.save_dir``.

    Does nothing unless ``config.save_dict`` is True. Otherwise it recreates
    ``config.dict_directory``, sets ``config.word_dict_path`` and
    ``config.label_dict_path``, writes both ``words2id`` mappings there and
    copies the directory under ``config.save_dir``.

    :param config: config
    :return: None
    """
    if config.save_dict is True:
        if os.path.exists(config.dict_directory):
            shutil.rmtree(config.dict_directory)
        if not os.path.isdir(config.dict_directory):
            os.makedirs(config.dict_directory)

        config.word_dict_path = "/".join([config.dict_directory, config.word_dict])
        config.label_dict_path = "/".join([config.dict_directory, config.label_dict])
        print("word_dict_path : {}".format(config.word_dict_path))
        print("label_dict_path : {}".format(config.label_dict_path))
        save_dict2file(config.alphabet.word_alphabet.words2id, config.word_dict_path)
        save_dict2file(config.alphabet.label_alphabet.words2id, config.label_dict_path)
        # copy to mulu (the run's output directory)
        print("copy dictionary to {}".format(config.save_dir))
        shutil.copytree(config.dict_directory, "/".join([config.save_dir, config.dict_directory]))


# load data / create alphabet / create iterator
def preprocessing(config):
    """Read the raw data, build the alphabet and create the batch iterators.

    When ``config.save_pkl`` is set, the data, the alphabet and the iterators
    are also written with ``torch.save`` under ``config.pkl_directory``.

    :param config: config
    :return: ``(train_iter, dev_iter, test_iter, alphabet)``
    """
    print("Processing Data......")
    # read file
    data_loader = DataLoader(path=[config.train_file, config.dev_file, config.test_file], shuffle=True, config=config)
    train_data, dev_data, test_data = data_loader.dataLoader()
    print("train sentence {}, dev sentence {}, test sentence {}.".format(len(train_data), len(dev_data), len(test_data)))
    data_dict = {"train_data": train_data, "dev_data": dev_data, "test_data": test_data}
    if config.save_pkl:
        torch.save(obj=data_dict, f=os.path.join(config.pkl_directory, config.pkl_data))

    # create the alphabet: without embedding fine-tuning the dev/test data are
    # also handed to CreateAlphabet; with fine-tuning only the training data is.
    alphabet = None
    if config.embed_finetune is False:
        alphabet = CreateAlphabet(min_freq=config.min_freq, train_data=train_data, dev_data=dev_data, test_data=test_data, config=config)
        alphabet.build_vocab()
    if config.embed_finetune is True:
        alphabet = CreateAlphabet(min_freq=config.min_freq, train_data=train_data, config=config)
        alphabet.build_vocab()
    alphabet_dict = {"alphabet": alphabet}
    if config.save_pkl:
        torch.save(obj=alphabet_dict, f=os.path.join(config.pkl_directory, config.pkl_alphabet))

    # create iterator
    create_iter = Iterators(batch_size=[config.batch_size, config.dev_batch_size, config.test_batch_size],
                            data=[train_data, dev_data, test_data], alphabet=alphabet, config=config)
    train_iter, dev_iter, test_iter = create_iter.createIterator()
    iter_dict = {"train_iter": train_iter, "dev_iter": dev_iter, "test_iter": test_iter}
    if config.save_pkl:
        torch.save(obj=iter_dict, f=os.path.join(config.pkl_directory, config.pkl_iter))
    return train_iter, dev_iter, test_iter, alphabet


def pre_embed(config, alphabet):
    """Load pre-trained word embeddings for the word alphabet.

    The ``embed_type`` handed to ``Embed`` is chosen from the first true flag
    among ``config.zeros`` / ``avg`` / ``uniform`` / ``nnembed``. When
    embeddings are loaded they are also written with ``torch.save`` to
    ``config.pkl_directory/config.pkl_embed``.

    :param config: config
    :param alphabet: alphabet dict
    :return: the result of ``Embed.get_embed()``, or ``None`` when
             ``config.pretrained_embed`` is not True
    """
    print("***************************************")
    pretrain_embed = None
    embed_types = ""
    if config.pretrained_embed and config.zeros:
        embed_types = "zero"
    elif config.pretrained_embed and config.avg:
        embed_types = "avg"
    elif config.pretrained_embed and config.uniform:
        embed_types = "uniform"
    elif config.pretrained_embed and config.nnembed:
        embed_types = "nn"
    if config.pretrained_embed is True:
        p = Embed(path=config.pretrained_embed_file, words_dict=alphabet.word_alphabet.id2words, embed_type=embed_types,
                  pad=paddingkey)
        pretrain_embed = p.get_embed()

        embed_dict = {"pretrain_embed": pretrain_embed}
        torch.save(obj=embed_dict, f=os.path.join(config.pkl_directory, config.pkl_embed))

    return pretrain_embed


def load_model(config):
    """Build the classification model described by ``config``.

    Uses ``Text_Classification_BertFeature`` when ``config.use_bert`` is set,
    otherwise ``Text_Classification``. Moves the model to ``config.device``
    when that is not the CPU device, and restores saved weights via
    ``load_test_model`` when ``config.test`` is True.

    :param config: config
    :return: nn model
    """
    print("***************************************")
    if config.use_bert:
        model = Text_Classification_BertFeature(config)
    else:
        model = Text_Classification(config)
    if config.device != cpu_device:
        model = model.to(config.device)
    if config.test is True:
        model = load_test_model(model, config)
    print(model)
    return model


def load_data(config):
    """Build or reload the batch iterators, the alphabet and the embedding.

    * ``config.train`` and ``config.process`` both True: recreate
      ``config.pkl_directory`` and run :func:`preprocessing` and
      :func:`pre_embed`.
    * ``config.train`` True with ``config.process`` False, or ``config.test``
      True: load the alphabet, the iterators and (if present) the embedding
      from the pkl files of an earlier processing run.

    ``config.pretrained_weight`` is set in both cases. When neither condition
    holds, all returned values are ``None``.

    :param config: config
    :return: ``(train_iter, dev_iter, test_iter, alphabet)``
    """
    print("load data for process or pkl data.")
    train_iter, dev_iter, test_iter = None, None, None
    alphabet = None
    start_time = time.time()
    if (config.train is True) and (config.process is True):
        print("process data")
        if os.path.exists(config.pkl_directory):
            shutil.rmtree(config.pkl_directory)
        if not os.path.isdir(config.pkl_directory):
            os.makedirs(config.pkl_directory)
        train_iter, dev_iter, test_iter, alphabet = preprocessing(config)
        config.pretrained_weight = pre_embed(config=config, alphabet=alphabet)
    elif ((config.train is True) and (config.process is False)) or (config.test is True):
        print("load data from pkl file")
        # NOTE(review): PyTorch >= 2.6 defaults torch.load(weights_only=True),
        # which rejects the pickled alphabet/iterator objects; on such versions
        # these calls need weights_only=False (only for trusted pkl files).
        # load alphabet from pkl
        alphabet_dict = torch.load(f=os.path.join(config.pkl_directory, config.pkl_alphabet))
        print(alphabet_dict.keys())
        alphabet = alphabet_dict["alphabet"]
        # load iter from pkl; look the iterators up by key rather than
        # unpacking ``values()``, which depends on the saved dict's key order.
        iter_dict = torch.load(f=os.path.join(config.pkl_directory, config.pkl_iter))
        print(iter_dict.keys())
        train_iter, dev_iter, test_iter = iter_dict["train_iter"], iter_dict["dev_iter"], iter_dict["test_iter"]
        # load embed from pkl (only written when pre-trained embeddings were used)
        config.pretrained_weight = None
        embed_path = os.path.join(config.pkl_directory, config.pkl_embed)
        if os.path.exists(embed_path):
            embed_dict = torch.load(f=embed_path)
            print(embed_dict.keys())
            config.pretrained_weight = embed_dict["pretrain_embed"]
    end_time = time.time()
    print("All Data/Alphabet/Iterator Use Time {:.4f}".format(end_time - start_time))
    print("****************************************")
    return train_iter, dev_iter, test_iter, alphabet
# --------------------------------------------------------------------------------
# /DataUtils/utils.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/8/24 9:58
# @File : utils.py
# @Last Modify Time : 2018/8/24 9:58
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE :  utils.py
FUNCTION : Training helpers: best-score bookkeeping, arg-max over model
           outputs, checkpoint saving and learning-rate access.
"""
import sys
import os
import torch
import numpy as np


class Best_Result:
    """Mutable record of the best scores seen so far during training."""

    def __init__(self):
        self.current_dev_score = -1
        self.best_dev_score = -1
        self.best_score = -1
        self.best_epoch = 1
        self.best_test = False
        self.early_current_patience = 0
        self.p = -1
        self.r = -1
        self.f = -1


def getMaxindex(model_out, label_size, args):
    """Return the index of the largest of the first ``label_size`` scores.

    Ties resolve to the lowest index.

    :param model_out: score tensor indexed via ``.data[i]`` (assumed 1-D)
    :param label_size: number of leading entries to compare
    :param args: unused; kept for interface compatibility
    :return: max index for predict
    """
    # ``best`` instead of ``max`` so the builtin is not shadowed.
    best = model_out.data[0]
    max_index = 0
    for idx in range(1, label_size):
        if model_out.data[idx] > best:
            best = model_out.data[idx]
            max_index = idx
    return max_index


def getMaxindex_np(model_out):
    """Return the index of the first maximum of ``model_out``.

    :param model_out: score tensor whose ``.data.tolist()`` is a flat list
    :return: max index for predict
    """
    model_out_list = model_out.data.tolist()
    return model_out_list.index(np.max(model_out_list))


def getMaxindex_batch(model_out):
    """Return, for every row of ``model_out``, the index of its first maximum.

    :param model_out: 2-D score tensor (batch x labels)
    :return: list with one predicted index per row
    """
    return [row.index(np.max(row)) for row in model_out.data.tolist()]


def torch_max(output):
    """Arg-max over the last label axis, one numpy array per batch element.

    :param output: batch * seq_len * label_num
    :return: list of length ``batch`` of numpy arrays of shape ``(seq_len,)``
    """
    batch_size = output.size(0)
    _, arg_max = torch.max(output, dim=2)
    label = []
    for i in range(batch_size):
        label.append(arg_max[i].cpu().data.numpy())
    return label


def save_model_all(model, save_dir, model_name, epoch):
    """Save ``model.state_dict()`` to ``save_dir/<model_name>_epoch_<epoch>.pt``.

    :param model: nn model
    :param save_dir: save model direction (created if missing)
    :param model_name: model name
    :param epoch: epoch
    :return: None
    """
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, model_name)
    save_path = '{}_epoch_{}.pt'.format(save_prefix, epoch)
    print("save all model to {}".format(save_path))
    # ``with`` closes the file even if torch.save raises.
    with open(save_path, mode="wb") as output:
        torch.save(model.state_dict(), output)


def save_best_model(model, save_dir, model_name, best_eval):
    """Checkpoint ``model`` when the current dev score ties or beats the best.

    On save, writes ``model.state_dict()`` to ``save_dir/<model_name>.pt``
    (overwriting the previous best) and resets
    ``best_eval.early_current_patience`` to 0.

    :param model: nn model
    :param save_dir: save model direction (created if missing)
    :param model_name: model name
    :param best_eval: ``Best_Result`` holding the current and best dev scores
    :return: None
    """
    if best_eval.current_dev_score >= best_eval.best_dev_score:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        model_name = "{}.pt".format(model_name)
        save_path = os.path.join(save_dir, model_name)
        print("save best model to {}".format(save_path))
        with open(save_path, mode="wb") as output:
            torch.save(model.state_dict(), output)
        best_eval.early_current_patience = 0


# adjust lr
def get_lrate(optim):
    """Yield the learning rate of each parameter group of ``optim``.

    :param optim: optimizer
    :return: generator of learning rates
    """
    for group in optim.param_groups:
        yield group['lr']


def set_lrate(optim, lr):
    """Set the learning rate of every parameter group of ``optim`` to ``lr``.

    :param optim: optimizer
    :param lr: learning rate
    :return: None
    """
    for group in optim.param_groups:
        group['lr'] = lr
# --------------------------------------------------------------------------------
# /Dataloader/DataLoader_SST_Binary.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/1/30 15:58
# @File : DataConll2003_Loader.py
# @Last Modify Time : 2018/1/30 15:58
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE :  DataLoader_SST_Binary.py
FUNCTION : Load the SST-binary train/dev/test files (and, optionally, the
           matching pre-extracted BERT feature files) into Instance lists.
"""
import sys
import os
import re
import random
import torch
import json
from Dataloader.Instance import Instance

from DataUtils.Common import *
torch.manual_seed(seed_num)
random.seed(seed_num)


class DataLoaderHelp(object):
    """Static text-processing helpers shared by the data loaders."""

    @staticmethod
    def _clean_str(string):
        r"""
        Tokenization/string cleaning for all datasets except for SST.
        Original taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py

        NOTE: ``re.sub`` keeps the backslash of an unknown escape such as ``\(``
        in the replacement string, so "(", ")" and "?" become the tokens
        "\(", "\)" and "\?". This is left unchanged on purpose: fixing it would
        change the tokens, and therefore the vocabulary, built from the data.
        """
        string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
        string = re.sub(r"\'s", " \'s", string)
        string = re.sub(r"\'ve", " \'ve", string)
        string = re.sub(r"n\'t", " n\'t", string)
        string = re.sub(r"\'re", " \'re", string)
        string = re.sub(r"\'d", " \'d", string)
        string = re.sub(r"\'ll", " \'ll", string)
        string = re.sub(r",", " , ", string)
        string = re.sub(r"!", " ! ", string)
        string = re.sub(r"\(", " \( ", string)
        string = re.sub(r"\)", " \) ", string)
        string = re.sub(r"\?", " \? ", string)
        string = re.sub(r"\s{2,}", " ", string)
        return string.strip().lower()

    @staticmethod
    def _normalize_word(word):
        """Return ``word`` with every digit character replaced by ``'0'``.

        :param word: input token
        :return: normalized token
        """
        return "".join('0' if char.isdigit() else char for char in word)

    @staticmethod
    def _sort(insts):
        """Return a new list of ``insts`` ordered by ``words_size``, longest first.

        The sort is stable, so instances of equal length keep their input order.

        :param insts: list of Instance
        :return: sorted list of Instance
        """
        sorted_insts = sorted(insts, key=lambda inst: inst.words_size, reverse=True)
        print("Sort Finished.")
        return sorted_insts


class DataLoader(DataLoaderHelp):
    """Read SST-binary files (``<label> <sentence>`` per line) into Instances."""

    def __init__(self, path, shuffle, config):
        """
        :param path: data path list (train, dev[, test])
        :param shuffle: whether to shuffle each loaded split
        :param config: config (``max_count``, ``use_bert`` and the BERT file paths)
        """
        print("Loading Data......")
        self.data_list = []
        self.max_count = config.max_count
        self.path = path
        self.shuffle = shuffle

        # BERT feature files, index-aligned with ``path`` (train, dev, test)
        self.bert_path = [config.bert_train_file,
                          config.bert_dev_file,
                          config.bert_test_file]

        self.use_bert = config.use_bert

    def dataLoader(self):
        """Load every file in ``self.path``.

        :return: ``(train, dev, test)`` lists of Instance for three paths, or
                 ``(train, dev)`` for two
        """
        path = self.path
        shuffle = self.shuffle
        assert isinstance(path, list), "Path Must Be In List"
        print("Data Path {}".format(path))
        # Start fresh so a second call does not append more splits (which used
        # to make the length checks below fail and return None).
        self.data_list = []
        for id_data in range(len(path)):
            print("Loading Data Form {}".format(path[id_data]))
            insts = self._Load_Each_Data(path=path[id_data], path_id=id_data)
            # Honour the ``shuffle`` argument; it used to be ignored.
            if shuffle:
                print("shuffle train data......")
                random.shuffle(insts)
            self.data_list.append(insts)
        # return train/dev/test data
        if len(self.data_list) == 3:
            return self.data_list[0], self.data_list[1], self.data_list[2]
        elif len(self.data_list) == 2:
            return self.data_list[0], self.data_list[1]

    def _Load_Each_Data(self, path=None, path_id=None):
        """Parse one SST-binary file into a list of Instance.

        Each line is ``<label> <sentence ...>``. Lines whose label is not "0"
        or "1" are reported and skipped, and blank lines are skipped. Reading
        stops once ``self.max_count`` instances have been collected.

        :param path: data file path
        :param path_id: index of this file in ``self.path``; selects the
                        matching BERT feature file when ``use_bert`` is on
        :return: list of Instance
        """
        assert path is not None, "The Data Path Is Not Allow Empty."
        insts = []
        now_lines = 0
        with open(path, encoding="UTF-8") as f:
            for line in f:
                line = line.strip()
                now_lines += 1
                if now_lines % 200 == 0:
                    sys.stdout.write("\rreading the {} line\t".format(now_lines))
                # ``line`` is already stripped, so a blank line is "". The old
                # ``line == "\n"`` test could never match, and ``line[0]`` below
                # then raised IndexError.
                if not line:
                    print("empty line")
                    continue

                inst = Instance()
                line = line.split()
                label = line[0]
                word = " ".join(line[1:])
                if label not in ["0", "1"]:
                    print("Error line: ", " ".join(line))
                    continue
                inst.words = self._clean_str(word).split()
                inst.labels.append(label)
                inst.words_size = len(inst.words)
                insts.append(inst)

                if len(insts) == self.max_count:
                    break
        if self.use_bert:
            insts = self._read_bert_file(insts, path=self.bert_path[path_id])
        return insts

    def _read_bert_file(self, insts, path):
        """Attach pre-extracted BERT features to ``insts`` line by line.

        Line *i* of ``path`` is parsed as JSON, and its
        ``["features"]["tokens"]`` and ``["features"]["values"]`` are stored on
        ``insts[i].bert_tokens`` / ``insts[i].bert_feature``. Pairing is purely
        positional, so the feature file must follow the data file's order.

        :param insts: list of Instance
        :param path: BERT feature file path
        :return: the same ``insts`` list
        """
        print("\nRead BERT Features File From {}".format(path))
        now_lines = 0
        with open(path, encoding="utf-8") as f:
            for inst, bert_line in zip(insts, f):
                now_lines += 1
                if now_lines % 2000 == 0:
                    sys.stdout.write("\rreading the {} line\t".format(now_lines))
                bert_fea = json.loads(bert_line)
                inst.bert_tokens = bert_fea["features"]["tokens"]
                inst.bert_feature = bert_fea["features"]["values"]
        sys.stdout.write("\rReading the {} line\t".format(now_lines))
        # ``zip`` stops at the shorter input; warn instead of silently leaving
        # the trailing instances without BERT features.
        if now_lines < len(insts):
            print("\nWarning: BERT feature file {} has only {} lines for {} instances; "
                  "the remaining instances have no BERT features.".format(path, now_lines, len(insts)))
        return insts


# --------------------------------------------------------------------------------
# /Dataloader/Instance.py:
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # @Author : bamtercelboo 3 | # @Datetime : 2018/1/30 15:56 4 | # @File : Instance.py 5 | # @Last Modify Time : 2018/1/30 15:56 6 | # @Contact : bamtercelboo@{gmail.com, 163.com} 7 | 8 | """ 9 | FILE : Instance.py 10 | FUNCTION : Data Instance 11 | """ 12 | 13 | import torch 14 | import random 15 | 16 | from DataUtils.Common import * 17 | torch.manual_seed(seed_num) 18 | random.seed(seed_num) 19 | 20 | 21 | class Instance: 22 | """ 23 | Instance: one parsed example (cleaned words, label strings, their index lists and optional BERT features). 24 | """ 25 | def __init__(self): 26 | self.words = [] 27 | self.labels = [] 28 | self.words_size = 0 29 | # index lists start empty; presumably filled when words/labels are mapped to vocabulary ids elsewhere - TODO confirm 30 | self.words_index = [] 31 | self.label_index = [] 32 | # set by DataLoader._read_bert_file (tokens and feature values from the BERT JSON file) when config.use_bert is true 33 | self.bert_tokens = [] 34 | self.bert_feature = None 35 | 36 | 37 | -------------------------------------------------------------------------------- /Dataloader/README.md: -------------------------------------------------------------------------------- 1 | ## Role of document in `Dataloader` directory ## 2 | - Dataloader 3 | - `DataLoader` ------ Load Data 4 | 5 | - `Instance.py` ------ It is a instance file to storage data that you read from train/dev/test file 6 | 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License.
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch_Bert_Text_Classification 2 | - Bert For Text Classification in SST 3 | 4 | ## Requirement ## 5 | 6 | PyTorch : 1.0.1 7 | Python : 3.6 8 | Cuda : 9.0 (optional; enables CUDA acceleration) 9 | 10 | ## Usage ## 11 | 12 | Modify the config file (see the Config directory). 13 | 14 | 1、sh run_train_p.sh 15 | 2、python -u main.py --config ./Config/config.cfg --device cuda:0 --train -p 16 | 17 | 18 | 19 | ## Bert ## 20 | 21 | - use [Bert_Script](https://github.com/bamtercelboo/PyTorch_Bert_Text_Classification/tree/master/Bert_Script) to extract features from the **bert-base-uncased** BERT model. 22 | 23 | ## Model ## 24 | 25 | - CNN 26 | - BiLSTM 27 | - BiLSTM + BertFeature 28 | - updating 29 | 30 | ## Data ## 31 | 32 | - SST-Binary 33 | 34 | ## Result ## 35 | The following test-set accuracies correspond to the best dev-set accuracy.
36 | 37 | | Model |Bert-Encoder |% SST-Binary | 38 | | ------------ | ------------ | ------------ | 39 | | Bi-LSTM | - | 86.4360 | 40 | | Bi-LSTM | AvgPooling | 86.3811 | 41 | | Bi-LSTM | MaxPooling | 86.9303 | 42 | | Bi-LSTM | BiLSTM+MaxPool | 89.7309 | 43 | 44 | ## Reference ## 45 | 46 | - [cnn-lstm-bilstm-deepcnn-clstm-in-pytorch](https://github.com/bamtercelboo/cnn-lstm-bilstm-deepcnn-clstm-in-pytorch) 47 | 48 | - https://github.com/huggingface/pytorch-pretrained-BERT 49 | 50 | - https://github.com/google-research/bert 51 | 52 | ## Question ## 53 | 54 | - if you have any question, you can open a issue or email **bamtercelboo@{gmail.com, 163.com}**. 55 | 56 | - if you have any good suggestions, you can PR or email me. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/24 15:30 3 | # @File : __init__.py 4 | # @Last Modify Time : 2018/8/24 15:30 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : __init__.py 9 | FUNCTION : None 10 | """ -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/1/30 19:50 3 | # @File : main.py 4 | # @Last Modify Time : 2018/1/30 19:50 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : main.py 9 | FUNCTION : main 10 | """ 11 | 12 | import argparse 13 | import datetime 14 | import Config.config as configurable 15 | from DataUtils.mainHelp import * 16 | from DataUtils.Alphabet import * 17 | from test import load_test_data 18 | from test import T_Inference 19 | from trainer import Train 20 | import random 21 | 22 | # solve default encoding problem 23 | from imp import reload 24 | defaultencoding = 'utf-8' 25 | if sys.getdefaultencoding() != 
defaultencoding: 26 | reload(sys) 27 | sys.setdefaultencoding(defaultencoding) 28 | 29 | # random seed 30 | torch.manual_seed(seed_num) 31 | random.seed(seed_num) 32 | 33 | 34 | def start_train(train_iter, dev_iter, test_iter, model, config): 35 | """ 36 | :param train_iter: train batch data iterator 37 | :param dev_iter: dev batch data iterator 38 | :param test_iter: test batch data iterator 39 | :param model: nn model 40 | :param config: config 41 | :return: None 42 | """ 43 | t = Train(train_iter=train_iter, dev_iter=dev_iter, test_iter=test_iter, model=model, config=config) 44 | t.train() 45 | print("Finish Train.") 46 | 47 | 48 | def start_test(train_iter, dev_iter, test_iter, model, alphabet, config): 49 | """ 50 | :param train_iter: train batch data iterator 51 | :param dev_iter: dev batch data iterator 52 | :param test_iter: test batch data iterator 53 | :param model: nn model 54 | :param alphabet: alphabet dict 55 | :param config: config 56 | :return: None 57 | """ 58 | print("\nTesting Start......") 59 | print("Test is not Complete Now, Future will update, Sorry.") 60 | exit() 61 | data, path_source, path_result = load_test_data(train_iter, dev_iter, test_iter, config) 62 | infer = T_Inference(model=model, data=data, path_source=path_source, path_result=path_result, alphabet=alphabet, 63 | use_crf=config.use_crf, config=config) 64 | infer.infer2file() 65 | print("Finished Test.") 66 | 67 | 68 | def main(): 69 | """ 70 | main() 71 | :return: 72 | """ 73 | # save file 74 | config.mulu = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 75 | # config.add_args(key="mulu", value=datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 76 | config.save_dir = os.path.join(config.save_direction, config.mulu) 77 | if not os.path.isdir(config.save_dir): os.makedirs(config.save_dir) 78 | 79 | # get data, iter, alphabet 80 | train_iter, dev_iter, test_iter, alphabet = load_data(config=config) 81 | 82 | # get params 83 | get_params(config=config, alphabet=alphabet) 
84 | 85 | # save dictionary 86 | save_dictionary(config=config) 87 | 88 | model = load_model(config) 89 | 90 | # print("Training Start......") 91 | if config.train is True: 92 | start_train(train_iter, dev_iter, test_iter, model, config) 93 | exit() 94 | elif config.test is True: 95 | start_test(train_iter, dev_iter, test_iter, model, alphabet, config) 96 | exit() 97 | 98 | 99 | def parse_argument(): 100 | """ 101 | :argument 102 | :return: 103 | """ 104 | parser = argparse.ArgumentParser(description="BERT For Text Classification") 105 | parser.add_argument("-c", "--config", dest="config_file", type=str, default="./Config/config.cfg", 106 | help="config path") 107 | parser.add_argument("-device", "--device", dest="device", type=str, default="cpu", 108 | help="device[‘cpu’,‘cuda:0’,‘cuda:1’,......]") 109 | parser.add_argument("--train", dest="train", action="store_true", default=True, help="train model") 110 | parser.add_argument("-p", "--process", dest="process", action="store_true", default=True, help="data process") 111 | parser.add_argument("-t", "--test", dest="test", action="store_true", default=False, help="test model") 112 | parser.add_argument("--t_model", dest="t_model", type=str, default=None, help="model for test") 113 | parser.add_argument("--t_data", dest="t_data", type=str, default=None, 114 | help="data[train dev test None] for test model") 115 | parser.add_argument("--predict", dest="predict", action="store_true", default=False, help="predict model") 116 | args = parser.parse_args() 117 | # print(vars(args)) 118 | config = configurable.Configurable(config_file=args.config_file) 119 | config.device = args.device 120 | config.train = args.train 121 | config.process = args.process 122 | config.test = args.test 123 | config.t_model = args.t_model 124 | config.t_data = args.t_data 125 | config.predict = args.predict 126 | # config 127 | if config.test is True: 128 | config.train = False 129 | if config.t_data not in [None, "train", "dev", "test"]: 130 | 
print("\nUsage") 131 | parser.print_help() 132 | print("t_data : {}, not in [None, 'train', 'dev', 'test']".format(config.t_data)) 133 | exit() 134 | print("***************************************") 135 | print("Device : {}".format(config.device)) 136 | print("Data Process : {}".format(config.process)) 137 | print("Train model : {}".format(config.train)) 138 | print("Test model : {}".format(config.test)) 139 | print("t_model : {}".format(config.t_model)) 140 | print("t_data : {}".format(config.t_data)) 141 | print("predict : {}".format(config.predict)) 142 | print("***************************************") 143 | 144 | return config 145 | 146 | 147 | if __name__ == "__main__": 148 | 149 | print("Process ID {}, Process Parent ID {}".format(os.getpid(), os.getppid())) 150 | config = parse_argument() 151 | if config.device != cpu_device: 152 | print("Using GPU To Train......") 153 | device_number = config.device[-1] 154 | torch.cuda.set_device(int(device_number)) 155 | print("Current Cuda Device {}".format(torch.cuda.current_device())) 156 | # torch.backends.cudnn.enabled = True 157 | # torch.backends.cudnn.deterministic = True 158 | torch.cuda.manual_seed(seed_num) 159 | torch.cuda.manual_seed_all(seed_num) 160 | print("torch.cuda.initial_seed", torch.cuda.initial_seed()) 161 | 162 | main() 163 | 164 | -------------------------------------------------------------------------------- /models/Text_Classification/BiLSTM.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/17 16:06 3 | # @File : BiLSTM.py 4 | # @Last Modify Time : 2018/8/17 16:06 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : BiLSTM.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | import random 15 | from DataUtils.Common import * 16 | from models.Text_Classification.initialize import * 17 
| from models.Text_Classification.modelHelp import prepare_pack_padded_sequence 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class BiLSTM(nn.Module): 23 | """ 24 | BiLSTM 25 | """ 26 | 27 | def __init__(self, **kwargs): 28 | super(BiLSTM, self).__init__() 29 | for k in kwargs: 30 | self.__setattr__(k, kwargs[k]) 31 | 32 | V = self.embed_num 33 | D = self.embed_dim 34 | C = self.label_num 35 | paddingId = self.paddingId 36 | 37 | self.embed = nn.Embedding(V, D, padding_idx=paddingId) 38 | 39 | if self.pretrained_embed: 40 | self.embed.weight.data.copy_(self.pretrained_weight) 41 | else: 42 | init_embedding(self.embed.weight) 43 | 44 | self.dropout_embed = nn.Dropout(self.dropout_emb) 45 | self.dropout = nn.Dropout(self.dropout) 46 | 47 | self.bilstm = nn.LSTM(input_size=D, hidden_size=self.lstm_hiddens, num_layers=self.lstm_layers, 48 | bidirectional=True, batch_first=True, bias=True) 49 | 50 | self.linear = nn.Linear(in_features=self.lstm_hiddens * 2, out_features=C, bias=True) 51 | # init_linear_weight_bias(self.linear) 52 | init_linear(self.linear) 53 | 54 | def forward(self, word, sentence_length): 55 | """ 56 | :param word: 57 | :param sentence_length: 58 | :param desorted_indices: 59 | :return: 60 | """ 61 | word, sentence_length, desorted_indices = prepare_pack_padded_sequence(word, sentence_length, device=self.device) 62 | x = self.embed(word) # (N,W,D) 63 | x = self.dropout_embed(x) 64 | packed_embed = pack_padded_sequence(x, sentence_length, batch_first=True) 65 | x, _ = self.bilstm(packed_embed) 66 | x, _ = pad_packed_sequence(x, batch_first=True) 67 | x = x[desorted_indices] 68 | x = x.permute(0, 2, 1) 69 | x = self.dropout(x) 70 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 71 | x = torch.tanh(x) 72 | logit = self.linear(x) 73 | return logit 74 | -------------------------------------------------------------------------------- /models/Text_Classification/CNN.py: 
-------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/10/15 9:52 3 | # @File : CNN.py 4 | # @Last Modify Time : 2018/10/15 9:52 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : CNN.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | import random 14 | from DataUtils.Common import * 15 | from models.Text_Classification.initialize import * 16 | 17 | torch.manual_seed(seed_num) 18 | random.seed(seed_num) 19 | 20 | 21 | class CNN(nn.Module): 22 | """ 23 | BiLSTM 24 | """ 25 | 26 | def __init__(self, **kwargs): 27 | super(CNN, self).__init__() 28 | for k in kwargs: 29 | self.__setattr__(k, kwargs[k]) 30 | 31 | V = self.embed_num 32 | D = self.embed_dim 33 | C = self.label_num 34 | Ci = 1 35 | kernel_nums = self.conv_filter_nums 36 | kernel_sizes = self.conv_filter_sizes 37 | paddingId = self.paddingId 38 | 39 | self.embed = nn.Embedding(V, D, padding_idx=paddingId) 40 | 41 | if self.pretrained_embed: 42 | self.embed.weight.data.copy_(self.pretrained_weight) 43 | else: 44 | init_embedding(self.embed.weight) 45 | 46 | self.dropout_embed = nn.Dropout(self.dropout_emb) 47 | self.dropout = nn.Dropout(self.dropout) 48 | 49 | # cnn 50 | if self.wide_conv: 51 | print("Using Wide Convolution") 52 | self.conv = [nn.Conv2d(in_channels=Ci, out_channels=kernel_nums, kernel_size=(K, D), stride=(1, 1), 53 | padding=(K // 2, 0), bias=False) for K in kernel_sizes] 54 | else: 55 | print("Using Narrow Convolution") 56 | self.conv = [nn.Conv2d(in_channels=Ci, out_channels=kernel_nums, kernel_size=(K, D), bias=True) for K in kernel_sizes] 57 | 58 | for conv in self.conv: 59 | if self.device != cpu_device: 60 | conv.to(self.device) 61 | 62 | in_fea = len(kernel_sizes) * kernel_nums 63 | self.linear = nn.Linear(in_features=in_fea, out_features=C, bias=True) 64 | init_linear_weight_bias(self.linear) 65 | 66 | def forward(self, word, sentence_length): 67 | """ 68 | 
:param word: 69 | :param sentence_length: 70 | :return: 71 | """ 72 | x = self.embed(word) # (N,W,D) 73 | x = self.dropout_embed(x) 74 | x = x.unsqueeze(1) # (N,Ci,W,D) 75 | x = [F.relu(conv(x)).squeeze(3) for conv in self.conv] #[(N,Co,W), ...]*len(Ks) 76 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks) 77 | x = torch.cat(x, 1) 78 | x = self.dropout(x) # (N,len(Ks)*Co) 79 | logit = self.linear(x) 80 | return logit -------------------------------------------------------------------------------- /models/Text_Classification/Text_Classification.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/9/14 8:43 3 | # @File : Sequence_Label.py 4 | # @Last Modify Time : 2018/9/14 8:43 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Sequence_Label.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch 13 | import torch.nn as nn 14 | import random 15 | from models.Text_Classification.BiLSTM import BiLSTM 16 | from models.Text_Classification.CNN import CNN 17 | from DataUtils.Common import * 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class Text_Classification(nn.Module): 23 | """ 24 | Text_Classification 25 | """ 26 | 27 | def __init__(self, config): 28 | super(Text_Classification, self).__init__() 29 | self.config = config 30 | # embed 31 | self.embed_num = config.embed_num 32 | self.embed_dim = config.embed_dim 33 | self.label_num = config.label_num 34 | self.paddingId = config.paddingId 35 | # dropout 36 | self.dropout_emb = config.dropout_emb 37 | self.dropout = config.dropout 38 | # lstm 39 | self.lstm_hiddens = config.lstm_hiddens 40 | self.lstm_layers = config.lstm_layers 41 | # pre train 42 | self.pretrained_embed = config.pretrained_embed 43 | self.pretrained_weight = config.pretrained_weight 44 | # cnn param 45 | self.wide_conv = config.wide_conv 46 | self.conv_filter_sizes = 
self._conv_filter(config.conv_filter_sizes) 47 | self.conv_filter_nums = config.conv_filter_nums 48 | # self.use_cuda = config.use_cuda 49 | self.device = config.device 50 | 51 | if self.config.model_bilstm: 52 | self.model = BiLSTM(embed_num=self.embed_num, embed_dim=self.embed_dim, label_num=self.label_num, 53 | paddingId=self.paddingId, dropout_emb=self.dropout_emb, dropout=self.dropout, 54 | lstm_hiddens=self.lstm_hiddens, lstm_layers=self.lstm_layers, 55 | pretrained_embed=self.pretrained_embed, pretrained_weight=self.pretrained_weight, 56 | device=self.device) 57 | elif self.config.model_cnn: 58 | self.model = CNN(embed_num=self.embed_num, embed_dim=self.embed_dim, label_num=self.label_num, 59 | paddingId=self.paddingId, dropout_emb=self.dropout_emb, dropout=self.dropout, 60 | conv_filter_nums=self.conv_filter_nums, conv_filter_sizes=self.conv_filter_sizes, 61 | wide_conv=self.wide_conv, 62 | pretrained_embed=self.pretrained_embed, pretrained_weight=self.pretrained_weight, 63 | device=self.device) 64 | 65 | @staticmethod 66 | def _conv_filter(str_list): 67 | """ 68 | :param str_list: 69 | :return: 70 | """ 71 | int_list = [] 72 | str_list = str_list.split(",") 73 | for str in str_list: 74 | int_list.append(int(str)) 75 | return int_list 76 | 77 | @staticmethod 78 | def _get_model_args(batch_features): 79 | """ 80 | :param batch_features: Batch Instance 81 | :return: 82 | """ 83 | word = batch_features.word_features 84 | mask = word > 0 85 | sentence_length = batch_features.sentence_length 86 | labels = batch_features.label_features 87 | batch_size = batch_features.batch_length 88 | bert_feature = batch_features.bert_features 89 | return word, bert_feature, mask, sentence_length, labels, batch_size 90 | 91 | def forward(self, batch_features, train=False): 92 | """ 93 | :param batch_features: 94 | :param train: 95 | :return: 96 | """ 97 | word, bert_feature, mask, sentence_length, labels, batch_size = self._get_model_args(batch_features) 98 | model_output = 
self.model(word, sentence_length) 99 | return model_output 100 | 101 | 102 | -------------------------------------------------------------------------------- /models/Text_Classification/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dalinvip/PyTorch_Bert_Text_Classification/a7afd2188c4968df3830aebedb0d7293f620dcfa/models/Text_Classification/__init__.py -------------------------------------------------------------------------------- /models/Text_Classification/initialize.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/25 9:15 3 | # @File : initialize.py 4 | # @Last Modify Time : 2018/8/25 9:15 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : initialize.py 9 | FUNCTION : None 10 | """ 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.init as init 16 | 17 | 18 | def init_cnn_weight(cnn_layer, seed=1337): 19 | """初始化cnn层权重 20 | Args: 21 | cnn_layer: weight.size() == [nb_filter, in_channels, [kernel_size]] 22 | seed: int 23 | """ 24 | filter_nums = cnn_layer.weight.size(0) 25 | kernel_size = cnn_layer.weight.size()[2:] 26 | scope = np.sqrt(2. / (filter_nums * np.prod(kernel_size))) 27 | torch.manual_seed(seed) 28 | nn.init.normal_(cnn_layer.weight, -scope, scope) 29 | cnn_layer.bias.data.zero_() 30 | 31 | 32 | def init_lstm_weight(lstm, num_layer=1, seed=1337): 33 | """初始化lstm权重 34 | Args: 35 | lstm: torch.nn.LSTM 36 | num_layer: int, lstm层数 37 | seed: int 38 | """ 39 | for i in range(num_layer): 40 | weight_h = getattr(lstm, 'weight_hh_l{0}'.format(i)) 41 | scope = np.sqrt(6.0 / (weight_h.size(0)/4. 
+ weight_h.size(1))) 42 | torch.manual_seed(seed) 43 | nn.init.uniform_(getattr(lstm, 'weight_hh_l{0}'.format(i)), -scope, scope) 44 | 45 | weight_i = getattr(lstm, 'weight_ih_l{0}'.format(i)) 46 | scope = np.sqrt(6.0 / (weight_i.size(0)/4. + weight_i.size(1))) 47 | torch.manual_seed(seed) 48 | nn.init.uniform_(getattr(lstm, 'weight_ih_l{0}'.format(i)), -scope, scope) 49 | 50 | if lstm.bias: 51 | for i in range(num_layer): 52 | weight_h = getattr(lstm, 'bias_hh_l{0}'.format(i)) 53 | weight_h.data.zero_() 54 | weight_h.data[lstm.hidden_size: 2*lstm.hidden_size] = 1 55 | weight_i = getattr(lstm, 'bias_ih_l{0}'.format(i)) 56 | weight_i.data.zero_() 57 | weight_i.data[lstm.hidden_size: 2*lstm.hidden_size] = 1 58 | 59 | 60 | def init_linear(input_linear, seed=1337): 61 | """初始化全连接层权重 62 | """ 63 | torch.manual_seed(seed) 64 | scope = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1))) 65 | nn.init.uniform_(input_linear.weight, -scope, scope) 66 | # nn.init.uniform(input_linear.bias, -scope, scope) 67 | if input_linear.bias is not None: 68 | input_linear.bias.data.zero_() 69 | 70 | 71 | def init_linear_weight_bias(input_linear, seed=1337): 72 | """ 73 | :param input_linear: 74 | :param seed: 75 | :return: 76 | """ 77 | # torch.manual_seed(seed) 78 | init.xavier_uniform_(input_linear.weight) 79 | scope = np.sqrt(6.0 / (input_linear.weight.size(0) + 1)) 80 | if input_linear.bias is not None: 81 | input_linear.bias.data.uniform_(-scope, scope) 82 | 83 | 84 | def init_embedding(input_embedding, seed=666): 85 | """初始化embedding层权重 86 | """ 87 | torch.manual_seed(seed) 88 | scope = np.sqrt(3.0 / input_embedding.size(1)) 89 | nn.init.uniform_(input_embedding, -scope, scope) 90 | 91 | -------------------------------------------------------------------------------- /models/Text_Classification/modelHelp.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/9/15 19:09 3 | # 
@File : modelHelp.py 4 | # @Last Modify Time : 2018/9/15 19:09 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : modelHelp.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch 13 | import random 14 | from DataUtils.Common import * 15 | torch.manual_seed(seed_num) 16 | random.seed(seed_num) 17 | 18 | 19 | def prepare_pack_padded_sequence(inputs_words, seq_lengths, device=cpu_device, descending=True): 20 | """ 21 | :param use_cuda: 22 | :param inputs_words: 23 | :param seq_lengths: 24 | :param descending: 25 | :return: 26 | """ 27 | sorted_seq_lengths, indices = torch.sort(torch.LongTensor(seq_lengths), descending=descending) 28 | if device != cpu_device: 29 | sorted_seq_lengths, indices = sorted_seq_lengths.to(device), indices.to(device) 30 | _, desorted_indices = torch.sort(indices, descending=False) 31 | sorted_inputs_words = inputs_words[indices] 32 | return sorted_inputs_words, sorted_seq_lengths.cpu().numpy(), desorted_indices 33 | 34 | 35 | -------------------------------------------------------------------------------- /models/Text_Classification_BertFeature/Bert_Encoder.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2019/3/16 15:07 3 | # @File : Bert_Encoder.py 4 | # @Last Modify Time : 2019/3/16 15:07 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Bert_Encoder.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | import random 15 | from DataUtils.Common import * 16 | from models.Text_Classification.initialize import * 17 | from models.Text_Classification.modelHelp import prepare_pack_padded_sequence 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class Bert_Encoder(nn.Module): 23 | """ 24 | Bert_Encoder 25 | """ 26 | 27 | def __init__(self, **kwargs): 28 | super(Bert_Encoder, self).__init__() 
29 | for k in kwargs: 30 | self.__setattr__(k, kwargs[k]) 31 | 32 | self.dropout_bert = nn.Dropout(self.dropout) 33 | 34 | self.bert_bilstm = nn.LSTM(input_size=self.bert_dim, hidden_size=200, 35 | bidirectional=True, batch_first=True, bias=True) 36 | 37 | self.bert_linear = nn.Linear(in_features=200 * 2, out_features=self.out_dim, 38 | bias=True) 39 | init_linear_weight_bias(self.bert_linear) 40 | 41 | def forward(self, bert_fea): 42 | """ 43 | :param bert_fea: 44 | :return: 45 | """ 46 | bert_fea = bert_fea.to(self.device) 47 | bert_fea, _ = self.bert_bilstm(bert_fea) 48 | bert_fea = bert_fea.permute(0, 2, 1) 49 | bert_fea = F.max_pool1d(bert_fea, bert_fea.size(2)).squeeze(2) 50 | bert_fea = self.bert_linear(bert_fea) 51 | return bert_fea 52 | 53 | -------------------------------------------------------------------------------- /models/Text_Classification_BertFeature/Bert_Encoder_Pool.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2019/3/16 15:07 3 | # @File : Bert_Encoder.py 4 | # @Last Modify Time : 2019/3/16 15:07 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : Bert_Encoder.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | import random 15 | from DataUtils.Common import * 16 | from models.Text_Classification.initialize import * 17 | from models.Text_Classification.modelHelp import prepare_pack_padded_sequence 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class Bert_Encoder(nn.Module): 23 | """ 24 | Bert_Encoder 25 | """ 26 | 27 | def __init__(self, **kwargs): 28 | super(Bert_Encoder, self).__init__() 29 | for k in kwargs: 30 | self.__setattr__(k, kwargs[k]) 31 | 32 | self.dropout_bert = nn.Dropout(self.dropout) 33 | self.bert_linear = nn.Linear(in_features=self.bert_dim, out_features=self.out_dim, 34 | bias=True) 35 | 
init_linear_weight_bias(self.bert_linear) 36 | 37 | def forward(self, bert_fea): 38 | """ 39 | :param bert_fea: 40 | :return: 41 | """ 42 | bert_fea = bert_fea.to(self.device) 43 | bert_fea = bert_fea.permute(0, 2, 1) 44 | bert_fea = F.max_pool1d(bert_fea, bert_fea.size(2)).squeeze(2) 45 | # bert_fea = F.avg_pool1d(bert_fea, bert_fea.size(2)).squeeze(2) 46 | bert_fea = self.bert_linear(bert_fea) 47 | return bert_fea 48 | 49 | -------------------------------------------------------------------------------- /models/Text_Classification_BertFeature/BiLSTM.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/8/17 16:06 3 | # @File : BiLSTM.py 4 | # @Last Modify Time : 2018/8/17 16:06 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : BiLSTM.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 14 | import random 15 | from DataUtils.Common import * 16 | from models.Text_Classification.initialize import * 17 | from models.Text_Classification.modelHelp import prepare_pack_padded_sequence 18 | torch.manual_seed(seed_num) 19 | random.seed(seed_num) 20 | 21 | 22 | class BiLSTM(nn.Module): 23 | """ 24 | BiLSTM 25 | """ 26 | 27 | def __init__(self, **kwargs): 28 | super(BiLSTM, self).__init__() 29 | for k in kwargs: 30 | self.__setattr__(k, kwargs[k]) 31 | 32 | V = self.embed_num 33 | D = self.embed_dim 34 | C = self.label_num 35 | paddingId = self.paddingId 36 | 37 | self.embed = nn.Embedding(V, D, padding_idx=paddingId) 38 | 39 | if self.pretrained_embed: 40 | self.embed.weight.data.copy_(self.pretrained_weight) 41 | else: 42 | init_embedding(self.embed.weight) 43 | 44 | self.dropout_embed = nn.Dropout(self.dropout_emb) 45 | self.dropout = nn.Dropout(self.dropout) 46 | 47 | self.bilstm = nn.LSTM(input_size=D, hidden_size=self.lstm_hiddens, num_layers=self.lstm_layers, 48 | 
bidirectional=True, batch_first=True, bias=True) 49 | 50 | self.linear = nn.Linear(in_features=self.lstm_hiddens * 2 + self.bert_out_dim, 51 | out_features=C, bias=True) 52 | # init_linear_weight_bias(self.linear) 53 | init_linear(self.linear) 54 | 55 | def forward(self, word, bert_fea, sentence_length): 56 | """ 57 | :param word: 58 | :param bert_fea: 59 | :param sentence_length: 60 | :return: 61 | """ 62 | word, sentence_length, desorted_indices = prepare_pack_padded_sequence(word, sentence_length, device=self.device) 63 | x = self.embed(word) # (N,W,D) 64 | x = self.dropout_embed(x) 65 | packed_embed = pack_padded_sequence(x, sentence_length, batch_first=True) 66 | x, _ = self.bilstm(packed_embed) 67 | x, _ = pad_packed_sequence(x, batch_first=True) 68 | x = x[desorted_indices] 69 | x = x.permute(0, 2, 1) 70 | x = self.dropout(x) 71 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 72 | x = torch.cat((x, bert_fea), 1) 73 | x = torch.tanh(x) 74 | logit = self.linear(x) 75 | return logit 76 | -------------------------------------------------------------------------------- /models/Text_Classification_BertFeature/CNN.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/10/15 9:52 3 | # @File : CNN.py 4 | # @Last Modify Time : 2018/10/15 9:52 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 | """ 8 | FILE : CNN.py 9 | FUNCTION : None 10 | """ 11 | 12 | import torch.nn.functional as F 13 | import random 14 | from DataUtils.Common import * 15 | from models.Text_Classification.initialize import * 16 | 17 | torch.manual_seed(seed_num) 18 | random.seed(seed_num) 19 | 20 | 21 | class CNN(nn.Module): 22 | """ 23 | BiLSTM 24 | """ 25 | 26 | def __init__(self, **kwargs): 27 | super(CNN, self).__init__() 28 | for k in kwargs: 29 | self.__setattr__(k, kwargs[k]) 30 | 31 | V = self.embed_num 32 | D = self.embed_dim 33 | C = self.label_num 34 | Ci = 1 35 | kernel_nums = self.conv_filter_nums 36 
| kernel_sizes = self.conv_filter_sizes 37 | paddingId = self.paddingId 38 | 39 | self.embed = nn.Embedding(V, D, padding_idx=paddingId) 40 | 41 | if self.pretrained_embed: 42 | self.embed.weight.data.copy_(self.pretrained_weight) 43 | else: 44 | init_embedding(self.embed.weight) 45 | 46 | self.dropout_embed = nn.Dropout(self.dropout_emb) 47 | self.dropout = nn.Dropout(self.dropout) 48 | 49 | # cnn 50 | if self.wide_conv: 51 | print("Using Wide Convolution") 52 | self.conv = [nn.Conv2d(in_channels=Ci, out_channels=kernel_nums, kernel_size=(K, D), stride=(1, 1), 53 | padding=(K // 2, 0), bias=False) for K in kernel_sizes] 54 | else: 55 | print("Using Narrow Convolution") 56 | self.conv = [nn.Conv2d(in_channels=Ci, out_channels=kernel_nums, kernel_size=(K, D), bias=True) for K in kernel_sizes] 57 | 58 | for conv in self.conv: 59 | if self.device != cpu_device: 60 | conv.to(self.device) 61 | 62 | in_fea = len(kernel_sizes) * kernel_nums 63 | self.linear = nn.Linear(in_features=in_fea, out_features=C, bias=True) 64 | init_linear_weight_bias(self.linear) 65 | 66 | def forward(self, word, sentence_length): 67 | """ 68 | :param word: 69 | :param sentence_length: 70 | :return: 71 | """ 72 | x = self.embed(word) # (N,W,D) 73 | x = self.dropout_embed(x) 74 | x = x.unsqueeze(1) # (N,Ci,W,D) 75 | x = [F.relu(conv(x)).squeeze(3) for conv in self.conv] #[(N,Co,W), ...]*len(Ks) 76 | x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] #[(N,Co), ...]*len(Ks) 77 | x = torch.cat(x, 1) 78 | x = self.dropout(x) # (N,len(Ks)*Co) 79 | logit = self.linear(x) 80 | return logit -------------------------------------------------------------------------------- /models/Text_Classification_BertFeature/Text_Classification.py: -------------------------------------------------------------------------------- 1 | # @Author : bamtercelboo 2 | # @Datetime : 2018/9/14 8:43 3 | # @File : Sequence_Label.py 4 | # @Last Modify Time : 2018/9/14 8:43 5 | # @Contact : bamtercelboo@{gmail.com, 163.com} 6 | 7 
"""
FILE : Text_Classification.py
FUNCTION : Text classification model that combines word embeddings with
           pre-extracted BERT features.
"""

import torch
import torch.nn as nn
import random
from models.Text_Classification_BertFeature.BiLSTM import BiLSTM
from models.Text_Classification_BertFeature.Bert_Encoder import Bert_Encoder
from models.Text_Classification.CNN import CNN
from DataUtils.Common import *
# Seed the global RNGs at import time (seed_num comes from DataUtils.Common).
torch.manual_seed(seed_num)
random.seed(seed_num)


class Text_Classification_BertFeature(nn.Module):
    """
    Text classifier built from a BERT-feature encoder and a BiLSTM.

    ``Bert_Encoder`` maps the BERT features carried by a batch to
    ``bert_out_dim`` dimensions; ``BiLSTM`` consumes the word ids together
    with that encoded feature and produces the model output.
    """

    def __init__(self, config):
        """
        :param config: configuration object; every hyper-parameter below
            (embedding, dropout, LSTM, CNN, BERT dimension, device) is read from it.
        """
        super(Text_Classification_BertFeature, self).__init__()
        self.config = config
        # embed
        self.embed_num = config.embed_num
        self.embed_dim = config.embed_dim
        self.label_num = config.label_num
        self.paddingId = config.paddingId
        # dropout
        self.dropout_emb = config.dropout_emb
        self.dropout = config.dropout
        # lstm
        self.lstm_hiddens = config.lstm_hiddens
        self.lstm_layers = config.lstm_layers
        # pre train
        self.pretrained_embed = config.pretrained_embed
        self.pretrained_weight = config.pretrained_weight
        # cnn param (parsed and stored; no CNN is built in this class)
        self.wide_conv = config.wide_conv
        self.conv_filter_sizes = self._conv_filter(config.conv_filter_sizes)
        self.conv_filter_nums = config.conv_filter_nums
        self.device = config.device

        # Output size of Bert_Encoder; BiLSTM is told to expect the same size.
        self.bert_out_dim = 200

        self.model = BiLSTM(embed_num=self.embed_num, embed_dim=self.embed_dim, label_num=self.label_num,
                            paddingId=self.paddingId, dropout_emb=self.dropout_emb, dropout=self.dropout,
                            lstm_hiddens=self.lstm_hiddens, lstm_layers=self.lstm_layers,
                            pretrained_embed=self.pretrained_embed, pretrained_weight=self.pretrained_weight,
                            device=self.device, bert_out_dim=self.bert_out_dim)

        self.Bert_Encoder = Bert_Encoder(dropout=0.5, bert_dim=config.bert_dim,
                                         out_dim=self.bert_out_dim, device=self.device)

    @staticmethod
    def _conv_filter(str_list):
        """
        Parse convolution filter sizes.

        :param str_list: comma-separated sizes such as "1,2,3". Whitespace around
            entries and empty entries (e.g. from a trailing comma) are ignored.
            An already-parsed list/tuple of sizes is accepted as well.
        :return: list of int
        """
        if isinstance(str_list, str):
            str_list = str_list.split(",")
        return [int(size) for size in str_list if str(size).strip()]

    @staticmethod
    def _get_model_args(batch_features):
        """
        Unpack the fields of a batch used by the model.

        :param batch_features: Batch Instance
        :return: (word, bert_feature, mask, sentence_length, labels, batch_size);
            mask marks word ids greater than 0.
        """
        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        labels = batch_features.label_features
        batch_size = batch_features.batch_length
        bert_feature = batch_features.bert_features
        return word, bert_feature, mask, sentence_length, labels, batch_size

    def forward(self, batch_features, train=False):
        """
        :param batch_features: Batch Instance
        :param train: not used by this method; kept because callers pass it
        :return: output of the BiLSTM for the batch
        """
        word, bert_feature, mask, sentence_length, labels, batch_size = self._get_model_args(batch_features)
        bert_fea = self.Bert_Encoder(bert_feature)
        model_output = self.model(word, bert_fea, sentence_length)
        return model_output


# --------------------------------------------------------------------------------
# /models/Text_Classification_BertFeature/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dalinvip/PyTorch_Bert_Text_Classification/a7afd2188c4968df3830aebedb0d7293f620dcfa/models/Text_Classification_BertFeature/__init__.py
# --------------------------------------------------------------------------------
# /models/Text_Classification_BertFeature/initialize.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/8/25 9:15
# @File : initialize.py
# @Last Modify Time : 2018/8/25 9:15
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE : initialize.py
FUNCTION : None
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init


def init_cnn_weight(cnn_layer, seed=1337):
    """Initialize the weights of a convolution layer.

    The weight is drawn from a zero-mean normal distribution with
    std = sqrt(2 / (nb_filter * prod(kernel_size))). The bias is set to zero
    when the layer has one.

    Args:
        cnn_layer: convolution layer with
            weight.size() == [nb_filter, in_channels, *kernel_size]
        seed: int, passed to torch.manual_seed before sampling
    """
    filter_nums = cnn_layer.weight.size(0)
    kernel_size = cnn_layer.weight.size()[2:]
    scope = np.sqrt(2. / (filter_nums * np.prod(kernel_size)))
    torch.manual_seed(seed)
    # nn.init.normal_ takes (mean, std), not a (low, high) range, so centre on 0.
    nn.init.normal_(cnn_layer.weight, mean=0.0, std=float(scope))
    # Convolutions created with bias=False have no bias tensor.
    if cnn_layer.bias is not None:
        cnn_layer.bias.data.zero_()


def init_lstm_weight(lstm, num_layer=1, seed=1337):
    """Initialize the weights and biases of an LSTM.

    For every layer, and for both directions of a bidirectional LSTM:
    weight_hh / weight_ih ~ U(-scope, scope) with
    scope = sqrt(6 / (rows / 4 + cols)). Rows are divided by 4 because the
    four gate matrices are stacked along dim 0. Biases are zeroed except the
    slice [hidden_size, 2 * hidden_size), which is set to 1. That slice is
    the forget gate in PyTorch's (i, f, g, o) gate order.

    Args:
        lstm: torch.nn.LSTM
        num_layer: int, number of LSTM layers to initialize
        seed: int, passed to torch.manual_seed before each weight is sampled
    """
    # Bidirectional LSTMs register the backward direction with a "_reverse" suffix.
    suffixes = [""]
    if getattr(lstm, "bidirectional", False):
        suffixes.append("_reverse")

    for i in range(num_layer):
        for suffix in suffixes:
            for kind in ("hh", "ih"):
                weight = getattr(lstm, "weight_{}_l{}{}".format(kind, i, suffix))
                scope = np.sqrt(6.0 / (weight.size(0) / 4. + weight.size(1)))
                torch.manual_seed(seed)
                nn.init.uniform_(weight, -scope, scope)

    if lstm.bias:
        for i in range(num_layer):
            for suffix in suffixes:
                for kind in ("hh", "ih"):
                    bias = getattr(lstm, "bias_{}_l{}{}".format(kind, i, suffix))
                    bias.data.zero_()
                    bias.data[lstm.hidden_size: 2 * lstm.hidden_size] = 1


def init_linear(input_linear, seed=1337):
    """Initialize a fully connected layer.

    Weight ~ U(-scope, scope) with scope = sqrt(6 / (weight.size(0) + weight.size(1)));
    the bias, if present, is zeroed.

    Args:
        input_linear: linear layer exposing ``weight`` and ``bias``
        seed: int, passed to torch.manual_seed before sampling
    """
    torch.manual_seed(seed)
    scope = np.sqrt(6.0 / (input_linear.weight.size(0) + input_linear.weight.size(1)))
    nn.init.uniform_(input_linear.weight, -scope, scope)
    if input_linear.bias is not None:
        input_linear.bias.data.zero_()


def init_linear_weight_bias(input_linear, seed=1337):
    """Xavier-uniform weight; bias ~ U(-scope, scope), scope = sqrt(6 / (weight.size(0) + 1)).

    :param input_linear: linear layer exposing ``weight`` and ``bias``
    :param seed: accepted for API compatibility but not used; sampling uses
        the current global RNG state
    :return: None
    """
    init.xavier_uniform_(input_linear.weight)
    scope = np.sqrt(6.0 / (input_linear.weight.size(0) + 1))
    if input_linear.bias is not None:
        input_linear.bias.data.uniform_(-scope, scope)


def init_embedding(input_embedding, seed=666):
    """Initialize an embedding weight tensor in place.

    Values ~ U(-scope, scope) with scope = sqrt(3 / input_embedding.size(1)).

    Args:
        input_embedding: 2-D weight tensor (e.g. ``nn.Embedding.weight``)
        seed: int, passed to torch.manual_seed before sampling
    """
    torch.manual_seed(seed)
    scope = np.sqrt(3.0 / input_embedding.size(1))
    nn.init.uniform_(input_embedding, -scope, scope)


# --------------------------------------------------------------------------------
# /models/Text_Classification_BertFeature/modelHelp.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/9/15 19:09
# @File : modelHelp.py
# @Last Modify Time : 2018/9/15 19:09
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE : modelHelp.py
FUNCTION : None
"""

import torch
import random
from DataUtils.Common import *
torch.manual_seed(seed_num)
random.seed(seed_num)


def prepare_pack_padded_sequence(inputs_words, seq_lengths, device=cpu_device, descending=True):
    """
    Sort a batch by sequence length, e.g. before packing it.

    :param inputs_words: batch tensor indexed along dim 0
    :param seq_lengths: per-example lengths (anything torch.LongTensor accepts)
    :param device: device the sort indices are moved to when it is not cpu_device
    :param descending: sort order of the lengths
    :return: (sorted inputs, sorted lengths as a numpy array,
              indices that restore the original order)
    """
    sorted_seq_lengths, indices = torch.sort(torch.LongTensor(seq_lengths), descending=descending)
    if device != cpu_device:
        sorted_seq_lengths, indices = sorted_seq_lengths.to(device), indices.to(device)
    # Sorting the permutation yields its inverse, used to undo the reordering.
    _, desorted_indices = torch.sort(indices, descending=False)
    sorted_inputs_words = inputs_words[indices]
    return sorted_inputs_words, sorted_seq_lengths.cpu().numpy(), desorted_indices


# --------------------------------------------------------------------------------
# /models/__init__.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/9/14 8:46
# @File : __init__.py
# @Last Modify Time : 2018/9/14 8:46
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE : __init__.py
FUNCTION : None
"""

# --------------------------------------------------------------------------------
# /run_train_p.sh:
# --------------------------------------------------------------------------------
# export OMP_NUM_THREADS=1
# export MKL_NUM_THREADS=1
# config=./Config/config.cfg
# device=cuda:0
# log_name=log
# # device ["cpu", "cuda:0", "cuda:1", ......]
# (continuation of /run_train_p.sh)
# nohup python -u main.py --config $config --device $device --train -p > $log_name 2>&1 &
# tail -f $log_name
# --------------------------------------------------------------------------------
# /test.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/8/24 15:27
# @File : test.py
# @Last Modify Time : 2018/8/24 15:27
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE : test.py
FUNCTION : None
"""
import os
import sys
import torch
from DataUtils.utils import *


def load_test_model(model, config):
    """
    Load saved parameters into ``model``.

    :param model: initial model
    :param config: config; ``t_model`` overrides the default path
        ``<save_best_model_dir>/<model_name>.pt``
    :return: loaded model
    """
    if config.t_model is None:
        test_model_dir = config.save_best_model_dir
        test_model_name = "{}.pt".format(config.model_name)
        test_model_path = os.path.join(test_model_dir, test_model_name)
        print("load default model from {}".format(test_model_path))
    else:
        test_model_path = config.t_model
        print("load user model from {}".format(test_model_path))
    # Load tensors onto the CPU so a checkpoint saved on a GPU also loads on a
    # CPU-only host; load_state_dict copies them onto the model's own device.
    model.load_state_dict(torch.load(test_model_path, map_location="cpu"))
    return model


def load_test_data(train_iter=None, dev_iter=None, test_iter=None, config=None):
    """
    Select the split to run inference on, based on ``config.t_data``.

    :param train_iter: train data
    :param dev_iter: dev data
    :param test_iter: test data
    :param config: config (``t_data`` in [None, 'train', 'dev', 'test'])
    :return: (data, path_source, path_result); the result path is the
        source file path with ".out" appended
    """
    data, path_source, path_result = None, None, None
    if config.t_data is None:
        print("default[test] for model test.")
        data = test_iter
        path_source = config.test_file
        path_result = "{}.out".format(config.test_file)
    elif config.t_data == "train":
        print("train data for model test.")
        data = train_iter
        path_source = config.train_file
        path_result = "{}.out".format(config.train_file)
    elif config.t_data == "dev":
        print("dev data for model test.")
        data = dev_iter
        path_source = config.dev_file
        path_result = "{}.out".format(config.dev_file)
    elif config.t_data == "test":
        print("test data for model test.")
        data = test_iter
        path_source = config.test_file
        path_result = "{}.out".format(config.test_file)
    else:
        print("Error value --- t_data = {}, must in [None, 'train', 'dev', 'test'].".format(config.t_data))
        sys.exit()
    return data, path_source, path_result


class T_Inference(object):
    """
    Run a trained model over a dataset and write one predicted label per
    token next to the source file.
    """
    def __init__(self, model, data, path_source, path_result, alphabet, use_crf, config):
        """
        :param model: nn model
        :param data: infer data (iterable of batches)
        :param path_source: source data path
        :param path_result: result data path
        :param alphabet: alphabet (provides label_alphabet.from_id)
        :param use_crf: decode with ``model.crf_layer`` when True
        :param config: config
        """
        print("Initialize T_Inference")
        self.model = model
        self.data = data
        self.path_source = path_source
        self.path_result = path_result
        self.alphabet = alphabet
        self.config = config
        self.use_crf = use_crf

    def infer2file(self):
        """
        Predict labels for every batch and write them to ``self.path_result``.

        :return: None
        """
        print("infer.....")
        self.model.eval()
        predict_label = []
        all_count = len(self.data)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            for now_count, data in enumerate(self.data, start=1):
                sys.stdout.write("\rinfer with batch number {}/{} .".format(now_count, all_count))
                word, char, mask, sentence_length, tags = self._get_model_args(data)
                # NOTE(review): this passes (word, char, sentence_length), while
                # Text_Classification_BertFeature.forward takes (batch_features, train);
                # confirm which model this inference path is meant for.
                logit = self.model(word, char, sentence_length, train=False)
                if self.use_crf is False:
                    predict_ids = torch_max(logit)
                    for id_batch in range(data.batch_length):
                        inst = data.inst[id_batch]
                        label_ids = predict_ids[id_batch]
                        for id_word in range(inst.words_size):
                            predict_label.append(self.alphabet.label_alphabet.from_id(label_ids[id_word]))
                else:
                    path_score, best_paths = self.model.crf_layer(logit, mask)
                    for id_batch in range(data.batch_length):
                        inst = data.inst[id_batch]
                        label_ids = best_paths[id_batch].cpu().data.numpy()[:inst.words_size]
                        for i in label_ids:
                            predict_label.append(self.alphabet.label_alphabet.from_id(i))

        print("\ninfer finished.")
        self.write2file(result=predict_label, path_source=self.path_source, path_result=self.path_result)

    @staticmethod
    def write2file(result, path_source, path_result):
        """
        Append one predicted label to each non-blank line of ``path_source``.

        Blank lines are copied through. Writing stops once every label in
        ``result`` has been used.

        :param result: list of predicted labels, in token order
        :param path_source: source data path
        :param path_result: output path (an existing file is replaced)
        :return: None
        """
        print("write result to file {}".format(path_result))
        if os.path.exists(path_source) is False:
            print("source data path[path_source] is not exist.")
            return
        if os.path.exists(path_result):
            os.remove(path_result)

        with open(path_source, encoding="UTF-8") as file, \
                open(path_result, encoding="UTF-8", mode="w") as file_out:
            idx = 0
            for line in file.readlines():
                sys.stdout.write("\rwrite with {}/{} .".format(idx + 1, len(result)))
                if line == "\n":
                    file_out.write("\n")
                    continue
                line = line.strip().split()
                line.append(result[idx])
                idx += 1
                file_out.write(" ".join(line) + "\n")
                if idx >= len(result):
                    break

        print("\nfinished.")

    @staticmethod
    def _get_model_args(batch_features):
        """
        Unpack the fields of a batch used for inference.

        NOTE(review): reads ``char_features``; confirm the batch type used here
        provides it.

        :param batch_features: Batch Instance
        :return: (word, char, mask, sentence_length, tags); mask marks word ids > 0
        """
        word = batch_features.word_features
        char = batch_features.char_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        tags = batch_features.label_features
        return word, char, mask, sentence_length, tags
# --------------------------------------------------------------------------------
# /trainer.py:
# --------------------------------------------------------------------------------
# @Author : bamtercelboo
# @Datetime : 2018/8/26 8:30
# @File : trainer.py
# @Last Modify Time : 2018/8/26 8:30
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
FILE : trainer.py
FUNCTION : None
"""

import os
import sys
import time
import numpy as np
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.utils as utils
from DataUtils.Optim import Optimizer
from DataUtils.utils import *
from DataUtils.Common import *
torch.manual_seed(seed_num)
random.seed(seed_num)


class Train(object):
    """
    Train a classification model, evaluating on dev and test after every epoch.
    """
    def __init__(self, **kwargs):
        """
        :param kwargs:
            Args of data:
                train_iter : train batch data iterator
                dev_iter : dev batch data iterator
                test_iter : test batch data iterator
            Args of train:
                model : nn model
                config : config
        """
        print("Training Start......")
        self.train_iter = kwargs["train_iter"]
        self.dev_iter = kwargs["dev_iter"]
        self.test_iter = kwargs["test_iter"]
        self.model = kwargs["model"]
        self.config = kwargs["config"]
        self.early_max_patience = self.config.early_max_patience
        self.optimizer = Optimizer(name=self.config.learning_algorithm, model=self.model, lr=self.config.learning_rate,
                                   weight_decay=self.config.weight_decay, grad_clip=self.config.clip_max_norm)
        self.loss_function = self._loss(learning_algorithm=self.config.learning_algorithm)
        print(self.optimizer)
        print(self.loss_function)
        self.best_score = Best_Result()
        self.train_iter_len = len(self.train_iter)

    @staticmethod
    def _loss(learning_algorithm):
        """
        :param learning_algorithm: optimizer name from the config
        :return: CrossEntropyLoss summed over the batch for "SGD", averaged otherwise
        """
        reduction = "sum" if learning_algorithm == "SGD" else "mean"
        return nn.CrossEntropyLoss(reduction=reduction)

    def _clip_model_norm(self, clip_max_norm_use, clip_max_norm):
        """
        Clip the global gradient norm of the model parameters.

        :param clip_max_norm_use: whether to use clip max norm for nn model
        :param clip_max_norm: max norm [float, numeric string, "None" or None];
            "None"/None disables clipping
        :return: None
        """
        if clip_max_norm_use is not True:
            return
        if clip_max_norm is None or clip_max_norm == "None":
            return
        utils.clip_grad_norm_(self.model.parameters(), max_norm=float(clip_max_norm))

    def _dynamic_lr(self, config, epoch, new_lr):
        """
        Multiply the learning rate by ``lr_rate_decay`` every ``max_patience`` epochs.
        The rate never drops below ``min_lrate``.

        :param config: config
        :param epoch: epoch (1-based)
        :param new_lr: current learning rate
        :return: the (possibly decayed) learning rate
        """
        if config.use_lr_decay is True and epoch > config.max_patience and (
                epoch - 1) % config.max_patience == 0 and new_lr > config.min_lrate:
            new_lr = max(new_lr * config.lr_rate_decay, config.min_lrate)
            set_lrate(self.optimizer, new_lr)
        return new_lr

    def _decay_learning_rate(self, config, epoch, init_lr):
        """Inverse-time lr decay: lr = init_lr / (1 + lr_rate_decay * epoch).

        Args:
            config: config
            epoch: int, epoch
            init_lr: initial lr
        Returns:
            the optimizer
        """
        if config.use_lr_decay:
            lr = init_lr / (1 + self.config.lr_rate_decay * epoch)
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
        return self.optimizer

    def _optimizer_batch_step(self, config, backward_count):
        """
        Step and reset the optimizer once every ``backward_batch_size`` backward
        passes (gradient accumulation), and always on the last batch of the epoch.

        :param config: config
        :param backward_count: number of backward passes so far in this epoch
        :return: None
        """
        if backward_count % config.backward_batch_size == 0 or backward_count == self.train_iter_len:
            self.optimizer.step()
            self.optimizer.zero_grad()

    def _early_stop(self, epoch):
        """
        Stop the process once the dev score has not improved for
        ``early_max_patience`` consecutive epochs.

        :param epoch: current epoch
        :return: None
        """
        best_epoch = self.best_score.best_epoch
        if epoch > best_epoch:
            self.best_score.early_current_patience += 1
            print("Dev Has Not Promote {} / {}".format(self.best_score.early_current_patience, self.early_max_patience))
            if self.best_score.early_current_patience >= self.early_max_patience:
                print("Early Stop Train. Best Score Locate on {} Epoch.".format(self.best_score.best_epoch))
                sys.exit()

    @staticmethod
    def _get_model_args(batch_features):
        """
        :param batch_features: Batch Instance
        :return: (word, bert_feature, mask, sentence_length, labels, batch_size)
        """
        word = batch_features.word_features
        mask = word > 0
        sentence_length = batch_features.sentence_length
        labels = batch_features.label_features
        batch_size = batch_features.batch_length
        bert_feature = batch_features.bert_features
        return word, bert_feature, mask, sentence_length, labels, batch_size

    def _calculate_loss(self, feats, labels):
        """
        Args:
            feats: model output; the class dimension is dim 1
            labels: gold label ids
        Returns:
            loss tensor from ``self.loss_function``
        """
        return self.loss_function(feats, labels)

    def train(self):
        """
        Run the full training loop: per-epoch lr schedule, shuffled batches with
        gradient accumulation, dev/test evaluation, checkpointing and early stop.

        :return: None
        """
        epochs = self.config.epochs
        clip_max_norm_use = self.config.clip_max_norm_use
        clip_max_norm = self.config.clip_max_norm
        new_lr = self.config.learning_rate

        for epoch in range(1, epochs + 1):
            print("\n## The {} Epoch, All {} Epochs ! ##".format(epoch, epochs))
            new_lr = self._dynamic_lr(config=self.config, epoch=epoch, new_lr=new_lr)
            print("now lr is {}".format(self.optimizer.param_groups[0].get("lr")), end="")
            start_time = time.time()
            random.shuffle(self.train_iter)
            self.model.train()
            steps = 1
            backward_count = 0
            self.optimizer.zero_grad()
            for batch_count, batch_features in enumerate(self.train_iter):
                backward_count += 1
                batch_size = batch_features.batch_length
                labels = batch_features.label_features
                logit = self.model(batch_features, train=True)
                loss = self._calculate_loss(logit, labels)
                loss.backward()
                self._clip_model_norm(clip_max_norm_use, clip_max_norm)
                self._optimizer_batch_step(config=self.config, backward_count=backward_count)
                steps += 1
                if (steps - 1) % self.config.log_interval == 0:
                    accuracy = self.getAcc(logit, labels, batch_size)
                    sys.stdout.write(
                        "\nbatch_count = [{}/{}] , loss is {:.6f}, [accuracy is {:.6f}%]".format(
                            batch_count + 1, self.train_iter_len, loss.item(), accuracy))
            end_time = time.time()
            print("\nTrain Time {:.3f}".format(end_time - start_time), end="")
            self.eval(model=self.model, epoch=epoch, config=self.config)
            self._model2file(model=self.model, config=self.config, epoch=epoch)
            self._early_stop(epoch=epoch)

    def eval(self, model, epoch, config):
        """
        Evaluate on dev (which updates the best score), then on test.

        :param model: nn model
        :param epoch: epoch
        :param config: config
        :return: None
        """
        eval_start_time = time.time()
        self.eval_batch(self.dev_iter, model, self.best_score, epoch, config, test=False)
        eval_end_time = time.time()
        print("Dev Time {:.3f}".format(eval_end_time - eval_start_time))

        eval_start_time = time.time()
        self.eval_batch(self.test_iter, model, self.best_score, epoch, config, test=True)
        eval_end_time = time.time()
        print("Test Time {:.3f}".format(eval_end_time - eval_start_time))

    def _model2file(self, model, config, epoch):
        """
        Save every epoch's model or only the best one, depending on the config.

        :param model: nn model
        :param config: config
        :param epoch: epoch
        :return: None
        """
        if config.save_model and config.save_all_model:
            save_model_all(model, config.save_dir, config.model_name, epoch)
        elif config.save_model and config.save_best_model:
            save_best_model(model, config.save_best_model_path, config.model_name, self.best_score)
        else:
            print()

    def eval_batch(self, data_iter, model, best_score, epoch, config, test=False):
        """
        Compute accuracy on ``data_iter`` and update ``best_score``.

        On dev, ``best_score`` records a new best epoch when accuracy ties or
        beats the best so far. On test, the accuracy is stored as ``best_score.p``
        only if the preceding dev pass flagged a new best.

        :param data_iter: eval batch data iterator
        :param model: eval model
        :param best_score: Best_Result tracker
        :param epoch: current epoch
        :param config: config
        :param test: whether to test
        :return: None
        :raises ValueError: if ``data_iter`` yields no instances
        """
        model.eval()
        corrects = 0
        size = 0
        # Evaluation only: skip building the autograd graph.
        with torch.no_grad():
            for batch_features in data_iter:
                batch_size = batch_features.batch_length
                labels = batch_features.label_features
                logit = model(batch_features, train=False)
                size += batch_size
                corrects += (torch.max(logit, 1)[1].view(labels.size()).data == labels.data).sum()

        if size == 0:
            raise ValueError("Error: evaluation data iterator yielded no instances.")
        accuracy = float(corrects) / size * 100.0

        test_flag = "Test"
        if test is False:
            print()
            test_flag = "Dev"
            best_score.current_dev_score = accuracy
            if accuracy >= best_score.best_dev_score:
                best_score.best_dev_score = accuracy
                best_score.best_epoch = epoch
                best_score.best_test = True
        if test is True and best_score.best_test is True:
            best_score.p = accuracy
        print("{} eval: Accuracy = {:.6f}%".format(test_flag, accuracy))
        if test is True:
            print("The Current Best Dev Accuracy: {:.6f}, Locate on {} Epoch.".format(best_score.best_dev_score,
                                                                                      best_score.best_epoch))
            print("The Current Best Test Accuracy: accuracy = {:.6f}%".format(best_score.p))
        if test is True:
            best_score.best_test = False

    @staticmethod
    def getAcc(logit, target, batch_size):
        """
        :param logit: model predict; the class dimension is dim 1
        :param target: gold label
        :param batch_size: batch size
        :return: accuracy in percent
        """
        corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
        accuracy = float(corrects) / batch_size * 100.0
        return accuracy