├── .coveragerc
├── .github
│   └── workflows
│       ├── pythonpackage.yml
│       └── pythonpush.yml
├── .gitignore
├── AUTHORS.rst
├── CHANGES.rst
├── LICENSE.txt
├── README.rst
├── deep_keyphrase
│   ├── __init__.py
│   ├── base_predictor.py
│   ├── base_trainer.py
│   ├── copy_cnn
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── model.py
│   │   ├── predict.py
│   │   └── train.py
│   ├── copy_rnn
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── model.py
│   │   ├── model_tf.py
│   │   ├── predict.py
│   │   ├── predict_tf.py
│   │   ├── train.py
│   │   └── train_tf.py
│   ├── copy_transformer
│   │   ├── __init__.py
│   │   ├── beam_search.py
│   │   ├── model.py
│   │   ├── predict.py
│   │   ├── train.py
│   │   └── transformer.py
│   ├── data_process
│   │   ├── __init__.py
│   │   └── preprocess.py
│   ├── dataloader.py
│   ├── evaluation.py
│   ├── predict_runner.py
│   └── utils
│       ├── __init__.py
│       ├── constants.py
│       ├── tokenizer.py
│       └── vocab_loader.py
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── .gitignore
│   ├── authors.rst
│   ├── changes.rst
│   ├── conf.py
│   ├── index.rst
│   └── license.rst
├── requirements.txt
├── scripts
│   ├── predict_kp20k.sh
│   ├── prepare_kp20k.sh
│   └── train_copyrnn_kp20k.sh
├── setup.cfg
├── setup.py
├── test-requirements.txt
└── tests
    └── test_utils
        ├── __init__.py
        └── test_tokenizer.py

/.coveragerc:
--------------------------------------------------------------------------------
# .coveragerc to control coverage.py
[run]
branch = True
source = deep_keyphrase
# omit = bad_file.py

[report]
# Regexes for lines to exclude from consideration
exclude_lines =
    # Have to re-enable the standard pragma
    pragma: no cover

    # Don't complain about missing debug-only code:
    def __repr__
    if self\.debug

    # Don't complain if tests don't hit defensive assertion code:
    raise AssertionError
    raise NotImplementedError

    # Don't complain if non-runnable code isn't run:
    if 0:
    if __name__ == .__main__.:

--------------------------------------------------------------------------------
/.github/workflows/pythonpackage.yml:
--------------------------------------------------------------------------------
name: ci
on: [push]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      max-parallel: 3
      matrix:
        python-version: [3.5, 3.6, 3.7]

    steps:
    - uses: actions/checkout@v1
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v1
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
    - name: Test with pytest
      run: |
        pip install -r test-requirements.txt
        pip install -e .
        pytest
--------------------------------------------------------------------------------
/.github/workflows/pythonpush.yml:
--------------------------------------------------------------------------------
name: Upload Python Package

on:
  push:
    tags:
      - v*

jobs:
  deploy:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@master
      - name: Set up Python
        uses: actions/setup-python@master
        with:
          python-version: '3.x'
      - name: build package
        run: |
          python -m pip install --upgrade pip
          pip install setuptools wheel twine
          python setup.py sdist
      - name: push to pypi
        if: github.event_name == 'push' && startsWith(github.event.ref, 'refs/tags')
        uses: pypa/gh-action-pypi-publish@master
        with:
          password: ${{ secrets.pypi_password }}
      - name: create github release
        id: create_release
        uses: actions/create-release@v1
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions, you do not need to create your own token
        with:
          tag_name: ${{ github.ref }}
          release_name: Release ${{ github.ref }}
          body: |
            Changes in this Release
            - First Change
            - Second Change
          draft: true
          prerelease: false
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Temporary and binary files
*~
*.py[cod]
*.so
*.cfg
!setup.cfg
*.orig
*.log
*.pot
__pycache__/*
.cache/*
.*.swp
*/.ipynb_checkpoints/*

# Project files
.ropeproject
.project
.pydevproject
.settings
.idea

# Package files
*.egg
*.eggs/
.installed.cfg
*.egg-info

# Unittest and coverage
htmlcov/*
.coverage
.tox
junit.xml
coverage.xml

# Build and docs folder/files
build/*
dist/*
sdist/*
docs/api/*
docs/_build/*
cover/*
MANIFEST
/data
/runners
.env
Pipfile
.vscode
--------------------------------------------------------------------------------
/AUTHORS.rst:
--------------------------------------------------------------------------------
==========
Developers
==========

* supercoderhawk
--------------------------------------------------------------------------------
/CHANGES.rst:
--------------------------------------------------------------------------------
=========
Changelog
=========
Version 0.0.7
====================
* fix inference output sentence order error
* continue implementing CopyCNN and CopyTransformer

Version 0.0.6
====================
* update vocab loader
* initialize the CopyRNN embedding with a uniform distribution


Version 0.0.5
==================
* fix beam search exception in CopyRNN

Version 0.0.4
==================
* fix beam search exception in CopyRNN

Version 0.0.3
==================

* fix github workflow push config

Version 0.0.2
==================

* fix requirements bug

Version 0.0.1
==================

* construct the basic architecture of the project
* implement CopyRNN (fully working) and CopyTransformer (still being debugged)


--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright 2019 xiayubin

--------------------------------------------------------------------------------
/README.rst:
--------------------------------------------------------------------------------
==============
deep-keyphrase
==============


Implementation of several keyphrase generation algorithms

.. image:: https://img.shields.io/github/workflow/status/supercoderhawk/deep-keyphrase/ci.svg

.. image:: https://img.shields.io/pypi/v/deep-keyphrase.svg
   :target: https://pypi.org/project/deep-keyphrase

.. image:: https://img.shields.io/pypi/dm/deep-keyphrase.svg
   :target: https://pypi.org/project/deep-keyphrase


Description
===========
Implemented Paper
>>>>>>>>>>>>>>>>>>>>>

CopyRNN

`Deep Keyphrase Generation (Meng et al., 2017)`__

.. __: https://arxiv.org/abs/1704.06879


ToDo List
>>>>>>>>>>>>>>>

CopyCNN

CopyTransformer


Usage
============

required files (4 files in total)
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>

1. vocab_file: one word per line (without any index!) ::

    this
    paper
    proposes

2. training, validation and test files

data format for training, validation and test
""""""""""""""""""""""""""""""""""""""""""""""""""
JSON Lines format, every line is a dict::

    {'tokens': ['this', 'paper', 'proposes', 'using', 'virtual', 'reality', 'to', 'enhance', 'the', 'perception', 'of', 'actions', 'by', 'distant', 'users', 'on', 'a', 'shared', 'application', '.', 'here', ',', 'distance', 'may', 'refer', 'either', 'to', 'space', '(', 'e.g.', 'in', 'a', 'remote', 'synchronous', 'collaboration', ')', 'or', 'time', '(', 'e.g.', 'during', 'playback', 'of', 'recorded', 'actions', ')', '.', 'our', 'approach', 'consists', 'in', 'immersing', 'the', 'application', 'in', 'a', 'virtual', 'inhabited', '3d', 'space', 'and', 'mimicking', 'user', 'actions', 'by', 'animating', 'avatars', '.', 'we', 'illustrate', 'this', 'approach', 'with', 'two', 'applications', ',', 'the', 'one', 'for', 'remote', 'collaboration', 'on', 'a', 'shared', 'application', 'and', 'the', 'other', 'to', 'playback', 'recorded', 'sequences', 'of', 'user', 'actions', '.', 'we', 'suggest', 'this', 'could', 'be', 'a', 'low', 'cost', 'enhancement', 'for', 'telepresence', '.'],
     'keyphrases': [['telepresence'], ['animation'], ['avatars'], ['application', 'sharing'], ['collaborative', 'virtual', 'environments']]}


Training
>>>>>>>>>>>>>>>
download the kp20k_

.. _kp20k: https://drive.google.com/uc?id=1ZTQEGZSq06kzlPlOv4yGjbUpoDrNxebR&export=download

::

    mkdir data
    mkdir data/raw
    mkdir data/raw/kp20k_new
    # !! please unzip the kp20k data and put the files into the above folder manually
    python -m nltk.downloader punkt
    bash scripts/prepare_kp20k.sh
    bash scripts/train_copyrnn_kp20k.sh

    # start tensorboard
    # enter the experiment result dir, the suffix is the time the experiment started
    cd data/kp20k/copyrnn_kp20k_basic-20191212-080000
    # start the tensorboard service
    tensorboard --bind_all --logdir logs --port 6006

Notes
=============================
1. compared with the original :code:`seq2seq-keyphrase-pytorch`, this implementation

   1. fixes the implementation errors:

      1. copy mechanism
      2. training and inference did not correspond (training had no input feeding while inference did)

2. easy data preparation
3. tensorboard support
4. **faster beam search (6x faster on CPU and more than 10x faster on GPU)**
--------------------------------------------------------------------------------
/deep_keyphrase/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import pkg_resources

try:
    __version__ = pkg_resources.get_distribution(__name__).version
except Exception:
    __version__ = 'unknown'
--------------------------------------------------------------------------------
/deep_keyphrase/base_predictor.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import os
import json
import torch
from collections import namedtuple, OrderedDict
from pysenal import read_file


class BasePredictor(object):
    def __init__(self, model_info):
        self.config = self.load_config(model_info)

    def load_config(self, model_info):
        if 'config' not in model_info:
            if isinstance(model_info['model'], str):
                config_path = os.path.splitext(model_info['model'])[0] + '.json'
            else:
                raise ValueError('config path is not assigned')
        else:
            config_info = model_info['config']
            if isinstance(config_info, str):
                config_path = config_info
            else:
                return config_info
        # json to object
        config = json.loads(read_file(config_path),
                            object_hook=lambda d: namedtuple('X', d.keys())(*d.values()))
        return config

    def load_model(self, model_info, model):
        if isinstance(model_info['model'], torch.nn.Module):
            return model_info['model']

        model_path = model_info['model']
        if not isinstance(model_path, str):
            raise TypeError('model path should be str')
        # model = load_model_func()
        if torch.cuda.is_available():
            checkpoint = torch.load(model_path)
        else:
            checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        state_dict = OrderedDict()
        # avoid errors when loading a parallel-trained model
        for k, v in checkpoint.items():
            if k.startswith('module.'):
                k = k[7:]
            state_dict[k] = v
        model.load_state_dict(state_dict)
        if torch.cuda.is_available():
            model = model.cuda()
        return model

    def predict(self, input_list, batch_size, delimiter=''):
        raise NotImplementedError('predict method is not implemented')
--------------------------------------------------------------------------------
/deep_keyphrase/base_trainer.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import time
import traceback
import logging
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from pysenal import write_json, get_logger
from deep_keyphrase.utils.vocab_loader import load_vocab
from deep_keyphrase.dataloader import KeyphraseDataLoader
from deep_keyphrase.evaluation import KeyphraseEvaluator
from deep_keyphrase.utils.constants import PAD_WORD


class BaseTrainer(object):
    def __init__(self, args, model):
        torch.manual_seed(0)
        torch.autograd.set_detect_anomaly(True)
        self.args = args
        self.vocab2id = load_vocab(self.args.vocab_path,
                                   self.args.vocab_size)

        self.model = model
        if torch.cuda.is_available():
            self.model = self.model.cuda()
            if args.train_parallel:
                self.model = nn.DataParallel(self.model)
        self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   self.args.schedule_step,
                                                   self.args.schedule_gamma)
        self.logger = get_logger('train')
        self.train_loader = KeyphraseDataLoader(data_source=self.args.train_filename,
                                                vocab2id=self.vocab2id,
                                                mode='train',
                                                args=args)
        if self.args.train_from:
            self.dest_dir = os.path.dirname(self.args.train_from) + '/'
        else:
            timemark = time.strftime('%Y%m%d-%H%M%S', time.localtime(time.time()))
            self.dest_dir = os.path.join(self.args.dest_base_dir, self.args.exp_name + '-' + timemark) + '/'
            os.mkdir(self.dest_dir)

        fh = logging.FileHandler(os.path.join(self.dest_dir, args.logfile))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
        self.logger.addHandler(fh)

        if not self.args.tensorboard_dir:
            tensorboard_dir = self.dest_dir + 'logs/'
        else:
            tensorboard_dir = self.args.tensorboard_dir
        self.writer = SummaryWriter(tensorboard_dir)
        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                                  args.token_field, args.keyphrase_field)
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                                  args.token_field, args.keyphrase_field)
        self.best_f1 = None
        self.best_step = 0
        self.not_update_count = 0

    def parse_args(self, args=None):
        raise NotImplementedError('parse_args is not implemented')

    def train(self):
        try:
            self.train_func()
        except KeyboardInterrupt:
            self.logger.info('training was terminated by the user')
        except Exception:
            self.logger.error('exception occurred')
            err_stack = traceback.format_exc()
            self.logger.error(err_stack)
        finally:
            # terminate the loader processes
            del self.train_loader

    def train_func(self):
        step = 0
        is_stop = False
        if self.args.train_from:
            step = self.args.step
            self.logger.info('train from destination dir:{}'.format(self.dest_dir))
            self.logger.info('train from step {}'.format(step))
        else:
            self.logger.info('destination dir:{}'.format(self.dest_dir))
        for epoch in range(1, self.args.epochs + 1):
            self.model.train()
            for batch_idx, batch in enumerate(self.train_loader):
                try:
                    loss = self.train_batch(batch, step)
                except Exception as e:
                    err_stack = traceback.format_exc()
                    self.logger.error(err_stack)
                    loss = 0.0
                step += 1
                self.writer.add_scalar('loss', loss, step)
                del loss
                if step and step % self.args.save_model_step == 0:
                    torch.cuda.empty_cache()
                    self.evaluate_and_save_model(step, epoch)
                    torch.cuda.empty_cache()
                    if self.not_update_count >= self.args.early_stop_tolerance:
                        is_stop = True
                        break
            if is_stop:
                self.logger.info('best step {}'.format(self.best_step))
                break

    def train_batch(self, batch, step):
        raise NotImplementedError('train_batch is not implemented')

    def evaluate_stage(self, step, stage, predict_callback):
        if stage == 'valid':
            src_filename = self.args.valid_filename
        elif stage == 'test':
            src_filename = self.args.test_filename
        else:
            raise ValueError('stage name error, must be one of `valid` or `test`')
        pred_filename = self.dest_dir + self.get_basename(src_filename)
        pred_filename += '.batch_{}.pred.jsonl'.format(step)
        torch.cuda.empty_cache()
        predict_callback()
        torch.cuda.empty_cache()
        macro_all_ret = self.macro_evaluator.evaluate(pred_filename)
        macro_present_ret = self.macro_evaluator.evaluate(pred_filename, 'present')
        macro_absent_ret = self.macro_evaluator.evaluate(pred_filename, 'absent')

        for n, counter in macro_all_ret.items():
            for k, v in counter.items():
                name = '{}/macro_{}@{}'.format(stage, k, n)
                self.writer.add_scalar(name, v, step)
        for n in self.eval_topn:
            name = 'present/{} macro_f1@{}'.format(stage, n)
            self.writer.add_scalar(name, macro_present_ret[n]['f1'], step)
        for n in self.eval_topn:
            absent_f1_name = 'absent/{} macro_f1@{}'.format(stage, n)
            self.writer.add_scalar(absent_f1_name, macro_absent_ret[n]['f1'], step)
            absent_recall_name = 'absent/{} macro_recall@{}'.format(stage, n)
            self.writer.add_scalar(absent_recall_name, macro_absent_ret[n]['recall'], step)

        statistics = {'{}_macro'.format(stage): macro_all_ret,
                      '{}_macro_present'.format(stage): macro_present_ret,
                      '{}_macro_absent'.format(stage): macro_absent_ret}
        return statistics

    def evaluate_and_save_model(self, step, epoch):
        valid_f1 = self.evaluate(step)
        if self.best_f1 is None:
            self.best_f1 = valid_f1
            self.best_step = step
        elif valid_f1 >= self.best_f1:
            self.best_f1 = valid_f1
            self.not_update_count = 0
            self.best_step = step
        else:
            self.not_update_count += 1
        exp_name = self.args.exp_name
        model_basename = self.dest_dir + '{}_epoch_{}_batch_{}'.format(exp_name, epoch, step)
        torch.save(self.model.state_dict(), model_basename + '.model')
        write_json(model_basename + '.json', vars(self.args))
        score_msg_tmpl = 'best score: step {} macro f1@{} {:.4f}'
        self.logger.info(score_msg_tmpl.format(self.best_step, self.eval_topn[-1], self.best_f1))
        self.logger.info('epoch {} step {}, model saved'.format(epoch, step))

    def evaluate(self, step):
        raise NotImplementedError('evaluate method is not implemented')

    def get_basename(self, filename):
        return os.path.splitext(os.path.basename(filename))[0]
--------------------------------------------------------------------------------
/deep_keyphrase/copy_cnn/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
--------------------------------------------------------------------------------
/deep_keyphrase/copy_cnn/beam_search.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-


class CopyCnnBeamSearch(object):
    def __init__(self):
        pass

    def beam_search(self):
        pass

    def greedy_search(self):
        pass
--------------------------------------------------------------------------------
/deep_keyphrase/copy_cnn/model.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from deep_keyphrase.dataloader import (TOKENS, TOKENS_LENS, TARGET)

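# A quick sketch (shapes only, an assumption based on the layer definitions below)
# of the ConvS2S-style gated-convolution block used by the encoder and decoder,
# given input embedded to (B, 1, L, D) and conv = nn.Conv2d(in_channels=1,
# out_channels=2 * D, kernel_size=(kw, D)):
#
#   conv(x)               -> (B, 2 * D, L - kw + 1, 1)
#   F.glu(conv(x), dim=1) -> (B, D, L - kw + 1, 1)    # glu(a, b) = a * sigmoid(b)
#
# Note the residual add with the previous layer output requires matching shapes,
# which still needs padding/reshaping here (CopyCNN is a work in progress, see
# CHANGES.rst).
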
class Attention(nn.Module):
    """
    attention layer for the CNN decoder (not implemented yet)
    """

    def __init__(self, dim_size):
        super().__init__()
        self.in_proj = nn.Linear(dim_size, dim_size)

    def forward(self, x, target_embedding, encoder_input, encoder_output, encoder_mask):
        pass


class CopyCnn(nn.Module):
    def __init__(self, args, vocab2id):
        super().__init__()
        self.args = args
        self.vocab2id = vocab2id
        self.embedding = nn.Embedding(len(vocab2id), args.dim_size)
        self.encoder = CopyCnnEncoder(vocab2id=vocab2id, embedding=self.embedding, args=args)
        self.decoder = CopyCnnDecoder(vocab2id=vocab2id, embedding=self.embedding, args=args)

    def forward(self, src_dict, encoder_output):
        if encoder_output is None:
            encoder_output = self.encoder(src_dict)
        # TODO: the decoder call is not implemented yet


class CopyCnnEncoder(nn.Module):
    def __init__(self, vocab2id, embedding, args):
        super().__init__()
        self.vocab2id = vocab2id
        self.embedding = embedding
        self.args = args
        self.dim_size = args.dim_size
        self.kernel_size = (args.kernel_width, self.dim_size)
        self.dropout = args.dropout
        # use nn.ModuleList so the layer parameters are registered with the module
        self.convolution_layers = nn.ModuleList()
        for i in range(args.encoder_layer_num):
            layer = nn.Conv2d(in_channels=1, out_channels=2 * self.dim_size,
                              kernel_size=self.kernel_size, bias=True)
            self.convolution_layers.append(layer)

    def forward(self, src_dict):
        tokens = src_dict[TOKENS]
        x = self.embedding(tokens).unsqueeze(1)
        # x = tokens.unsqueeze(1)
        layer_output = [x]
        for layer in self.convolution_layers:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = layer(x)
            x = F.glu(x, dim=1) + layer_output[-1]
            layer_output.append(x)
        return x


class CopyCnnDecoder(nn.Module):
    def __init__(self, vocab2id, embedding, args):
        super().__init__()
        self.vocab2id = vocab2id
        self.embedding = embedding
        self.args = args
        self.vocab_size = self.args.vocab_size
        self.max_oov_count = self.args.max_oov_count
        self.total_vocab_size = self.vocab_size + self.max_oov_count
        self.dim_size = args.dim_size
        self.kernel_size = (args.kernel_width, self.dim_size)
        self.dropout = args.dropout
        # use nn.ModuleList so the layer parameters are registered with the module
        self.convolution_layers = nn.ModuleList()
        self.attn_linear_layers = nn.ModuleList()
        self.decoder_layer_num = args.decoder_layer_num
        for i in range(self.decoder_layer_num):
            conv_layer = nn.Conv2d(in_channels=1, out_channels=2 * self.dim_size,
                                   kernel_size=self.kernel_size, bias=True)
            self.convolution_layers.append(conv_layer)
            attn_linear_layer = nn.Linear(self.dim_size, self.dim_size, bias=True)
            self.attn_linear_layers.append(attn_linear_layer)
        self.generate_proj = nn.Linear(self.dim_size, self.vocab_size)
        self.copy_proj = nn.Linear(self.dim_size, self.total_vocab_size)

    def forward(self, src_dict, prev_tokens, encoder_output):
        """

        :param src_dict:
        :param prev_tokens:
        :param encoder_output:
        :return:
        """
        src_tokens = src_dict[TOKENS]
        tokens = src_dict[TARGET][:, :-1]
        x = self.embedding(tokens).unsqueeze(1)
        prev_x = self.embedding(prev_tokens)
        src_x = self.embedding(src_tokens)
        layer_output = [x]
        for conv_layer, linear_layer in zip(self.convolution_layers, self.attn_linear_layers):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = conv_layer(x)
            x = F.glu(x, dim=1) + layer_output[-1]
            # attention
            d = linear_layer(x) + prev_x
            attn_weights = torch.softmax(torch.bmm(encoder_output, d.unsqueeze(2)), dim=1)
            c = attn_weights * (encoder_output + src_x)
            # residual connection
            final_output = x + c.squeeze(2)
            layer_output.append(final_output)
        generate_logits = self.generate_proj(layer_output[-1])
        # TODO: the copy scores and the final output distribution are not computed yet

    def forward_one_pass(self):
        pass

    def forward_auto_regressive(self):
        pass

    def get_attn_read(self, encoder_output, src_tokens_with_oov, decoder_output, encoder_output_mask):
        pass
--------------------------------------------------------------------------------
/deep_keyphrase/copy_cnn/predict.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
from deep_keyphrase.base_predictor import BasePredictor


class CopyCnnPredictor(BasePredictor):
    def __init__(self, model_info):
        super().__init__(model_info)

    def predict(self, input_list, batch_size, delimiter=''):
        pass

    def eval_predict(self, src_filename, dest_filename, args,
                     model=None, remove_existed=False):
        pass
--------------------------------------------------------------------------------
/deep_keyphrase/copy_cnn/train.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import os
import argparse
import torch
from collections import OrderedDict
from munch import Munch
from pysenal import read_json
from deep_keyphrase.base_trainer import BaseTrainer
from deep_keyphrase.utils.vocab_loader import load_vocab
from deep_keyphrase.copy_cnn.model import CopyCnn


class CopyCnnTrainer(BaseTrainer):
    def __init__(self):
        self.args = self.parse_args()
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
        model = self.load_model()
        super().__init__(self.args, model)

    def load_model(self):
        if not self.args.train_from:
            model = CopyCnn(self.args, self.vocab2id)
        else:
            model_path = self.args.train_from
            config_path = os.path.join(os.path.dirname(model_path),
                                       self.get_basename(model_path) + '.json')

            old_config = read_json(config_path)
            old_config['train_from'] = model_path
            old_config['step'] = int(model_path.rsplit('_', 1)[-1].split('.')[0])
            self.args = Munch(old_config)
            self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

            model = CopyCnn(self.args, self.vocab2id)

            if torch.cuda.is_available():
                checkpoint = torch.load(model_path)
            else:
                checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            state_dict = OrderedDict()
            # avoid errors when loading a parallel-trained model
            for k, v in checkpoint.items():
                if k.startswith('module.'):
                    k = k[7:]
                state_dict[k] = v
            model.load_state_dict(state_dict)

        return model

    def train_batch(self, batch, step):
        self.model.train()
        loss = 0
        self.optimizer.zero_grad()
        # TODO: the forward/backward pass is not implemented yet

    def evaluate(self, step):
        pass

    def parse_args(self, args=None):
        parser = argparse.ArgumentParser()
        parser.add_argument("-exp_name", required=True, type=str, help='')
        parser.add_argument("-train_filename", required=True, type=str, help='')
        parser.add_argument("-valid_filename", required=True, type=str, help='')
        parser.add_argument("-test_filename", required=True, type=str, help='')
parser.add_argument("-dest_base_dir", required=True, type=str, help='') 65 | parser.add_argument("-vocab_path", required=True, type=str, help='') 66 | parser.add_argument("-vocab_size", type=int, default=500000, help='') 67 | parser.add_argument("-train_from", default='', type=str, help='') 68 | parser.add_argument("-token_field", default='tokens', type=str, help='') 69 | parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='') 70 | # parser.add_argument("-auto_regressive", action='store_true', help='') 71 | parser.add_argument("-epochs", type=int, default=10, help='') 72 | parser.add_argument("-batch_size", type=int, default=64, help='') 73 | parser.add_argument("-learning_rate", type=float, default=1e-4, help='') 74 | parser.add_argument("-eval_batch_size", type=int, default=50, help='') 75 | parser.add_argument("-dropout", type=float, default=0.0, help='') 76 | parser.add_argument("-grad_norm", type=float, default=0.0, help='') 77 | parser.add_argument("-max_grad", type=float, default=5.0, help='') 78 | parser.add_argument("-shuffle", action='store_true', help='') 79 | # parser.add_argument("-teacher_forcing", action='store_true', help='') 80 | parser.add_argument("-beam_size", type=float, default=50, help='') 81 | parser.add_argument('-tensorboard_dir', type=str, default='', help='') 82 | parser.add_argument('-logfile', type=str, default='train_log.log', help='') 83 | parser.add_argument('-save_model_step', type=int, default=5000, help='') 84 | parser.add_argument('-early_stop_tolerance', type=int, default=100, help='') 85 | parser.add_argument('-train_parallel', action='store_true', help='') 86 | # parser.add_argument('-schedule_lr', action='store_true', help='') 87 | # parser.add_argument('-schedule_step', type=int, default=100000, help='') 88 | # parser.add_argument('-schedule_gamma', type=float, default=0.5, help='') 89 | # parser.add_argument('-processed', action='store_true', help='') 90 | parser.add_argument('-prefetch', action='store_true', help='') 91 | 92 | parser.add_argument('-dim_size', type=int, default=100, help='') 93 | parser.add_argument('-kernel_width', type=int, default=5, help='') 94 | parser.add_argument('-encoder_layer_num', type=int, default=6, help='') 95 | parser.add_argument('-decoder_layer_num', type=int, default=6, help='') 96 | 97 | args = parser.parse_args(args) 98 | return args 99 | 100 | 101 | if __name__ == '__main__': 102 | CopyCnnTrainer().train() 103 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/beam_search.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import torch 3 | from deep_keyphrase.dataloader import (TOKENS, TOKENS_LENS, TOKENS_OOV, 4 | OOV_COUNT, OOV_LIST, EOS_WORD, UNK_WORD) 5 | 6 | 7 | class BeamSearch(object): 8 | def __init__(self, model, beam_size, max_target_len, id2vocab, bos_idx, unk_idx, args): 9 | self.model = model 10 | self.beam_size = beam_size 11 | self.id2vocab = id2vocab 12 | self.vocab_size = len(self.id2vocab) 13 | self.max_target_len = max_target_len 14 | self.bos_idx = bos_idx 15 | self.unk_idx = unk_idx 16 | self.target_hidden_size = args.target_hidden_size 17 | 18 | def beam_search(self, src_dict, delimiter=None): 19 | """ 20 | generate beam 
        generate beam search results
        main idea: run inference on inputs expanded to (batch size x beam size) and select the top-k continuations at every step
        :param src_dict:
        :param delimiter:
        :return:
        """
        oov_list = src_dict[OOV_LIST]
        batch_size = len(src_dict[TOKENS])
        encoder_output_dict = None
        hidden_state = None
        beam_batch_size = self.beam_size * batch_size
        prev_output_tokens = torch.tensor([[self.bos_idx]] * batch_size, dtype=torch.int64)
        decoder_state = torch.zeros(batch_size, self.target_hidden_size)

        if torch.cuda.is_available():
            prev_output_tokens = prev_output_tokens.cuda()
            decoder_state = decoder_state.cuda()

        # get BOS output and repeat to beam size
        model_output = self.model(src_dict=src_dict,
                                  prev_output_tokens=prev_output_tokens,
                                  encoder_output_dict=encoder_output_dict,
                                  prev_decoder_state=decoder_state,
                                  prev_hidden_state=hidden_state)
        decoder_prob, encoder_output_dict, decoder_state, hidden_state = model_output
        prev_best_probs, prev_best_index = torch.topk(decoder_prob, self.beam_size, 1)
        # map oov token to unk
        oov_token_mask = prev_best_index >= self.vocab_size
        prev_best_index.masked_fill_(oov_token_mask, self.unk_idx)
        # B*b x TH
        prev_decoder_state = decoder_state.unsqueeze(1).repeat(1, self.beam_size, 1)
        prev_decoder_state = prev_decoder_state.view(beam_batch_size, -1)
        hidden_state[0] = hidden_state[0].unsqueeze(2).repeat(1, 1, self.beam_size, 1)
        hidden_state[0] = hidden_state[0].view(-1, beam_batch_size, hidden_state[0].size(-1))
        hidden_state[1] = hidden_state[1].unsqueeze(2).repeat(1, 1, self.beam_size, 1)
        hidden_state[1] = hidden_state[1].view(-1, beam_batch_size, hidden_state[1].size(-1))
        prev_hidden_state = hidden_state

        result_sequences = prev_best_index.unsqueeze(2)
        encoder_output_dict = self.expand_encoder_output(encoder_output_dict, batch_size)
        # the beam search scores are stored as the negation (absolute value) of the real log probability
        # B x b
        beam_search_best_probs = torch.abs(prev_best_probs)

        for k in [TOKENS, TOKENS_OOV]:
            src_dict[k] = src_dict[k].unsqueeze(1).repeat(1, self.beam_size, 1).reshape(beam_batch_size, -1)

        for k in [TOKENS_LENS, OOV_COUNT]:
            src_dict[k] = src_dict[k].unsqueeze(1).repeat(1, self.beam_size).flatten()

        for target_idx in range(1, self.max_target_len):
            model_output = self.model(src_dict=src_dict,
                                      prev_output_tokens=prev_best_index.view(-1, 1),
                                      encoder_output_dict=encoder_output_dict,
                                      prev_decoder_state=prev_decoder_state,
                                      prev_hidden_state=prev_hidden_state)
            decoder_prob, encoder_output_dict, decoder_state, prev_hidden_state = model_output
            # accumulated_probs: B x b*V
            accumulated_probs = beam_search_best_probs.view(beam_batch_size, -1)
            accumulated_probs = accumulated_probs.repeat(1, decoder_prob.size(1))
            accumulated_probs += torch.abs(decoder_prob)
            accumulated_probs = accumulated_probs.view(batch_size, -1)
            top_token_probs, top_token_index = torch.topk(-accumulated_probs, self.beam_size, 1)
            beam_search_best_probs = -top_token_probs

            select_idx_factor = torch.tensor(range(batch_size)) * self.beam_size
            select_idx_factor = select_idx_factor.unsqueeze(1).repeat(1, self.beam_size)
            if torch.cuda.is_available():
                select_idx_factor = select_idx_factor.cuda()
            state_select_idx = top_token_index.flatten() // decoder_prob.size(1)
            state_select_idx += select_idx_factor.flatten()
            prev_decoder_state = decoder_state.index_select(0, state_select_idx)
            prev_best_index = top_token_index % decoder_prob.size(1)
            prev_hidden_state[0] = prev_hidden_state[0].index_select(1, state_select_idx)
            prev_hidden_state[1] = prev_hidden_state[1].index_select(1, state_select_idx)
            # map oov token to unk
            oov_token_mask = prev_best_index >= self.vocab_size
            prev_best_index.masked_fill_(oov_token_mask, self.unk_idx)

            result_sequences = result_sequences.view(batch_size * self.beam_size, -1)
            result_sequences = result_sequences.index_select(0, state_select_idx)
            result_sequences = result_sequences.view(batch_size, self.beam_size, -1)

            result_sequences = torch.cat([result_sequences, prev_best_index.unsqueeze(2)], dim=2)
            prev_best_index = prev_best_index.view(beam_batch_size, -1)

        if torch.cuda.is_available():
            result_sequences = result_sequences.cpu().numpy().tolist()
        else:
            result_sequences = result_sequences.numpy().tolist()
        for k in [TOKENS, TOKENS_OOV, TOKENS_LENS, OOV_COUNT]:
            src_dict[k] = src_dict[k].narrow(0, 0, batch_size)

        results = self.__idx2result_beam(delimiter, oov_list, result_sequences)
        return results

    def __idx2result_beam(self, delimiter, oov_list, result_sequences):
        results = []
        for batch_idx, batch in enumerate(result_sequences):
            beam_list = []
            item_oov_list = oov_list[batch_idx]
            for beam in batch:
                phrase = []
                for idx in beam:
                    if self.id2vocab.get(idx) == EOS_WORD:
                        break
                    if idx in self.id2vocab:
                        phrase.append(self.id2vocab[idx])
                    else:
                        oov_idx = idx - len(self.id2vocab)
                        if oov_idx < len(item_oov_list):
                            phrase.append(item_oov_list[oov_idx])
                        else:
                            phrase.append(UNK_WORD)

                if delimiter is not None:
                    phrase = delimiter.join(phrase)
                if phrase not in beam_list:
                    beam_list.append(phrase)
            results.append(beam_list)
        return results

    def expand_encoder_output(self, encoder_output_dict, batch_size):
        beam_batch_size = batch_size * self.beam_size
        encoder_output = encoder_output_dict['encoder_output']
        encoder_mask = encoder_output_dict['encoder_padding_mask']
        encoder_hidden_state = encoder_output_dict['encoder_hidden']
        max_len = encoder_output.size(-2)
        hidden_size = encoder_hidden_state[0].size(-1)
        encoder_output = encoder_output.unsqueeze(1).repeat(1, self.beam_size, 1, 1)
        encoder_output = encoder_output.reshape(beam_batch_size, max_len, -1)
        encoder_mask = encoder_mask.unsqueeze(1).repeat(1, self.beam_size, 1)
        encoder_mask = encoder_mask.reshape(beam_batch_size, -1)
        encoder_hidden_state0 = encoder_hidden_state[0].unsqueeze(2).repeat(1, 1, self.beam_size, 1)
        encoder_hidden_state0 = encoder_hidden_state0.reshape(-1, beam_batch_size, hidden_size)
        encoder_hidden_state1 = encoder_hidden_state[1].unsqueeze(2).repeat(1, 1, self.beam_size, 1)
        encoder_hidden_state1 = encoder_hidden_state1.reshape(-1, beam_batch_size, hidden_size)
        encoder_output_dict['encoder_output'] = encoder_output
        encoder_output_dict['encoder_padding_mask'] = encoder_mask
        encoder_output_dict['encoder_hidden'] = [encoder_hidden_state0, encoder_hidden_state1]
        return encoder_output_dict

    def greedy_search(self, src_dict, delimiter=None):
        """

        :param src_dict:
        :param delimiter:
        :return:
        """
        oov_list = src_dict[OOV_LIST]
        batch_size = len(src_dict[TOKENS])
        encoder_output_dict = None
        hidden_state = None
        prev_output_tokens = [[self.bos_idx]] * batch_size
        decoder_state = torch.zeros(batch_size, self.model.decoder.target_hidden_size)
        result_seqs = None

        for target_idx in range(self.max_target_len):
            model_output = self.model(src_dict=src_dict,
                                      prev_output_tokens=prev_output_tokens,
                                      encoder_output_dict=encoder_output_dict,
                                      prev_decoder_state=decoder_state,
                                      prev_hidden_state=hidden_state)
            decoder_prob, encoder_output_dict, decoder_state, hidden_state = model_output
            best_probs, best_indices = torch.topk(decoder_prob, 1, dim=1)
            if result_seqs is None:
                result_seqs = best_indices
            else:
                result_seqs = torch.cat([result_seqs, best_indices], dim=1)
            prev_output_tokens = result_seqs[:, -1].unsqueeze(1)
        result = self.__idx2result_greedy(delimiter, oov_list, result_seqs)

        return result

    def __idx2result_greedy(self, delimiter, oov_list, result_seqs):
        result = []
        # move to cpu first so that .numpy() also works when decoding ran on gpu
        for batch_idx, batch in enumerate(result_seqs.cpu().numpy().tolist()):
            item_oov_list = oov_list[batch_idx]
            phrase = []
            for idx in batch:
                if self.id2vocab.get(idx) == EOS_WORD:
                    break
                if idx in self.id2vocab:
                    phrase.append(self.id2vocab[idx])
                else:
                    oov_idx = idx - len(self.id2vocab)
                    if oov_idx < len(item_oov_list):
                        phrase.append(item_oov_list[oov_idx])
                    else:
                        phrase.append(UNK_WORD)
            if delimiter is not None:
                phrase = delimiter.join(phrase)
            result.append(phrase)
        return result
--------------------------------------------------------------------------------
/deep_keyphrase/copy_rnn/model.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from deep_keyphrase.utils.constants import *
from deep_keyphrase.dataloader import TOKENS, TOKENS_OOV, TOKENS_LENS, OOV_COUNT, TARGET


class Attention(nn.Module):
    """
    implement attention mechanism
    """

    def __init__(self, input_dim, output_dim, score_mode='general'):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.score_mode = score_mode
        if self.score_mode == 'general':
            self.attn = nn.Linear(self.output_dim, self.input_dim, bias=False)
        elif self.score_mode == 'concat':
            self.query_proj = nn.Linear(self.output_dim, self.output_dim, bias=False)
            self.key_proj = nn.Linear(self.input_dim, self.output_dim, bias=False)
            self.concat_proj = nn.Linear(self.output_dim, 1)
        elif self.score_mode == 'dot':
            if self.input_dim != self.output_dim:
                raise ValueError('input and output dim must be equal when attention score mode is dot')
        else:
            raise ValueError('attention score mode error')
        self.output_proj = nn.Linear(self.input_dim + self.output_dim, self.output_dim)

    def score(self, query, key, encoder_padding_mask):
        """

        :param query:
        :param key:
        :param encoder_padding_mask:
        :return:
        """
        tgt_len = query.size(1)
        src_len = key.size(1)
        if self.score_mode == 'general':
            attn_weights = torch.bmm(self.attn(query), key.permute(0, 2, 1))
        elif self.score_mode == 'concat':
            query_w = self.query_proj(query.unsqueeze(2).repeat(1, 1, src_len, 1))
            key_w = self.key_proj(key.unsqueeze(1).repeat(1, tgt_len, 1, 1))
            score = torch.tanh(query_w + key_w)
            attn_weights = self.concat_proj(score)
            attn_weights = torch.squeeze(attn_weights, 3)
        elif self.score_mode == 'dot':
            attn_weights = torch.bmm(query, key.permute(0, 2, 1))

        # mask the padded input positions to -Inf; they become zero after softmax.
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.unsqueeze(1).repeat(1, tgt_len, 1)
            attn_weights.masked_fill_(encoder_padding_mask, float('-inf'))
        attn_weights = torch.softmax(attn_weights, 2)
        return attn_weights

    def forward(self, decoder_output, encoder_outputs, encoder_padding_mask):
        """

        :param decoder_output: B x tgt_dim
        :param encoder_outputs: B x L x src_dim
        :param encoder_padding_mask:
        :return:
        """
        attn_weights = self.score(decoder_output, encoder_outputs, encoder_padding_mask)
        context_embed = torch.bmm(attn_weights, encoder_outputs)
        attn_outputs = torch.tanh(self.output_proj(torch.cat([context_embed, decoder_output], dim=2)))
        return attn_outputs, attn_weights


class CopyRNN(nn.Module):
    """
    Notation:
        B: batch size
        L: source max len
        SH: source hidden size
        TH: target hidden size
        GV: generative vocab size
        V: total vocab size (generative vocab size and copy vocab size)
    """

    def __init__(self, args, vocab2id):
        super().__init__()
        src_hidden_size = args.src_hidden_size
        target_hidden_size = args.target_hidden_size
        embed_size = args.embed_size
        embedding = nn.Embedding(len(vocab2id), embed_size, padding_idx=vocab2id[PAD_WORD])
        nn.init.uniform_(embedding.weight, -0.1, 0.1)
        self.encoder = CopyRnnEncoder(vocab2id=vocab2id,
                                      embedding=embedding,
                                      hidden_size=src_hidden_size,
                                      bidirectional=args.bidirectional,
                                      dropout=args.dropout)
        if args.bidirectional:
            decoder_src_hidden_size = 2 * src_hidden_size
        else:
            decoder_src_hidden_size = src_hidden_size
        self.decoder = CopyRnnDecoder(vocab2id=vocab2id, embedding=embedding, args=args)
        if decoder_src_hidden_size != target_hidden_size:
            self.encoder2decoder_state = nn.Linear(decoder_src_hidden_size, target_hidden_size)
            self.encoder2decoder_cell = nn.Linear(decoder_src_hidden_size, target_hidden_size)

    def forward(self, src_dict, prev_output_tokens, encoder_output_dict,
                prev_decoder_state, prev_hidden_state):
        """

        :param src_dict:
        :param prev_output_tokens:
        :param encoder_output_dict:
        :param prev_decoder_state:
        :param prev_hidden_state:
        :return:
        """
        if torch.cuda.is_available():
            src_dict[TOKENS] = src_dict[TOKENS].cuda()
            src_dict[TOKENS_LENS] = src_dict[TOKENS_LENS].cuda()
            src_dict[TOKENS_OOV] = src_dict[TOKENS_OOV].cuda()
            src_dict[OOV_COUNT] = src_dict[OOV_COUNT].cuda()
            if prev_output_tokens is not None:
                prev_output_tokens = prev_output_tokens.cuda()
            prev_decoder_state = prev_decoder_state.cuda()
        if encoder_output_dict is None:
            encoder_output_dict = self.encoder(src_dict)
            prev_hidden_state = encoder_output_dict['encoder_hidden']
            prev_hidden_state[0] = self.encoder2decoder_state(prev_hidden_state[0])
            prev_hidden_state[1] = self.encoder2decoder_cell(prev_hidden_state[1])

        decoder_prob, prev_decoder_state, prev_hidden_state = self.decoder(
            src_dict=src_dict,
            prev_output_tokens=prev_output_tokens,
            encoder_output_dict=encoder_output_dict,
            prev_context_state=prev_decoder_state,
            prev_rnn_state=prev_hidden_state)
        return decoder_prob, encoder_output_dict, prev_decoder_state, prev_hidden_state


class CopyRnnEncoder(nn.Module):
    def __init__(self, vocab2id, embedding, hidden_size,
                 bidirectional, dropout):
        super().__init__()
        embed_dim = embedding.embedding_dim
        self.embed_dim = embed_dim
        self.embedding = embedding
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        self.num_layers = 1
        self.pad_idx = vocab2id[PAD_WORD]
        self.dropout = dropout
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=self.num_layers,
            bidirectional=bidirectional,
            batch_first=True)

    def forward(self, src_dict):
        """

        :param src_dict:
        :return:
        """
        src_tokens = src_dict[TOKENS]
        src_lengths = src_dict[TOKENS_LENS]
        batch_size = len(src_tokens)
        src_embed = self.embedding(src_tokens)
        src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)

        total_length = src_embed.size(1)
        packed_src_embed = nn.utils.rnn.pack_padded_sequence(src_embed,
                                                             src_lengths,
                                                             batch_first=True,
                                                             enforce_sorted=False)
        state_size = [self.num_layers, batch_size, self.hidden_size]
        if self.bidirectional:
            state_size[0] *= 2
        h0 = src_embed.new_zeros(state_size)
        c0 = src_embed.new_zeros(state_size)
        hidden_states, (final_hiddens, final_cells) = self.lstm(packed_src_embed, (h0, c0))
        hidden_states, _ = nn.utils.rnn.pad_packed_sequence(hidden_states,
                                                            padding_value=self.pad_idx,
                                                            batch_first=True,
                                                            total_length=total_length)
        encoder_padding_mask = src_tokens.eq(self.pad_idx)
        if self.bidirectional:
            final_hiddens = torch.cat((final_hiddens[0], final_hiddens[1]), dim=1).unsqueeze(0)
            final_cells = torch.cat((final_cells[0], final_cells[1]), dim=1).unsqueeze(0)
        output = {'encoder_output': hidden_states,
                  'encoder_padding_mask': encoder_padding_mask,
                  'encoder_hidden': [final_hiddens, final_cells]}
        return output


class CopyRnnDecoder(nn.Module):
    def __init__(self, vocab2id, embedding, args):
        super().__init__()
        self.vocab2id = vocab2id
        vocab_size = embedding.num_embeddings
        embed_dim = embedding.embedding_dim
        self.vocab_size = vocab_size
        self.embed_size = embed_dim
        self.embedding = embedding
        self.target_hidden_size = args.target_hidden_size
        if args.bidirectional:
            self.src_hidden_size = args.src_hidden_size * 2
        else:
            self.src_hidden_size = args.src_hidden_size
        self.max_src_len = args.max_src_len
        self.max_oov_count = args.max_oov_count
        self.dropout = args.dropout
        self.pad_idx = vocab2id[PAD_WORD]
        self.is_copy = args.copy_net
        self.input_feeding = args.input_feeding
        self.auto_regressive = args.auto_regressive

        if not self.auto_regressive and self.input_feeding:
            raise ValueError('auto regressive must be used when input_feeding is on')

        decoder_input_size = embed_dim
        if args.input_feeding:
            decoder_input_size += self.src_hidden_size

        self.lstm = nn.LSTM(
            input_size=decoder_input_size,
            hidden_size=self.target_hidden_size,
            num_layers=1,
            batch_first=True
        )
        self.attn_layer = Attention(self.src_hidden_size, self.target_hidden_size, args.attention_mode)
        self.copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
        self.input_copy_proj = nn.Linear(self.src_hidden_size, self.target_hidden_size, bias=False)
        self.generate_proj = nn.Linear(self.target_hidden_size, self.vocab_size, bias=False)

    def forward(self, prev_output_tokens, encoder_output_dict, prev_context_state,
                prev_rnn_state, src_dict):
        """

        :param prev_output_tokens: B x 1
        :param encoder_output_dict:
        :param prev_context_state: B x TH
        :param prev_rnn_state:
        :param src_dict:
        :return:
        """
        if self.is_copy:
            if self.auto_regressive or not self.training:
                output = self.forward_copyrnn_auto_regressive(encoder_output_dict=encoder_output_dict,
                                                              prev_context_state=prev_context_state,
                                                              prev_output_tokens=prev_output_tokens,
                                                              prev_rnn_state=prev_rnn_state,
                                                              src_dict=src_dict)
            else:
                output = self.forward_copyrnn_one_pass(encoder_output_dict=encoder_output_dict,
                                                       src_dict=src_dict,
                                                       encoder_hidden_state=prev_rnn_state)
        else:
            if self.auto_regressive or not self.training:
                output = self.forward_rnn_auto_regressive(encoder_output_dict=encoder_output_dict,
                                                          prev_output_tokens=prev_output_tokens,
                                                          prev_rnn_state=prev_rnn_state,
                                                          prev_context_state=prev_context_state)
            else:
                output = self.forward_rnn_one_pass(encoder_output_dict=encoder_output_dict,
                                                   src_dict=src_dict,
                                                   encoder_hidden_state=prev_rnn_state)
        return output

    def forward_copyrnn_one_pass(self, encoder_output_dict, encoder_hidden_state, src_dict):
        """

        :param encoder_output_dict:
        :param encoder_hidden_state:
        :param src_dict:
        :return:
        """
        dec_len = src_dict[TARGET].size(1) - 1
        src_tokens_with_oov = src_dict[TOKENS_OOV]
        batch_size = len(src_tokens_with_oov)
        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']

        decoder_input = self.embedding(src_dict[TARGET][:, :-1])

        rnn_output, rnn_state = self.lstm(decoder_input, encoder_hidden_state)
        attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)

        generate_logits = torch.exp(self.generate_proj(attn_output))
        # add 1e-10 to avoid -inf in torch.log
        generate_oov_logits = torch.zeros(batch_size, dec_len, self.max_oov_count) + 1e-10
        if torch.cuda.is_available():
            generate_oov_logits = generate_oov_logits.cuda()
        generate_logits = torch.cat([generate_logits, generate_oov_logits], dim=2)
        copy_logits = self.get_copy_score(encoder_output,
                                          src_tokens_with_oov,
                                          attn_output,
                                          encoder_output_mask)
        # log softmax
        # !! important !!
        # the generative and copy logits must be summed after the exp, so torch.log_softmax cannot be
        # used here: it would sum the generative and copy logits before the exp, which is equivalent
        # to multiplying exp(generative) and exp(copy) instead of summing them.
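        # concretely, with generative scores g and copy scores c over the extended vocabulary,
        # the lines below compute
        #     p(w) = (exp(g_w) + exp(c_w)) / sum_v (exp(g_v) + exp(c_v))
        # whereas log_softmax(g + c) would compute
        #     exp(g_w) * exp(c_w) / sum_v (exp(g_v) * exp(c_v))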
        total_logit = generate_logits + copy_logits
        total_prob = total_logit / torch.sum(total_logit, 2).unsqueeze(2)
        total_prob = torch.log(total_prob)
        return total_prob, attn_output, rnn_state

    def forward_copyrnn_auto_regressive(self,
                                        encoder_output_dict,
                                        prev_context_state,
                                        prev_output_tokens,
                                        prev_rnn_state,
                                        src_dict):
        """

        :param encoder_output_dict:
        :param prev_context_state:
        :param prev_output_tokens:
        :param prev_rnn_state:
        :param src_dict:
        :return:
        """
        src_tokens = src_dict[TOKENS]
        src_tokens_with_oov = src_dict[TOKENS_OOV]
        batch_size = len(src_tokens)
        prev_output_tokens = torch.as_tensor(prev_output_tokens, dtype=torch.int64)
        if torch.cuda.is_available():
            prev_output_tokens = prev_output_tokens.cuda()

        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']
        # B x 1 x L
        copy_state = self.get_attn_read_input(encoder_output,
                                              prev_context_state,
                                              prev_output_tokens,
                                              src_tokens_with_oov)

        # map copied oov tokens to the UNK idx to avoid an embedding lookup error
        prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD]
        src_embed = self.embedding(prev_output_tokens)

        if self.input_feeding:
            decoder_input = torch.cat([src_embed, copy_state], dim=2)
        else:
            decoder_input = src_embed
        decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training)
        rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
        rnn_state = list(rnn_state)
        # attn_output is the final hidden state of the decoder layer
        # attn_output B x 1 x TH
        attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)
        generate_logits = torch.exp(self.generate_proj(attn_output).squeeze(1))
        # add 1e-10 to avoid -inf in torch.log
        generate_oov_logits = torch.zeros(batch_size, self.max_oov_count) + 1e-10
        if torch.cuda.is_available():
            generate_oov_logits = generate_oov_logits.cuda()
        generate_logits = torch.cat([generate_logits, generate_oov_logits], dim=1)
        copy_logits = self.get_copy_score(encoder_output,
                                          src_tokens_with_oov,
                                          attn_output,
                                          encoder_output_mask)
        # log softmax
        # !! important !!
        # the generative and copy logits must be summed after the exp, so torch.log_softmax cannot be
        # used here: it would sum the generative and copy logits before the exp, which is equivalent
        # to multiplying exp(generative) and exp(copy) instead of summing them.
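        # generate_logits and copy_logits.squeeze(1) are both B x (vocab_size + max_oov_count) here,
        # so the extended-vocabulary scores can be summed elementwise before renormalization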
        total_logit = generate_logits + copy_logits.squeeze(1)
        total_prob = total_logit / torch.sum(total_logit, 1).unsqueeze(1)
        total_prob = torch.log(total_prob)
        return total_prob, attn_output.squeeze(1), rnn_state

    def forward_rnn_one_pass(self, encoder_output_dict, encoder_hidden_state, src_dict):
        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']

        decoder_input = self.embedding(src_dict[TARGET][:, :-1])

        rnn_output, rnn_state = self.lstm(decoder_input, encoder_hidden_state)
        attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)
        probs = torch.log_softmax(self.generate_proj(attn_output), dim=-1)
        return probs, attn_output, rnn_state

    def forward_rnn_auto_regressive(self, encoder_output_dict, prev_output_tokens,
                                    prev_rnn_state, prev_context_state):
        """

        :param encoder_output_dict:
        :param prev_output_tokens:
        :param prev_rnn_state:
        :param prev_context_state:
        :return:
        """
        encoder_output = encoder_output_dict['encoder_output']
        encoder_output_mask = encoder_output_dict['encoder_padding_mask']
        src_embed = self.embedding(prev_output_tokens)
        if self.input_feeding:
            prev_context_state = prev_context_state.unsqueeze(1)
            decoder_input = torch.cat([src_embed, prev_context_state], dim=2)
        else:
            decoder_input = src_embed
        rnn_output, rnn_state = self.lstm(decoder_input, prev_rnn_state)
        rnn_state = list(rnn_state)
        attn_output, attn_weights = self.attn_layer(rnn_output, encoder_output, encoder_output_mask)
        probs = torch.log_softmax(self.generate_proj(attn_output).squeeze(1), 1)
        return probs, attn_output.squeeze(1), rnn_state

    def get_attn_read_input(self, encoder_output, prev_context_state,
                            prev_output_tokens, src_tokens_with_oov):
        """
        build the CopyNet decoder input for the "attentive read" part.
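        (this is the "selective read" of the CopyNet paper: encoder states at source positions
        whose token equals the previously generated token are softmax-weighted and summed into
        a B x 1 x SH state that is concatenated to the next decoder input when input feeding is on)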
        :param encoder_output:
        :param prev_context_state:
        :param prev_output_tokens:
        :param src_tokens_with_oov:
        :return:
        """
        # mask : B x L x 1
        mask_bool = torch.eq(prev_output_tokens.repeat(1, self.max_src_len),
                             src_tokens_with_oov).unsqueeze(2)
        mask = mask_bool.type_as(encoder_output)
        # B x L x SH
        aggregate_weight = torch.tanh(self.input_copy_proj(torch.mul(mask, encoder_output)))
        # when all prev_tokens are not in src_tokens, don't execute mask -inf to avoid nan result in softmax
        no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, self.max_src_len).unsqueeze(2)
        # mask out the positions that do NOT match the previous output token
        input_copy_logit_mask = no_zero_mask & (~mask_bool)
        input_copy_logit = torch.bmm(aggregate_weight, prev_context_state.unsqueeze(2))
        input_copy_logit.masked_fill_(input_copy_logit_mask, float('-inf'))
        input_copy_weight = torch.softmax(input_copy_logit.squeeze(2), 1)
        # B x 1 x SH
        copy_state = torch.bmm(input_copy_weight.unsqueeze(1), encoder_output)
        return copy_state

    def get_copy_score(self, encoder_out, src_tokens_with_oov, decoder_output, encoder_output_mask):
        """

        :param encoder_out: B x L x SH
        :param src_tokens_with_oov: B x L
        :param decoder_output: B x dec_len x TH
        :param encoder_output_mask: B x L
        :return: B x dec_len x V
        """

        dec_len = decoder_output.size(1)
        batch_size = len(encoder_out)
        # copy_score: B x L x dec_len
        copy_score_in_seq = torch.bmm(torch.tanh(self.copy_proj(encoder_out)),
                                      decoder_output.permute(0, 2, 1))
        copy_score_mask = encoder_output_mask.unsqueeze(2).repeat(1, 1, dec_len)
        copy_score_in_seq.masked_fill_(copy_score_mask, float('-inf'))
        copy_score_in_seq = torch.exp(copy_score_in_seq)
        total_vocab_size = self.vocab_size + self.max_oov_count
        copy_score_in_vocab = torch.zeros(batch_size, total_vocab_size, dec_len)
        if torch.cuda.is_available():
            copy_score_in_vocab = copy_score_in_vocab.cuda()
        token_ids = src_tokens_with_oov.unsqueeze(2).repeat(1, 1, dec_len)
        copy_score_in_vocab.scatter_add_(1, token_ids, copy_score_in_seq)
        copy_score_in_vocab = copy_score_in_vocab.permute(0, 2, 1)
        return copy_score_in_vocab
--------------------------------------------------------------------------------
/deep_keyphrase/copy_rnn/model_tf.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-
import argparse
import tensorflow as tf
from ..dataloader import UNK_WORD, BOS_WORD


def mask_fill(t, mask, num):
    """

    :param t: input tensor
    :param mask: mask tensor, True keeps the value and False replaces it with num
    :param num: the number to fill in at masked indices
    :return:
    """
    t_dtype = t.dtype
    mask = tf.cast(mask, dtype=t_dtype)
    neg_mask = 1 - mask
    filled_t = t * mask + neg_mask * num
    return filled_t


class Attention(tf.keras.layers.Layer):
    def __init__(self, encoder_dim, decoder_dim, score_mode='general'):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim
        self.score_mode = score_mode
        self.permute_1_2 = tf.keras.layers.Permute((2, 1))
        if self.score_mode == 'general':
            self.attn = tf.keras.layers.Dense(self.encoder_dim, use_bias=False)

        self.output_layer = tf.keras.layers.Dense(self.decoder_dim)

    @tf.function
mask, dec_len):
36 |         if self.score_mode == 'general':
37 |             attn_weights = tf.matmul(self.attn(query), self.permuate_1_2(key))
38 |         elif self.score_mode == 'concat':
39 |             raise NotImplementedError('concat score mode is not implemented')
40 |         elif self.score_mode == 'dot':
41 |             attn_weights = tf.matmul(query, self.permuate_1_2(key))
42 | 
43 |         mask = tf.repeat(tf.expand_dims(mask, 1), repeats=dec_len, axis=1)
44 |         attn_weights = mask_fill(attn_weights, mask, -1e20)
45 |         attn_weights = tf.nn.softmax(attn_weights, axis=2)
46 |         return attn_weights
47 | 
48 |     @tf.function
49 |     def call(self, decoder_output, encoder_output, enc_mask, dec_len):
50 |         attn_weights = self.score(decoder_output, encoder_output, enc_mask, dec_len)
51 |         context_embed = tf.matmul(attn_weights, encoder_output)
52 |         attn_output = tf.tanh(self.output_layer(tf.concat([context_embed, decoder_output], axis=-1)))
53 |         return attn_output
54 | 
55 | 
56 | class CopyRnnTF(tf.keras.Model):
57 |     def __init__(self, args: argparse.Namespace, vocab2id):
58 |         super().__init__()
59 |         self.args = args
60 |         self.vocab_size = len(vocab2id)
61 |         initializer = tf.random_uniform_initializer(minval=-0.1, maxval=0.1)
62 |         self.embedding = tf.keras.layers.Embedding(self.vocab_size, args.embed_dim,
63 |                                                    embeddings_initializer=initializer,
64 |                                                    dtype=tf.float32)
65 |         self.encoder = Encoder(args, self.embedding)
66 |         self.decoder = Decoder(args, self.embedding)
67 |         self.max_target_len = self.args.max_target_len
68 |         self.total_vocab_size = self.vocab_size + args.max_oov_count
69 |         self.encoder2decoder_state = tf.keras.layers.Dense(args.decoder_hidden_size)
70 |         self.encoder2decoder_cell = tf.keras.layers.Dense(args.decoder_hidden_size)
71 |         self.beam_size = args.beam_size
72 |         self.beam_size_t = tf.constant(args.beam_size, dtype=tf.int64)
73 |         self.unk_idx = vocab2id[UNK_WORD]
74 |         self.bos_idx = vocab2id[BOS_WORD]
75 | 
76 |     def call(self, x, x_with_oov, x_len, enc_output, dec_x, prev_h, prev_c,
77 |              batch_size, dec_len):
78 |         if enc_output._rank() <= 1:
79 |             enc_output, prev_h, prev_c = self.encoder(x, x_len, batch_size)
80 |             prev_h = self.encoder2decoder_state(prev_h)  # project final encoder h/c into the decoder state space
81 |             prev_c = self.encoder2decoder_cell(prev_c)
82 | 
83 |         probs, prev_h, prev_c = self.decoder(dec_x, x_with_oov, x_len, enc_output,
84 |                                              prev_h, prev_c, batch_size, dec_len)
85 | 
86 |         return probs, enc_output, prev_h, prev_c
87 | 
88 |     @tf.function
89 |     def beam_search(self, x, x_with_oov, x_len, batch_size):
90 |         """
91 | 
92 |         :param x: padded source token ids, B x max_src_len
93 |         :param x_with_oov: source token ids with OOV tokens mapped to extended ids, B x max_src_len
94 |         :param x_len: true source lengths, B
95 |         :param batch_size: 1-D tensor, because SavedModel does not support scalar input parameters
96 |         :return:
97 |         """
98 |         batch_size = tf.reduce_sum(batch_size)
99 |         beam_batch_size = self.beam_size * batch_size
100 |         prev_output_tokens = tf.ones([batch_size, 1], dtype=tf.int64) * self.bos_idx
101 |         # encoder_output is passed as the tf.constant(0) placeholder so that call() runs the encoder first
102 |         probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len, tf.constant(0),
103 |                                                       prev_output_tokens, tf.zeros([1, 100]),
104 |                                                       tf.zeros([1, 100]), batch_size,
105 |                                                       tf.ones([], dtype=tf.int64))
106 |         probs = tf.squeeze(probs, axis=1)
107 |         prev_best_probs, prev_best_index = tf.math.top_k(probs, k=self.beam_size)
108 |         prev_best_index = tf.cast(prev_best_index, dtype=tf.int64)
109 | 
110 |         prev_h = tf.repeat(prev_h, self.beam_size, axis=0)
111 |         prev_c = tf.repeat(prev_c, self.beam_size, axis=0)
112 |         enc_output = tf.repeat(enc_output, self.beam_size, axis=0)
113 |         result_sequences = prev_best_index
114 | 
115 |         prev_best_index = mask_fill(prev_best_index, prev_best_index < 
self.vocab_size, self.unk_idx) 116 | prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, -1]) 117 | x = tf.repeat(x, repeats=self.beam_size, axis=0) 118 | x_with_oov = tf.repeat(x_with_oov, repeats=self.beam_size, axis=0) 119 | x_len = tf.repeat(x_len, repeats=self.beam_size, axis=0) 120 | 121 | for target_idx in range(1, self.max_target_len): 122 | probs, enc_output, prev_h, prev_c = self.call(x, x_with_oov, x_len, enc_output, 123 | prev_best_index, 124 | prev_h, prev_c, beam_batch_size, 125 | tf.ones([], dtype=tf.int64)) 126 | probs = tf.squeeze(probs, axis=1) 127 | # B x b*V 128 | accumulated_probs = tf.reshape(prev_best_probs, [beam_batch_size, 1]) 129 | accumulated_probs = tf.repeat(accumulated_probs, repeats=self.total_vocab_size, axis=1) 130 | accumulated_probs += probs 131 | accumulated_probs = tf.reshape(accumulated_probs, 132 | [batch_size, self.beam_size * self.total_vocab_size]) 133 | prev_best_probs, top_token_index = tf.math.top_k(accumulated_probs, k=self.beam_size) 134 | top_token_index = tf.cast(top_token_index, dtype=tf.int64) 135 | 136 | select_idx_factor = tf.range(0, batch_size, dtype=tf.int64) * self.beam_size 137 | select_idx_factor = tf.repeat(tf.expand_dims(select_idx_factor, axis=1), 138 | self.beam_size, axis=1) 139 | state_select_idx = tf.reshape(top_token_index, [beam_batch_size]) // probs.shape[1] 140 | state_select_idx += tf.reshape(select_idx_factor, [beam_batch_size]) 141 | 142 | prev_best_index = top_token_index % probs.shape[1] 143 | prev_h = tf.gather(prev_h, state_select_idx, axis=0) 144 | prev_c = tf.gather(prev_c, state_select_idx, axis=0) 145 | 146 | result_sequences = tf.reshape(result_sequences, [beam_batch_size, -1]) 147 | result_sequences = tf.gather(result_sequences, state_select_idx, axis=0) 148 | result_sequences = tf.reshape(result_sequences, [batch_size, self.beam_size, -1]) 149 | result_sequences = tf.concat([result_sequences, tf.expand_dims(prev_best_index, axis=2)], 150 | axis=2) 151 | 152 | prev_best_index = tf.reshape(prev_best_index, [beam_batch_size, 1]) 153 | prev_best_index = mask_fill(prev_best_index, prev_best_index < self.vocab_size, self.unk_idx) 154 | 155 | return result_sequences 156 | 157 | 158 | class Encoder(tf.keras.layers.Layer): 159 | def __init__(self, args, embedding): 160 | super().__init__() 161 | self.args = args 162 | self.embedding = embedding 163 | self.lstm = tf.keras.layers.LSTM(self.args.encoder_hidden_size, 164 | return_state=True, return_sequences=True) 165 | if args.bidirectional: 166 | self.lstm = tf.keras.layers.Bidirectional(self.lstm) 167 | 168 | self.max_dec = self.args.max_src_len 169 | 170 | @tf.function 171 | def call(self, x, x_len, batch_size): 172 | embed_x = self.embedding(x) 173 | mask = tf.sequence_mask(x_len, maxlen=self.max_dec) 174 | if self.args.bidirectional: 175 | lstm_output, state_fw_h, state_fw_c, state_bw_h, state_bw_c = self.lstm(embed_x, mask=mask) 176 | state_h = tf.concat([state_fw_h, state_bw_h], axis=1) 177 | state_c = tf.concat([state_fw_c, state_bw_c], axis=1) 178 | else: 179 | lstm_output, state_h, state_c = self.lstm(embed_x) 180 | return lstm_output, state_h, state_c 181 | 182 | 183 | class Decoder(tf.keras.layers.Layer): 184 | def __init__(self, args, embedding): 185 | super().__init__() 186 | self.args = args 187 | self.embedding = embedding 188 | self.vocab_size = self.embedding.input_dim 189 | self.max_oov_count = self.args.max_oov_count 190 | self.max_src_len = self.args.max_src_len 191 | self.max_enc = self.args.max_src_len 192 | self.lstm = 
tf.keras.layers.LSTM(self.args.decoder_hidden_size, 193 | return_state=True, return_sequences=True) 194 | if self.args.bidirectional: 195 | enc_hidden_size = self.args.encoder_hidden_size * 2 196 | else: 197 | enc_hidden_size = self.args.encoder_hidden_size 198 | self.attention = Attention(enc_hidden_size, 199 | self.args.decoder_hidden_size) 200 | self.generate_layer = tf.keras.layers.Dense(self.vocab_size, use_bias=False) 201 | self.concat_layer = tf.keras.layers.Concatenate() 202 | self.copy_layer = tf.keras.layers.Dense(self.args.decoder_hidden_size) 203 | self.permuate_1_2 = tf.keras.layers.Permute((2, 1)) 204 | 205 | @tf.function 206 | def call(self, dec_x, enc_x_with_oov, enc_len, enc_output, 207 | enc_h, enc_c, batch_size, dec_len): 208 | """ 209 | 210 | :return: 211 | """ 212 | embed_dec_x = self.embedding(dec_x) 213 | mask = tf.sequence_mask(enc_len, maxlen=self.max_src_len) 214 | hidden_states, state_h, state_c = self.lstm(embed_dec_x, initial_state=(enc_h, enc_c)) 215 | attn_output = self.attention(hidden_states, enc_output, mask, dec_len) 216 | generation_logits = tf.exp(self.generate_layer(attn_output)) 217 | 218 | generation_logits = tf.pad(generation_logits, [[0, 0], [0, 0], [0, self.max_oov_count]], 219 | constant_values=1e-10) 220 | copy_logits = self.get_copy_score(enc_output, enc_x_with_oov, attn_output, mask, batch_size, dec_len) 221 | total_logits = generation_logits + copy_logits 222 | total_prob = total_logits / tf.reduce_sum(total_logits, axis=2, keepdims=True) 223 | total_prob = tf.math.log(total_prob) 224 | return total_prob, state_h, state_c 225 | 226 | @tf.function 227 | def get_copy_score(self, src_output, x_with_oov, tgt_output, mask, batch_size, dec_len): 228 | total_vocab_size = self.vocab_size + self.max_oov_count 229 | tgt_output = self.permuate_1_2(tgt_output) 230 | 231 | copy_score_in_seq = tf.matmul(tf.tanh(self.copy_layer(src_output)), tgt_output) 232 | copy_score_in_seq = self.permuate_1_2(copy_score_in_seq) 233 | mask = tf.repeat(tf.expand_dims(mask, axis=1), repeats=dec_len, axis=1) 234 | copy_score_in_seq = mask_fill(copy_score_in_seq, mask, -1e20) 235 | copy_score_in_seq = tf.exp(copy_score_in_seq) 236 | 237 | batch_idx = tf.transpose(tf.broadcast_to(tf.range(batch_size, dtype=tf.int64), 238 | [self.max_src_len, dec_len, batch_size])) 239 | src_idx = tf.broadcast_to(tf.range(dec_len, dtype=tf.int64), [batch_size, dec_len]) 240 | src_idx = tf.repeat(tf.expand_dims(src_idx, axis=2), repeats=self.max_src_len, axis=2) 241 | x_with_oov = tf.repeat(tf.expand_dims(x_with_oov, axis=1), repeats=dec_len, axis=1) 242 | 243 | score_idx = tf.stack([batch_idx, src_idx, x_with_oov], axis=-1) 244 | 245 | to_shape = [batch_size, dec_len, total_vocab_size] 246 | 247 | copy_score_in_vocab = tf.scatter_nd(score_idx, copy_score_in_seq, to_shape) 248 | return copy_score_in_vocab 249 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import torch 4 | from munch import Munch 5 | from pysenal import read_file, append_jsonlines 6 | from deep_keyphrase.base_predictor import BasePredictor 7 | from deep_keyphrase.copy_rnn.model import CopyRNN 8 | from deep_keyphrase.copy_rnn.beam_search import BeamSearch 9 | from deep_keyphrase.dataloader import KeyphraseDataLoader, RAW_BATCH, TOKENS, INFERENCE_MODE, EVAL_MODE 10 | from deep_keyphrase.utils.vocab_loader import load_vocab 11 | from 
deep_keyphrase.utils.constants import BOS_WORD, UNK_WORD 12 | from deep_keyphrase.utils.tokenizer import token_char_tokenize 13 | 14 | 15 | class CopyRnnPredictor(BasePredictor): 16 | def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_length): 17 | """ 18 | 19 | :param model_info: input the model information. 20 | str type: model path 21 | dict type: must have `model` and `config` field, 22 | indicate the model object and config object 23 | 24 | :param vocab_info: input the vocab information. 25 | str type: vocab path 26 | dict type: vocab2id dict which map word to id 27 | :param beam_size: beam size 28 | :param max_target_len: max keyphrase token length 29 | :param max_src_length: max source text length 30 | """ 31 | super().__init__(model_info) 32 | if isinstance(vocab_info, str): 33 | self.vocab2id = load_vocab(vocab_info) 34 | elif isinstance(vocab_info, dict): 35 | self.vocab2id = vocab_info 36 | else: 37 | raise ValueError('vocab info type error') 38 | self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys())) 39 | self.config = self.load_config(model_info) 40 | self.model = self.load_model(model_info, CopyRNN(self.config, self.vocab2id)) 41 | self.model.eval() 42 | self.beam_size = beam_size 43 | self.max_target_len = max_target_len 44 | self.max_src_len = max_src_length 45 | self.beam_searcher = BeamSearch(model=self.model, 46 | beam_size=self.beam_size, 47 | max_target_len=self.max_target_len, 48 | id2vocab=self.id2vocab, 49 | bos_idx=self.vocab2id[BOS_WORD], 50 | unk_idx=self.vocab2id[UNK_WORD], 51 | args=self.config) 52 | self.pred_base_config = {'max_oov_count': self.config.max_oov_count, 53 | 'max_src_len': self.max_src_len, 54 | 'max_target_len': self.max_target_len, 55 | 'prefetch': False, 56 | 'shuffle_in_batch': False, 57 | 'token_field': TOKENS, 58 | 'keyphrase_field': 'keyphrases'} 59 | 60 | def predict(self, text_list, batch_size=10, delimiter=None, tokenized=False): 61 | """ 62 | 63 | :param text_list: 64 | :param batch_size: 65 | :param delimiter: 66 | :param tokenized: 67 | :return: 68 | """ 69 | # eval mode closes dropout, triggers auto regression in decoding stage 70 | self.model.eval() 71 | if len(text_list) < batch_size: 72 | batch_size = len(text_list) 73 | 74 | if tokenized: 75 | text_list = [{TOKENS: i} for i in text_list] 76 | else: 77 | text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list] 78 | args = Munch({'batch_size': batch_size, **self.config._asdict(), **self.pred_base_config}) 79 | loader = KeyphraseDataLoader(data_source=text_list, 80 | vocab2id=self.vocab2id, 81 | mode=INFERENCE_MODE, 82 | args=args) 83 | result = [] 84 | for batch in loader: 85 | with torch.no_grad(): 86 | result.extend(self.beam_searcher.beam_search(batch, delimiter=delimiter)) 87 | return result 88 | 89 | def eval_predict(self, src_filename, dest_filename, args, 90 | model=None, remove_existed=False): 91 | args_dict = vars(args) 92 | args_dict['batch_size'] = args_dict['eval_batch_size'] 93 | args = Munch(args_dict) 94 | loader = KeyphraseDataLoader(data_source=src_filename, 95 | vocab2id=self.vocab2id, 96 | mode=EVAL_MODE, 97 | args=args) 98 | 99 | if os.path.exists(dest_filename): 100 | print('destination filename {} existed'.format(dest_filename)) 101 | if remove_existed: 102 | os.remove(dest_filename) 103 | if model is not None: 104 | model.eval() 105 | self.beam_searcher = BeamSearch(model=model, 106 | beam_size=self.beam_size, 107 | max_target_len=self.max_target_len, 108 | id2vocab=self.id2vocab, 109 | 
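# the searcher is rebuilt here so that decoding uses the model passed into
# eval_predict (e.g. the current training snapshot) instead of the model
# loaded when the predictor was constructed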
bos_idx=self.vocab2id[BOS_WORD], 110 | unk_idx=self.vocab2id[UNK_WORD], 111 | args=self.config) 112 | 113 | for batch in loader: 114 | with torch.no_grad(): 115 | batch_result = self.beam_searcher.beam_search(batch, delimiter=None) 116 | final_result = [] 117 | assert len(batch_result) == len(batch[RAW_BATCH]) 118 | for item_input, item_output in zip(batch[RAW_BATCH], batch_result): 119 | item_input['pred_keyphrases'] = item_output 120 | final_result.append(item_input) 121 | append_jsonlines(dest_filename, final_result) 122 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/predict_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | from .model_tf import CopyRnnTF 5 | from deep_keyphrase.dataloader import UNK_WORD, PAD_WORD, TOKENS, TOKENS_OOV, TOKENS_LENS, OOV_LIST, EOS_WORD 6 | from ..utils.tokenizer import token_char_tokenize 7 | 8 | 9 | class PredictorTF(object): 10 | def __init__(self, model: CopyRnnTF, vocab2id, args): 11 | self.model = model 12 | self.max_src_len = args.max_src_len 13 | self.vocab2id = vocab2id 14 | self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys())) 15 | self.vocab_size = len(self.vocab2id) 16 | self.max_oov_count = args.max_oov_count 17 | self.pad_idx = vocab2id[PAD_WORD] 18 | 19 | def eval_predict(self, batch, model=None, delimiter=None): 20 | if model is None: 21 | model = self.model 22 | batch_size = len(batch[TOKENS]) 23 | result_tensor = model.beam_search(batch[TOKENS], batch[TOKENS_OOV], batch[TOKENS_LENS], 24 | np.array([batch_size], dtype=np.int64)) 25 | result_np = result_tensor.numpy() 26 | oov_list = batch[OOV_LIST] 27 | return self.__idx2result_beam(delimiter, oov_list, result_np) 28 | 29 | def predict(self, text_list): 30 | x_batch, x_oov_batch, sent_len_batch, oov_list_batch = self.generate_input_batch(text_list) 31 | batch_size = len(x_batch) 32 | batch_size_np = np.array([batch_size], dtype=np.long) 33 | result_tensor = self.model.beam_search(x_batch, x_oov_batch, sent_len_batch, batch_size_np) 34 | result_np = result_tensor.numpy() 35 | return self.__idx2result_beam('', oov_list_batch, result_np) 36 | 37 | def __idx2result_beam(self, delimiter, oov_list, result_sequences): 38 | results = [] 39 | for batch_idx, batch in enumerate(result_sequences): 40 | beam_list = [] 41 | item_oov_list = oov_list[batch_idx] 42 | for beam in batch: 43 | phrase = [] 44 | for idx in beam: 45 | if self.id2vocab.get(idx) == EOS_WORD: 46 | break 47 | if idx in self.id2vocab: 48 | phrase.append(self.id2vocab[idx]) 49 | else: 50 | oov_idx = idx - len(self.id2vocab) 51 | if oov_idx < len(item_oov_list): 52 | phrase.append(item_oov_list[oov_idx]) 53 | else: 54 | phrase.append(UNK_WORD) 55 | 56 | if delimiter is not None: 57 | phrase = delimiter.join(phrase) 58 | if phrase not in beam_list: 59 | beam_list.append(phrase) 60 | results.append(beam_list) 61 | return results 62 | 63 | def generate_input_batch(self, text_list): 64 | x_batch = [] 65 | x_oov_batch = [] 66 | sent_len_batch = [] 67 | oov_list_batch = [] 68 | for text in text_list: 69 | tokens = token_char_tokenize(text) 70 | x, x_oov, oov_list, sent_len = self.generate_input(tokens) 71 | x_batch.append(x) 72 | x_oov_batch.append(x_oov) 73 | sent_len_batch.append(sent_len) 74 | oov_list_batch.append(oov_list) 75 | x_batch = tf.convert_to_tensor(x_batch) 76 | x_oov_batch = tf.convert_to_tensor(x_oov_batch) 77 | sent_len_batch = 
tf.convert_to_tensor(sent_len_batch) 78 | return x_batch, x_oov_batch, sent_len_batch, oov_list_batch 79 | 80 | def generate_input(self, tokens): 81 | if len(tokens) > self.max_src_len: 82 | tokens = tokens[:self.max_src_len] 83 | token_ids_with_oov = [] 84 | token_ids = [] 85 | oov_list = [] 86 | 87 | for token in tokens: 88 | idx = self.vocab2id.get(token, self.vocab_size) 89 | if idx == self.vocab_size: 90 | token_ids.append(self.vocab2id[UNK_WORD]) 91 | if token not in oov_list: 92 | if len(oov_list) >= self.max_oov_count: 93 | token_ids_with_oov.append(self.vocab_size + self.max_oov_count - 1) 94 | else: 95 | token_ids_with_oov.append(self.vocab_size + len(oov_list)) 96 | oov_list.append(token) 97 | else: 98 | token_ids_with_oov.append(self.vocab_size + oov_list.index(token)) 99 | else: 100 | token_ids.append(idx) 101 | token_ids_with_oov.append(idx) 102 | sent_len = len(token_ids) 103 | 104 | if len(token_ids) < self.max_src_len: 105 | pad_tokens = [self.vocab2id[PAD_WORD]] * (self.max_src_len - len(token_ids)) 106 | token_ids.extend(pad_tokens) 107 | token_ids_with_oov.extend(pad_tokens) 108 | elif len(token_ids) > self.max_src_len: 109 | token_ids = token_ids[:self.max_src_len] 110 | token_ids_with_oov = token_ids_with_oov[:self.max_src_len] 111 | sent_len = self.max_src_len 112 | return token_ids, token_ids_with_oov, oov_list, sent_len 113 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import argparse 4 | from collections import OrderedDict 5 | from munch import Munch 6 | import torch 7 | from pysenal import write_json, read_json 8 | from deep_keyphrase.utils.vocab_loader import load_vocab 9 | from deep_keyphrase.copy_rnn.model import CopyRNN 10 | from deep_keyphrase.base_trainer import BaseTrainer 11 | from deep_keyphrase.dataloader import TOKENS, TARGET 12 | from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor 13 | 14 | 15 | class CopyRnnTrainer(BaseTrainer): 16 | def __init__(self): 17 | self.args = self.parse_args() 18 | self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size) 19 | model = self.load_model() 20 | super().__init__(self.args, model) 21 | 22 | def load_model(self): 23 | if not self.args.train_from: 24 | model = CopyRNN(self.args, self.vocab2id) 25 | else: 26 | model_path = self.args.train_from 27 | config_path = os.path.join(os.path.dirname(model_path), 28 | self.get_basename(model_path) + '.json') 29 | 30 | old_config = read_json(config_path) 31 | old_config['train_from'] = model_path 32 | old_config['step'] = int(model_path.rsplit('_', 1)[-1].split('.')[0]) 33 | self.args = Munch(old_config) 34 | self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size) 35 | 36 | model = CopyRNN(self.args, self.vocab2id) 37 | 38 | if torch.cuda.is_available(): 39 | checkpoint = torch.load(model_path) 40 | else: 41 | checkpoint = torch.load(model_path, map_location=torch.device('cpu')) 42 | state_dict = OrderedDict() 43 | # avoid error when load parallel trained model 44 | for k, v in checkpoint.items(): 45 | if k.startswith('module.'): 46 | k = k[7:] 47 | state_dict[k] = v 48 | model.load_state_dict(state_dict) 49 | 50 | return model 51 | 52 | def train_batch(self, batch, step): 53 | self.model.train() 54 | loss = 0 55 | self.optimizer.zero_grad() 56 | if torch.cuda.is_available(): 57 | batch[TARGET] = batch[TARGET].cuda() 58 | targets = 
batch[TARGET] 59 | if self.args.auto_regressive: 60 | loss = self.get_auto_regressive_loss(batch, loss, targets) 61 | else: 62 | loss = self.get_one_pass_loss(batch, targets) 63 | 64 | loss.backward() 65 | 66 | # clip norm, this is very import for avoiding nan gradient and misconvergence 67 | if self.args.max_grad: 68 | torch.nn.utils.clip_grad_value_(self.model.parameters(), self.args.max_grad) 69 | if self.args.grad_norm: 70 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_norm) 71 | 72 | self.optimizer.step() 73 | if self.args.schedule_lr and step <= self.args.schedule_step: 74 | self.scheduler.step() 75 | return loss 76 | 77 | def get_one_pass_loss(self, batch, targets): 78 | batch_size = len(batch) 79 | encoder_output = None 80 | decoder_state = torch.zeros(batch_size, self.args.target_hidden_size) 81 | hidden_state = None 82 | prev_output_tokens = None 83 | output = self.model(src_dict=batch, 84 | prev_output_tokens=prev_output_tokens, 85 | encoder_output_dict=encoder_output, 86 | prev_decoder_state=decoder_state, 87 | prev_hidden_state=hidden_state) 88 | decoder_prob, encoder_output, decoder_state, hidden_state = output 89 | vocab_size = decoder_prob.size(-1) 90 | decoder_prob = decoder_prob.view(-1, vocab_size) 91 | loss = self.loss_func(decoder_prob, targets[:, 1:].flatten()) 92 | return loss 93 | 94 | def get_auto_regressive_loss(self, batch, loss, targets): 95 | batch_size = len(batch[TOKENS]) 96 | encoder_output = None 97 | decoder_state = torch.zeros(batch_size, self.args.target_hidden_size) 98 | hidden_state = None 99 | for target_index in range(self.args.max_target_len): 100 | if target_index == 0: 101 | # bos indices 102 | prev_output_tokens = targets[:, target_index].unsqueeze(1) 103 | else: 104 | if self.args.teacher_forcing: 105 | prev_output_tokens = targets[:, target_index].unsqueeze(1) 106 | else: 107 | best_probs, prev_output_tokens = torch.topk(decoder_prob, 1, 1) 108 | prev_output_tokens = prev_output_tokens.clone() 109 | output = self.model(src_dict=batch, 110 | prev_output_tokens=prev_output_tokens, 111 | encoder_output_dict=encoder_output, 112 | prev_decoder_state=decoder_state, 113 | prev_hidden_state=hidden_state) 114 | decoder_prob, encoder_output, decoder_state, hidden_state = output 115 | true_indices = targets[:, target_index + 1].clone() 116 | loss += self.loss_func(decoder_prob, true_indices) 117 | loss /= self.args.max_target_len 118 | return loss 119 | 120 | def evaluate(self, step): 121 | predictor = CopyRnnPredictor(model_info={'model': self.model, 'config': self.args}, 122 | vocab_info=self.vocab2id, 123 | beam_size=self.args.beam_size, 124 | max_target_len=self.args.max_target_len, 125 | max_src_length=self.args.max_src_len) 126 | 127 | def pred_callback(stage): 128 | if stage == 'valid': 129 | src_filename = self.args.valid_filename 130 | dest_filename = self.dest_dir + self.get_basename(self.args.valid_filename) 131 | elif stage == 'test': 132 | src_filename = self.args.test_filename 133 | dest_filename = self.dest_dir + self.get_basename(self.args.test_filename) 134 | else: 135 | raise ValueError('stage name error, must be in `valid` and `test`') 136 | dest_filename += '.batch_{}.pred.jsonl'.format(step) 137 | def predict_func(): 138 | predictor.eval_predict(src_filename=src_filename, 139 | dest_filename=dest_filename, 140 | args=self.args, 141 | model=self.model, 142 | remove_existed=True) 143 | 144 | return predict_func 145 | 146 | valid_statistics = self.evaluate_stage(step, 'valid', pred_callback('valid')) 147 | 
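# the test split goes through the same predict-and-score path below; only the
# valid macro F1 at the largest top-n is returned, presumably as the metric
# BaseTrainer uses for early stopping and best-model tracking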
test_statistics = self.evaluate_stage(step, 'test', pred_callback('test')) 148 | total_statistics = {**valid_statistics, **test_statistics} 149 | 150 | eval_filename = self.dest_dir + self.args.exp_name + '.batch_{}.eval.json'.format(step) 151 | write_json(eval_filename, total_statistics) 152 | return valid_statistics['valid_macro'][self.eval_topn[-1]]['f1'] 153 | 154 | def parse_args(self, args=None): 155 | parser = argparse.ArgumentParser() 156 | # train and evaluation parameter 157 | parser.add_argument("-exp_name", required=True, type=str, help='') 158 | parser.add_argument("-train_filename", required=True, type=str, help='') 159 | parser.add_argument("-valid_filename", required=True, type=str, help='') 160 | parser.add_argument("-test_filename", required=True, type=str, help='') 161 | parser.add_argument("-dest_base_dir", required=True, type=str, help='') 162 | parser.add_argument("-vocab_path", required=True, type=str, help='') 163 | parser.add_argument("-vocab_size", type=int, default=500000, help='') 164 | parser.add_argument("-train_from", default='', type=str, help='') 165 | parser.add_argument("-token_field", default='tokens', type=str, help='') 166 | parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='') 167 | parser.add_argument("-auto_regressive", action='store_true', help='') 168 | parser.add_argument("-epochs", type=int, default=10, help='') 169 | parser.add_argument("-batch_size", type=int, default=64, help='') 170 | parser.add_argument("-learning_rate", type=float, default=1e-4, help='') 171 | parser.add_argument("-eval_batch_size", type=int, default=50, help='') 172 | parser.add_argument("-dropout", type=float, default=0.0, help='') 173 | parser.add_argument("-grad_norm", type=float, default=0.0, help='') 174 | parser.add_argument("-max_grad", type=float, default=5.0, help='') 175 | parser.add_argument("-shuffle", action='store_true', help='') 176 | parser.add_argument("-teacher_forcing", action='store_true', help='') 177 | parser.add_argument("-beam_size", type=float, default=50, help='') 178 | parser.add_argument('-tensorboard_dir', type=str, default='', help='') 179 | parser.add_argument('-logfile', type=str, default='train_log.log', help='') 180 | parser.add_argument('-save_model_step', type=int, default=5000, help='') 181 | parser.add_argument('-early_stop_tolerance', type=int, default=100, help='') 182 | parser.add_argument('-train_parallel', action='store_true', help='') 183 | parser.add_argument('-schedule_lr', action='store_true', help='') 184 | parser.add_argument('-schedule_step', type=int, default=10000, help='') 185 | parser.add_argument('-schedule_gamma', type=float, default=0.1, help='') 186 | parser.add_argument('-processed', action='store_true', help='') 187 | parser.add_argument('-prefetch', action='store_true', help='') 188 | parser.add_argument('-lazy_loading', action='store_true', help='') 189 | parser.add_argument('-fix_batch_size', action='store_true', help='') 190 | parser.add_argument('-backend', type=str, default='tf', help='') 191 | 192 | # model specific parameter 193 | parser.add_argument("-embed_size", type=int, default=200, help='') 194 | parser.add_argument("-max_oov_count", type=int, default=100, help='') 195 | parser.add_argument("-max_src_len", type=int, default=1500, help='') 196 | parser.add_argument("-max_target_len", type=int, default=8, help='') 197 | parser.add_argument("-src_hidden_size", type=int, default=100, help='') 198 | parser.add_argument("-target_hidden_size", type=int, default=100, help='') 199 
| parser.add_argument('-src_num_layers', type=int, default=1, help='') 200 | parser.add_argument('-target_num_layers', type=int, default=1, help='') 201 | parser.add_argument("-attention_mode", type=str, default='general', 202 | choices=['general', 'dot', 'concat'], help='') 203 | parser.add_argument("-bidirectional", action='store_true', help='') 204 | parser.add_argument("-copy_net", action='store_true', help='') 205 | parser.add_argument("-input_feeding", action='store_true', help='') 206 | 207 | args = parser.parse_args(args) 208 | return args 209 | 210 | 211 | def accuracy(probs, true_indices, pad_idx): 212 | pred_indices = torch.argmax(probs, dim=1) 213 | mask = torch.eq(true_indices, torch.ones(*true_indices.size(), dtype=torch.int64) * pad_idx) 214 | tp_result = torch.eq(pred_indices, true_indices).type(torch.int) * (~mask).type(torch.int) 215 | return torch.sum(tp_result).numpy() / true_indices.numel() 216 | 217 | 218 | if __name__ == '__main__': 219 | CopyRnnTrainer().train() 220 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_rnn/train_tf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | 4 | gpu_devices = tf.config.experimental.list_physical_devices('GPU') 5 | for device in gpu_devices: 6 | tf.config.experimental.set_memory_growth(device, True) 7 | 8 | import argparse 9 | import tensorflow as tf 10 | from pysenal import write_jsonline, append_jsonlines, get_logger, write_json 11 | from munch import Munch 12 | from deep_keyphrase.dataloader import * 13 | from deep_keyphrase.utils.vocab_loader import load_vocab 14 | from deep_keyphrase.copy_rnn.model_tf import CopyRnnTF 15 | from deep_keyphrase.dataloader import PAD_WORD 16 | from deep_keyphrase.copy_rnn.predict_tf import PredictorTF 17 | from deep_keyphrase.evaluation import KeyphraseEvaluator 18 | 19 | 20 | class CopyRnnTrainerTF(object): 21 | def __init__(self): 22 | self.args = self.parse_args() 23 | self.vocab2id = load_vocab(self.args.vocab_path) 24 | self.dest_base_dir = self.args.dest_base_dir 25 | self.writer = tf.summary.create_file_writer(self.dest_base_dir + '/logs') 26 | self.exp_name = self.args.exp_name 27 | self.pad_idx = self.vocab2id[PAD_WORD] 28 | self.eval_topn = (5, 10) 29 | self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro', 30 | self.args.token_field, self.args.keyphrase_field) 31 | self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro', 32 | self.args.token_field, self.args.keyphrase_field) 33 | self.best_f1 = None 34 | self.best_step = 0 35 | self.not_update_count = 0 36 | self.logger = get_logger(__name__) 37 | self.total_vocab_size = len(self.vocab2id) + self.args.max_oov_count 38 | 39 | def train(self): 40 | # avoid tensorflow hold all gpus 41 | with tf.device('/device:GPU:0'): 42 | self.train_func() 43 | 44 | def train_func(self): 45 | model = CopyRnnTF(self.args, self.vocab2id) 46 | dataloader = KeyphraseDataLoader(data_source=self.args.train_filename, 47 | vocab2id=self.vocab2id, 48 | mode='train', 49 | args=self.args) 50 | optimizer = tf.keras.optimizers.Adam(learning_rate=self.args.learning_rate) 51 | 52 | @tf.function 53 | def train_step(x, x_with_oov, x_len, target): 54 | batch_size = x.shape[0] 55 | dec_len = self.args.max_target_len 56 | with tf.GradientTape() as tape: 57 | loss = 0 58 | probs, enc_output, prev_h, prev_c = model(x, x_with_oov, x_len, tf.constant(0), 59 | target[:, :-1], None, None, 60 | 
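# enc_output is the tf.constant(0) placeholder and prev_h / prev_c are None,
# so CopyRnnTF.call runs the encoder first and then decodes target[:, :-1]
# in a single teacher-forced pass of length dec_len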
tf.convert_to_tensor(batch_size), dec_len) 61 | for batch_idx in range(batch_size): 62 | dec_target = target[batch_idx, 1:] 63 | target_idx = tf.one_hot(dec_target, self.total_vocab_size) 64 | dec_step_loss = -tf.reduce_sum(probs[batch_idx, :] * target_idx, axis=1) 65 | mask = tf.cast(dec_target != self.pad_idx, dtype=tf.float32) 66 | 67 | dec_step_loss *= mask 68 | loss += tf.reduce_sum(dec_step_loss) / tf.reduce_sum(mask) 69 | 70 | loss /= batch_size 71 | grads = tape.gradient(loss, model.trainable_variables) 72 | grads = [(tf.clip_by_value(grad, -0.1, 0.1)) for grad in grads] 73 | optimizer.apply_gradients(zip(grads, model.trainable_weights)) 74 | return loss 75 | 76 | step_idx = 0 77 | for epoch in range(self.args.epochs): 78 | for batch in dataloader: 79 | loss = train_step(batch[TOKENS], batch[TOKENS_OOV], batch[TOKENS_LENS], batch[TARGET]) 80 | with self.writer.as_default(): 81 | tf.summary.scalar('loss', loss, step=step_idx) 82 | step_idx += 1 83 | if not step_idx % self.args.save_model_step: 84 | model_basename = self.dest_base_dir + '/{}_step{}'.format(self.exp_name, step_idx) 85 | # write_json(model_basename + '.json', vars(self.args)) 86 | # beam_search_graph = model.beam_search.get_concrete_function( 87 | # x=tf.TensorSpec(shape=[None, self.args.max_src_len], dtype=tf.int64), 88 | # x_with_oov=tf.TensorSpec(shape=[None, self.args.max_src_len], dtype=tf.int64), 89 | # x_len=tf.TensorSpec(shape=[None], dtype=tf.int64), 90 | # batch_size=tf.TensorSpec(shape=[None], dtype=tf.int64) 91 | # ) 92 | # tf.saved_model.save(model, model_basename, signatures=beam_search_graph) 93 | model.save_weights(model_basename + '.ckpt', save_format='tf') 94 | write_json(model_basename + '.json', vars(self.args)) 95 | f1 = self.evaluate(model, step_idx) 96 | self.logger.info('step {}, f1 {}'.format(step_idx, f1)) 97 | 98 | def evaluate(self, model: CopyRnnTF, step): 99 | test_basename = '/{}_step_{}.pred.jsonl'.format(self.args.exp_name, step) 100 | pred_test_filename = self.dest_base_dir + test_basename 101 | predictor = PredictorTF(model, self.vocab2id, self.args) 102 | args_dict = vars(self.args) 103 | args_dict['batch_size'] = args_dict['eval_batch_size'] 104 | args = Munch(args_dict) 105 | loader = KeyphraseDataLoader(data_source=self.args.test_filename, 106 | vocab2id=self.vocab2id, 107 | mode=EVAL_MODE, 108 | args=args) 109 | 110 | for batch in loader: 111 | kp_result = predictor.eval_predict(batch) 112 | result = [] 113 | for item, pred_keyphrases in zip(batch[RAW_BATCH], kp_result): 114 | result_item = {'patent_id': item['patent_id'], 'pred_keyphrases': pred_keyphrases, 115 | self.args.token_field: item[self.args.token_field], 116 | self.args.keyphrase_field: item[self.args.keyphrase_field]} 117 | result.append(result_item) 118 | append_jsonlines(pred_test_filename, result) 119 | 120 | macro_all_ret = self.macro_evaluator.evaluate(pred_test_filename) 121 | macro_present_ret = self.macro_evaluator.evaluate(pred_test_filename, 'present') 122 | macro_absent_ret = self.macro_evaluator.evaluate(pred_test_filename, 'absent') 123 | stage = 'test' 124 | 125 | for n, counter in macro_all_ret.items(): 126 | for k, v in counter.items(): 127 | name = '{}_macro_{}_{}'.format('test', k, n) 128 | tf.summary.scalar(name, v, step=step) 129 | for n in self.eval_topn: 130 | name = 'present_{}_macro_f1_{}'.format(stage, n) 131 | tf.summary.scalar(name, macro_present_ret[n]['f1'], step=step) 132 | for n in self.eval_topn: 133 | absent_f1_name = 'absent_{}_macro_f1_{}'.format(stage, n) 134 | 
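# absent keyphrases never occur verbatim in the source text, so recall is
# logged alongside F1: it shows how well the model generates, rather than
# copies, its predictions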
tf.summary.scalar(absent_f1_name, macro_absent_ret[n]['f1'], step=step) 135 | absent_recall_name = 'absent_{}_macro_recall_{}'.format(stage, n) 136 | tf.summary.scalar(absent_recall_name, macro_absent_ret[n]['recall'], step=step) 137 | return macro_all_ret[self.eval_topn[-1]]['f1'] 138 | 139 | def parse_args(self, args=None): 140 | parser = argparse.ArgumentParser() 141 | # train and evaluation parameter 142 | parser.add_argument("-exp_name", required=True, type=str, help='') 143 | parser.add_argument("-train_filename", required=True, type=str, help='') 144 | parser.add_argument("-valid_filename", required=True, type=str, help='') 145 | parser.add_argument("-test_filename", required=True, type=str, help='') 146 | parser.add_argument("-dest_base_dir", required=True, type=str, help='') 147 | parser.add_argument("-vocab_path", required=True, type=str, help='') 148 | parser.add_argument("-vocab_size", type=int, default=500000, help='') 149 | parser.add_argument("-train_from", default='', type=str, help='') 150 | parser.add_argument("-token_field", default='tokens', type=str, help='') 151 | parser.add_argument("-keyphrase_field", default='keyphrases', type=str, help='') 152 | parser.add_argument("-auto_regressive", action='store_true', help='') 153 | parser.add_argument("-epochs", type=int, default=10, help='') 154 | parser.add_argument("-batch_size", type=int, default=128, help='') 155 | parser.add_argument("-learning_rate", type=float, default=1e-3, help='') 156 | parser.add_argument("-eval_batch_size", type=int, default=20, help='') 157 | parser.add_argument("-dropout", type=float, default=0.0, help='') 158 | parser.add_argument("-grad_norm", type=float, default=0.0, help='') 159 | parser.add_argument("-max_grad", type=float, default=5.0, help='') 160 | parser.add_argument("-shuffle", action='store_true', help='') 161 | parser.add_argument("-teacher_forcing", action='store_true', help='') 162 | parser.add_argument("-beam_size", type=float, default=50, help='') 163 | parser.add_argument('-tensorboard_dir', type=str, default='', help='') 164 | parser.add_argument('-logfile', type=str, default='train_log.log', help='') 165 | parser.add_argument('-save_model_step', type=int, default=5000, help='') 166 | parser.add_argument('-early_stop_tolerance', type=int, default=100, help='') 167 | parser.add_argument('-train_parallel', action='store_true', help='') 168 | parser.add_argument('-schedule_lr', action='store_true', help='') 169 | parser.add_argument('-schedule_step', type=int, default=10000, help='') 170 | parser.add_argument('-schedule_gamma', type=float, default=0.1, help='') 171 | parser.add_argument('-processed', action='store_true', help='') 172 | parser.add_argument('-prefetch', action='store_true', help='') 173 | parser.add_argument('-backend', type=str, default='tf', help='') 174 | parser.add_argument('-lazy_loading', action='store_true', help='') 175 | parser.add_argument('-fix_batch_size', action='store_true', help='') 176 | 177 | # model specific parameter 178 | parser.add_argument("-embed_dim", type=int, default=200, help='') 179 | parser.add_argument("-max_oov_count", type=int, default=100, help='') 180 | parser.add_argument("-max_src_len", type=int, default=1500, help='') 181 | parser.add_argument("-max_target_len", type=int, default=8, help='') 182 | parser.add_argument("-encoder_hidden_size", type=int, default=100, help='') 183 | parser.add_argument("-decoder_hidden_size", type=int, default=100, help='') 184 | parser.add_argument('-src_num_layers', type=int, default=1, help='') 
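# example invocation (hypothetical paths; see scripts/train_copyrnn_kp20k.sh
# for the PyTorch counterpart):
#   python -m deep_keyphrase.copy_rnn.train_tf \
#       -exp_name copyrnn_kp20k -vocab_path data/vocab_kp20k.txt \
#       -train_filename data/kp20k.train.jsonl -valid_filename data/kp20k.valid.jsonl \
#       -test_filename data/kp20k.test.jsonl -dest_base_dir data/copyrnn_tf \
#       -bidirectional -teacher_forcing -shuffle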
185 | parser.add_argument('-target_num_layers', type=int, default=1, help='') 186 | parser.add_argument("-attention_mode", type=str, default='general', 187 | choices=['general', 'dot', 'concat'], help='') 188 | parser.add_argument("-bidirectional", action='store_true', help='') 189 | parser.add_argument("-copy_net", action='store_true', help='') 190 | parser.add_argument("-input_feeding", action='store_true', help='') 191 | 192 | args = parser.parse_args(args) 193 | return args 194 | 195 | 196 | if __name__ == '__main__': 197 | CopyRnnTrainerTF().train() 198 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- -------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/beam_search.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import torch 3 | from deep_keyphrase.dataloader import (OOV_LIST, TOKENS, EOS_WORD, TOKENS_OOV) 4 | 5 | 6 | class TransformerBeamSearch(object): 7 | def __init__(self, model, beam_size, max_target_len, id2vocab, bos_idx, args): 8 | self.model = model 9 | self.beam_size = beam_size 10 | self.id2vocab = id2vocab 11 | self.max_target_len = max_target_len 12 | self.bos_idx = bos_idx 13 | self.target_hidden_size = args.target_hidden_size 14 | self.bidirectional = args.bidirectional 15 | self.input_dim = args.input_dim 16 | 17 | def beam_search(self, src_dict, delimiter=None): 18 | batch_size = len(src_dict[TOKENS]) 19 | beam_batch_size = batch_size * self.beam_size 20 | encoder_output = encoder_mask = None 21 | prev_copy_state = None 22 | prev_decoder_state = torch.zeros(batch_size, ) 23 | prev_output_tokens = torch.tensor([[self.bos_idx]] * batch_size) 24 | 25 | output = self.model(src_dict=src_dict, 26 | prev_output_tokens=prev_output_tokens, 27 | encoder_output=encoder_output, 28 | encoder_mask=encoder_mask, 29 | prev_decoder_state=prev_decoder_state, 30 | position=0, 31 | prev_copy_state=prev_copy_state) 32 | decoder_prob, prev_decoder_state, prev_copy_state, encoder_output, encoder_mask = output 33 | prev_decoder_state = self.beam_repeat(prev_decoder_state) 34 | prev_copy_state = self.beam_repeat(prev_copy_state) 35 | encoder_output = self.beam_repeat(encoder_output) 36 | encoder_mask = self.beam_repeat(encoder_mask) 37 | src_dict[TOKENS] = self.beam_repeat(src_dict[TOKENS]) 38 | src_dict[TOKENS_OOV] = self.beam_repeat(src_dict[TOKENS_OOV]) 39 | prev_best_probs, prev_best_index = torch.topk(decoder_prob, self.beam_size, 1) 40 | beam_search_best_probs = torch.abs(prev_best_probs) 41 | result_sequences = prev_best_index.unsqueeze(2) 42 | 43 | for target_idx in range(1, self.max_target_len): 44 | output = self.model(src_dict=src_dict, 45 | prev_output_tokens=prev_output_tokens, 46 | encoder_output=encoder_output, 47 | encoder_mask=encoder_mask, 48 | prev_decoder_state=prev_decoder_state, 49 | position=target_idx, 50 | prev_copy_state=prev_copy_state) 51 | decoder_prob, decoder_state, copy_state, encoder_output, encoder_mask = output 52 | accumulated_probs = beam_search_best_probs.view(beam_batch_size, -1) 53 | accumulated_probs = accumulated_probs.repeat(1, decoder_prob.size(1)) 54 | accumulated_probs += torch.abs(decoder_prob) 55 | accumulated_probs = accumulated_probs.view(batch_size, -1) 56 | top_token_probs, top_token_index = torch.topk(-accumulated_probs, 
self.beam_size, 1)
57 |             beam_search_best_probs = -top_token_probs
58 | 
59 |             select_idx_factor = torch.tensor(range(batch_size)) * self.beam_size
60 |             select_idx_factor = select_idx_factor.unsqueeze(1).repeat(1, self.beam_size)
61 |             if torch.cuda.is_available():
62 |                 select_idx_factor = select_idx_factor.cuda()
63 |             state_select_idx = top_token_index.flatten() // decoder_prob.size(1)
64 |             state_select_idx += select_idx_factor.flatten()
65 | 
66 |             prev_decoder_state = prev_decoder_state.index_select(0, state_select_idx)
67 |             prev_copy_state = prev_copy_state.index_select(0, state_select_idx)
68 |             prev_output_tokens = prev_output_tokens.index_select(0, state_select_idx)
69 | 
70 |             prev_best_index = top_token_index % decoder_prob.size(1)
71 |             result_sequences = result_sequences.view(beam_batch_size, -1)
72 |             result_sequences = result_sequences.index_select(0, state_select_idx)
73 |             result_sequences = result_sequences.view(batch_size, self.beam_size, -1)
74 | 
75 |             result_sequences = torch.cat([result_sequences, prev_best_index.unsqueeze(2)], dim=2)
76 |             prev_best_index = prev_best_index.view(beam_batch_size, -1)
77 |             prev_output_tokens = torch.cat([prev_output_tokens, prev_best_index.unsqueeze(2)], dim=2)
78 |         result = self.__idx2result_beam(delimiter, src_dict[OOV_LIST], result_sequences.tolist())
79 |         return result
80 | 
81 |     def beam_repeat(self, t):
82 |         size = list(t.size())
83 |         size[0] *= self.beam_size
84 |         repeat_size = [1] * len(size)
85 |         repeat_size[1] *= self.beam_size
86 |         t = t.unsqueeze(1).repeat(repeat_size)
87 |         t = t.reshape(size)
88 |         return t
89 | 
90 |     def greedy_search(self, src_dict, delimiter=None):
91 |         batch_size = len(src_dict[TOKENS])
92 |         encoder_output = encoder_mask = None
93 |         prev_copy_state = None
94 |         prev_decoder_state = torch.zeros(batch_size, self.input_dim)
95 |         prev_output_tokens = torch.tensor([[self.bos_idx]] * batch_size)
96 |         result_matrix = None
97 |         for target_idx in range(self.max_target_len):
98 |             output = self.model(src_dict=src_dict,
99 |                                 prev_output_tokens=prev_output_tokens,
100 |                                 encoder_output=encoder_output,
101 |                                 encoder_mask=encoder_mask,
102 |                                 prev_decoder_state=prev_decoder_state,
103 |                                 position=target_idx,
104 |                                 prev_copy_state=prev_copy_state)
105 |             total_prob, prev_decoder_state, prev_copy_state, encoder_output, encoder_mask = output
106 |             prev_output_tokens = total_prob.topk(k=1, dim=1)[1].clone()  # keep the indices, not the (values, indices) pair
107 |             if result_matrix is None:
108 |                 result_matrix = prev_output_tokens
109 |             else:
110 |                 result_matrix = torch.cat([result_matrix, prev_output_tokens], dim=1)
111 |         result = self.__idx2result_beam(delimiter, src_dict[OOV_LIST], result_matrix.tolist())
112 |         return result
113 | 
114 |     def __idx2result_beam(self, delimiter, oov_list, result_sequences):
115 |         results = []
116 |         for batch in result_sequences:
117 |             beam_list = []
118 |             for beam in batch:
119 |                 phrase = []
120 |                 for idx in beam:
121 |                     if self.id2vocab.get(idx) == EOS_WORD:  # get() avoids KeyError on OOV ids
122 |                         break
123 |                     if idx in self.id2vocab:
124 |                         phrase.append(self.id2vocab[idx])
125 |                     else:
126 |                         phrase.append(oov_list[idx - len(self.id2vocab)])
127 | 
128 |                 if delimiter is not None:
129 |                     phrase = delimiter.join(phrase)
130 |                 if phrase not in beam_list:
131 |                     beam_list.append(phrase)
132 |             results.append(beam_list)
133 |         return results
134 | 
-------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/model.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import torch
3 | import torch.nn as 
nn
4 | import torch.nn.functional as F
5 | from torch.nn.modules.transformer import (TransformerEncoder, TransformerDecoder,
6 |                                           TransformerEncoderLayer, TransformerDecoderLayer)
7 | from deep_keyphrase.dataloader import (TOKENS, TOKENS_LENS, TOKENS_OOV,
8 |                                        UNK_WORD, PAD_WORD, OOV_COUNT, TARGET)
9 | 
10 | 
11 | def get_position_encoding(input_tensor):
12 |     batch_size, position, dim_size = input_tensor.size()
13 |     assert dim_size % 2 == 0
14 |     num_timescales = dim_size // 2
15 |     time_scales = torch.arange(0, position, dtype=torch.float).unsqueeze(1)  # one row per position, matching the input length
16 |     dim_scales = torch.arange(0, num_timescales, dtype=torch.float).unsqueeze(0)
17 |     dim_val = torch.pow(1.0e4, 2 * dim_scales / dim_size)
18 |     matrix = torch.matmul(time_scales, 1.0 / dim_val)
19 |     position_embed = torch.cat([torch.sin(matrix), torch.cos(matrix)], dim=1).repeat(batch_size, 1, 1)
20 | 
21 |     if torch.cuda.is_available():
22 |         position_embed = position_embed.cuda()
23 | 
24 |     return position_embed
25 | 
26 | 
27 | class CopyTransformer(nn.Module):
28 |     def __init__(self, args, vocab2id):
29 |         super().__init__()
30 |         embedding = nn.Embedding(len(vocab2id), args.input_dim, vocab2id[PAD_WORD])
31 |         self.encoder = CopyTransformerEncoder(embedding=embedding, args=args)
32 |         self.decoder = CopyTransformerDecoder(embedding=embedding, vocab2id=vocab2id, args=args)
33 | 
34 |     def forward(self, src_dict, prev_output_tokens, encoder_output, encoder_mask,
35 |                 prev_decoder_state, position, prev_copy_state):
36 |         if torch.cuda.is_available():
37 |             src_dict[TOKENS] = src_dict[TOKENS].cuda()
38 |             src_dict[TOKENS_LENS] = src_dict[TOKENS_LENS].cuda()
39 |             src_dict[TOKENS_OOV] = src_dict[TOKENS_OOV].cuda()
40 |             src_dict[OOV_COUNT] = src_dict[OOV_COUNT].cuda()
41 |             if prev_output_tokens is not None:
42 |                 prev_output_tokens = prev_output_tokens.cuda()
43 |             prev_decoder_state = prev_decoder_state.cuda()
44 |         if encoder_output is None:
45 |             encoder_output, encoder_mask = self.encoder(src_dict=src_dict)
46 |         output = self.decoder(prev_output_tokens=prev_output_tokens,
47 |                               prev_decoder_state=prev_decoder_state,
48 |                               position=position,
49 |                               encoder_output=encoder_output,
50 |                               encoder_mask=encoder_mask,
51 |                               src_dict=src_dict,
52 |                               prev_copy_state=prev_copy_state)
53 |         return output
54 | 
55 | 
56 | class CopyTransformerEncoder(nn.Module):
57 |     def __init__(self, embedding, args):
58 |         super().__init__()
59 |         self.embedding = embedding
60 |         self.input_dim = args.input_dim
61 |         self.head_size = args.head_size
62 |         self.feed_forward_dim = args.feed_forward_dim
63 |         self.num_layers = args.num_layers
64 |         self.dropout = args.dropout
65 |         layer = TransformerEncoderLayer(d_model=self.input_dim,
66 |                                         nhead=self.head_size,
67 |                                         dim_feedforward=self.feed_forward_dim,
68 |                                         dropout=self.dropout)
69 |         self.encoder = TransformerEncoder(encoder_layer=layer, num_layers=self.num_layers)
70 | 
71 |     def forward(self, src_dict):
72 |         batch_size, max_len = src_dict[TOKENS].size()
73 |         mask_range = torch.arange(max_len).unsqueeze(0).repeat(batch_size, 1)
74 | 
75 |         if torch.cuda.is_available():
76 |             mask_range = mask_range.cuda()
77 |         mask = mask_range >= src_dict[TOKENS_LENS].unsqueeze(1)  # B x L padding mask
78 |         # mask = (mask_range > src_dict[TOKENS_LENS].unsqueeze(1)).expand(batch_size, max_len, max_len)
79 |         src_embed = self.embedding(src_dict[TOKENS]).transpose(1, 0)
80 |         pos_embed = get_position_encoding(src_embed)
81 |         src_embed = src_embed + pos_embed
82 |         src_embed = F.dropout(src_embed, p=self.dropout, training=self.training)
83 |         output = self.encoder(src_embed, 
src_key_padding_mask=mask).transpose(1, 0) 84 | return output, mask 85 | 86 | 87 | class CopyTransformerDecoder(nn.Module): 88 | def __init__(self, embedding, vocab2id, args): 89 | super().__init__() 90 | self.embedding = embedding 91 | self.vocab2id = vocab2id 92 | self.args = args 93 | self.input_dim = args.input_dim 94 | self.head_size = args.target_head_size 95 | self.feed_forward_dim = args.feed_forward_dim 96 | self.dropout = args.target_dropout 97 | self.num_layers = args.target_layers 98 | self.target_max_len = args.max_target_len 99 | self.max_oov_count = args.max_oov_count 100 | self.vocab_size = embedding.num_embeddings 101 | 102 | layer = TransformerDecoderLayer(d_model=self.input_dim, 103 | nhead=self.head_size, 104 | dim_feedforward=self.feed_forward_dim, 105 | dropout=self.dropout) 106 | self.decoder = TransformerDecoder(decoder_layer=layer, num_layers=self.num_layers) 107 | self.input_copy_proj = nn.Linear(self.input_dim, self.input_dim, bias=False) 108 | self.copy_proj = nn.Linear(self.input_dim, self.input_dim, bias=False) 109 | self.embed_proj = nn.Linear(2 * self.input_dim, self.input_dim, bias=False) 110 | self.generate_proj = nn.Linear(self.input_dim, self.vocab_size, bias=False) 111 | 112 | def forward(self, prev_output_tokens, prev_decoder_state, position, 113 | encoder_output, encoder_mask, src_dict, prev_copy_state): 114 | if self.args.auto_regressive and not self.training: 115 | output = self.forward_auto_regressive(prev_output_tokens, prev_decoder_state, position, 116 | encoder_output, encoder_mask, src_dict, prev_copy_state) 117 | else: 118 | output = self.forward_one_pass(encoder_output, encoder_mask, src_dict) 119 | return output 120 | 121 | def forward_transformer(self, encoder_output, encoder_mask, src_dict): 122 | token_embed = self.embedding(src_dict[TARGET][:, :-1]) 123 | pos_embed = get_position_encoding(token_embed) 124 | # B x seq_len x H 125 | src_embed = token_embed + pos_embed 126 | decoder_input = F.dropout(src_embed, p=self.dropout, training=self.training) 127 | decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1) 128 | decoder_output = self.decoder(tgt=decoder_input, 129 | memory=encoder_output.transpose(1, 0), 130 | memory_key_padding_mask=decoder_input_mask) 131 | probs = torch.softmax(decoder_output, dim=-1) 132 | return probs, decoder_output.squeeze(1), None, encoder_output, encoder_mask 133 | 134 | def forward_one_pass(self, encoder_output, encoder_mask, src_dict): 135 | batch_size = len(src_dict[TOKENS]) 136 | token_embed = self.embedding(src_dict[TARGET][:, :-1]) 137 | src_tokens_with_oov = src_dict[TOKENS_OOV] 138 | pos_embed = get_position_encoding(token_embed) 139 | # B x seq_len x H 140 | src_embed = token_embed + pos_embed 141 | decoder_input = F.dropout(src_embed, p=self.dropout, training=self.training) 142 | decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1) 143 | decoder_output = self.decoder(tgt=decoder_input, 144 | memory=encoder_output.transpose(1, 0), 145 | memory_key_padding_mask=decoder_input_mask) 146 | # B x seq_len x H 147 | decoder_output = decoder_output.transpose(1, 0) 148 | generation_logits = torch.exp(self.generate_proj(decoder_output).squeeze(1)) 149 | generation_oov_logits = torch.zeros(batch_size, self.max_oov_count) 150 | if torch.cuda.is_available(): 151 | generation_oov_logits = generation_oov_logits.cuda() 152 | generation_logits = torch.cat([generation_logits, generation_oov_logits], dim=1) 153 | copy_logits = self.get_copy_score(encoder_output, 154 | 
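# the copy scores are computed per source position and scattered onto the
# extended vocabulary (vocab + per-batch OOV slots) inside get_copy_score,
# so they can be summed with generation_logits below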
src_tokens_with_oov, 155 | decoder_output, 156 | encoder_mask) 157 | total_logit = generation_logits + copy_logits 158 | total_prob = total_logit / torch.sum(total_logit, 1).unsqueeze(1) 159 | total_prob = torch.log(total_prob) 160 | 161 | return total_prob, decoder_output.squeeze(1), None, encoder_output, encoder_mask 162 | 163 | def forward_auto_regressive(self, prev_output_tokens, prev_decoder_state, position, 164 | encoder_output, encoder_mask, src_dict, prev_copy_state): 165 | src_tokens = src_dict[TOKENS] 166 | src_tokens_with_oov = src_dict[TOKENS_OOV] 167 | batch_size, src_max_len = src_tokens.size() 168 | prev_output_tokens = torch.as_tensor(prev_output_tokens, dtype=torch.int64) 169 | if torch.cuda.is_available(): 170 | prev_output_tokens = prev_output_tokens.cuda() 171 | copy_state = self.get_attn_read_input(encoder_output, 172 | prev_decoder_state, 173 | prev_output_tokens[:, -1:], 174 | src_tokens_with_oov, 175 | src_max_len, 176 | prev_copy_state) 177 | # map copied oov tokens to OOV idx to avoid embedding lookup error 178 | prev_output_tokens[prev_output_tokens >= self.vocab_size] = self.vocab2id[UNK_WORD] 179 | token_embed = self.embedding(prev_output_tokens) 180 | 181 | pos_embed = get_position_encoding(token_embed) 182 | # B x seq_len x H 183 | src_embed = token_embed + pos_embed 184 | decoder_input = self.embed_proj(torch.cat([src_embed, copy_state], dim=2)).transpose(1, 0) 185 | decoder_input = F.dropout(decoder_input, p=self.dropout, training=self.training) 186 | decoder_input_mask = torch.triu(torch.ones(self.input_dim, self.input_dim), 1) 187 | # B x seq_len x H 188 | decoder_output = self.decoder(tgt=decoder_input, 189 | memory=encoder_output.transpose(1, 0), 190 | memory_key_padding_mask=decoder_input_mask) 191 | decoder_output = decoder_output.transpose(1, 0) 192 | 193 | # B x 1 x H 194 | decoder_output = decoder_output[:, -1:, :] 195 | generation_logits = self.generate_proj(decoder_output).squeeze(1) 196 | generation_oov_logits = torch.zeros(batch_size, self.max_oov_count) 197 | if torch.cuda.is_available(): 198 | generation_oov_logits = generation_oov_logits.cuda() 199 | generation_logits = torch.cat([generation_logits, generation_oov_logits], dim=1) 200 | copy_logits = self.get_copy_score(encoder_output, 201 | src_tokens_with_oov, 202 | decoder_output, 203 | encoder_mask) 204 | total_logit = torch.exp(generation_logits) + copy_logits 205 | total_prob = total_logit / torch.sum(total_logit, 1).unsqueeze(1) 206 | total_prob = torch.log(total_prob) 207 | return total_prob, decoder_output.squeeze(1), copy_state, encoder_output, encoder_mask 208 | 209 | def get_attn_read_input(self, encoder_output, prev_context_state, 210 | prev_output_tokens, src_tokens_with_oov, 211 | src_max_len, prev_copy_state): 212 | """ 213 | build CopyNet decoder input of "attentive read" part. 
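        The previous output token is matched against every source position;
        the encoder states at matching positions are scored against the
        previous context state and combined by softmax into a single read
        vector that augments the next decoder input.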
214 | :param encoder_output: 215 | :param prev_context_state: 216 | :param prev_output_tokens: 217 | :param src_tokens_with_oov: 218 | :return: 219 | """ 220 | # mask : B x SL x 1 221 | mask_bool = torch.eq(prev_output_tokens.repeat(1, src_max_len), src_tokens_with_oov).unsqueeze(2) 222 | mask = mask_bool.type_as(encoder_output) 223 | # B x SL x H 224 | aggregate_weight = torch.tanh(self.input_copy_proj(torch.mul(mask, encoder_output))) 225 | # when all prev_tokens are not in src_tokens, don't execute mask -inf to avoid nan result in softmax 226 | no_zero_mask = ((mask != 0).sum(dim=1) != 0).repeat(1, src_max_len).unsqueeze(2) 227 | input_copy_logit_mask = no_zero_mask * mask_bool 228 | input_copy_logit = torch.bmm(aggregate_weight, prev_context_state.unsqueeze(2)) 229 | input_copy_logit.masked_fill_(input_copy_logit_mask, float('-inf')) 230 | input_copy_weight = torch.softmax(input_copy_logit.squeeze(2), 1) 231 | # B x 1 x H 232 | copy_state = torch.bmm(input_copy_weight.unsqueeze(1), encoder_output) 233 | if prev_copy_state is not None: 234 | copy_state = torch.cat([prev_copy_state, copy_state], dim=1) 235 | return copy_state 236 | 237 | def get_copy_score(self, encoder_out, src_tokens_with_oov, decoder_output, encoder_output_mask): 238 | """ 239 | 240 | :param encoder_out: B x L x SH 241 | :param src_tokens_with_oov: B x L 242 | :param decoder_output: B x dec_len x TH 243 | :param encoder_output_mask: B x L 244 | :return: B x dec_len x V 245 | """ 246 | # copy_score: B x L 247 | dec_len = decoder_output.size(1) 248 | batch_size = len(encoder_out) 249 | # copy_score: B x L x dec_len 250 | copy_score_in_seq = torch.bmm(torch.tanh(self.copy_proj(encoder_out)), 251 | decoder_output.permute(0, 2, 1)) 252 | copy_score_mask = encoder_output_mask.unsqueeze(2).repeat(1, 1, dec_len) 253 | copy_score_in_seq.masked_fill_(copy_score_mask, float('-inf')) 254 | copy_score_in_seq = torch.exp(copy_score_in_seq) 255 | total_vocab_size = self.vocab_size + self.max_oov_count 256 | copy_score_in_vocab = torch.zeros(batch_size, total_vocab_size, dec_len) 257 | if torch.cuda.is_available(): 258 | copy_score_in_vocab = copy_score_in_vocab.cuda() 259 | token_ids = src_tokens_with_oov.unsqueeze(2).repeat(1, 1, dec_len) 260 | copy_score_in_vocab.scatter_add_(1, token_ids, copy_score_in_seq) 261 | copy_score_in_vocab = copy_score_in_vocab.permute(0, 2, 1) 262 | 263 | return copy_score_in_vocab 264 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import torch 4 | from pysenal import append_jsonlines 5 | from munch import Munch 6 | from deep_keyphrase.base_predictor import BasePredictor 7 | from deep_keyphrase.dataloader import KeyphraseDataLoader, TOKENS, RAW_BATCH 8 | from deep_keyphrase.utils.constants import BOS_WORD 9 | from deep_keyphrase.utils.vocab_loader import load_vocab 10 | from deep_keyphrase.utils.tokenizer import token_char_tokenize 11 | from .model import CopyTransformer 12 | from .beam_search import TransformerBeamSearch 13 | 14 | 15 | class CopyTransformerPredictor(BasePredictor): 16 | def __init__(self, model_info, vocab_info, beam_size, max_target_len, max_src_length): 17 | super().__init__(model_info) 18 | if isinstance(vocab_info, str): 19 | self.vocab2id = load_vocab(vocab_info) 20 | elif isinstance(vocab_info, dict): 21 | self.vocab2id = vocab_info 22 | else: 23 | raise ValueError('vocab 
info type error') 24 | self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys())) 25 | self.config = self.load_config(model_info) 26 | self.model = self.load_model(model_info, CopyTransformer(self.config, self.vocab2id)) 27 | self.model.eval() 28 | self.beam_size = beam_size 29 | self.max_target_len = max_target_len 30 | self.max_src_len = max_src_length 31 | self.beam_searcher = TransformerBeamSearch(model=self.model, 32 | beam_size=self.beam_size, 33 | max_target_len=self.max_target_len, 34 | id2vocab=self.id2vocab, 35 | bos_idx=self.vocab2id[BOS_WORD], 36 | args=self.config) 37 | self.pred_base_config = {'max_oov_count': self.config.max_oov_count, 38 | 'max_src_len': self.max_src_len, 39 | 'max_target_len': self.max_target_len} 40 | 41 | def predict(self, text_list, batch_size, delimiter=None): 42 | self.model.eval() 43 | if len(text_list) < batch_size: 44 | batch_size = len(text_list) 45 | args = Munch({'batch_size': batch_size, **self.pred_base_config}) 46 | text_list = [{TOKENS: token_char_tokenize(i)} for i in text_list] 47 | loader = KeyphraseDataLoader(data_source=text_list, 48 | vocab2id=self.vocab2id, 49 | args=args, 50 | mode='inference') 51 | result = [] 52 | for batch in loader: 53 | with torch.no_grad(): 54 | result.extend(self.beam_searcher.beam_search(batch, delimiter=delimiter)) 55 | return result 56 | 57 | def eval_predict(self, src_filename, dest_filename, args, 58 | model=None, remove_existed=False): 59 | args_dict = vars(args) 60 | args_dict['batch_size'] = args_dict['eval_batch_size'] 61 | args = Munch(args_dict) 62 | loader = KeyphraseDataLoader(data_source=src_filename, 63 | vocab2id=self.vocab2id, 64 | mode='inference', 65 | args=args) 66 | if os.path.exists(dest_filename): 67 | print('destination filename {} existed'.format(dest_filename)) 68 | if remove_existed: 69 | os.remove(dest_filename) 70 | if model is not None: 71 | model.eval() 72 | self.beam_searcher = TransformerBeamSearch(model=model, 73 | beam_size=self.beam_size, 74 | max_target_len=self.max_target_len, 75 | id2vocab=self.id2vocab, 76 | bos_idx=self.vocab2id[BOS_WORD], 77 | args=self.config) 78 | 79 | for batch in loader: 80 | with torch.no_grad(): 81 | batch_result = self.beam_searcher.beam_search(batch, delimiter=None) 82 | final_result = [] 83 | assert len(batch_result) == len(batch[RAW_BATCH]) 84 | for item_input, item_output in zip(batch[RAW_BATCH], batch_result): 85 | item_input['pred_keyphrases'] = item_output 86 | final_result.append(item_input) 87 | append_jsonlines(dest_filename, final_result) 88 | -------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import argparse 3 | import torch 4 | from pysenal import write_json 5 | from deep_keyphrase.base_trainer import BaseTrainer 6 | from deep_keyphrase.copy_transformer.model import CopyTransformer 7 | from deep_keyphrase.copy_transformer.predict import CopyTransformerPredictor 8 | from deep_keyphrase.dataloader import (TARGET, TOKENS) 9 | from deep_keyphrase.utils.vocab_loader import load_vocab 10 | 11 | 12 | class CopyTransformerTrainer(BaseTrainer): 13 | def __init__(self): 14 | args = self.parse_args() 15 | vocab2id = load_vocab(args.vocab_path, vocab_size=args.vocab_size) 16 | model = CopyTransformer(args, vocab2id) 17 | super().__init__(args, model) 18 | 19 | def train_batch(self, batch, step): 20 | torch.autograd.set_detect_anomaly(True) 21 | loss 
= 0
22 |         self.optimizer.zero_grad()
23 |         targets = batch[TARGET]
24 |         if torch.cuda.is_available():
25 |             targets = targets.cuda()
26 |         if self.args.auto_regressive:
27 |             loss = self.get_auto_regressive_loss(batch, loss, targets)
28 |         else:
29 |             loss = self.get_one_pass_loss(batch, targets)
30 |         loss.backward()
31 |         self.optimizer.step()
32 |         # torch.cuda.empty_cache()
33 |         return loss
34 | 
35 |     def get_one_pass_loss(self, batch, targets):
36 |         batch_size = len(batch[TOKENS])
37 |         encoder_output = encoder_mask = None
38 |         prev_copy_state = None
39 |         prev_decoder_state = torch.zeros(batch_size, self.args.input_dim, device='cuda' if torch.cuda.is_available() else 'cpu')  # keep the initial decoder state on the model's device
40 |         output = self.model(src_dict=batch,
41 |                             prev_output_tokens=None,
42 |                             encoder_output=encoder_output,
43 |                             encoder_mask=encoder_mask,
44 |                             prev_decoder_state=prev_decoder_state,
45 |                             position=0,
46 |                             prev_copy_state=prev_copy_state)
47 |         decoder_prob, prev_decoder_state, prev_copy_state, encoder_output, encoder_mask = output
48 |         vocab_size = decoder_prob.size(-1)
49 |         decoder_prob = decoder_prob.view(-1, vocab_size)
50 |         loss = self.loss_func(decoder_prob, targets[:, 1:].flatten())
51 |         return loss
52 | 
53 |     def get_auto_regressive_loss(self, batch, loss, targets):
54 |         batch_size = len(batch[TOKENS])
55 |         encoder_output = encoder_mask = None
56 |         prev_copy_state = None
57 |         prev_decoder_state = torch.zeros(batch_size, self.args.input_dim, device='cuda' if torch.cuda.is_available() else 'cpu')
58 |         for target_index in range(self.args.max_target_len):
59 |             prev_output_tokens = targets[:, :target_index + 1].clone()
60 |             true_indices = targets[:, target_index + 1].clone()
61 |             output = self.model(src_dict=batch,
62 |                                 prev_output_tokens=prev_output_tokens,
63 |                                 encoder_output=encoder_output,
64 |                                 encoder_mask=encoder_mask,
65 |                                 prev_decoder_state=prev_decoder_state,
66 |                                 position=target_index,
67 |                                 prev_copy_state=prev_copy_state)
68 |             probs, prev_decoder_state, prev_copy_state, encoder_output, encoder_mask = output
69 |             loss += self.loss_func(probs, true_indices)
70 |         loss /= self.args.max_target_len
71 |         return loss
72 | 
73 |     def evaluate(self, step):
74 |         predictor = CopyTransformerPredictor(model_info={'model': self.model, 'config': self.args},
75 |                                              vocab_info=self.vocab2id,
76 |                                              beam_size=self.args.beam_size,
77 |                                              max_target_len=self.args.max_target_len,
78 |                                              max_src_length=self.args.max_src_len)
79 | 
80 |         def pred_callback(stage):
81 |             if stage == 'valid':
82 |                 src_filename = self.args.valid_filename
83 |                 dest_filename = self.dest_dir + self.get_basename(self.args.valid_filename)
84 |             elif stage == 'test':
85 |                 src_filename = self.args.test_filename
86 |                 dest_filename = self.dest_dir + self.get_basename(self.args.test_filename)
87 |             else:
88 |                 raise ValueError('stage name error, must be `valid` or `test`')
89 | 
90 |             def predict_func():
91 |                 predictor.eval_predict(src_filename=src_filename,
92 |                                        dest_filename=dest_filename,
93 |                                        args=self.args,
94 |                                        model=self.model,
95 |                                        remove_existed=True)
96 | 
97 |             return predict_func
98 | 
99 |         valid_statistics = self.evaluate_stage(step, 'valid', pred_callback('valid'))
100 |         test_statistics = self.evaluate_stage(step, 'test', pred_callback('test'))
101 |         total_statistics = {**valid_statistics, **test_statistics}
102 | 
103 |         eval_filename = self.dest_dir + self.args.exp_name + '.batch_{}.eval.json'.format(step)
104 |         write_json(eval_filename, total_statistics)
105 |         return valid_statistics['valid_macro'][self.eval_topn[-1]]['f1']
106 | 
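    # Example invocation (hypothetical paths, modeled on scripts/train_copyrnn_kp20k.sh;
    # only the required flags plus a few common ones are shown):
    #
    #   python3 deep_keyphrase/copy_transformer/train.py -exp_name copytransformer_kp20k \
    #       -train_filename data/kp20k.train.jsonl -valid_filename data/kp20k.valid.jsonl \
    #       -test_filename data/kp20k.test.jsonl -dest_base_dir data/kp20k/ \
    #       -vocab_path data/vocab_kp20k.txt -batch_size 12 -auto_regressive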
107 |     def parse_args(self, args=None):
108 |         parser = argparse.ArgumentParser()
109 |         # train and evaluation parameter
110 |         parser.add_argument("-exp_name", required=True, type=str, help='')
111 |         parser.add_argument("-train_filename", required=True, type=str, help='')
112 |         parser.add_argument("-valid_filename", required=True, type=str, help='')
113 |         parser.add_argument("-test_filename", required=True, type=str, help='')
114 |         parser.add_argument("-dest_base_dir", required=True, type=str, help='')
115 |         parser.add_argument("-vocab_path", required=True, type=str, help='')
116 |         parser.add_argument("-vocab_size", type=int, default=500000, help='')
117 |         parser.add_argument("-epochs", type=int, default=10, help='')
118 |         parser.add_argument("-batch_size", type=int, default=12, help='')
119 |         parser.add_argument("-learning_rate", type=float, default=1e-4, help='')
120 |         parser.add_argument("-eval_batch_size", type=int, default=1, help='')
121 |         parser.add_argument("-dropout", type=float, default=0.0, help='')
122 |         parser.add_argument("-grad_norm", type=float, default=0.0, help='')
123 |         parser.add_argument("-max_grad", type=float, default=5.0, help='')
124 |         parser.add_argument("-shuffle_in_batch", action='store_true', help='')
125 |         parser.add_argument("-teacher_forcing", action='store_true', help='')
126 |         parser.add_argument("-beam_size", type=int, default=50, help='')
127 |         parser.add_argument('-tensorboard_dir', type=str, default='', help='')
128 |         parser.add_argument('-logfile', type=str, default='train_log.log', help='')
129 |         parser.add_argument('-save_model_step', type=int, default=5000, help='')
130 |         parser.add_argument('-early_stop_tolerance', type=int, default=50, help='')
131 |         parser.add_argument('-train_parallel', action='store_true', help='')
132 |         parser.add_argument('-auto_regressive', action='store_true', help='')
133 | 
134 |         # model specific parameter
135 |         parser.add_argument("-input_dim", type=int, default=256, help='')
136 |         parser.add_argument("-src_head_size", type=int, default=4, help='')
137 |         parser.add_argument("-target_head_size", type=int, default=4, help='')
138 |         parser.add_argument("-feed_forward_dim", type=int, default=1024, help='')
139 |         parser.add_argument("-src_dropout", type=float, default=0.1, help='')
140 |         parser.add_argument("-target_dropout", type=float, default=0.1, help='')
141 |         parser.add_argument("-src_layers", type=int, default=6, help='')
142 |         parser.add_argument("-target_layers", type=int, default=6, help='')
143 |         parser.add_argument("-max_src_len", type=int, default=1000, help='')
144 |         parser.add_argument("-max_target_len", type=int, default=8, help='')
145 |         parser.add_argument("-max_oov_count", type=int, default=100, help='')
146 |         parser.add_argument("-copy_net", action='store_true', help='')
147 |         parser.add_argument("-input_feeding", action='store_true', help='')
148 | 
149 |         args = parser.parse_args()
150 |         return args
151 | 
152 | 
153 | if __name__ == '__main__':
154 |     CopyTransformerTrainer().train()
155 | 
-------------------------------------------------------------------------------- /deep_keyphrase/copy_transformer/transformer.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class Transformer(nn.Module):
7 |     def __init__(self):
8 |         super().__init__()
9 | 
10 |     def forward(self):
11 |         # from-scratch transformer stack, implementation pending
12 |         pass
13 | 
14 | 
15 | class TransformerEncoder(nn.Module):
16 |     def __init__(self):
17 |         super().__init__()
18 | 
19 |     def forward(self):
20 | 
21 |         pass
22 | 
23 | 
24 | class TransformerDecoder(nn.Module):
25 |     def __init__(self):
26 | 
super().__init__() 27 | 28 | def forward(self): 29 | pass 30 | 31 | 32 | class TransformerEncoderLayer(nn.Module): 33 | def __init__(self): 34 | super().__init__() 35 | 36 | def forward(self): 37 | pass 38 | 39 | 40 | class TransformerDecoderLayer(nn.Module): 41 | def __init__(self): 42 | super().__init__() 43 | 44 | def forward(self): 45 | pass 46 | -------------------------------------------------------------------------------- /deep_keyphrase/data_process/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- -------------------------------------------------------------------------------- /deep_keyphrase/data_process/preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import os 3 | import re 4 | import argparse 5 | import string 6 | from collections import Counter 7 | from itertools import chain 8 | from multiprocessing import Pool 9 | from nltk.tokenize import word_tokenize 10 | from nltk.stem import PorterStemmer 11 | from pysenal import get_chunk, read_jsonline_lazy, append_jsonlines, write_lines 12 | from deep_keyphrase.utils.constants import (PAD_WORD, UNK_WORD, DIGIT_WORD, 13 | BOS_WORD, EOS_WORD, SEP_WORD) 14 | 15 | 16 | class Kp20kPreprocessor(object): 17 | """ 18 | kp20k data preprocessor, build the data and vocab for training. 19 | 20 | """ 21 | num_and_punc_regex = re.compile(r'[_\-—<>{,(?\\.\'%]|\d+([.]\d+)?', re.IGNORECASE) 22 | num_regex = re.compile(r'\d+([.]\d+)?') 23 | 24 | def __init__(self, args): 25 | self.src_filename = args.src_filename 26 | self.dest_filename = args.dest_filename 27 | self.dest_vocab_path = args.dest_vocab_path 28 | self.vocab_size = args.vocab_size 29 | self.parallel_count = args.parallel_count 30 | self.is_src_lower = args.src_lower 31 | self.is_src_stem = args.src_stem 32 | self.is_target_lower = args.target_lower 33 | self.is_target_stem = args.target_stem 34 | self.stemmer = PorterStemmer() 35 | if os.path.exists(self.dest_filename): 36 | print('destination file existed, will be deleted!!!') 37 | os.remove(self.dest_filename) 38 | 39 | def build_vocab(self, tokens): 40 | vocab = [PAD_WORD, UNK_WORD, BOS_WORD, EOS_WORD, DIGIT_WORD, SEP_WORD] 41 | vocab.extend(list(string.digits)) 42 | 43 | token_counter = Counter(tokens).most_common(self.vocab_size) 44 | for token, count in token_counter: 45 | vocab.append(token) 46 | if len(vocab) >= self.vocab_size: 47 | break 48 | return vocab 49 | 50 | def process(self): 51 | pool = Pool(self.parallel_count) 52 | tokens = [] 53 | chunk_size = 100 54 | for item_chunk in get_chunk(read_jsonline_lazy(self.src_filename), chunk_size): 55 | processed_records = pool.map(self.tokenize_record, item_chunk) 56 | if self.dest_vocab_path: 57 | for record in processed_records: 58 | tokens.extend(record['title_and_abstract_tokens'] + record['flatten_keyword_tokens']) 59 | for record in processed_records: 60 | record.pop('flatten_keyword_tokens') 61 | append_jsonlines(self.dest_filename, processed_records) 62 | if self.dest_vocab_path: 63 | vocab = self.build_vocab(tokens) 64 | write_lines(self.dest_vocab_path, vocab) 65 | 66 | def tokenize_record(self, record): 67 | abstract_tokens = self.tokenize(record['abstract'], self.is_src_lower, self.is_src_stem) 68 | title_tokens = self.tokenize(record['title'], self.is_src_lower, self.is_src_stem) 69 | keyword_token_list = [] 70 | for keyword in record['keyword'].split(';'): 71 | keyword_token_list.append(self.tokenize(keyword, 
self.is_target_lower, self.is_target_stem)) 72 | result = { 73 | # 'title_tokens': title_tokens, 'abstract_tokens': abstract_tokens, 74 | 'title_and_abstract_tokens': title_tokens + abstract_tokens, 75 | 'keyword_tokens': keyword_token_list, 76 | 'flatten_keyword_tokens': list(chain(*keyword_token_list)) 77 | } 78 | return result 79 | 80 | def tokenize(self, text, is_lower, is_stem): 81 | text = self.num_and_punc_regex.sub(r' \g<0> ', text) 82 | tokens = word_tokenize(text) 83 | if is_lower: 84 | tokens = [token.lower() for token in tokens] 85 | if is_stem: 86 | tokens = [self.stemmer.stem(token) for token in tokens] 87 | for idx, token in enumerate(tokens): 88 | token = tokens[idx] 89 | if self.num_regex.fullmatch(token): 90 | tokens[idx] = DIGIT_WORD 91 | return tokens 92 | 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('-src_filename', type=str, required=True, 97 | help='input source kp20k file path') 98 | parser.add_argument('-dest_filename', type=str, required=True, 99 | help='destination of processed file path') 100 | parser.add_argument('-dest_vocab_path', type=str, 101 | help='') 102 | parser.add_argument('-vocab_size', type=int, default=50000, 103 | help='') 104 | parser.add_argument('-parallel_count', type=int, default=10) 105 | parser.add_argument('-src_lower', action='store_true') 106 | parser.add_argument('-src_stem', action='store_true') 107 | parser.add_argument('-target_lower', action='store_true') 108 | parser.add_argument('-target_stem', action='store_true') 109 | 110 | args = parser.parse_args() 111 | processor = Kp20kPreprocessor(args) 112 | processor.process() 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /deep_keyphrase/dataloader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import random 3 | import traceback 4 | import sys 5 | import numpy as np 6 | import multiprocessing 7 | from pysenal import read_jsonline_lazy, get_chunk, read_jsonline 8 | from deep_keyphrase.utils.constants import * 9 | 10 | TOKENS = 'tokens' 11 | TOKENS_LENS = 'tokens_len' 12 | TOKENS_OOV = 'tokens_with_oov' 13 | OOV_COUNT = 'oov_count' 14 | OOV_LIST = 'oov_list' 15 | TARGET_LIST = 'targets' 16 | TARGET = 'target' 17 | RAW_BATCH = 'raw' 18 | 19 | TRAIN_MODE = 'train' 20 | EVAL_MODE = 'eval' 21 | INFERENCE_MODE = 'inference' 22 | 23 | 24 | class ExceptionWrapper(object): 25 | """ 26 | Wraps an exception plus traceback to communicate across threads 27 | """ 28 | 29 | def __init__(self, exc_info): 30 | self.exc_type = exc_info[0] 31 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 32 | 33 | 34 | class KeyphraseDataLoader(object): 35 | def __init__(self, data_source, vocab2id, mode, args): 36 | self.data_source = data_source 37 | self.vocab2id = vocab2id 38 | self.vocab_size = len(self.vocab2id) 39 | self.args = args 40 | self.batch_size = args.batch_size 41 | self.max_src_len = args.max_src_len 42 | self.max_oov_count = args.max_oov_count 43 | self.max_target_len = args.max_target_len 44 | self.mode = mode 45 | self.fix_batch_size = args.fix_batch_size 46 | self.prefetch = args.prefetch 47 | self.lazy_loading = args.lazy_loading 48 | self.shuffle = args.shuffle 49 | self.token_field = args.token_field 50 | self.keyphrases_field = args.keyphrase_field 51 | 52 | self.is_inference = mode != TRAIN_MODE 53 | 54 | def __iter__(self): 55 | return iter(KeyphraseDataIterator(self)) 56 
| 57 | def collate_fn(self, item): 58 | tokens = item[self.token_field] 59 | if len(tokens) > self.max_src_len: 60 | tokens = tokens[:self.max_src_len] 61 | token_ids_with_oov = [] 62 | token_ids = [] 63 | oov_list = [] 64 | 65 | for token in tokens: 66 | idx = self.vocab2id.get(token, self.vocab_size) 67 | if idx == self.vocab_size: 68 | token_ids.append(self.vocab2id[UNK_WORD]) 69 | if token not in oov_list: 70 | if len(oov_list) >= self.max_oov_count: 71 | token_ids_with_oov.append(self.vocab_size + self.max_oov_count - 1) 72 | else: 73 | token_ids_with_oov.append(self.vocab_size + len(oov_list)) 74 | oov_list.append(token) 75 | else: 76 | token_ids_with_oov.append(self.vocab_size + oov_list.index(token)) 77 | else: 78 | token_ids.append(idx) 79 | token_ids_with_oov.append(idx) 80 | 81 | final_item = {TOKENS: token_ids, 82 | TOKENS_OOV: token_ids_with_oov, 83 | OOV_COUNT: len(oov_list), 84 | OOV_LIST: oov_list} 85 | 86 | if self.is_inference: 87 | final_item[RAW_BATCH] = item 88 | else: 89 | keyphrase = item['phrase'] 90 | target_ids = [self.vocab2id[BOS_WORD]] 91 | for token in keyphrase: 92 | target_ids.append(self.vocab2id.get(token, self.vocab2id[UNK_WORD])) 93 | target_ids.append(self.vocab2id[EOS_WORD]) 94 | final_item[TARGET] = target_ids[:self.max_target_len] 95 | return final_item 96 | 97 | 98 | class KeyphraseDataIterator(object): 99 | def __init__(self, loader): 100 | self.loader = loader 101 | self.data_source = loader.data_source 102 | self.batch_size = loader.batch_size 103 | self.token_field = loader.token_field 104 | self.keyphrases_field = loader.keyphrases_field 105 | self.lazy_loading = loader.lazy_loading 106 | self.backend = loader.args.backend 107 | self.fix_batch_size = loader.fix_batch_size 108 | self.num_workers = multiprocessing.cpu_count() // 2 or 1 109 | 110 | if self.loader.mode == TRAIN_MODE: 111 | self.chunk_size = self.batch_size * 5 112 | else: 113 | self.chunk_size = self.batch_size 114 | self._data = self.load_data(self.chunk_size) 115 | self._batch_count_in_output_queue = 0 116 | self._redundant_batch = [] 117 | self.workers = [] 118 | self.worker_shutdown = False 119 | 120 | if self.loader.mode in {TRAIN_MODE, EVAL_MODE}: 121 | self.input_queue = multiprocessing.Queue(-1) 122 | self.output_queue = multiprocessing.Queue(-1) 123 | self.__prefetch() 124 | for _ in range(self.num_workers): 125 | worker = multiprocessing.Process(target=self._data_worker_loop) 126 | self.workers.append(worker) 127 | for worker in self.workers: 128 | worker.daemon = True 129 | worker.start() 130 | 131 | def __iter__(self): 132 | if self.loader.mode == TRAIN_MODE: 133 | yield from self.iter_train() 134 | elif self.loader.mode == EVAL_MODE: 135 | yield from self.iter_inference_parallel() 136 | else: 137 | yield from self.iter_inference() 138 | 139 | def load_data(self, chunk_size): 140 | if isinstance(self.data_source, str): 141 | if self.lazy_loading: 142 | data = read_jsonline_lazy(self.data_source) 143 | else: 144 | data = read_jsonline(self.data_source) 145 | if self.loader.shuffle: 146 | random.shuffle(data) 147 | random.shuffle(data) 148 | random.shuffle(data) 149 | elif isinstance(self.data_source, list): 150 | data = iter(self.data_source) 151 | else: 152 | raise TypeError('input filename type is error') 153 | return get_chunk(data, chunk_size) 154 | 155 | def _data_worker_loop(self): 156 | while True: 157 | raw_batch = self.input_queue.get() 158 | 159 | # exit signal 160 | if raw_batch is None: 161 | break 162 | try: 163 | batch = self.padding_batch(raw_batch) 164 | 
165 | self.output_queue.put(batch) 166 | 167 | except Exception as e: 168 | self.output_queue.put(ExceptionWrapper(sys.exc_info())) 169 | 170 | def padding_batch(self, raw_batch): 171 | max_src_len = self.loader.max_src_len 172 | max_target_len = self.loader.max_target_len 173 | pad_id = self.loader.vocab2id[PAD_WORD] 174 | token_ids_list = [] 175 | token_len_list = [] 176 | token_oov_ids_list = [] 177 | oov_len_list = [] 178 | oov_list = [] 179 | raw_item_list = [] 180 | target_ids_list = [] 181 | 182 | for raw_item in raw_batch: 183 | if not self.loader.args.processed: 184 | item = self.loader.collate_fn(raw_item) 185 | else: 186 | item = raw_item 187 | token_len = len(item[TOKENS]) 188 | token_len_list.append(token_len) 189 | token_ids = item[TOKENS] + [pad_id] * (max_src_len - token_len) 190 | token_ids_list.append(token_ids) 191 | token_oov_ids = item[TOKENS_OOV] + [pad_id] * (max_src_len - token_len) 192 | token_oov_ids_list.append(token_oov_ids) 193 | oov_len_list.append(item[OOV_COUNT]) 194 | oov_list.append(item[OOV_LIST]) 195 | if self.loader.is_inference: 196 | raw_item_list.append(raw_item) 197 | else: 198 | target_ids = item[TARGET] + [pad_id] * (max_target_len + 1 - len(item[TARGET])) 199 | target_ids_list.append(target_ids) 200 | token_ids_np = np.array(token_ids_list, dtype=np.long) 201 | token_len_np = np.array(token_len_list, dtype=np.long) 202 | token_oov_np = np.array(token_oov_ids_list, dtype=np.long) 203 | oov_len_np = np.array(oov_len_list, dtype=np.long) 204 | batch = {TOKENS: token_ids_np, 205 | TOKENS_OOV: token_oov_np, 206 | TOKENS_LENS: token_len_np, 207 | OOV_COUNT: oov_len_np, 208 | OOV_LIST: oov_list} 209 | if not self.loader.is_inference: 210 | batch[TARGET] = np.array(target_ids_list, dtype=np.long) 211 | else: 212 | batch[RAW_BATCH] = raw_item_list 213 | 214 | return batch 215 | 216 | def __prefetch(self): 217 | if self.loader.mode == TRAIN_MODE: 218 | item_chunk = next(self._data) 219 | if self.loader.shuffle: 220 | random.shuffle(item_chunk) 221 | batches, redundant_batch = self.get_batches(item_chunk, []) 222 | self._redundant_batch = redundant_batch 223 | for batch in batches: 224 | self.input_queue.put(batch) 225 | self._batch_count_in_output_queue += 1 226 | else: 227 | for _ in range(self.num_workers): 228 | try: 229 | item_chunk = next(self._data) 230 | except StopIteration: 231 | break 232 | self.input_queue.put(item_chunk) 233 | self._batch_count_in_output_queue += 1 234 | 235 | def iter_train(self): 236 | redundant_batch = self._redundant_batch 237 | for item_chunk in self._data: 238 | if self.loader.shuffle: 239 | random.shuffle(item_chunk) 240 | batches, redundant_batch = self.get_batches(item_chunk, redundant_batch) 241 | batch_idx = 0 242 | for idx in range(self._batch_count_in_output_queue): 243 | if batch_idx < len(batches): 244 | self.input_queue.put(batches[batch_idx]) 245 | batch_idx += 1 246 | yield self.batch2tensor(self.output_queue.get()) 247 | 248 | if batch_idx < len(batches): 249 | for batch in batches[batch_idx:]: 250 | self.input_queue.put(batch) 251 | 252 | self._batch_count_in_output_queue = len(batches) 253 | if redundant_batch: 254 | self.input_queue.put(redundant_batch) 255 | self._batch_count_in_output_queue += 1 256 | 257 | if self._batch_count_in_output_queue: 258 | for idx in range(self._batch_count_in_output_queue): 259 | yield self.batch2tensor(self.output_queue.get()) 260 | 261 | def get_batches(self, item_chunk, batch): 262 | if self.loader.args.processed: 263 | return self.get_batches_processed(item_chunk, batch) 
264 | else: 265 | return self.get_batches_raw(item_chunk, batch) 266 | 267 | def get_batches_processed(self, item_chunk, batch): 268 | batches = [] 269 | for new_batch in get_chunk(batch + item_chunk, self.batch_size): 270 | batches.append(new_batch) 271 | return batches, [] 272 | 273 | def get_batches_raw(self, item_chunk, batch): 274 | batches = [] 275 | 276 | for item in item_chunk: 277 | if self.fix_batch_size: 278 | if batch and len(batch) > self.batch_size: 279 | tail_count = len(batch) % self.batch_size 280 | if tail_count: 281 | batch_chunk = batch[:-tail_count] 282 | batch = batch[-tail_count:] 283 | else: 284 | batch_chunk = batch 285 | batch = [] 286 | for sliced_batch in get_chunk(batch_chunk, self.batch_size): 287 | batches.append(sliced_batch) 288 | 289 | flatten_items = self.flatten_raw_item(item) 290 | batch.extend(flatten_items) 291 | else: 292 | if batch and len(batch) > self.batch_size: 293 | for sliced_batch in get_chunk(batch, self.batch_size): 294 | batches.append(sliced_batch) 295 | batch = [] 296 | flatten_items = self.flatten_raw_item(item) 297 | if batch and len(batch) + len(flatten_items) > self.batch_size: 298 | batches.append(batch) 299 | batch = flatten_items 300 | else: 301 | batch.extend(flatten_items) 302 | # batches = self.reorder_batch_list(batches) 303 | return batches, batch 304 | 305 | def reorder_batch(self, batch): 306 | seq_idx_and_len = [(idx, len(item[TOKENS])) for idx, item in enumerate(batch)] 307 | seq_idx_and_len = sorted(seq_idx_and_len, key=lambda i: i[1], reverse=True) 308 | batch = [batch[idx] for idx, _ in seq_idx_and_len] 309 | return batch 310 | 311 | def reorder_batch_list(self, batches): 312 | """ 313 | for onnx format compatibility 314 | :param batches: 315 | :return: 316 | """ 317 | new_batches = [] 318 | for batch in batches: 319 | new_batches.append(self.reorder_batch(batch)) 320 | return new_batches 321 | 322 | def flatten_raw_item(self, item): 323 | flatten_items = [] 324 | for phrase in item[self.loader.keyphrases_field]: 325 | flatten_items.append({self.token_field: item[self.token_field], 'phrase': phrase}) 326 | return flatten_items 327 | 328 | def iter_inference_parallel(self): 329 | assert not self._redundant_batch 330 | assert self.workers 331 | for item_chunk in self._data: 332 | if self._batch_count_in_output_queue > 0: 333 | yield self.batch2tensor(self.output_queue.get()) 334 | self.input_queue.put(item_chunk) 335 | else: 336 | self.input_queue.put(item_chunk) 337 | yield self.batch2tensor(self.output_queue.get()) 338 | if self._batch_count_in_output_queue: 339 | for _ in range(self._batch_count_in_output_queue): 340 | yield self.batch2tensor(self.output_queue.get()) 341 | 342 | def iter_inference(self): 343 | assert not self._batch_count_in_output_queue 344 | assert not self._redundant_batch 345 | for item_chunk in self._data: 346 | yield self.batch2tensor(self.padding_batch(item_chunk)) 347 | 348 | def batch2tensor(self, batch): 349 | new_batch = {} 350 | for key, val in batch.items(): 351 | if isinstance(val, np.ndarray): 352 | if self.backend == 'torch': 353 | import torch 354 | new_batch[key] = torch.as_tensor(val) 355 | elif self.backend == 'tf': 356 | import tensorflow as tf 357 | new_batch[key] = tf.constant(val) 358 | else: 359 | new_batch[key] = val 360 | return new_batch 361 | 362 | def _shutdown_workers(self): 363 | if not self.workers: 364 | return 365 | 366 | self.input_queue.close() 367 | self.output_queue.close() 368 | 369 | for worker in self.workers: 370 | worker.terminate() 371 | self.workers = [] 
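    # A minimal usage sketch of the loader this iterator serves (comments only;
    # the Munch fields are exactly the ones read in KeyphraseDataLoader.__init__
    # and in this iterator, while the values are hypothetical):
    #
    #   from munch import Munch
    #   args = Munch(batch_size=32, max_src_len=1500, max_oov_count=100,
    #                max_target_len=8, fix_batch_size=False, prefetch=True,
    #                lazy_loading=True, shuffle=True, token_field='tokens',
    #                keyphrase_field='keyphrases', backend='torch', processed=False)
    #   loader = KeyphraseDataLoader('kp20k.train.jsonl', vocab2id, TRAIN_MODE, args)
    #   for batch in loader:
    #       batch[TOKENS], batch[TOKENS_OOV], batch[TARGET]  # torch tensors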
372 | 373 | def __del__(self): 374 | if self.workers: 375 | self._shutdown_workers() 376 | -------------------------------------------------------------------------------- /deep_keyphrase/evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import copy 3 | from operator import itemgetter 4 | from collections import OrderedDict 5 | from pysenal import read_jsonline, read_json 6 | 7 | 8 | class KeyphraseEvaluator(object): 9 | def __init__(self, top_n, 10 | metrics_mode, 11 | token_field='tokens', 12 | true_keyphrase_field='keyphrases', 13 | pred_keyphrase_field='pred_keyphrases'): 14 | self.top_n = top_n 15 | self.metrics_mode = metrics_mode 16 | self.token_field = token_field 17 | self.true_keyphrase_field = true_keyphrase_field 18 | self.pred_keyphrase_field = pred_keyphrase_field 19 | 20 | def __load_data(self, input_data): 21 | if isinstance(input_data, str): 22 | if input_data.endswith('.json'): 23 | data_source = read_json(input_data) 24 | elif input_data.endswith('.jsonl'): 25 | data_source = read_jsonline(input_data) 26 | else: 27 | raise ValueError('input file type is not supported, only support .json and .jsonl') 28 | elif isinstance(input_data, list): 29 | data_source = copy.deepcopy(input_data) 30 | else: 31 | raise TypeError('input data type error. only accept str (path) and list.') 32 | return data_source 33 | 34 | def evaluate(self, input_data, eval_mode='all', ): 35 | data_source = self.__load_data(input_data) 36 | if self.metrics_mode == 'micro': 37 | result = self.evaluate_micro_average(data_source, eval_mode) 38 | elif self.metrics_mode == 'macro': 39 | result = self.evaluate_macro_average(data_source, eval_mode) 40 | else: 41 | raise ValueError('evaluation mode is error.') 42 | return result 43 | 44 | def evaluate_micro_average(self, data_source, eval_mode): 45 | self.__check_eval_mode(eval_mode) 46 | top_counter = OrderedDict() 47 | for n in self.top_n: 48 | true_positive_count = 0 49 | pred_count = 0 50 | true_count = 0 51 | for record in data_source: 52 | tokens = record[self.token_field] 53 | true_positive_phrase_list = [] 54 | pred_phrase_list_topn = self.filter_phrase(record[self.pred_keyphrase_field], 55 | eval_mode, 56 | tokens, 57 | n) 58 | true_phrase_list = self.filter_phrase(record[self.true_keyphrase_field], eval_mode, tokens) 59 | for predict_phrase in pred_phrase_list_topn: 60 | if predict_phrase in true_phrase_list: 61 | true_positive_phrase_list.append(predict_phrase) 62 | true_positive_count += len(true_positive_phrase_list) 63 | pred_count += len(pred_phrase_list_topn) 64 | true_count += len(true_phrase_list) 65 | if not pred_count: 66 | prec = 0 67 | else: 68 | prec = true_positive_count / pred_count 69 | if not true_count: 70 | recall = 0 71 | else: 72 | recall = true_positive_count / true_count 73 | if prec + recall == 0: 74 | f1 = 0 75 | else: 76 | f1 = 2 * prec * recall / (prec + recall) 77 | top_counter[n] = {'precision': prec, 78 | 'recall': recall, 79 | 'f1': f1} 80 | return top_counter 81 | 82 | def evaluate_macro_average(self, data_source, eval_mode): 83 | self.__check_eval_mode(eval_mode) 84 | top_counter = OrderedDict() 85 | for n in self.top_n: 86 | counter = [] 87 | for record in data_source: 88 | tokens = record[self.token_field] 89 | true_positive_topn_phrase_list = [] 90 | pred_phrase_list_topn = self.filter_phrase(record[self.pred_keyphrase_field], 91 | eval_mode, 92 | record[self.token_field], 93 | n) 94 | true_phrase_list = 
self.filter_phrase(record[self.true_keyphrase_field], eval_mode, tokens)
95 |                 for predict_phrase in pred_phrase_list_topn:
96 |                     if predict_phrase in true_phrase_list:
97 |                         true_positive_topn_phrase_list.append(predict_phrase)
98 |                 if not pred_phrase_list_topn:
99 |                     p = 0
100 |                 else:
101 |                     p = len(true_positive_topn_phrase_list) / len(pred_phrase_list_topn)
102 | 
103 |                 if not true_phrase_list:
104 |                     continue
105 |                 else:
106 |                     r = len(true_positive_topn_phrase_list) / len(true_phrase_list)
107 | 
108 |                 if p + r == 0:
109 |                     f1 = 0
110 |                 else:
111 |                     f1 = 2 * p * r / (p + r)
112 |                 counter.append({'true_positive': len(true_positive_topn_phrase_list),
113 |                                 'pred_count': len(pred_phrase_list_topn),
114 |                                 'true_count': len(true_phrase_list),
115 |                                 'precision': p,
116 |                                 'recall': r,
117 |                                 'f1': f1})
118 |             top_counter[n] = {'precision': sum(map(itemgetter('precision'), counter)) / max(len(counter), 1),  # max() guards against every record being skipped above
119 |                               'recall': sum(map(itemgetter('recall'), counter)) / max(len(counter), 1),
120 |                               'f1': sum(map(itemgetter('f1'), counter)) / max(len(counter), 1)}
121 |         return top_counter
122 | 
123 |     def filter_phrase(self, phrase_list, mode, input_tokens, top_n=None):
124 |         input_text = ' '.join(input_tokens)
125 |         filtered_phrase_list = []
126 | 
127 |         for phrase in phrase_list:
128 |             phrase_text = ' '.join(phrase)
129 |             if mode == 'all':
130 |                 filtered_phrase_list.append(phrase)
131 |             elif mode == 'present' and phrase_text in input_text:
132 |                 filtered_phrase_list.append(phrase)
133 |             elif mode == 'absent' and phrase_text not in input_text:
134 |                 filtered_phrase_list.append(phrase)
135 | 
136 |         if top_n is not None:
137 |             filtered_phrase_list = filtered_phrase_list[:top_n]
138 |         return filtered_phrase_list
139 | 
140 |     def __check_eval_mode(self, eval_mode):
141 |         if eval_mode not in {'all', 'present', 'absent'}:
142 |             raise ValueError('evaluation mode must be `all`, `present` or `absent`')
143 | 
-------------------------------------------------------------------------------- /deep_keyphrase/predict_runner.py: --------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 | import argparse
3 | from munch import Munch
4 | from pysenal import get_chunk, read_jsonline_lazy
5 | from deep_keyphrase.copy_rnn.predict import CopyRnnPredictor
6 | from deep_keyphrase.dataloader import KeyphraseDataLoader
7 | from deep_keyphrase.utils.vocab_loader import load_vocab
8 | 
9 | 
10 | class PredictRunner(object):
11 |     def __init__(self):
12 |         self.args = self.parse_args()
13 |         self.predictor = CopyRnnPredictor(model_info=self.args.model_path,
14 |                                           vocab_info=self.args.vocab_path,
15 |                                           beam_size=self.args.beam_size,
16 |                                           max_src_length=self.args.max_src_len,
17 |                                           max_target_len=self.args.max_target_len)
18 | 
19 |         self.config = {**self.predictor.config, 'batch_size': self.args.batch_size}
20 |         self.config = Munch(self.config)
21 | 
22 |     def predict(self):
23 |         # vocab2id = load_vocab(self.args.vocab_path, vocab_size=self.config.vocab_size)
24 |         # loader = KeyphraseDataLoader(self.args.src_filename,
25 |         #                              vocab2id=vocab2id,
26 |         #                              mode='inference', args=self.config)
27 |         # for batch in loader:
28 | 
29 |         self.predictor.eval_predict(self.args.src_filename, self.args.dest_filename,
30 |                                     args=self.config)
31 |         # chunk_size =
32 |         # for item_chunk in get_chunk(read_jsonline_lazy())
33 |         # pass
34 | 
35 |     def parse_args(self):
36 |         parser = argparse.ArgumentParser()
37 |         parser.add_argument('-src_filename', type=str, help='')
38 |         parser.add_argument('-model_path', type=str, help=''); parser.add_argument('-dest_filename', type=str, help='')  # __init__ reads args.model_path; predict() reads args.dest_filename
39 | 
parser.add_argument('-vocab_path', type=str, help='') 40 | parser.add_argument('-batch_size', type=int, default=10, help='') 41 | parser.add_argument('-beam_size', type=int, default=200, help='') 42 | parser.add_argument('-max_src_len', type=int, default=1500, help='') 43 | parser.add_argument('-max_target_len', type=int, default=8, help='') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | if __name__ == '__main__': 49 | PredictRunner().predict() 50 | -------------------------------------------------------------------------------- /deep_keyphrase/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- -------------------------------------------------------------------------------- /deep_keyphrase/utils/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | # !! please don't change the value of following variable 3 | PAD_WORD = '' 4 | UNK_WORD = '' 5 | BOS_WORD = '' 6 | EOS_WORD = '' 7 | DIGIT_WORD = '' 8 | SEP_WORD = '' 9 | -------------------------------------------------------------------------------- /deep_keyphrase/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import re 3 | from .constants import DIGIT_WORD 4 | 5 | num_regex = re.compile(r'\d+([.]\d+)?') 6 | 7 | # '1.' => ['', '.'], '1.1' => [''] 8 | char_regex = re.compile(r'[_\-—<>{,(?\\.\'%]|\d+([.]\d+)?', re.IGNORECASE) 9 | 10 | 11 | def token_char_tokenize(text): 12 | text = char_regex.sub(r' \g<0> ', text) 13 | tokens = num_regex.sub(DIGIT_WORD, text).split() 14 | chars = [] 15 | for token in tokens: 16 | if token == DIGIT_WORD: 17 | chars.append(token) 18 | else: 19 | chars.extend(list(token)) 20 | return chars 21 | -------------------------------------------------------------------------------- /deep_keyphrase/utils/vocab_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from pysenal import read_lines_lazy 3 | from .constants import * 4 | 5 | 6 | def load_vocab(src_filename, vocab_size=None): 7 | vocab2id = {} 8 | for word in read_lines_lazy(src_filename): 9 | if word not in vocab2id: 10 | vocab2id[word] = len(vocab2id) 11 | 12 | if vocab_size and len(vocab2id) >= vocab_size: 13 | break 14 | if PAD_WORD not in vocab2id: 15 | raise ValueError('padding char is not in vocab') 16 | if UNK_WORD not in vocab2id: 17 | raise ValueError('unk char is not in vocab') 18 | if BOS_WORD not in vocab2id: 19 | raise ValueError('begin of sentence char is not in vocab') 20 | if EOS_WORD not in vocab2id: 21 | raise ValueError('end of sentence char is not in vocab') 22 | # if DIGIT_WORD not in vocab2id: 23 | # raise ValueError('digit char is not in vocab') 24 | # if SEP_WORD not in vocab2id: 25 | # raise ValueError('separator char is not in vocab') 26 | return vocab2id 27 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. 
Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make ' where is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 
82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/deep-keyphrase.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/deep-keyphrase.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $HOME/.local/share/devhelp/deep-keyphrase" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $HOME/.local/share/devhelp/deep-keyphrase" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 
168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 178 | -------------------------------------------------------------------------------- /docs/_static/.gitignore: -------------------------------------------------------------------------------- 1 | # Empty directory 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. _authors: 2 | .. include:: ../AUTHORS.rst 3 | -------------------------------------------------------------------------------- /docs/changes.rst: -------------------------------------------------------------------------------- 1 | .. _changes: 2 | .. include:: ../CHANGES.rst 3 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # This file is execfile()d with the current directory set to its containing dir. 4 | # 5 | # Note that not all possible configuration values are present in this 6 | # autogenerated file. 7 | # 8 | # All configuration values have a default; values that are commented out 9 | # serve to show the default. 10 | 11 | import sys 12 | 13 | # If extensions (or modules to document with autodoc) are in another directory, 14 | # add these directories to sys.path here. If the directory is relative to the 15 | # documentation root, use os.path.abspath to make it absolute, like shown here. 16 | # sys.path.insert(0, os.path.abspath('.')) 17 | 18 | # -- Hack for ReadTheDocs ------------------------------------------------------ 19 | # This hack is necessary since RTD does not issue `sphinx-apidoc` before running 20 | # `sphinx-build -b html . _build/html`. See Issue: 21 | # https://github.com/rtfd/readthedocs.org/issues/1139 22 | # DON'T FORGET: Check the box "Install your project inside a virtualenv using 23 | # setup.py install" in the RTD Advanced Settings. 24 | import os 25 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 26 | if on_rtd: 27 | import inspect 28 | from sphinx import apidoc 29 | 30 | __location__ = os.path.join(os.getcwd(), os.path.dirname( 31 | inspect.getfile(inspect.currentframe()))) 32 | 33 | output_dir = os.path.join(__location__, "../docs/api") 34 | module_dir = os.path.join(__location__, "../deep_keyphrase") 35 | cmd_line_template = "sphinx-apidoc -f -o {outputdir} {moduledir}" 36 | cmd_line = cmd_line_template.format(outputdir=output_dir, moduledir=module_dir) 37 | apidoc.main(cmd_line.split(" ")) 38 | 39 | # -- General configuration ----------------------------------------------------- 40 | 41 | # If your documentation needs a minimal Sphinx version, state it here. 42 | # needs_sphinx = '1.0' 43 | 44 | # Add any Sphinx extension module names here, as strings. They can be extensions 45 | # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
46 | extensions = ['sphinx.ext.autodoc', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 47 | 'sphinx.ext.autosummary', 'sphinx.ext.viewcode', 'sphinx.ext.coverage', 48 | 'sphinx.ext.doctest', 'sphinx.ext.ifconfig', 'sphinx.ext.mathjax', 49 | 'sphinx.ext.napoleon'] 50 | 51 | # Add any paths that contain templates here, relative to this directory. 52 | templates_path = ['_templates'] 53 | 54 | # The suffix of source filenames. 55 | source_suffix = '.rst' 56 | 57 | # The encoding of source files. 58 | # source_encoding = 'utf-8-sig' 59 | 60 | # The master toctree document. 61 | master_doc = 'index' 62 | 63 | # General information about the project. 64 | project = u'deep-keyphrase' 65 | copyright = u'2019, xiayubin' 66 | 67 | # The version info for the project you're documenting, acts as replacement for 68 | # |version| and |release|, also used in various other places throughout the 69 | # built documents. 70 | # 71 | # The short X.Y version. 72 | version = '' # Is set by calling `setup.py docs` 73 | # The full version, including alpha/beta/rc tags. 74 | release = '' # Is set by calling `setup.py docs` 75 | 76 | # The language for content autogenerated by Sphinx. Refer to documentation 77 | # for a list of supported languages. 78 | # language = None 79 | 80 | # There are two options for replacing |today|: either, you set today to some 81 | # non-false value, then it is used: 82 | # today = '' 83 | # Else, today_fmt is used as the format for a strftime call. 84 | # today_fmt = '%B %d, %Y' 85 | 86 | # List of patterns, relative to source directory, that match files and 87 | # directories to ignore when looking for source files. 88 | exclude_patterns = ['_build'] 89 | 90 | # The reST default role (used for this markup: `text`) to use for all documents. 91 | # default_role = None 92 | 93 | # If true, '()' will be appended to :func: etc. cross-reference text. 94 | # add_function_parentheses = True 95 | 96 | # If true, the current module name will be prepended to all description 97 | # unit titles (such as .. function::). 98 | # add_module_names = True 99 | 100 | # If true, sectionauthor and moduleauthor directives will be shown in the 101 | # output. They are ignored by default. 102 | # show_authors = False 103 | 104 | # The name of the Pygments (syntax highlighting) style to use. 105 | pygments_style = 'sphinx' 106 | 107 | # A list of ignored prefixes for module index sorting. 108 | # modindex_common_prefix = [] 109 | 110 | # If true, keep warnings as "system message" paragraphs in the built documents. 111 | # keep_warnings = False 112 | 113 | 114 | # -- Options for HTML output --------------------------------------------------- 115 | 116 | # The theme to use for HTML and HTML Help pages. See the documentation for 117 | # a list of builtin themes. 118 | html_theme = 'alabaster' 119 | 120 | # Theme options are theme-specific and customize the look and feel of a theme 121 | # further. For a list of options available for each theme, see the 122 | # documentation. 123 | # html_theme_options = {} 124 | 125 | # Add any paths that contain custom themes here, relative to this directory. 126 | # html_theme_path = [] 127 | 128 | # The name for this set of Sphinx documents. If None, it defaults to 129 | # " v documentation". 130 | try: 131 | from deep_keyphrase import __version__ as version 132 | except ImportError: 133 | pass 134 | else: 135 | release = version 136 | 137 | # A shorter title for the navigation bar. Default is the same as html_title. 
138 | # html_short_title = None 139 | 140 | # The name of an image file (relative to this directory) to place at the top 141 | # of the sidebar. 142 | # html_logo = "" 143 | 144 | # The name of an image file (within the static path) to use as favicon of the 145 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 146 | # pixels large. 147 | # html_favicon = None 148 | 149 | # Add any paths that contain custom static files (such as style sheets) here, 150 | # relative to this directory. They are copied after the builtin static files, 151 | # so a file named "default.css" will overwrite the builtin "default.css". 152 | html_static_path = ['_static'] 153 | 154 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 155 | # using the given strftime format. 156 | # html_last_updated_fmt = '%b %d, %Y' 157 | 158 | # If true, SmartyPants will be used to convert quotes and dashes to 159 | # typographically correct entities. 160 | # html_use_smartypants = True 161 | 162 | # Custom sidebar templates, maps document names to template names. 163 | # html_sidebars = {} 164 | 165 | # Additional templates that should be rendered to pages, maps page names to 166 | # template names. 167 | # html_additional_pages = {} 168 | 169 | # If false, no module index is generated. 170 | # html_domain_indices = True 171 | 172 | # If false, no index is generated. 173 | # html_use_index = True 174 | 175 | # If true, the index is split into individual pages for each letter. 176 | # html_split_index = False 177 | 178 | # If true, links to the reST sources are added to the pages. 179 | # html_show_sourcelink = True 180 | 181 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 182 | # html_show_sphinx = True 183 | 184 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 185 | # html_show_copyright = True 186 | 187 | # If true, an OpenSearch description file will be output, and all pages will 188 | # contain a tag referring to it. The value of this option must be the 189 | # base URL from which the finished HTML is served. 190 | # html_use_opensearch = '' 191 | 192 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 193 | # html_file_suffix = None 194 | 195 | # Output file base name for HTML help builder. 196 | htmlhelp_basename = 'deep_keyphrase-doc' 197 | 198 | 199 | # -- Options for LaTeX output -------------------------------------------------- 200 | 201 | latex_elements = { 202 | # The paper size ('letterpaper' or 'a4paper'). 203 | # 'papersize': 'letterpaper', 204 | 205 | # The font size ('10pt', '11pt' or '12pt'). 206 | # 'pointsize': '10pt', 207 | 208 | # Additional stuff for the LaTeX preamble. 209 | # 'preamble': '', 210 | } 211 | 212 | # Grouping the document tree into LaTeX files. List of tuples 213 | # (source start file, target name, title, author, documentclass [howto/manual]). 214 | latex_documents = [ 215 | ('index', 'user_guide.tex', u'deep-keyphrase Documentation', 216 | u'xiayubin', 'manual'), 217 | ] 218 | 219 | # The name of an image file (relative to this directory) to place at the top of 220 | # the title page. 221 | # latex_logo = "" 222 | 223 | # For "manual" documents, if this is true, then toplevel headings are parts, 224 | # not chapters. 225 | # latex_use_parts = False 226 | 227 | # If true, show page references after internal links. 228 | # latex_show_pagerefs = False 229 | 230 | # If true, show URL addresses after external links. 
231 | # latex_show_urls = False 232 | 233 | # Documents to append as an appendix to all manuals. 234 | # latex_appendices = [] 235 | 236 | # If false, no module index is generated. 237 | # latex_domain_indices = True 238 | 239 | # -- External mapping ------------------------------------------------------------ 240 | python_version = '.'.join(map(str, sys.version_info[0:2])) 241 | intersphinx_mapping = { 242 | 'sphinx': ('http://sphinx.pocoo.org', None), 243 | 'python': ('http://docs.python.org/' + python_version, None), 244 | 'matplotlib': ('http://matplotlib.sourceforge.net', None), 245 | 'numpy': ('http://docs.scipy.org/doc/numpy', None), 246 | 'sklearn': ('http://scikit-learn.org/stable', None), 247 | 'pandas': ('http://pandas.pydata.org/pandas-docs/stable', None), 248 | 'scipy': ('http://docs.scipy.org/doc/scipy/reference/', None), 249 | } 250 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | ============== 2 | deep-keyphrase 3 | ============== 4 | 5 | This is the documentation of **deep-keyphrase**. 6 | 7 | .. note:: 8 | 9 | This is the main page of your project's `Sphinx `_ 10 | documentation. It is formatted in `reStructuredText 11 | `__. Add additional pages by creating 12 | rst-files in ``docs`` and adding them to the `toctree 13 | `_ below. Use then 14 | `references `__ in order to link 15 | them from this page, e.g. :ref:`authors ` and :ref:`changes`. 16 | 17 | It is also possible to refer to the documentation of other Python packages 18 | with the `Python domain syntax 19 | `__. By default you 20 | can reference the documentation of `Sphinx `__, 21 | `Python `__, `NumPy 22 | `__, `SciPy 23 | `__, `matplotlib 24 | `__, `Pandas 25 | `__, `Scikit-Learn 26 | `__. You can add more by 27 | extending the ``intersphinx_mapping`` in your Sphinx's ``conf.py``. 28 | 29 | The pretty useful extension `autodoc 30 | `__ is activated by 31 | default and lets you include documentation from docstrings. Docstrings can 32 | be written in `Google 33 | `__ 34 | (recommended!), `NumPy 35 | `__ 36 | and `classical 37 | `__ 38 | style. 39 | 40 | 41 | Contents 42 | ======== 43 | 44 | .. toctree:: 45 | :maxdepth: 2 46 | 47 | License 48 | Authors 49 | Changelog 50 | Module Reference 51 | 52 | 53 | Indices and tables 54 | ================== 55 | 56 | * :ref:`genindex` 57 | * :ref:`modindex` 58 | * :ref:`beam_search` 59 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. 
/docs/license.rst: -------------------------------------------------------------------------------- 1 | .. _license: 2 | 3 | ======= 4 | License 5 | ======= 6 | 7 | .. literalinclude:: ../LICENSE.txt 8 | --------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.2.0 2 | pysenal>=0.0.8 3 | tensorboard>=2.0.1 4 | munch==2.5.0 5 | nltk 6 | future --------------------------------------------------------------------------------
/scripts/predict_kp20k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 deep_keyphrase/predict_runner.py --------------------------------------------------------------------------------
/scripts/prepare_kp20k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_PATH=${PWD} 3 | DATA_DIR=$ROOT_PATH/data 4 | SRC_TRAIN=$DATA_DIR/raw/kp20k_new/kp20k_training.json 5 | SRC_VALID=$DATA_DIR/raw/kp20k_new/kp20k_validation.json 6 | SRC_TEST=$DATA_DIR/raw/kp20k_new/kp20k_testing.json 7 | DEST_TRAIN=$DATA_DIR/kp20k.train.jsonl 8 | DEST_VALID=$DATA_DIR/kp20k.valid.jsonl 9 | DEST_TEST=$DATA_DIR/kp20k.test.jsonl 10 | DEST_VOCAB=$DATA_DIR/vocab_kp20k.txt 11 | 12 | python3 deep_keyphrase/data_process/preprocess.py -src_filename $SRC_TRAIN \ 13 | -dest_filename $DEST_TRAIN -dest_vocab_path $DEST_VOCAB -src_lower -target_lower 14 | python3 deep_keyphrase/data_process/preprocess.py -src_filename $SRC_VALID \ 15 | -dest_filename $DEST_VALID -src_lower -target_lower 16 | python3 deep_keyphrase/data_process/preprocess.py -src_filename $SRC_TEST \ 17 | -dest_filename $DEST_TEST -src_lower -target_lower 18 | --------------------------------------------------------------------------------
/scripts/train_copyrnn_kp20k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ROOT_PATH=${PWD} 3 | DATA_DIR=$ROOT_PATH/data 4 | TRAIN_FILENAME=$DATA_DIR/kp20k.train.jsonl 5 | VALID_FILENAME=$DATA_DIR/kp20k.valid.jsonl 6 | TEST_FILENAME=$DATA_DIR/kp20k.test.jsonl 7 | VOCAB_PATH=$DATA_DIR/vocab_kp20k.txt 8 | DEST_DIR=$DATA_DIR/kp20k/ 9 | EXP_NAME=copyrnn_kp20k_basic 10 | 11 | # export CUDA_VISIBLE_DEVICES=1 12 | 13 | python3 deep_keyphrase/copy_rnn/train.py -exp_name $EXP_NAME \ 14 | -train_filename $TRAIN_FILENAME \ 15 | -valid_filename $VALID_FILENAME -test_filename $TEST_FILENAME \ 16 | -batch_size 128 -max_src_len 1500 -learning_rate 1e-3 \ 17 | -token_field title_and_abstract_tokens -keyphrase_field keyword_tokens \ 18 | -vocab_path $VOCAB_PATH -dest_base_dir $DEST_DIR \ 19 | -bidirectional -teacher_forcing -copy_net -shuffle -prefetch \ 20 | -schedule_lr -schedule_step 10000 --------------------------------------------------------------------------------
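train_copyrnn_kp20k.sh reads the preprocessed JSONL files and names the fields to consume (-token_field title_and_abstract_tokens, -keyphrase_field keyword_tokens). The exact schema emitted by preprocess.py is not visible in this dump, so the record below is only a guess shaped by those flags, and the values are invented for illustration:

    import json

    # Hypothetical line of kp20k.train.jsonl: one tokenized document plus its
    # tokenized gold keyphrases. Field names come from the training flags above.
    record = {
        'title_and_abstract_tokens': ['copy', 'mechanisms', 'help', 'keyphrase',
                                      'generation', 'models', 'handle', 'rare', 'words'],
        'keyword_tokens': [['copy', 'mechanism'], ['keyphrase', 'generation']],
    }

    # JSONL means one JSON object per line.
    with open('kp20k.sample.jsonl', 'w', encoding='utf-8') as f:
        f.write(json.dumps(record, ensure_ascii=False) + '\n')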
/setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = deep-keyphrase 3 | summary = Deep keyphrase generation models (CopyRNN, CopyCNN, CopyTransformer) in PyTorch 4 | author = xiayubin 5 | author-email = supercoderhawk@gmail.com 6 | license = MIT 7 | home-page = https://github.com/supercoderhawk/deep-keyphrase 8 | description-file = README.rst 9 | # Add here all kinds of additional classifiers as defined under 10 | # https://pypi.python.org/pypi?%3Aaction=list_classifiers 11 | classifier = 12 | Development Status :: 4 - Beta 13 | Programming Language :: Python 14 | 15 | [entry_points] 16 | # Add here console scripts like: 17 | # console_scripts = 18 | # script_name = deep_keyphrase.module:function 19 | # For example: 20 | # console_scripts = 21 | # fibonacci = deep_keyphrase.skeleton:run 22 | # as well as other entry_points. 23 | 24 | 25 | [files] 26 | # Add here 'data_files', 'packages' or 'namespace_packages'. 27 | # Additional data files are defined as key value pairs of target directory 28 | # and source location from the root of the repository: 29 | packages = 30 | deep_keyphrase 31 | # data_files = 32 | # share/deep_keyphrase_docs = docs/* 33 | 34 | [extras] 35 | # Add here additional requirements for extra features, like: 36 | # PDF = 37 | # ReportLab>=1.2 38 | # RXP 39 | 40 | [test] 41 | # py.test options when running `python setup.py test` 42 | addopts = tests 43 | 44 | [tool:pytest] 45 | # Options for py.test: 46 | # Specify command line options as you would do when invoking py.test directly. 47 | # e.g. --cov-report html (or xml) for html/xml output or --junitxml junit.xml 48 | # in order to write a coverage file that can be read by Jenkins. 49 | addopts = 50 | --cov deep_keyphrase --cov-report term-missing 51 | --verbose 52 | 53 | [aliases] 54 | docs = build_sphinx 55 | 56 | [bdist_wheel] 57 | # Use this option if your package is pure-python 58 | universal = 1 59 | 60 | [build_sphinx] 61 | source_dir = docs 62 | build_dir = docs/_build 63 | 64 | [pbr] 65 | # Let pbr run sphinx-apidoc 66 | autodoc_tree_index_modules = True 67 | # autodoc_tree_excludes = ... 68 | # Let pbr itself generate the apidoc 69 | # autodoc_index_modules = True 70 | # autodoc_exclude_modules = ... 71 | # Convert warnings to errors 72 | # warnerrors = True 73 | 74 | [devpi:upload] 75 | # Options for the devpi: PyPI server and packaging tool 76 | # VCS export must be deactivated since we are using setuptools-scm 77 | no-vcs = 1 78 | formats = bdist_wheel 79 | --------------------------------------------------------------------------------
/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Setup file for deep_keyphrase. 5 | 6 | This file was generated with PyScaffold 2.5.11, a tool that easily 7 | puts up a scaffold for your new Python project. Learn more under: 8 | http://pyscaffold.readthedocs.org/ 9 | """ 10 | 11 | import sys 12 | from setuptools import setup 13 | 14 | 15 | def setup_package(): 16 | needs_sphinx = {'build_sphinx', 'upload_docs'}.intersection(sys.argv) 17 | sphinx = ['sphinx'] if needs_sphinx else [] 18 | setup(setup_requires=['six', 'pyscaffold>=2.5,<=2.6'] + sphinx, 19 | use_pyscaffold=True) 20 | 21 | 22 | if __name__ == "__main__": 23 | setup_package() 24 | --------------------------------------------------------------------------------
/test-requirements.txt: -------------------------------------------------------------------------------- 1 | # Add requirements only needed for your unittests and during development here. 2 | # They will be installed automatically when running `python setup.py test`.
3 | # ATTENTION: Don't remove pytest-cov and pytest as they are needed. 4 | pytest-cov 5 | pytest 6 | --------------------------------------------------------------------------------
/tests/test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- --------------------------------------------------------------------------------
/tests/test_utils/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from deep_keyphrase.utils.tokenizer import token_char_tokenize 3 | 4 | 5 | def test_token_char_tokenize(): 6 | tokens = token_char_tokenize('1.1在10~11个之间。') 7 | assert tokens == ['<digit>', '在', '<digit>', '~', '<digit>', '个', '之', '间', '。'] 8 | 9 | tokens = token_char_tokenize('1.发明内容11.11-11.12') 10 | assert tokens == ['<digit>', '.', '发', '明', '内', '容', '<digit>', '-', '<digit>'] --------------------------------------------------------------------------------
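The assertions above only make sense if token_char_tokenize collapses every number run ('1.1', '10', '11.12') into a single placeholder token; the '<digit>' literal restored here follows the usual '<pad>'/'<unk>' naming convention, and the real constant presumably lives in deep_keyphrase/utils/constants.py, which this dump does not show. A minimal sketch that satisfies both test cases, under those assumptions:

    # -*- coding: UTF-8 -*-
    import re

    DIGIT_WORD = '<digit>'  # assumed constant name; not confirmed by this dump

    NUMBER_RUN = r'\d+(?:\.\d+)*'  # matches '1', '10', '1.1', '11.12', ...


    def token_char_tokenize(text):
        """Character-level tokenization that maps each number run to DIGIT_WORD."""
        tokens = []
        # The capturing group makes re.split keep the number runs in the output.
        for piece in re.split('({})'.format(NUMBER_RUN), text):
            if not piece:
                continue
            if re.fullmatch(NUMBER_RUN, piece):
                tokens.append(DIGIT_WORD)
            else:
                tokens.extend(ch for ch in piece if not ch.isspace())
        return tokens

This is not the repository's implementation (deep_keyphrase/utils/tokenizer.py is not included above), just the simplest function consistent with the two assertions.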