├── README.md
├── bert_cnews.py
├── loss.png
├── main.py
├── model.py
├── utils.py
├── vocab.json
└── vocab.py

/README.md:
--------------------------------------------------------------------------------
1 | # Multi_Model_Classification
2 | Multi-model Chinese news text classification
3 |
4 |
5 | This project builds RNN, CNN, AVG, and BERT models for the Chinese news (cnews) text classification task. The results of the models are summarized below:
6 |
7 | Model|acc|f1-score|acc_and_f1
8 | :----:|:----:|:----:|:----:
9 | AVG|0.9391|0.9385|0.9388
10 | CNN|0.9790|0.9789|0.9790
11 | RNN|0.9676|0.9672|0.9674
12 | BERT|0.9656|0.9654|0.9655
13 |
14 |
15 | The training loss curves of the models are shown below:
16 |
17 | ![loss curves](loss.png)
18 |
19 |
20 | Overall comparison and analysis:
21 |
22 | * The AVG model, which simply averages word embeddings, still reaches an accuracy of 0.9391. This shows how strong the fitting capacity of even a simple neural network is; AVG is simple and direct yet performs respectably, so it makes a solid baseline.
23 | * The training dataset is probably relatively easy.
24 | * The CNN model performs best on this dataset and also trains fastest, which suggests that for simple tasks, or as a building block inside larger architectures, a CNN is well worth using; there is no need to start every project with BERT.
25 | * BERT does not reach the best performance here, presumably because the dataset is relatively simple. Moreover, BERT already scores very high after fine-tuning for just one epoch, which suggests that on some simple tasks BERT works well with little or even no fine-tuning. See the paper [To Tune or Not to Tune?](https://arxiv.org/abs/1903.05987).
26 |
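27 | As a quick illustration of the AVG idea, here is a minimal, self-contained sketch of a word-averaging classifier. The sizes below mirror this repo's defaults (50,000-word vocabulary, 300-dim embeddings, 10 classes), but the tensor names and random inputs are made up for the example; see `model.py` for the actual `WordAVGModel`:
28 |
29 | ```python
30 | import torch
31 | import torch.nn as nn
32 |
33 | # Toy word-averaging classifier: embed the tokens, average over the
34 | # sequence dimension, then project the average down to class logits.
35 | embed = nn.Embedding(num_embeddings=50000, embedding_dim=300, padding_idx=0)
36 | fc = nn.Linear(300, 10)  # 10 cnews categories
37 |
38 | token_ids = torch.randint(1, 50000, (8, 32))  # hypothetical batch: [batch=8, seq_len=32]
39 | avg = embed(token_ids).mean(dim=1)            # [8, 300], one vector per document
40 | logits = fc(avg)                              # [8, 10] class scores
41 | ```
42 |
43 | The averaging step is the entire "encoder" of the AVG model, which is why it trains so much faster than the RNN or BERT.
--------------------------------------------------------------------------------
/bert_cnews.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.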
16 | """ Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa).""" 17 | 18 | from __future__ import absolute_import, division, print_function 19 | 20 | import argparse 21 | import glob 22 | import logging 23 | import os 24 | import random 25 | 26 | import numpy as np 27 | import torch 28 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 29 | TensorDataset) 30 | from torch.utils.data.distributed import DistributedSampler 31 | from transformers import DataProcessor, InputExample, InputFeatures 32 | 33 | try: 34 | from torch.utils.tensorboard import SummaryWriter 35 | except: 36 | from tensorboardX import SummaryWriter 37 | 38 | from tqdm import tqdm, trange 39 | 40 | from transformers import (WEIGHTS_NAME, BertConfig, 41 | BertForSequenceClassification, BertTokenizer, 42 | RobertaConfig, 43 | RobertaForSequenceClassification, 44 | RobertaTokenizer, 45 | XLMConfig, XLMForSequenceClassification, 46 | XLMTokenizer, XLNetConfig, 47 | XLNetForSequenceClassification, 48 | XLNetTokenizer, 49 | DistilBertConfig, 50 | DistilBertForSequenceClassification, 51 | DistilBertTokenizer, 52 | AlbertConfig, 53 | AlbertForSequenceClassification, 54 | AlbertTokenizer, 55 | ) 56 | 57 | from transformers import AdamW, get_linear_schedule_with_warmup 58 | 59 | from transformers import glue_compute_metrics as compute_metrics 60 | from transformers import glue_output_modes as output_modes 61 | from transformers import glue_processors as processors 62 | from transformers import glue_convert_examples_to_features as convert_examples_to_features 63 | 64 | from sklearn.metrics import f1_score 65 | 66 | logger = logging.getLogger(__name__) 67 | 68 | ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, 69 | RobertaConfig, DistilBertConfig)), ()) 70 | 71 | MODEL_CLASSES = { 72 | 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 73 | 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), 74 | 'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer), 75 | 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), 76 | 'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer), 77 | 'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer) 78 | } 79 | 80 | 81 | def set_seed(args): 82 | random.seed(args.seed) 83 | np.random.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | if args.n_gpu > 0: 86 | torch.cuda.manual_seed_all(args.seed) 87 | 88 | 89 | def simple_accuracy(preds, labels): 90 | return (preds == labels).mean() 91 | 92 | 93 | def acc_and_f1(preds, labels): 94 | acc = simple_accuracy(preds, labels) 95 | f1 = f1_score(y_true=labels, y_pred=preds,average='weighted') 96 | return { 97 | "acc": acc, 98 | "f1": f1, 99 | "acc_and_f1": (acc + f1) / 2, 100 | } 101 | 102 | class CnesProcessor(DataProcessor): 103 | """Processor for the cnews data set (GLUE version).""" 104 | 105 | def get_example_from_tensor_dict(self, tensor_dict): 106 | """See base class.""" 107 | return InputExample(tensor_dict['idx'].numpy(), 108 | tensor_dict['sentence'].numpy().decode('utf-8'), 109 | None, 110 | str(tensor_dict['label'].numpy())) 111 | 112 | def get_train_examples(self, data_dir): 113 | """See base class.""" 114 | return self._create_examples( 115 | self._read_tsv(os.path.join(data_dir, "cnews.train.txt")), "train") 116 | 117 | def get_dev_examples(self, data_dir): 118 | """See base class.""" 119 | 
119 |         return self._create_examples(
120 |             self._read_tsv(os.path.join(data_dir, "cnews.test.txt")), "dev")
121 |
122 |     def get_labels(self):
123 |         """See base class."""
124 |         return ["体育",
125 |                 "娱乐",
126 |                 "家居",
127 |                 "房产",
128 |                 "教育",
129 |                 "时尚",
130 |                 "时政",
131 |                 "游戏",
132 |                 "科技",
133 |                 "财经"]
134 |
135 |     def _create_examples(self, lines, set_type):
136 |         """Creates examples for the training and dev sets."""
137 |         examples = []
138 |         for (i, line) in enumerate(lines):
139 |             guid = "%s-%s" % (set_type, i)
140 |             text_a = line[1]
141 |             label = line[0]
142 |             examples.append(
143 |                 InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
144 |         return examples
145 |
146 |
147 | processors["cnews"] = CnewsProcessor
148 | output_modes["cnews"] = "classification"
149 |
150 |
151 | def train(args, train_dataset, model, tokenizer):
152 |     """ Train the model """
153 |     if args.local_rank in [-1, 0]:
154 |         tb_writer = SummaryWriter('./runs/bert')
155 |
156 |     args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
157 |     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
158 |     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
159 |
160 |     if args.max_steps > 0:
161 |         t_total = args.max_steps
162 |         args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
163 |     else:
164 |         t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
165 |
166 |     # Prepare optimizer and schedule (linear warmup and decay)
167 |     no_decay = ['bias', 'LayerNorm.weight']
168 |     optimizer_grouped_parameters = [
169 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
170 |          'weight_decay': args.weight_decay},
171 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
172 |     ]
173 |
174 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
175 |     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
176 |                                                 num_training_steps=t_total)
177 |     if args.fp16:
178 |         try:
179 |             from apex import amp
180 |         except ImportError:
181 |             raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
182 |         model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
183 |
184 |     # multi-gpu training (should be after apex fp16 initialization)
185 |     if args.n_gpu > 1:
186 |         model = torch.nn.DataParallel(model)
187 |
188 |     # Distributed training (should be after apex fp16 initialization)
189 |     if args.local_rank != -1:
190 |         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
191 |                                                           output_device=args.local_rank,
192 |                                                           find_unused_parameters=True)
193 |     # Train!
194 |     logger.info("***** Running training *****")
195 |     logger.info("  Num examples = %d", len(train_dataset))
196 |     logger.info("  Num Epochs = %d", args.num_train_epochs)
197 |     logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
198 |     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
199 |                 args.train_batch_size * args.gradient_accumulation_steps * (
200 |                     torch.distributed.get_world_size() if args.local_rank != -1 else 1))
201 |     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
202 |     logger.info("  Total optimization steps = %d", t_total)
203 |
204 |     global_step = 0
205 |     tr_loss, logging_loss = 0.0, 0.0
206 |     model.zero_grad()
207 |     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
208 |     set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
209 |     for _ in train_iterator:
210 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
211 |         for step, batch in enumerate(epoch_iterator):
212 |             model.train()
213 |             batch = tuple(t.to(args.device) for t in batch)
214 |             inputs = {'input_ids': batch[0],
215 |                       'attention_mask': batch[1],
216 |                       'labels': batch[3]}
217 |             if args.model_type != 'distilbert':
218 |                 inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
219 |                                                                            'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
220 |             outputs = model(**inputs)
221 |             loss = outputs[0]  # model outputs are always tuple in transformers (see doc)
222 |
223 |             if args.n_gpu > 1:
224 |                 loss = loss.mean()  # mean() to average on multi-gpu parallel training
225 |             if args.gradient_accumulation_steps > 1:
226 |                 loss = loss / args.gradient_accumulation_steps
227 |
228 |             if args.fp16:
229 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
230 |                     scaled_loss.backward()
231 |             else:
232 |                 loss.backward()
233 |
234 |             tr_loss += loss.item()
235 |             if (step + 1) % args.gradient_accumulation_steps == 0:
236 |                 if args.fp16:
237 |                     torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
238 |                 else:
239 |                     torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
240 |
241 |                 optimizer.step()
242 |                 scheduler.step()  # Update learning rate schedule
243 |                 model.zero_grad()
244 |                 global_step += 1
245 |
246 |                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
247 |                     # Log metrics
248 |                     if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
249 |                         results = evaluate(args, model, tokenizer)
250 |                         for key, value in results.items():
251 |                             tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
252 |                     tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
253 |                     tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
254 |                     logging_loss = tr_loss
255 |
256 |                 if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
257 |                     # Save model checkpoint
258 |                     output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
259 |                     if not os.path.exists(output_dir):
260 |                         os.makedirs(output_dir)
261 |                     model_to_save = model.module if hasattr(model,
262 |                                                             'module') else model  # Take care of distributed/parallel training
263 |                     model_to_save.save_pretrained(output_dir)
264 |                     torch.save(args, os.path.join(output_dir, 'training_args.bin'))
265 |                     logger.info("Saving model checkpoint to %s", output_dir)
266 |
267 |             if args.max_steps > 0 and global_step > args.max_steps:
268 |                 epoch_iterator.close()
269 |                 break
270 |         if args.max_steps > 0 and global_step > args.max_steps:
271 |             train_iterator.close()
272 |             break
273 |
274 |     if args.local_rank in [-1, 0]:
275 |         tb_writer.close()
276 |
277 |     return global_step, tr_loss / global_step
278 |
279 |
280 | def evaluate(args, model, tokenizer, prefix=""):
281 |     # Loop to handle MNLI double evaluation (matched, mis-matched)
282 |     eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
283 |     eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)
284 |
285 |     results = {}
286 |     for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
287 |         eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
288 |
289 |         if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
290 |             os.makedirs(eval_output_dir)
291 |
292 |         args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
293 |         # Note that DistributedSampler samples randomly
294 |         eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
295 |         eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
296 |
297 |         # multi-gpu eval
298 |         if args.n_gpu > 1:
299 |             model = torch.nn.DataParallel(model)
300 |
301 |         # Eval!
302 |         logger.info("***** Running evaluation {} *****".format(prefix))
303 |         logger.info("  Num examples = %d", len(eval_dataset))
304 |         logger.info("  Batch size = %d", args.eval_batch_size)
305 |         eval_loss = 0.0
306 |         nb_eval_steps = 0
307 |         preds = None
308 |         out_label_ids = None
309 |         for batch in tqdm(eval_dataloader, desc="Evaluating"):
310 |             model.eval()
311 |             batch = tuple(t.to(args.device) for t in batch)
312 |
313 |             with torch.no_grad():
314 |                 inputs = {'input_ids': batch[0],
315 |                           'attention_mask': batch[1],
316 |                           'labels': batch[3]}
317 |                 if args.model_type != 'distilbert':
318 |                     inputs['token_type_ids'] = batch[2] if args.model_type in ['bert',
319 |                                                                                'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
320 |                 outputs = model(**inputs)
321 |                 tmp_eval_loss, logits = outputs[:2]
322 |
323 |                 eval_loss += tmp_eval_loss.mean().item()
324 |             nb_eval_steps += 1
325 |             if preds is None:
326 |                 preds = logits.detach().cpu().numpy()
327 |                 out_label_ids = inputs['labels'].detach().cpu().numpy()
328 |             else:
329 |                 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
330 |                 out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
331 |
332 |         eval_loss = eval_loss / nb_eval_steps
333 |         if args.output_mode == "classification":
334 |             preds = np.argmax(preds, axis=1)
335 |         elif args.output_mode == "regression":
336 |             preds = np.squeeze(preds)
337 |
338 |         if eval_task == "cnews":
339 |             result = acc_and_f1(preds, out_label_ids)
340 |         else:
341 |             result = compute_metrics(eval_task, preds, out_label_ids)
342 |         results.update(result)
343 |
344 |         output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
345 |         with open(output_eval_file, "w") as writer:
346 |             logger.info("***** Eval results {} *****".format(prefix))
347 |             for key in sorted(result.keys()):
348 |                 logger.info("  %s = %s", key, str(result[key]))
349 |                 writer.write("%s = %s\n" % (key, str(result[key])))
350 |
351 |     return results
352 |
353 |
354 | def load_and_cache_examples(args, task, tokenizer, evaluate=False):
355 |     if args.local_rank not in [-1, 0] and not evaluate:
356 |         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
357 |
358 |     processor = processors[task]()
359 |     output_mode = output_modes[task]
360 |     # Load data features from cache or dataset file
361 |     cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
362 |         'dev' if evaluate else 'train',
363 |         'bert',
364 |         str(args.max_seq_length),
365 |         str(task)))
366 |     if os.path.exists(cached_features_file) and not args.overwrite_cache:
367 |         logger.info("Loading features from cached file %s", cached_features_file)
368 |         features = torch.load(cached_features_file)
369 |     else:
370 |         logger.info("Creating features from dataset file at %s", args.data_dir)
371 |         label_list = processor.get_labels()
372 |         if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
373 |             # HACK(label indices are swapped in RoBERTa pretrained model)
374 |             label_list[1], label_list[2] = label_list[2], label_list[1]
375 |         examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(
376 |             args.data_dir)
377 |         features = convert_examples_to_features(examples,
378 |                                                 tokenizer,
379 |                                                 label_list=label_list,
380 |                                                 max_length=args.max_seq_length,
381 |                                                 output_mode=output_mode,
382 |                                                 pad_on_left=bool(args.model_type in ['xlnet']),
383 |                                                 # pad on the left for xlnet
384 |                                                 pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
385 |                                                 pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0,
386 |                                                 )
387 |         if args.local_rank in [-1, 0]:
388 |             logger.info("Saving features into cached file %s", cached_features_file)
389 |             torch.save(features, cached_features_file)
390 |
391 |     if args.local_rank == 0 and not evaluate:
392 |         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
393 |
394 |     # Convert to Tensors and build dataset
395 |     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
396 |     all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
397 |     all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
398 |     if output_mode == "classification":
399 |         all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
400 |     elif output_mode == "regression":
401 |         all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
402 |
403 |     dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
404 |     return dataset
405 |
406 |
407 | def main():
408 |     parser = argparse.ArgumentParser()
409 |
410 |     ## Required parameters
411 |     parser.add_argument("--data_dir", default='./cnews', type=str, required=False,
412 |                         help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
413 |     parser.add_argument("--model_type", default='bert', type=str, required=False,
414 |                         help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
415 |     parser.add_argument("--model_name_or_path", default='D:\\NLP\\my-wholes-models\\chinese_wwm_pytorch', type=str, required=False,
416 |                         help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(
417 |                             ALL_MODELS))
418 |     parser.add_argument("--task_name", default='cnews', type=str, required=False,
419 |                         help="The name of the task to train selected in the list: " + ", ".join(processors.keys()))
420 |     parser.add_argument("--output_dir", default='./outs', type=str, required=False,
421 |                         help="The output directory where the model predictions and checkpoints will be written.")
422 |
423 |     ## Other parameters
424 |     parser.add_argument("--config_name", default="", type=str,
425 |                         help="Pretrained config name or path if not the same as model_name")
426 |     parser.add_argument("--tokenizer_name", default="", type=str,
427 |                         help="Pretrained tokenizer name or path if not the same as model_name")
428 |     parser.add_argument("--cache_dir", default="", type=str,
429 |                         help="Where do you want to store the pre-trained models downloaded from s3")
430 |     parser.add_argument("--max_seq_length", default=128, type=int,
431 |                         help="The maximum total input sequence length after tokenization. Sequences longer "
432 |                              "than this will be truncated, sequences shorter will be padded.")
433 |     parser.add_argument("--do_train", action='store_true',
434 |                         help="Whether to run training.")
435 |     parser.add_argument("--do_eval", default=True, action='store_true',
436 |                         help="Whether to run eval on the dev set.")
437 |     parser.add_argument("--evaluate_during_training", action='store_true',
438 |                         help="Run evaluation during training at each logging step.")
439 |     parser.add_argument("--do_lower_case", action='store_true',
440 |                         help="Set this flag if you are using an uncased model.")
441 |
442 |     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
443 |                         help="Batch size per GPU/CPU for training.")
444 |     parser.add_argument("--per_gpu_eval_batch_size", default=16, type=int,
445 |                         help="Batch size per GPU/CPU for evaluation.")
446 |     parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
447 |                         help="Number of updates steps to accumulate before performing a backward/update pass.")
448 |     parser.add_argument("--learning_rate", default=2e-5, type=float,
449 |                         help="The initial learning rate for Adam.")
450 |     parser.add_argument("--weight_decay", default=0.0, type=float,
451 |                         help="Weight decay if we apply some.")
452 |     parser.add_argument("--adam_epsilon", default=1e-8, type=float,
453 |                         help="Epsilon for Adam optimizer.")
454 |     parser.add_argument("--max_grad_norm", default=1.0, type=float,
455 |                         help="Max gradient norm.")
456 |     parser.add_argument("--num_train_epochs", default=10.0, type=float,
457 |                         help="Total number of training epochs to perform.")
458 |     parser.add_argument("--max_steps", default=-1, type=int,
459 |                         help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
460 |     parser.add_argument("--warmup_steps", default=0, type=int,
461 |                         help="Linear warmup over warmup_steps.")
462 |
463 |     parser.add_argument('--logging_steps', type=int, default=100,
464 |                         help="Log every X updates steps.")
465 |     parser.add_argument('--save_steps', type=int, default=6000,
466 |                         help="Save checkpoint every X updates steps.")
467 |     parser.add_argument("--eval_all_checkpoints", default=True, action='store_true',
468 |                         help="Evaluate all checkpoints starting with the same prefix as model_name and ending with step number")
469 |     parser.add_argument("--no_cuda", action='store_true',
470 |                         help="Avoid using CUDA when available")
471 |     parser.add_argument('--overwrite_output_dir', action='store_true',
472 |                         help="Overwrite the content of the output directory")
473 |     parser.add_argument('--overwrite_cache', action='store_true',
474 |                         help="Overwrite the cached training and evaluation sets")
475 |     parser.add_argument('--seed', type=int, default=42,
476 |                         help="random seed for initialization")
477 |
478 |     parser.add_argument('--fp16', action='store_true',
479 |                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
480 |     parser.add_argument('--fp16_opt_level', type=str, default='O1',
481 |                         help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
482 |                              "See details at https://nvidia.github.io/apex/amp.html")
483 |     parser.add_argument("--local_rank", type=int, default=-1,
484 |                         help="For distributed training: local_rank")
485 |     parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
486 |     parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
487 |     args = parser.parse_args()
488 |
489 |     if os.path.exists(args.output_dir) and os.listdir(
490 |             args.output_dir) and args.do_train and not args.overwrite_output_dir:
491 |         raise ValueError(
492 |             "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
493 |                 args.output_dir))
494 |
495 |     # Setup distant debugging if needed
496 |     if args.server_ip and args.server_port:
497 |         # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
498 |         import ptvsd
499 |         print("Waiting for debugger attach")
500 |         ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
501 |         ptvsd.wait_for_attach()
502 |
503 |     # Setup CUDA, GPU & distributed training
504 |     if args.local_rank == -1 or args.no_cuda:
505 |         device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
506 |         args.n_gpu = torch.cuda.device_count()
507 |     else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
508 |         torch.cuda.set_device(args.local_rank)
509 |         device = torch.device("cuda", args.local_rank)
510 |         torch.distributed.init_process_group(backend='nccl')
511 |         args.n_gpu = 1
512 |     args.device = device
513 |
514 |     # Setup logging
515 |     logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
516 |                         datefmt='%m/%d/%Y %H:%M:%S',
517 |                         level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
518 |     logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
519 |                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
520 |
521 |     # Set seed
522 |     set_seed(args)
523 |
524 |     # Prepare GLUE task
525 |     args.task_name = args.task_name.lower()
526 |     if args.task_name not in processors:
527 |         raise ValueError("Task not found: %s" % (args.task_name))
528 |     processor = processors[args.task_name]()
529 |     args.output_mode = output_modes[args.task_name]
530 |     label_list = processor.get_labels()
531 |     num_labels = len(label_list)
532 |
533 |     # Load pretrained model and tokenizer
534 |     if args.local_rank not in [-1, 0]:
535 |         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
536 |
537 |     args.model_type = args.model_type.lower()
538 |     config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
539 |     config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
540 |                                           num_labels=num_labels,
541 |                                           finetuning_task=args.task_name,
542 |                                           cache_dir=args.cache_dir if args.cache_dir else None)
543 |     tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
544 |                                                 do_lower_case=args.do_lower_case,
545 |                                                 cache_dir=args.cache_dir if args.cache_dir else None)
546 |     model = model_class.from_pretrained(args.model_name_or_path,
547 |                                         from_tf=bool('.ckpt' in args.model_name_or_path),
548 |                                         config=config,
549 |                                         cache_dir=args.cache_dir if args.cache_dir else None)
550 |
551 |     if args.local_rank == 0:
552 |         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
553 |
554 |     model.to(args.device)
555 |
556 |     logger.info("Training/evaluation parameters %s", args)
557 |
558 |     # Training
559 |     if args.do_train:
560 |         train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
561 |         global_step, tr_loss = train(args, train_dataset, model, tokenizer)
562 |         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
563 |
564 |     # Saving best-practices: if you use default names for the model, you can reload it using from_pretrained()
565 |     if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
566 |         # Create output directory if needed
567 |         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
568 |             os.makedirs(args.output_dir)
569 |
570 |         logger.info("Saving model checkpoint to %s", args.output_dir)
571 |         # Save a trained model, configuration and tokenizer using `save_pretrained()`.
572 |         # They can then be reloaded using `from_pretrained()`
573 |         model_to_save = model.module if hasattr(model,
574 |                                                 'module') else model  # Take care of distributed/parallel training
575 |         model_to_save.save_pretrained(args.output_dir)
576 |         tokenizer.save_pretrained(args.output_dir)
577 |
578 |         # Good practice: save your training arguments together with the trained model
579 |         torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
580 |
581 |         # Load a trained model and vocabulary that you have fine-tuned
582 |         model = model_class.from_pretrained(args.output_dir)
583 |         tokenizer = tokenizer_class.from_pretrained(args.output_dir)
584 |         model.to(args.device)
585 |
586 |     # Evaluation
587 |     results = {}
588 |     if args.do_eval and args.local_rank in [-1, 0]:
589 |         tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
590 |         checkpoints = [args.output_dir]
591 |         if args.eval_all_checkpoints:
592 |             checkpoints = list(
593 |                 os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
594 |             logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
595 |         logger.info("Evaluate the following checkpoints: %s", checkpoints)
596 |         for checkpoint in checkpoints:
597 |             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
598 |             prefix = "checkpoint-" + str(checkpoint.split("-")[-1]) if checkpoint.find('checkpoint') != -1 else ""
599 |
600 |             model = model_class.from_pretrained(checkpoint)
601 |             model.to(args.device)
602 |             result = evaluate(args, model, tokenizer, prefix=prefix)
603 |             result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
604 |             results.update(result)
605 |
606 |     return results
607 |
608 |
609 | if __name__ == "__main__":
610 |     main()
611 |
--------------------------------------------------------------------------------
/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeHappyForMe/Multi_Model_Classification/67dfd01c309c09f0de50134e1d779fc5223868a1/loss.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import random
5 | import time
6 | import numpy as np
7 | import pkuseg
8 | import argparse
9 | from tqdm import trange, tqdm
10 | import os
11 | from utils import read_corpus, batch_iter
12 | from vocab import Vocab
13 | from model import RNN, CNN, WordAVGModel
14 | import math
15 | from sklearn.metrics import f1_score
16 |
17 | try:
18 |     from torch.utils.tensorboard import SummaryWriter
19 | except ImportError:
20 |     from tensorboardX import SummaryWriter
21 |
22 | from transformers import AdamW, get_linear_schedule_with_warmup
23 |
24 |
25 | def set_seed():
26 |     random.seed(3344)
27 |     np.random.seed(3344)
28 |     torch.manual_seed(3344)
29 |     if torch.cuda.is_available():
30 |         torch.cuda.manual_seed(3344)
31 |
32 |
33 | def tokenizer(text):
34 |     """
35 |     Define the tokenization rule for the raw text.
36 |     """
37 |     # regex = re.compile(r'[^\u4e00-\u9fa5A-Za-z0-9]')
38 |     # text = regex.sub(' ', text)
39 |     seg = pkuseg.pkuseg()
40 |     return [word for word in seg.cut(text) if word.strip()]
41 |
42 |
43 | def train(args, model, train_data, dev_data, vocab, dtype='CNN'):
44 |     LOG_FILE = args.output_file
45 |     with open(LOG_FILE, "a") as fout:
46 |         fout.write('\n')
47 |         fout.write('=========='*6)
48 |         fout.write('start training: {}'.format(dtype))
49 |         fout.write('\n')
50 |
51 |     time_start = time.time()
52 |     if not os.path.exists(os.path.join('./runs', dtype)):
53 |         os.makedirs(os.path.join('./runs', dtype))
54 |     tb_writer = SummaryWriter(os.path.join('./runs', dtype))
55 |
56 |     t_total = args.num_epoch * (math.ceil(len(train_data) / args.batch_size))
57 |     optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=1e-8)
58 |     scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=args.warmup_steps,
59 |                                                 num_training_steps=t_total)
60 |     criterion = nn.CrossEntropyLoss()
61 |     global_step = 0
62 |     total_loss = 0.
63 |     logg_loss = 0.
64 |     val_acces = []
65 |     train_epoch = trange(args.num_epoch, desc='train_epoch')
66 |     for epoch in train_epoch:
67 |         model.train()
68 |
69 |         for src_sents, labels in batch_iter(train_data, args.batch_size, shuffle=True):
70 |             src_sents = vocab.vocab.to_input_tensor(src_sents, args.device)
71 |             global_step += 1
72 |             optimizer.zero_grad()
73 |
74 |             logits = model(src_sents)
75 |             y_labels = torch.tensor(labels, device=args.device)
76 |
77 |             example_losses = criterion(logits, y_labels)
78 |
79 |             example_losses.backward()
80 |             torch.nn.utils.clip_grad_norm_(model.parameters(), args.GRAD_CLIP)
81 |             optimizer.step()
82 |             scheduler.step()
83 |
84 |             total_loss += example_losses.item()
85 |             if global_step % 100 == 0:
86 |                 loss_scalar = (total_loss - logg_loss) / 100
87 |                 logg_loss = total_loss
88 |
89 |                 with open(LOG_FILE, "a") as fout:
90 |                     fout.write("epoch: {}, iter: {}, loss: {}, learning_rate: {}\n".format(epoch, global_step, loss_scalar,
91 |                                                                                            scheduler.get_lr()[0]))
92 |                 print("epoch: {}, iter: {}, loss: {}, learning_rate: {}".format(epoch, global_step, loss_scalar,
93 |                                                                                 scheduler.get_lr()[0]))
94 |                 tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
95 |                 tb_writer.add_scalar("loss", loss_scalar, global_step)
96 |
97 |         print("Epoch", epoch, "Training loss", total_loss / global_step)
98 |
99 |         eval_loss, eval_result = evaluate(args, criterion, model, dev_data, vocab)  # evaluate the model
100 |         with open(LOG_FILE, "a") as fout:
101 |             fout.write("EVALUATE: epoch: {}, loss: {}, eval_result: {}\n".format(epoch, eval_loss, eval_result))
102 |         # tb_writer.add_scalars('eval_result', eval_result, epoch)
103 |         # tb_writer.add_scalar('eval loss', eval_loss, epoch)
104 |         eval_acc = eval_result['acc']
105 |         if len(val_acces) == 0 or eval_acc > max(val_acces):
106 |             # save the model whenever the accuracy beats the previous best
107 |             print("best model on epoch: {}, eval_acc: {}".format(epoch, eval_acc))
108 |             torch.save(model.state_dict(), "classifier-best-{}.th".format(dtype))
109 |         val_acces.append(eval_acc)
110 |
111 |     time_end = time.time()
112 |     print("run model of {}, taking total {} m".format(dtype, (time_end - time_start) / 60))
113 |     with open(LOG_FILE, "a") as fout:
114 |         fout.write("run model of {}, taking total {} m\n".format(dtype, (time_end - time_start) / 60))
115 |
116 | def evaluate(args, criterion, model, dev_data, vocab):
117 |     model.eval()
118 |     total_loss = 0.
119 |     total_step = 0.
120 |     preds = None
121 |     out_label_ids = None
122 |     with torch.no_grad():  # no parameter updates here, so no gradients are needed
123 |         for src_sents, labels in batch_iter(dev_data, args.batch_size):
124 |             src_sents = vocab.vocab.to_input_tensor(src_sents, args.device)
125 |             logits = model(src_sents)
126 |             labels = torch.tensor(labels, device=args.device)
127 |             example_losses = criterion(logits, labels)
128 |
129 |             total_loss += example_losses.item()
130 |             total_step += 1
131 |
132 |             if preds is None:
133 |                 preds = logits.detach().cpu().numpy()
134 |                 out_label_ids = labels.detach().cpu().numpy()
135 |             else:
136 |                 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
137 |                 out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0)
138 |
139 |     preds = np.argmax(preds, axis=1)
140 |     result = acc_and_f1(preds, out_label_ids)
141 |     model.train()
142 |     print("Evaluation loss", total_loss / total_step)
143 |     print('Evaluation result', result)
144 |     return total_loss / total_step, result
145 |
146 | def acc_and_f1(preds, labels):
147 |     acc = (preds == labels).mean()
148 |     f1 = f1_score(y_true=labels, y_pred=preds, average='weighted')
149 |     return {
150 |         "acc": acc,
151 |         "f1": f1,
152 |         "acc_and_f1": (acc + f1) / 2,
153 |     }
154 |
155 | def build_vocab(args):
156 |     if not os.path.exists(args.vocab_path):
157 |         src_sents, labels = read_corpus(args.train_data_dir)
158 |         labels = {label: idx for idx, label in enumerate(sorted(set(labels)))}  # map each distinct label to an index
159 |         vocab = Vocab.build(src_sents, labels, args.max_vocab_size, args.min_freq)
160 |         vocab.save(args.vocab_path)
161 |     else:
162 |         vocab = Vocab.load(args.vocab_path)
163 |     return vocab
164 |
165 |
166 | def main():
167 |     parse = argparse.ArgumentParser()
168 |
169 |     parse.add_argument("--train_data_dir", default='./cnews/cnews.train.txt', type=str, required=False)
170 |     parse.add_argument("--dev_data_dir", default='./cnews/cnews.val.txt', type=str, required=False)
171 |     parse.add_argument("--test_data_dir", default='./cnews/cnews.test.txt', type=str, required=False)
172 |     parse.add_argument("--output_file", default='deep_model.log', type=str, required=False)
173 |     parse.add_argument("--batch_size", default=8, type=int)
174 |     parse.add_argument("--do_train", default=True, action="store_true", help="Whether to run training.")
175 |     parse.add_argument("--do_test", default=True, action="store_true", help="Whether to run evaluation on the test set.")
176 |     parse.add_argument("--learning_rate", default=5e-4, type=float)
177 |     parse.add_argument("--num_epoch", default=10, type=int)
178 |     parse.add_argument("--max_vocab_size", default=50000, type=int)
179 |     parse.add_argument("--min_freq", default=2, type=int)
180 |     parse.add_argument("--embed_size", default=300, type=int)
181 |     parse.add_argument("--hidden_size", default=256, type=int)
182 |     parse.add_argument("--dropout_rate", default=0.2, type=float)
183 |     parse.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
184 |     parse.add_argument("--GRAD_CLIP", default=1, type=float)
185 |     parse.add_argument("--vocab_path", default='./vocab.json', type=str)
186 |     parse.add_argument("--do_cnn", default=True, action="store_true", help="Whether to train the CNN model.")
187 |     parse.add_argument("--do_rnn", default=True, action="store_true", help="Whether to train the RNN model.")
188 |     parse.add_argument("--do_avg", default=True, action="store_true", help="Whether to train the AVG model.")
189 |
190 |     parse.add_argument("--num_filter", default=100, type=int, help="Output channels per filter size in the CNN model")
191 |
192 |     args = parse.parse_args()
193 |
194 |     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.device("cuda" if torch.cuda.is_available() else "cpu") 195 | args.device = device 196 | set_seed() 197 | 198 | if os.path.exists('./cnews/cache_train_data'): 199 | train_data = torch.load('./cnews/cache_train_data') 200 | else: 201 | train_data = read_corpus(args.train_data_dir) 202 | train_data = [(text,labs) for text,labs in zip(*train_data)] 203 | torch.save(train_data,'./cnews/cache_train_data') 204 | 205 | if os.path.exists('./cnews/cache_dev_data'): 206 | dev_data = torch.load('./cnews/cache_dev_data') 207 | else: 208 | dev_data = read_corpus(args.dev_data_dir) 209 | dev_data = [(text,labs) for text,labs in zip(*dev_data)] 210 | torch.save(dev_data, './cnews/cache_dev_data') 211 | 212 | vocab = build_vocab(args) 213 | label_map = vocab.labels 214 | print(label_map) 215 | 216 | if args.do_train: 217 | if args.do_cnn: 218 | cnn_model = CNN(len(vocab.vocab),args.embed_size,args.num_filter,[2,3,4],len(label_map),dropout=args.dropout_rate) 219 | cnn_model.to(device) 220 | train(args,cnn_model,train_data,dev_data,vocab,dtype='CNN') 221 | 222 | if args.do_avg: 223 | avg_model = WordAVGModel(len(vocab.vocab),args.embed_size,len(label_map),dropout=args.dropout_rate) 224 | avg_model.to(device) 225 | train(args, avg_model, train_data, dev_data, vocab, dtype='AVG') 226 | 227 | if args.do_rnn: 228 | rnn_model = RNN(len(vocab.vocab),args.embed_size,args.hidden_size, 229 | len(label_map),n_layers=1,bidirectional=True,dropout=args.dropout_rate) 230 | rnn_model.to(device) 231 | train(args, rnn_model, train_data, dev_data, vocab, dtype='RNN') 232 | 233 | if args.do_test: 234 | 235 | if os.path.exists('./cnews/cache_test_data'): 236 | test_data = torch.load('./cnews/cache_test_data') 237 | else: 238 | test_data = read_corpus(args.test_data_dir) 239 | test_data = [(text, labs) for text, labs in zip(*test_data)] 240 | torch.save(test_data, './cnews/cache_test_data') 241 | 242 | cirtion = nn.CrossEntropyLoss() 243 | 244 | cnn_model = CNN(len(vocab.vocab), args.embed_size, args.num_filter, [2, 3, 4], len(label_map), 245 | dropout=args.dropout_rate) 246 | cnn_model.load_state_dict(torch.load('classifa-best-CNN.th')) 247 | cnn_model.to(device) 248 | cnn_test_loss , cnn_result = evaluate(args,cirtion,cnn_model,test_data,vocab) 249 | 250 | avg_model = WordAVGModel(len(vocab.vocab), args.embed_size, len(label_map), dropout=args.dropout_rate) 251 | avg_model.load_state_dict(torch.load('classifa-best-AVG.th')) 252 | avg_model.to(device) 253 | avg_test_loss, avg_result = evaluate(args, cirtion, avg_model, test_data, vocab) 254 | 255 | rnn_model = RNN(len(vocab.vocab), args.embed_size, args.hidden_size, 256 | len(label_map), n_layers=1, bidirectional=True, dropout=args.dropout_rate) 257 | rnn_model.load_state_dict(torch.load('classifa-best-RNN.th')) 258 | rnn_model.to(device) 259 | rnn_test_loss, rnn_result = evaluate(args, cirtion, rnn_model, test_data, vocab) 260 | 261 | with open(args.output_file, "a") as fout: 262 | fout.write('\n') 263 | fout.write('=============== test result ============\n') 264 | fout.write("test model of {}, loss: {},result: {}\n".format('CNN', cnn_test_loss,cnn_result)) 265 | fout.write("test model of {}, loss: {},result: {}\n".format('AVG', avg_test_loss, avg_result)) 266 | fout.write("test model of {}, loss: {},result: {}\n".format('RNN', rnn_test_loss, rnn_result)) 267 | 268 | 269 | if __name__ == '__main__': 270 | main() 271 | -------------------------------------------------------------------------------- /model.py: 
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 |
6 | class WordAVGModel(nn.Module):
7 |     def __init__(self, vocab_size, embedding_dim, output_dim, dropout=0.2, pad_idx=0):
8 |         # initialize parameters
9 |         super().__init__()
10 |         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
11 |
12 |         self.fc = nn.Linear(embedding_dim, output_dim)
13 |         self.dropout = nn.Dropout(dropout)
14 |
15 |     def forward(self, text):
16 |         # embedded.shape = (batch_size, seq, embed_size)
17 |         embedded = self.dropout(self.embedding(text))
18 |
19 |         pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
20 |         # [batch size, embedding_dim]: average-pool the sequence dimension down to length 1, then squeeze it away
21 |
22 |         return self.fc(pooled)
23 |         # (batch size, output_dim)
24 |
25 |
26 | class RNN(nn.Module):
27 |     def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim,
28 |                  n_layers=2, bidirectional=True, dropout=0.2, pad_idx=0):
29 |         super().__init__()
30 |         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
31 |         self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, batch_first=True,
32 |                            bidirectional=bidirectional)
33 |
34 |         self.fc = nn.Linear(hidden_dim * 2, output_dim)
35 |         # hidden_dim is multiplied by 2 because the LSTM is bidirectional and the two
36 |         # directions are concatenated; this is independent of n_layers.
37 |
38 |         self.dropout = nn.Dropout(dropout)
39 |
40 |     def forward(self, text):
41 |         # text.shape = [batch_size, seq_len] (the LSTM is built with batch_first=True)
42 |         embedded = self.dropout(self.embedding(text))
43 |         # output: [batch, seq, 2*hidden if bidirectional else hidden]
44 |         # hidden/cell: [bidirec * n_layers, batch, hidden]
45 |         output, (hidden, cell) = self.rnn(embedded)
46 |
47 |         # concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
48 |         hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
49 |         # hidden = [batch size, hid dim * num directions]
50 |
51 |         return self.fc(hidden)  # a final fully connected layer; output is [batch size, output_dim]
52 |
53 |
54 | class CNN(nn.Module):
55 |     def __init__(self, vocab_size, embedding_dim, num_filter,
56 |                  filter_sizes, output_dim, dropout=0.2, pad_idx=0):
57 |         super().__init__()
58 |
59 |         self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
60 |         self.convs = nn.ModuleList([
61 |             nn.Conv2d(in_channels=1, out_channels=num_filter,
62 |                       kernel_size=(fs, embedding_dim))
63 |             for fs in filter_sizes
64 |         ])
65 |         # in_channels: input channels; for text this is always 1
66 |         # out_channels: number of output channels (filters) per filter size
67 |         # fs: how many words each sliding window covers, i.e. the n in n-gram
68 |         # one conv layer per entry in filter_sizes; their outputs are concatenated below
69 |
70 |         self.fc = nn.Linear(len(filter_sizes) * num_filter, output_dim)
71 |         self.dropout = nn.Dropout(dropout)
72 |
73 |     def forward(self, text):
74 |         embedded = self.dropout(self.embedding(text))  # [batch size, sent len, emb dim]
75 |         embedded = embedded.unsqueeze(1)  # [batch size, 1, sent len, emb dim]
76 |         # add a channel dimension so the shape matches what nn.Conv2d expects
77 |         conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
78 |         # conved = [batch size, num_filter, sent len - fs + 1]
79 |         # one conved tensor per filter size
80 |
81 |         pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]  # [batch, num_filter]
82 |
83 |         cat = self.dropout(torch.cat(pooled, dim=1))
84 |         # cat = [batch size, num_filter * len(filter_sizes)]
85 |         # concatenate the len(filter_sizes) pooled outputs and feed them to the fully connected layer
86 |
87 |         return self.fc(cat)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import pkuseg
2 | import math
3 | import random
4 | from tqdm import tqdm, trange
5 | import codecs
6 |
7 | label_map = {'财经': 0, '教育': 1, '房产': 2, '娱乐': 3, '游戏': 4,
8 |              '体育': 5, '时尚': 6, '科技': 7, '时政': 8, '家居': 9}
9 | def read_corpus(file_path):
10 |     """Read the corpus: one `label<TAB>text` pair per line.
11 |     :param file_path:
12 |     :return: (list of tokenized texts, list of labels)
13 |     """
14 |     src_data = []
15 |     labels = []
16 |     seg = pkuseg.pkuseg()
17 |     with codecs.open(file_path, 'r', encoding='utf-8') as fin:
18 |         for line in tqdm(fin.readlines(), desc='reading corpus'):
19 |             if line.strip():
20 |                 pair = line.strip().split('\t')
21 |                 if len(pair) != 2:
22 |                     print(pair)
23 |                     continue
24 |                 src_data.append(seg.cut(pair[1]))
25 |                 labels.append(pair[0])
26 |     return (src_data, labels)
27 |
28 | def pad_sents(sents, pad_token):
29 |     """Pad all sentences to the length of the longest one."""
30 |     sents_padded = []
31 |     lengths = [len(s) for s in sents]
32 |     max_len = max(lengths)
33 |     for sent in sents:
34 |         sent_padded = sent + [pad_token] * (max_len - len(sent))
35 |         sents_padded.append(sent_padded)
36 |     return sents_padded
37 |
38 | def batch_iter(data, batch_size, shuffle=False):
39 |     """
40 |     Yield mini-batches of data.
41 |     :param data: list of tuple (sentence, label)
42 |     :param batch_size:
43 |     :param shuffle:
44 |     :return:
45 |     """
46 |     batch_num = math.ceil(len(data) / batch_size)
47 |     index_array = list(range(len(data)))
48 |     if shuffle:
49 |         random.shuffle(index_array)
50 |
51 |     for i in trange(batch_num, desc='get mini_batch data'):
52 |         indices = index_array[i * batch_size:(i + 1) * batch_size]
53 |         examples = [data[idx] for idx in indices]
54 |         examples = sorted(examples, key=lambda x: len(x[0]), reverse=True)  # sort by sentence length, longest first
55 |         src_sents = [e[0] for e in examples]
56 |         labels = [label_map[e[1]] for e in examples]
57 |
58 |         yield src_sents, labels
--------------------------------------------------------------------------------
/vocab.py:
--------------------------------------------------------------------------------
1 | from utils import read_corpus, pad_sents
2 |
3 | from typing import List
4 | from collections import Counter
5 | from itertools import chain
6 | import json
7 | import torch
8 |
9 | class VocabEntry(object):
10 |     def __init__(self, word2id=None):
11 |         """
12 |         Initialize the VocabEntry.
13 |         :param word2id: mapping word to indices
14 |         """
15 |         if word2id:
16 |             self.word2id = word2id
17 |         else:
18 |             self.word2id = dict()
19 |             self.word2id['<pad>'] = 0
20 |             self.word2id['<unk>'] = 1
21 |         self.unk_id = self.word2id['<unk>']
22 |         self.id2word = {v: k for k, v in self.word2id.items()}
23 |
24 |     def __getitem__(self, word):
25 |         """Get the index of a word, falling back to <unk>."""
26 |         return self.word2id.get(word, self.unk_id)
27 |
28 |     def __contains__(self, word):
29 |         return word in self.word2id
30 |
31 |     def __setitem__(self, key, value):
32 |         raise ValueError('vocabulary is readonly')
33 |
34 |     def __len__(self):
35 |         return len(self.word2id)
36 |
37 |     def __repr__(self):
38 |         return 'Vocabulary[size=%d]' % (len(self.word2id))
39 |
40 |     def add(self, word):
41 |         """Add a word to the vocabulary and return its index."""
42 |         if word not in self.word2id:
43 |             wid = self.word2id[word] = len(self.word2id)
44 |             self.id2word[wid] = word
45 |             return wid
46 |         else:
47 |             return self.word2id[word]
48 |
49 |     def words2indices(self, sents):
50 |         """
51 |         Convert sentences to numeric indices.
52 |         :param sents: list(word) or list(list(word))
53 |         :return:
54 |         """
55 |         if type(sents[0]) == list:
56 |             return [[self.word2id.get(w, self.unk_id) for w in s] for s in sents]
57 |         else:
58 |             return [self.word2id.get(s, self.unk_id) for s in sents]
59 |
60 |     def indices2words(self, idxs):
61 |         return [self.id2word[id] for id in idxs]
62 |
63 |     def to_input_tensor(self, sents: List[List[str]], device: torch.device):
64 |         """
65 |         Convert raw sentences to a tensor, padding every sentence to max_len with <pad>.
66 |         :param sents: list of list
67 |         :param device:
68 |         :return:
69 |         """
70 |         sents = self.words2indices(sents)
71 |         sents = pad_sents(sents, self.word2id['<pad>'])
72 |         sents_var = torch.tensor(sents, device=device)
73 |         return sents_var
74 |
75 |     @staticmethod
76 |     def from_corpus(corpus, size, min_freq=3):
77 |         """Create a VocabEntry from the given corpus."""
78 |         vocab_entry = VocabEntry()
79 |         word_freq = Counter(chain(*corpus))
80 |         valid_words = word_freq.most_common(size - 2)
81 |         valid_words = [word for word, value in valid_words if value >= min_freq]
82 |         print('number of word types: {}, number of word types w/ frequency >= {}: {}'
83 |               .format(len(word_freq), min_freq, len(valid_words)))
84 |         for word in valid_words:
85 |             vocab_entry.add(word)
86 |         return vocab_entry
87 |
88 | class Vocab(object):
89 |     """Vocabulary class holding the source vocabulary and the label map."""
90 |     def __init__(self, src_vocab: VocabEntry, labels: dict):
91 |         self.vocab = src_vocab
92 |         self.labels = labels
93 |
94 |     @staticmethod
95 |     def build(src_sents, labels, vocab_size, min_freq):
96 |
97 |         print('initialize source vocabulary ..')
98 |         src = VocabEntry.from_corpus(src_sents, vocab_size, min_freq)
99 |
100 |         return Vocab(src, labels)
101 |
102 |     def save(self, file_path):
103 |         with open(file_path, 'w') as fout:
104 |             json.dump(dict(src_word2id=self.vocab.word2id, labels=self.labels), fout, indent=2)
105 |
106 |     @staticmethod
107 |     def load(file_path):
108 |         with open(file_path, 'r') as fin:
109 |             entry = json.load(fin)
110 |         src_word2id = entry['src_word2id']
111 |         labels = entry['labels']
112 |
113 |         return Vocab(VocabEntry(src_word2id), labels)
114 |
115 |     def __repr__(self):
116 |         """ Representation of Vocab to be used
117 |         when printing the object.
118 |         """
119 |         return 'Vocab(source %d words)' % (len(self.vocab))
120 |
121 |
122 | if __name__ == '__main__':
123 |
124 |     src_sents, labels = read_corpus('/Users/zhoup/develop/NLPSpace/my-whole-data/cnews/cnews.train.txt')
125 |     labels = {label: idx for idx, label in enumerate(sorted(set(labels)))}  # map each distinct label to an index
126 |
127 |     vocab = Vocab.build(src_sents, labels, 50000, 3)
128 |     print('generated vocabulary, source %d words' % (len(vocab.vocab)))
129 |     vocab.save('./vocab.json')
--------------------------------------------------------------------------------