├── tests ├── fancy_nlp │ ├── callbacks │ │ ├── test_ensemble.py │ │ └── test_metrics.py │ ├── layers │ │ ├── test_attention.py │ │ └── test_matching.py │ ├── utils │ │ ├── test_other.py │ │ ├── test_embedding.py │ │ ├── test_generator.py │ │ ├── test_save_load_model.py │ │ └── test_data_loader.py │ ├── preprocessors │ │ ├── test_ner_preprocessor.py │ │ ├── test_preprocessor.py │ │ └── test_spm_preprocessor.py │ ├── applications │ │ ├── test_ner.py │ │ └── test_spm.py │ └── trainer │ │ ├── test_spm_trainer.py │ │ └── test_ner_trainer.py └── test_sample.py ├── setup.cfg ├── fancy_nlp ├── losses │ ├── __init__.py │ └── crf_losses.py ├── models │ ├── spm │ │ └── __init__.py │ ├── text_classification │ │ ├── __init__.py │ │ ├── base_text_classification_model.py │ │ └── text_classification_models.py │ ├── __init__.py │ ├── ner │ │ ├── __init__.py │ │ └── base_ner_model.py │ └── base_model.py ├── metrics │ ├── __init__.py │ └── crf_accuracies.py ├── callbacks │ ├── __init__.py │ ├── ensemble.py │ └── metrics.py ├── layers │ ├── __init__.py │ ├── other.py │ ├── attention.py │ └── matching.py ├── trainers │ ├── __init__.py │ ├── text_classification_trainer.py │ └── spm_trainer.py ├── applications │ └── __init__.py ├── predictors │ ├── __init__.py │ ├── spm_predictor.py │ └── text_classification_predictor.py ├── __init__.py ├── preprocessors │ ├── __init__.py │ ├── preprocessor.py │ └── text_classification_preprocessor.py ├── config.py └── utils │ ├── __init__.py │ ├── save_load_model.py │ ├── other.py │ ├── data_generator.py │ ├── data_loader.py │ └── embedding.py ├── img ├── text_matching.png ├── entity_extract.png ├── entity_linking.png └── fancy-nlp_300-180_white.jpg ├── data ├── embeddings │ └── bert_sample_model │ │ ├── bert_model.ckpt.index │ │ ├── bert_model.ckpt.meta │ │ ├── vocab.txt │ │ ├── bert_model.ckpt.data-00000-of-00001 │ │ └── bert_config.json ├── spm │ └── webank │ │ └── example.txt └── ner │ └── msra │ └── example.txt ├── pytest.ini ├── .travis.yml ├── README.en.md ├── requirements.txt ├── setup.py ├── examples ├── ner_example.py ├── spm_example.py ├── bert_single.py ├── bert_fine_tuning.py ├── bert_combination.py └── text_classification_example.py └── .gitignore /tests/fancy_nlp/callbacks/test_ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /tests/fancy_nlp/callbacks/test_metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [easy_install] 2 | index_url = https://mirrors.cloud.tencent.com/pypi/simple/ 3 | -------------------------------------------------------------------------------- /fancy_nlp/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .crf_losses import crf_loss 4 | -------------------------------------------------------------------------------- /fancy_nlp/models/spm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .spm_models import * 4 | -------------------------------------------------------------------------------- /img/text_matching.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/img/text_matching.png -------------------------------------------------------------------------------- /img/entity_extract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/img/entity_extract.png -------------------------------------------------------------------------------- /img/entity_linking.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/img/entity_linking.png -------------------------------------------------------------------------------- /fancy_nlp/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .crf_accuracies import crf_accuracy 4 | -------------------------------------------------------------------------------- /img/fancy-nlp_300-180_white.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/img/fancy-nlp_300-180_white.jpg -------------------------------------------------------------------------------- /fancy_nlp/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .ensemble import * 4 | from .metrics import * 5 | -------------------------------------------------------------------------------- /fancy_nlp/models/text_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .text_classification_models import * 4 | -------------------------------------------------------------------------------- /fancy_nlp/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from . import ner 4 | from . import text_classification 5 | from . 
import spm 6 | -------------------------------------------------------------------------------- /data/embeddings/bert_sample_model/bert_model.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/data/embeddings/bert_sample_model/bert_model.ckpt.index -------------------------------------------------------------------------------- /data/embeddings/bert_sample_model/bert_model.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/data/embeddings/bert_sample_model/bert_model.ckpt.meta -------------------------------------------------------------------------------- /fancy_nlp/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .attention import * 4 | from .matching import * 5 | from .other import * 6 | from .crf import CRF 7 | -------------------------------------------------------------------------------- /tests/test_sample.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | def func(x): 5 | return x + 1 6 | 7 | 8 | def test_answer(): 9 | assert func(3) == 4 10 | -------------------------------------------------------------------------------- /data/embeddings/bert_sample_model/vocab.txt: -------------------------------------------------------------------------------- 1 | [PAD] 2 | [UNK] 3 | [CLS] 4 | [SEP] 5 | [MASK] 6 | all 7 | work 8 | and 9 | no 10 | play 11 | makes 12 | jack 13 | a 14 | dull 15 | boy -------------------------------------------------------------------------------- /data/embeddings/bert_sample_model/bert_model.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/boat-group/fancy-nlp/HEAD/data/embeddings/bert_sample_model/bert_model.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /fancy_nlp/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .ner_trainer import NERTrainer 4 | from .text_classification_trainer import TextClassificationTrainer 5 | from .spm_trainer import SPMTrainer 6 | -------------------------------------------------------------------------------- /fancy_nlp/applications/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | 5 | from .ner import NER 6 | from .text_classification import TextClassification 7 | from .spm import SPM 8 | -------------------------------------------------------------------------------- /fancy_nlp/predictors/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .ner_predictor import NERPredictor 4 | from .text_classification_predictor import TextClassificationPredictor 5 | from .spm_predictor import SPMPredictor 6 | -------------------------------------------------------------------------------- /fancy_nlp/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | import os 5 | os.environ['TF_KERAS'] = '1' 6 | 7 | from . 
import utils 8 | from . import applications 9 | from . import preprocessors 10 | 11 | __version__ = '1.0.0' 12 | -------------------------------------------------------------------------------- /fancy_nlp/models/ner/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .ner_models import * 4 | 5 | ner_model_dict = { 6 | 'bilstm': BiLSTMNER, 7 | 'bilstm_cnn': BiLSTMCNNNER, 8 | 'bigru': BiGRUNER, 9 | 'bigru_cnn': BiGRUCNNNER, 10 | 'bert': BertNER 11 | } 12 | -------------------------------------------------------------------------------- /fancy_nlp/preprocessors/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | 5 | from .ner_preprocessor import NERPreprocessor 6 | from .text_classification_preprocessor import TextClassificationPreprocessor 7 | from .spm_preprocessor import SPMPreprocessor 8 | -------------------------------------------------------------------------------- /data/spm/webank/example.txt: -------------------------------------------------------------------------------- 1 | 为何苹果手机显示微粒贷暂未开放? 我想开通微粒贷,不知我应该做写什么准备材料呢 0 2 | 为什么综合评分不能过? 综合评估为何过不了 1 3 | 我账户有足够存款怎么不会扣钱? 为什么卡的有钱,不自动扣钱呢。 1 4 | 什么时候才会全面开放名额 微信6.2的版本没有微粒贷吗? 0 5 | 借款成功后的钱只能用于日常什么消费 借款的钱能用来干什么? 1 6 | 请帮我查看我刚才借贷500是否成功 如果今天扣款不成功,会不会造成以后借款的问题 0 7 | QQ钱包在那打开 QQ钱包里的微粒贷和微信里的有区别吗? 0 8 | 5点怎么没有扣款, 我一点钟转钱进还款卡里了,现在过5点了,怎么还没自动扣款呢? 1 -------------------------------------------------------------------------------- /fancy_nlp/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | CACHE_DIR = '~/.fancy_cache' 6 | if not os.path.exists(os.path.expanduser(CACHE_DIR)): 7 | os.makedirs(os.path.expanduser(CACHE_DIR)) 8 | 9 | MODEL_STORAGE_PREFIX = \ 10 | 'https://fancy-nlp-1253403094.cos.ap-shanghai.myqcloud.com/pretrained_models/' 11 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore:the imp module is deprecated in favour of importlib:DeprecationWarning 4 | ignore:`scipy.sparse.sparsetools` is deprecated!:DeprecationWarning 5 | ignore:inspect.getargspec\(\) is deprecated:DeprecationWarning 6 | ignore:builtin type EagerTensor has no __module__ attribute:DeprecationWarning 7 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | cache: pip 3 | python: 4 | - "3.6" 5 | install: 6 | - pip install -r requirements.txt 7 | script: 8 | - pip install google-compute-engine 9 | - pip install pytest==5.0.1 10 | - pip install pytest-cov==2.7.1 11 | - pip install coveralls 12 | - pip install -e . 
13 | - pytest --cov=fancy_nlp tests/ 14 | after_success: 15 | - coveralls 16 | -------------------------------------------------------------------------------- /fancy_nlp/models/base_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Base model with tf.Keras 4 | """ 5 | 6 | 7 | class BaseModel(object): 8 | def build_input(self): 9 | """We build input and embedding layer for tf.keras model here""" 10 | raise NotImplementedError 11 | 12 | def build_model(self): 13 | """We build tf.keras model here""" 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /tests/fancy_nlp/layers/test_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | from fancy_nlp.layers import MultiHeadAttention 6 | 7 | 8 | class TestAttention: 9 | def test_multihead_attention(self): 10 | input_embed = tf.keras.layers.Input(shape=(3, 300)) 11 | input_encode = MultiHeadAttention()(input_embed) 12 | model = tf.keras.models.Model(input_embed, input_encode) 13 | -------------------------------------------------------------------------------- /README.en.md: -------------------------------------------------------------------------------- 1 | # fancy-nlp 2 | [![Build Status](https://travis-ci.org/boat-group/fancy-nlp.svg?branch=master)](https://travis-ci.org/boat-group/fancy-nlp) 3 | [![Coverage Status](https://coveralls.io/repos/github/boat-group/fancy-nlp/badge.svg?branch=master)](https://coveralls.io/github/boat-group/fancy-nlp?branch=master) 4 | [![Commitizen friendly](https://img.shields.io/badge/commitizen-friendly-brightgreen.svg)](http://commitizen.github.io/cz-cli/) 5 | 6 | A fast and easy-to-use natural language processing (NLP) toolkit, satisfying your imagination about NLP. 
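For a quick feel of the API, below is a minimal sketch of the NER workflow, condensed from `examples/ner_example.py` in this repository. The dataset paths and the model name are placeholders you need to supply yourself; see the example scripts for the full set of training options.

```python
import os

from fancy_nlp.utils import load_ner_data_and_labels
from fancy_nlp.applications import NER

checkpoint_dir = 'pretrained_models'  # directory where checkpoints will be written
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

# Placeholder paths -- point these at your own CoNLL-style NER corpus.
train_data, train_labels = load_ner_data_and_labels('datasets/ner/msra/train_data')
dev_data, dev_labels = load_ner_data_and_labels('datasets/ner/msra/test_data')

# Train a BiLSTM-CNN + CRF tagger from scratch, then tag a raw sentence.
ner = NER(use_pretrained=False)
ner.fit(train_data, train_labels, dev_data, dev_labels,
        ner_model_type='bilstm_cnn',
        callback_list=['modelcheckpoint', 'earlystopping'],
        checkpoint_dir=checkpoint_dir,
        model_name='my_ner_bilstm_cnn_crf')
print(ner.analyze('同济大学位于上海市杨浦区,校长为陈杰'))
```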
7 | -------------------------------------------------------------------------------- /fancy_nlp/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import absolute_import 4 | from .data_loader import load_ner_data_and_labels, load_text_classification_data_and_labels, \ 5 | load_spm_data_and_labels 6 | from .embedding import load_pre_trained, train_w2v, train_fasttext 7 | from .save_load_model import load_keras_model, save_keras_model 8 | from .other import pad_sequences_2d, get_len_from_corpus, get_custom_objects, ChineseBertTokenizer 9 | from .data_generator import NERGenerator, TextClassificationGenerator, SPMGenerator 10 | -------------------------------------------------------------------------------- /data/embeddings/bert_sample_model/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 4, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 16, 9 | "max_position_embeddings": 16, 10 | "num_attention_heads": 4, 11 | "num_hidden_layers": 2, 12 | "pooler_fc_size": 4, 13 | "pooler_num_attention_heads": 4, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 16, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 15 19 | } 20 | -------------------------------------------------------------------------------- /fancy_nlp/metrics/crf_accuracies.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # crf metrics that works with tf2.x 3 | 4 | 5 | import tensorflow as tf 6 | 7 | from ..layers.crf import CRF 8 | 9 | 10 | def crf_accuracy(y_true, y_pred): 11 | """ 12 | Args 13 | y_true: true targets tensor. 14 | y_pred: predictions tensor. 15 | Returns: 16 | scalar. 17 | """ 18 | crf_layer = y_pred._keras_history[0] 19 | 20 | # check if last layer is CRF 21 | if not isinstance(crf_layer, CRF): 22 | raise ValueError( 23 | "Last layer must be CRF for use {}.".format("crf_accuracy")) 24 | 25 | accuracy = crf_layer.get_accuracy(y_true, y_pred) 26 | 27 | return accuracy 28 | -------------------------------------------------------------------------------- /fancy_nlp/layers/other.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | 6 | class NonMaskingLayer(tf.keras.layers.Layer): 7 | """ 8 | Fix convolutional 1D can't receive masked input. 
9 | See: https://github.com/keras-team/keras/issues/4978 10 | Thanks for https://github.com/jacoxu 11 | """ 12 | 13 | def __init__(self, **kwargs): 14 | self.supports_masking = True 15 | super(NonMaskingLayer, self).__init__(**kwargs) 16 | 17 | def build(self, input_shape): 18 | pass 19 | 20 | def compute_mask(self, inputs, input_mask=None): 21 | # do not pass the mask to the next layers 22 | return None 23 | 24 | def call(self, x, mask=None): 25 | return x 26 | -------------------------------------------------------------------------------- /fancy_nlp/losses/crf_losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # crf losses that works with tf2.x,originally forked from 3 | # https://github.com/howl-anderson/addons/blob/feature/crf_layers/tensorflow_addons/losses/crf_losses.py 4 | 5 | import tensorflow as tf 6 | 7 | from ..layers.crf import CRF 8 | 9 | 10 | def crf_loss(y_true, y_pred): 11 | """ 12 | Args 13 | y_true: true targets tensor. 14 | y_pred: predictions tensor. 15 | Returns: 16 | scalar. 17 | """ 18 | crf_layer = y_pred._keras_history[0] 19 | 20 | # check if last layer is CRF 21 | if not isinstance(crf_layer, CRF): 22 | raise ValueError( 23 | "Last layer must be CRF for use {}.".format("crf_loss")) 24 | 25 | loss_vector = crf_layer.get_loss(y_true, y_pred) 26 | 27 | return tf.keras.backend.mean(loss_vector) 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astor==0.8.1 3 | boto==2.49.0 4 | boto3==1.11.3 5 | botocore==1.14.3 6 | cachetools==4.0.0 7 | certifi==2019.11.28 8 | chardet==3.0.4 9 | docutils==0.15.2 10 | gast==0.2.2 11 | gensim==3.8.1 12 | google-auth==1.10.1 13 | google-auth-oauthlib==0.4.1 14 | google-pasta==0.1.8 15 | grpcio==1.26.0 16 | h5py==2.10.0 17 | idna==2.8 18 | jieba==0.42 19 | jmespath==0.9.4 20 | joblib==0.14.1 21 | Keras-Applications==1.0.8 22 | Keras-Preprocessing==1.1.0 23 | Markdown==3.1.1 24 | numpy==1.18.1 25 | oauthlib==3.1.0 26 | opt-einsum==3.1.0 27 | protobuf==3.11.2 28 | pyasn1==0.4.8 29 | pyasn1-modules==0.2.8 30 | python-dateutil==2.8.1 31 | requests==2.22.0 32 | requests-oauthlib==1.3.0 33 | rsa==4.0 34 | s3transfer==0.3.0 35 | scikit-learn==0.22.1 36 | scipy==1.4.1 37 | six==1.14.0 38 | smart-open==1.9.0 39 | tensorboard==2.1.0 40 | tensorflow==2.1.0 41 | tensorflow-estimator==2.1.0 42 | termcolor==1.1.0 43 | urllib3==1.25.7 44 | Werkzeug==0.16.0 45 | wrapt==1.11.2 46 | bert4keras==0.4.9 47 | tensorflow-addons==0.8.1 48 | seqeval==0.0.12 49 | -------------------------------------------------------------------------------- /tests/fancy_nlp/layers/test_matching.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from keras.layers import * 4 | from keras.models import Model 5 | 6 | from fancy_nlp.layers import FullMatching, MaxPoolingMatching, AttentiveMatching, \ 7 | MaxAttentiveMatching 8 | 9 | 10 | class TestMatching: 11 | def test_full_matching(self): 12 | input_embed_a = Input(shape=(3, 300)) 13 | input_embed_b = Input(shape=(300,)) 14 | input_encode = FullMatching()([input_embed_a, input_embed_b]) 15 | model = Model([input_embed_a, input_embed_b], input_encode) 16 | 17 | def test_max_pooling_matching(self): 18 | input_embed_a = Input(shape=(3, 300)) 19 | input_embed_b = Input(shape=(3, 300)) 20 | input_encode = 
MaxPoolingMatching()([input_embed_a, input_embed_b]) 21 | model = Model([input_embed_a, input_embed_b], input_encode) 22 | 23 | def test_attentive_matching(self): 24 | input_embed_a = Input(shape=(3, 300)) 25 | input_embed_b = Input(shape=(3, 300)) 26 | input_encode = AttentiveMatching()([input_embed_a, input_embed_b]) 27 | model = Model([input_embed_a, input_embed_b], input_encode) 28 | 29 | def test_max_attentive_matching(self): 30 | input_embed_a = Input(shape=(3, 300)) 31 | input_embed_b = Input(shape=(3, 300)) 32 | input_encode = MaxAttentiveMatching()([input_embed_a, input_embed_b]) 33 | model = Model([input_embed_a, input_embed_b], input_encode) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import codecs 3 | 4 | from setuptools import setup 5 | from setuptools import find_packages 6 | 7 | long_description = ''' 8 | fancy-nlp is a fast and easy-to-use natural language processing (NLP) toolkit, 9 | satisfying your imagination about NLP. 10 | 11 | fancy-nlp is compatible with Python 3.6 12 | and is distributed under the GPLv3 license. 13 | ''' 14 | 15 | with codecs.open('requirements.txt', 'r', 'utf8') as reader: 16 | install_requires = list(map(lambda x: x.strip(), reader.readlines())) 17 | 18 | setup(name='fancy-nlp', 19 | version='1.0.0', 20 | author='boat-group', 21 | author_email='e.shijia@foxmail.com', 22 | description='NLP for humans', 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/boat-group/fancy-nlp", 26 | install_requires=install_requires, 27 | classifiers=[ 28 | 'Programming Language :: Python :: 3.6', 29 | 'Intended Audience :: Developers', 30 | 'Intended Audience :: Science/Research', 31 | 'Intended Audience :: Education', 32 | 'Intended Audience :: Information Technology', 33 | 'License :: OSI Approved :: GNU General Public License v3 (GPLv3)', 34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 35 | 'Topic :: Software Development :: Libraries', 36 | 'Topic :: Software Development :: Libraries :: Python Modules' 37 | ], 38 | packages=find_packages()) 39 | 40 | -------------------------------------------------------------------------------- /fancy_nlp/utils/save_load_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Any 4 | 5 | from absl import logging 6 | import tensorflow as tf 7 | 8 | 9 | def save_keras_model(model: tf.keras.models.Model, json_file: str, weights_file: str) -> None: 10 | """Save keras model's architecture and weights to disk 11 | 12 | Args: 13 | model: Instance of tf.keras model. 14 | json_file: str, Path to save model's architecture info. 15 | weights_file: str. Path to save model's weights. 16 | 17 | Returns: 18 | 19 | """ 20 | model_json = model.to_json() 21 | with open(json_file, 'w') as writer: 22 | writer.write(model_json) 23 | model.save_weights(weights_file) 24 | logging.info('Saved model to disk') 25 | 26 | 27 | def load_keras_model(json_file: str, 28 | weights_file: str, 29 | custom_objects: Dict[str, Any] = None) -> tf.keras.models.Model: 30 | """Load keras model from disk 31 | 32 | Args: 33 | json_file: str. File path to model's architecture info. 34 | weights_file: str. File path to model's weights. 
35 | custom_objects: Optional dictionary mapping names (strings) to custom classes or 36 | functions to be considered during deserialization. Must provided when 37 | using custom layer. 38 | """ 39 | with open(json_file, 'r') as reader: 40 | model = tf.keras.models.model_from_json(reader.read(), custom_objects=custom_objects) 41 | model.load_weights(weights_file) 42 | return model 43 | -------------------------------------------------------------------------------- /tests/fancy_nlp/utils/test_other.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from fancy_nlp.utils.other import pad_sequences_2d 5 | 6 | 7 | class TestOther: 8 | def test_pad_sequence_2d(self): 9 | test_case = [[[1, 2, 4], [1, 2], [3]], 10 | [[2, 4], [1, 0, 2]], 11 | [[1, 2, 3, 4, 5]]] 12 | expected = [[[1, 2, 4], [0, 1, 2]], 13 | [[0, 2, 4], [1, 0, 2]], 14 | [[0, 0, 0], [1, 2, 3]]] 15 | result = pad_sequences_2d(test_case, max_len_1=2, max_len_2=3, padding='pre', 16 | truncating='post') 17 | np.testing.assert_equal(result, expected) 18 | 19 | test_case = [[[1, 2, 4], [1, 2], [3]], 20 | [[2, 4], [1, 0, 2]], 21 | [[1, 2, 3, 4, 5]]] 22 | expected = [[[1, 2, 0, 0, 0], [3, 0, 0, 0, 0]], 23 | [[2, 4, 0, 0, 0], [1, 0, 2, 0, 0]], 24 | [[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]]] 25 | result = pad_sequences_2d(test_case, max_len_1=2, max_len_2=None, padding='post', 26 | truncating='pre') 27 | np.testing.assert_equal(result, expected) 28 | 29 | test_case = [[[1, 2, 3, 4], [1, 2], [1], [1, 2, 3]], 30 | [[1, 2, 3, 4, 5], [1, 2], [1, 2, 3, 4]]] 31 | expected = [[[1, 2, 3, 4, 0], [1, 2, 0, 0, 0], [1, 0, 0, 0, 0], [1, 2, 3, 0, 0]], 32 | [[1, 2, 3, 4, 5], [1, 2, 0, 0, 0], [1, 2, 3, 4, 0], [0, 0, 0, 0, 0]]] 33 | result = pad_sequences_2d(test_case) 34 | np.testing.assert_equal(result, expected) 35 | -------------------------------------------------------------------------------- /tests/fancy_nlp/utils/test_embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import numpy as np 6 | 7 | from fancy_nlp.utils import load_ner_data_and_labels 8 | from fancy_nlp.utils import train_w2v, train_fasttext, load_pre_trained 9 | 10 | 11 | class TestEmbedding: 12 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 13 | embedding_file = os.path.join(os.path.dirname(__file__), 14 | '../../../data/embeddings/Tencent_ChineseEmbedding_example.txt') 15 | 16 | def setup_class(self): 17 | self.test_corpus, _ = load_ner_data_and_labels(self.test_file) 18 | self.test_vocab = {'': 0, '': 1} 19 | for token in set(self.test_corpus[0]): 20 | self.test_vocab[token] = len(self.test_vocab) 21 | 22 | def test_train_w2v(self): 23 | emb = train_w2v(self.test_corpus, self.test_vocab, embedding_dim=10) 24 | assert emb.shape[0] == len(self.test_vocab) and emb.shape[1] == 10 25 | assert not np.any(emb[0]) 26 | 27 | def test_train_fasttext(self): 28 | emb = train_fasttext(self.test_corpus, self.test_vocab, embedding_dim=10) 29 | assert emb.shape[0] == len(self.test_vocab) and emb.shape[1] == 10 30 | assert not np.any(emb[0]) 31 | 32 | def test_load_pre_trained(self): 33 | emb = load_pre_trained(load_filename=self.embedding_file, 34 | embedding_dim=200, 35 | vocabulary=self.test_vocab) 36 | assert emb.shape[0] == len(self.test_vocab) and emb.shape[1] == 200 37 | assert not np.any(emb[0]) 38 | 
-------------------------------------------------------------------------------- /examples/ner_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_ner_data_and_labels 6 | from fancy_nlp.applications import NER 7 | 8 | msra_train_file = 'datasets/ner/msra/train_data' 9 | msra_dev_file = 'datasets/ner/msra/test_data' 10 | 11 | checkpoint_dir = 'pretrained_models' 12 | model_name = 'msra_ner_bilstm_cnn_crf' 13 | 14 | if not os.path.exists(checkpoint_dir): 15 | os.makedirs(checkpoint_dir) 16 | 17 | train_data, train_labels = load_ner_data_and_labels(msra_train_file) 18 | dev_data, dev_labels = load_ner_data_and_labels(msra_dev_file) 19 | 20 | ner = NER(use_pretrained=False) 21 | 22 | ner.fit(train_data, train_labels, dev_data, dev_labels, 23 | ner_model_type='bilstm_cnn', 24 | char_embed_trainable=True, 25 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 26 | checkpoint_dir=checkpoint_dir, 27 | model_name=model_name, 28 | load_swa_model=True) 29 | 30 | ner.save(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 31 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 32 | 33 | ner.load(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 34 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 35 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 36 | 37 | print(ner.score(dev_data, dev_labels)) 38 | print(ner.analyze(train_data[2])) 39 | print(ner.analyze_batch(train_data[:3])) 40 | print(ner.restrict_analyze(train_data[2])) 41 | print(ner.restrict_analyze_batch(train_data[:3])) 42 | print(ner.analyze('同济大学位于上海市杨浦区,校长为陈杰')) 43 | -------------------------------------------------------------------------------- /examples/spm_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_spm_data_and_labels 6 | from fancy_nlp.applications import SPM 7 | 8 | train_file = 'datasets/spm/webank/BQ_train.txt' 9 | valid_file = 'datasets/spm/webank/BQ_dev.txt' 10 | test_file = 'datasets/spm/webank/BQ_test.txt' 11 | 12 | model_name = 'webank_spm_siamese_cnn_word' 13 | checkpoint_dir = 'pretrained_models' 14 | 15 | if not os.path.exists(checkpoint_dir): 16 | os.makedirs(checkpoint_dir) 17 | 18 | train_data, train_labels = load_spm_data_and_labels(train_file) 19 | valid_data, valid_labels = load_spm_data_and_labels(valid_file) 20 | test_data, test_labels = load_spm_data_and_labels(test_file) 21 | 22 | spm_app = SPM(use_pretrained=False) 23 | 24 | spm_app.fit(train_data, train_labels, valid_data, valid_labels, 25 | spm_model_type='siamese_cnn', 26 | word_embed_trainable=True, 27 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 28 | checkpoint_dir=checkpoint_dir, 29 | model_name=model_name, 30 | max_len=60, 31 | load_swa_model=True) 32 | 33 | spm_app.save( 34 | preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 35 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 36 | 37 | spm_app.load( 38 | preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 39 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 40 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 41 | 42 | print(spm_app.score(test_data, test_labels)) 43 | 
print(spm_app.predict(('未满足微众银行审批是什么意思', '为什么我未满足微众银行审批'))) 44 | print(spm_app.analyze(('未满足微众银行审批是什么意思', '为什么我未满足微众银行审批'))) 45 | -------------------------------------------------------------------------------- /tests/fancy_nlp/utils/test_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import math 6 | 7 | from fancy_nlp.utils import load_ner_data_and_labels, load_spm_data_and_labels 8 | from fancy_nlp.preprocessors import NERPreprocessor, SPMPreprocessor 9 | from fancy_nlp.utils import NERGenerator, SPMGenerator 10 | 11 | 12 | class TestGenerator: 13 | def test_ner_generator(self): 14 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 15 | x_train, y_train = load_ner_data_and_labels(test_file) 16 | 17 | preprocessor = NERPreprocessor(x_train, y_train) 18 | generator = NERGenerator(preprocessor, x_train, batch_size=64) 19 | assert len(generator) == math.ceil(len(x_train) / 64) 20 | for i, (features, y) in enumerate(generator): 21 | if i < len(generator) - 1: 22 | assert features.shape[0] == 64 23 | assert y is None 24 | else: 25 | assert features.shape[0] == len(x_train) - 64 * (len(generator) - 1) 26 | assert y is None 27 | 28 | def test_spm_generator(self): 29 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt') 30 | x_train, y_train = load_spm_data_and_labels(test_file) 31 | 32 | preprocessor = SPMPreprocessor(x_train, y_train) 33 | generator = SPMGenerator(preprocessor, x_train, batch_size=64) 34 | assert len(generator) == math.ceil(len(x_train[0]) / 64) 35 | for i, (features, y) in enumerate(generator): 36 | if i < len(generator) - 1: 37 | assert features[0].shape[0] == features[1].shape[0] == 64 38 | assert y is None 39 | else: 40 | assert features[0].shape[0] == features[1].shape[0] == \ 41 | len(x_train[0]) - 64 * (len(generator) - 1) 42 | assert y is None 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # MacOS files 107 | .DS_Store 108 | 109 | # Pycharm files 110 | .idea/ 111 | 112 | # Model checkpoints 113 | pretrained_models/ 114 | 115 | # Datasets 116 | datasets/ 117 | 118 | # Pretrained embeddings 119 | pretrained_embeddings/ 120 | -------------------------------------------------------------------------------- /examples/bert_single.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_ner_data_and_labels 6 | from fancy_nlp.applications import NER 7 | 8 | msra_train_file = 'datasets/ner/msra/train_data' 9 | msra_dev_file = 'datasets/ner/msra/test_data' 10 | 11 | checkpoint_dir = 'pretrained_models' 12 | model_name = 'msra_ner_bilstm_cnn_bert_crf' 13 | 14 | if not os.path.exists(checkpoint_dir): 15 | os.makedirs(checkpoint_dir) 16 | 17 | train_data, train_labels = load_ner_data_and_labels(msra_train_file) 18 | valid_data, valid_labels = load_ner_data_and_labels(msra_dev_file) 19 | 20 | ner = NER(use_pretrained=False) 21 | ner.fit(train_data, train_labels, valid_data, valid_labels, 22 | ner_model_type='bilstm_cnn', 23 | use_char=False, 24 | use_word=False, 25 | use_bert=True, 26 | # 传入bert模型各文件的路径 27 | bert_vocab_file='pretrained_embeddings/chinese_L-12_H-768_A-12/vocab.txt', 28 | bert_config_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_config.json', 29 | bert_checkpoint_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_model.ckpt', 30 | # 设置bert不可训练 31 | bert_trainable=False, 32 | optimizer='adam', 33 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 34 | checkpoint_dir=checkpoint_dir, 35 | model_name=model_name, 36 | load_swa_model=True) 37 | 38 | ner.save(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 39 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 40 | 41 | ner.load(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 42 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 43 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 44 | 45 | print(ner.score(valid_data, valid_labels)) 46 | print(ner.analyze(train_data[2])) 47 | print(ner.analyze_batch(train_data[:3])) 48 | print(ner.restrict_analyze(train_data[2])) 49 | print(ner.restrict_analyze_batch(train_data[:3])) 50 | 
print(ner.analyze('同济大学位于上海市杨浦区,校长为陈杰')) 51 | -------------------------------------------------------------------------------- /examples/bert_fine_tuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import tensorflow as tf 5 | 6 | from fancy_nlp.utils import load_ner_data_and_labels 7 | from fancy_nlp.applications import NER 8 | 9 | msra_train_file = 'datasets/ner/msra/train_data' 10 | msra_dev_file = 'datasets/ner/msra/test_data' 11 | 12 | checkpoint_dir = 'pretrained_models' 13 | model_name = 'msra_ner_bert_crf' 14 | 15 | if not os.path.exists(checkpoint_dir): 16 | os.makedirs(checkpoint_dir) 17 | 18 | train_data, train_labels = load_ner_data_and_labels(msra_train_file) 19 | valid_data, valid_labels = load_ner_data_and_labels(msra_dev_file) 20 | 21 | ner = NER(use_pretrained=False) 22 | ner.fit(train_data, train_labels, valid_data, valid_labels, 23 | ner_model_type='bert', 24 | use_char=False, 25 | use_word=False, 26 | use_bert=True, 27 | # 传入bert模型各文件的路径 28 | bert_vocab_file='pretrained_embeddings/chinese_L-12_H-768_A-12/vocab.txt', 29 | bert_config_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_config.json', 30 | bert_checkpoint_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_model.ckpt', 31 | bert_trainable=True, 32 | optimizer=tf.keras.optimizers.Adam(1e-5), 33 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 34 | checkpoint_dir=checkpoint_dir, 35 | model_name=model_name, 36 | load_swa_model=True) 37 | 38 | ner.save(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 39 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 40 | 41 | ner.load(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 42 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 43 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 44 | 45 | print(ner.score(valid_data, valid_labels)) 46 | print(ner.analyze(train_data[2])) 47 | print(ner.analyze_batch(train_data[:3])) 48 | print(ner.restrict_analyze(train_data[2])) 49 | print(ner.restrict_analyze_batch(train_data[:3])) 50 | print(ner.analyze('同济大学位于上海市杨浦区,校长为陈杰')) 51 | -------------------------------------------------------------------------------- /tests/fancy_nlp/utils/test_save_load_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import tensorflow as tf 6 | import tensorflow.keras.backend as K 7 | 8 | from fancy_nlp.utils import save_keras_model, load_keras_model 9 | 10 | 11 | class MyLayer(tf.keras.layers.Layer): 12 | def __init__(self, output_dim, **kwargs): 13 | self.output_dim = output_dim 14 | super(MyLayer, self).__init__(**kwargs) 15 | 16 | def build(self, input_shape): 17 | # Create a trainable weight variable for this layer. 18 | self.kernel = self.add_weight(name='kernel', 19 | shape=(input_shape[1], self.output_dim), 20 | initializer='uniform', 21 | trainable=True) 22 | super(MyLayer, self).build(input_shape) # Be sure to call this somewhere! 
23 | 24 | def call(self, inputs, **kwargs): 25 | return K.dot(inputs, self.kernel) 26 | 27 | def compute_output_shape(self, input_shape): 28 | return input_shape[0], self.output_dim 29 | 30 | def get_config(self): 31 | config = {'output_dim': self.output_dim} 32 | 33 | base_config = super(MyLayer, self).get_config() 34 | return dict(list(base_config.items()) + list(config.items())) 35 | 36 | 37 | class TestSaveLoad: 38 | test_json_file = 'my_model_architecture.json' 39 | test_weights_file = 'my_model_weights.h5' 40 | 41 | def setup_class(self): 42 | self.model = tf.keras.models.Sequential() 43 | self.model.add(tf.keras.layers.Dense(32, input_shape=(784, ))) 44 | self.model.add(MyLayer(100, name='my_layer1')) 45 | self.model.add(MyLayer(100, name='my_layer2')) 46 | 47 | def test_save(self): 48 | save_keras_model(self.model, self.test_json_file, self.test_weights_file) 49 | 50 | def test_load(self): 51 | self.model = load_keras_model(self.test_json_file, self.test_weights_file, 52 | custom_objects={'MyLayer': MyLayer}) 53 | assert len(self.model.layers) == 3 54 | 55 | def teardown_class(self): 56 | os.remove(self.test_json_file) 57 | os.remove(self.test_weights_file) 58 | -------------------------------------------------------------------------------- /examples/bert_combination.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import tensorflow as tf 5 | 6 | from fancy_nlp.utils import load_ner_data_and_labels 7 | from fancy_nlp.applications import NER 8 | 9 | msra_train_file = 'datasets/ner/msra/train_data' 10 | msra_dev_file = 'datasets/ner/msra/test_data' 11 | 12 | checkpoint_dir = 'pretrained_models' 13 | model_name = 'msra_ner_bilstm_cnn_char_bert_crf' 14 | 15 | if not os.path.exists(checkpoint_dir): 16 | os.makedirs(checkpoint_dir) 17 | 18 | train_data, train_labels = load_ner_data_and_labels(msra_train_file) 19 | valid_data, valid_labels = load_ner_data_and_labels(msra_dev_file) 20 | 21 | ner = NER(use_pretrained=False) 22 | ner.fit(train_data, train_labels, valid_data, valid_labels, 23 | ner_model_type='bilstm_cnn', 24 | use_char=True, 25 | use_word=False, 26 | use_bert=True, 27 | # 传入bert模型各文件的路径 28 | bert_vocab_file='pretrained_embeddings/chinese_L-12_H-768_A-12/vocab.txt', 29 | bert_config_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_config.json', 30 | bert_checkpoint_file='pretrained_embeddings/chinese_L-12_H-768_A-12/bert_model.ckpt', 31 | # 设置bert可训练 32 | bert_trainable=True, 33 | # 使用小一点学习率的优化器 34 | optimizer=tf.keras.optimizers.Adam(1e-5), 35 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 36 | checkpoint_dir='pretrained_models', 37 | model_name=model_name, 38 | load_swa_model=True) 39 | 40 | ner.save(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 41 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 42 | 43 | ner.load(preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 44 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 45 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 46 | 47 | print(ner.score(valid_data, valid_labels)) 48 | print(ner.analyze(train_data[2])) 49 | print(ner.analyze_batch(train_data[:3])) 50 | print(ner.restrict_analyze(train_data[2])) 51 | print(ner.restrict_analyze_batch(train_data[:3])) 52 | print(ner.analyze('同济大学位于上海市杨浦区,校长为陈杰')) 53 | -------------------------------------------------------------------------------- 
/examples/text_classification_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_text_classification_data_and_labels 6 | from fancy_nlp.applications import TextClassification 7 | 8 | data_file = 'datasets/text_classification/toutiao/toutiao_cat_data.txt' 9 | dict_file = 'datasets/text_classification/toutiao/toutiao_label_dict.txt' 10 | model_name = 'toutiao_text_classification_cnn' 11 | checkpoint_dir = 'pretrained_models' 12 | 13 | if not os.path.exists(checkpoint_dir): 14 | os.makedirs(checkpoint_dir) 15 | 16 | train_data, train_labels, valid_data, valid_labels, test_data, test_labels = \ 17 | load_text_classification_data_and_labels(data_file, 18 | label_index=1, 19 | text_index=3, 20 | delimiter='_!_', 21 | split_mode=2, 22 | split_size=0.3) 23 | 24 | text_classification_app = TextClassification(use_pretrained=False) 25 | 26 | text_classification_app.fit(train_data, train_labels, valid_data, valid_labels, 27 | text_classification_model_type='cnn', 28 | char_embed_trainable=True, 29 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 30 | checkpoint_dir=checkpoint_dir, 31 | model_name=model_name, 32 | label_dict_file=dict_file, 33 | max_len=60, 34 | load_swa_model=True) 35 | 36 | text_classification_app.save( 37 | preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 38 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json')) 39 | 40 | text_classification_app.load( 41 | preprocessor_file=os.path.join(checkpoint_dir, f'{model_name}_preprocessor.pkl'), 42 | json_file=os.path.join(checkpoint_dir, f'{model_name}.json'), 43 | weights_file=os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 44 | 45 | print(text_classification_app.score(test_data, test_labels)) 46 | print(text_classification_app.predict('小米公司成立十周年')) 47 | print(text_classification_app.analyze('小米公司成立十周年')) 48 | -------------------------------------------------------------------------------- /tests/fancy_nlp/utils/test_data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_ner_data_and_labels, load_spm_data_and_labels 6 | 7 | 8 | class TestNerDataLoader: 9 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 10 | 11 | def test_load_ner(self): 12 | x_train, y_train = load_ner_data_and_labels(self.test_file) 13 | assert len(x_train) == len(y_train) 14 | assert len(x_train) > 0 15 | assert len(x_train[0]) == len(y_train[0]) 16 | assert len(x_train[0]) > 0 17 | assert x_train[:5] != y_train[:5] 18 | 19 | def test_load_ner_split(self): 20 | x_train, y_train, x_test, y_test = load_ner_data_and_labels(self.test_file, split=True) 21 | assert len(x_train) == len(y_train) and len(x_test) == len(y_test) 22 | assert len(x_train) > 0 and len(x_test) > 0 23 | assert len(x_train[0]) == len(y_train[0]) and len(x_test[0]) == len(y_test[0]) 24 | assert len(x_train[0]) > 0 and len(x_test[0]) > 0 25 | assert x_train[:5] != y_train[:5] and x_test[:5] != y_test[:5] 26 | assert x_train[:5] != x_test[:5] and y_train[:5] != y_test[:5] 27 | 28 | 29 | class TestSpmDataLoader: 30 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt') 31 | 32 | def test_load_spm(self): 33 | x_train, y_train = load_spm_data_and_labels(self.test_file) 34 | assert len(x_train) == 2 35 | assert 
len(x_train[0]) == len(x_train[1]) 36 | assert len(x_train[0]) == len(y_train) 37 | assert len(x_train[0]) > 0 38 | 39 | def test_load_spm_split1(self): 40 | x_train, y_train, x_test, y_test = load_spm_data_and_labels(self.test_file, split_mode=1) 41 | assert len(x_train[0]) == len(x_train[1]) 42 | assert len(x_test[0]) == len(x_test[1]) 43 | assert len(x_train[0]) == len(y_train) and len(x_test[0]) == len(y_test) 44 | assert len(x_train[0]) > 0 and len(x_test[0]) > 0 45 | 46 | def test_load_spm_split2(self): 47 | x_train, y_train, x_valid, y_valid, x_test, y_test = \ 48 | load_spm_data_and_labels(self.test_file, split_mode=2, split_size=0.4) 49 | assert len(x_train[0]) == len(x_train[1]) == len(y_train) 50 | assert len(x_valid[0]) == len(x_valid[1]) == len(y_valid) 51 | assert len(x_test[0]) == len(x_test[1]) == len(y_test) 52 | assert len(x_train[0]) > 0 and len(x_test[0]) > 0 -------------------------------------------------------------------------------- /fancy_nlp/callbacks/ensemble.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from absl import logging 6 | import tensorflow as tf 7 | 8 | ''' 9 | Applying ensemble during a single training process 10 | ''' 11 | 12 | 13 | class SWA(tf.keras.callbacks.Callback): 14 | """This callback implements a stochastic weight averaging (SWA) method with constant lr as 15 | presented in the paper: "Izmailov et al. Averaging Weights Leads to Wider Optima and Better 16 | Generalization" (https://arxiv.org/abs/1803.05407) 17 | 18 | Author's implementation: https://github.com/timgaripov/swa 19 | """ 20 | def __init__(self, 21 | swa_model: tf.keras.models.Model, 22 | checkpoint_dir: str, 23 | model_name: str, 24 | swa_start: int = 1) -> None: 25 | """ 26 | 27 | Args: 28 | swa_model: Instance of `tf.keras.models.Model`. The model that is used to store the 29 | average of the weights once SWA begins 30 | checkpoint_dir: str. The directory where the model will be saved 31 | model_name: str. The name of the model we're training. 32 | We use checkpoint_dir and model_name to save the swa model's weights after training 33 | is done. For example, if checkpoint_dir is 'ckpt' and model_name is 'model', 34 | the weights of the swa model will be saved in 'ckpt/model_swa.hdf5'. 35 | swa_start: int. The epoch when averaging begins. We generally pre-train the network 36 | for a certain number of epochs to start (swa_start > 1), as opposed to starting to 37 | track the average from the very beginning. 38 | """ 39 | super(SWA, self).__init__() 40 | self.checkpoint_dir = checkpoint_dir 41 | self.model_name = model_name 42 | self.swa_start = swa_start 43 | self.swa_model = swa_model 44 | 45 | def on_train_begin(self, logs=None): 46 | self.epoch = 0 47 | self.swa_n = 0 48 | # Note: I found that deep-copying a model with a customized layer would give errors 49 | # self.swa_model = copy.deepcopy(self.model) # make a copy of the model we're training 50 | 51 | # Note: Something weird still happens even though I use the keras.models.clone_model method, 52 | # so I build swa_model outside this callback and pass it as an argument. 
53 | # It's not fancy, but the best I can do :) 54 | # see: https://github.com/keras-team/keras/issues/1765 55 | # self.swa_model = keras.models.clone_model(self.model) 56 | self.swa_model.set_weights(self.model.get_weights()) 57 | 58 | def on_epoch_end(self, epoch, logs=None): 59 | if (self.epoch + 1) >= self.swa_start: 60 | self.update_average_model() 61 | self.swa_n += 1 62 | 63 | self.epoch += 1 64 | 65 | def update_average_model(self): 66 | # update running average of parameters 67 | alpha = 1. / (self.swa_n + 1) 68 | for layer, swa_layer in zip(self.model.layers, self.swa_model.layers): 69 | weights = [] 70 | for w1, w2 in zip(swa_layer.get_weights(), layer.get_weights()): 71 | weights.append((1 - alpha) * w1 + alpha * w2) 72 | swa_layer.set_weights(weights) 73 | 74 | def on_train_end(self, logs=None): 75 | logging.info('Logging Info - Saving SWA model checkpoint: %s_swa.hdf5' % self.model_name) 76 | weights_file = os.path.join(self.checkpoint_dir, '{}_swa.hdf5'.format(self.model_name)) 77 | self.swa_model.save_weights(weights_file) 78 | logging.info('Logging Info - SWA model Saved') 79 | -------------------------------------------------------------------------------- /fancy_nlp/predictors/spm_predictor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Tuple, List 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | 8 | from fancy_nlp.preprocessors import SPMPreprocessor 9 | 10 | 11 | class SPMPredictor(object): 12 | """SPM predictor for evaluating spm model, output predictive probabilities 13 | for input sentence 14 | """ 15 | def __init__(self, 16 | model: tf.keras.models.Model, 17 | preprocessor: SPMPreprocessor) -> None: 18 | """ 19 | 20 | Args: 21 | model: instance of keras model 22 | preprocessor: `SPMPreprocessor` instance to prepare feature input for spm model 23 | """ 24 | self.model = model 25 | self.preprocessor = preprocessor 26 | 27 | def predict_prob(self, text: Tuple[str, str]) -> np.ndarray: 28 | """Return probabilities for a pair of sentence 29 | 30 | Args: 31 | text: a pair of untokenized text(str) 32 | 33 | Returns: np.array, shaped [num_classes] 34 | 35 | """ 36 | assert isinstance(text, tuple) and len(text) == 2, "input must be a tuple of two texts" 37 | features, _ = self.preprocessor.prepare_input(([text[0]], [text[1]])) 38 | pred_probs = self.model.predict(features) 39 | return pred_probs[0] 40 | 41 | def predict_prob_batch(self, texts: Tuple[List[str], List[str]]) -> np.ndarray: 42 | """Return probabilities for a batch sentence pairs 43 | 44 | Args: 45 | texts: a list of text pairs, each text must be untokenized 46 | 47 | Returns: np.array, shaped [num_texts, num_classes] 48 | """ 49 | assert isinstance(texts, (list, tuple)) and len(texts) == 2, "input must be text pairs" 50 | features, _ = self.preprocessor.prepare_input(texts) 51 | pred_probs = self.model.predict(features) 52 | return pred_probs 53 | 54 | def matching(self, text: Tuple[str, str]) -> str: 55 | """Return label string for a pair of text 56 | 57 | Args: 58 | text: can be untokenized text pair 59 | 60 | Returns: str 61 | 62 | """ 63 | pred_prob = self.predict_prob(text) 64 | labels = self.preprocessor.label_decode(np.expand_dims(pred_prob, 0)) 65 | return labels[0] 66 | 67 | def matching_batch(self, texts: Tuple[List[str], List[str]]) -> List[str]: 68 | """Return label string for a batch text pairs 69 | 70 | Args: 71 | texts: a list of text pairs, each text must be untokenized 72 | 73 | Returns: list 
of str 74 | 75 | """ 76 | pred_probs = self.predict_prob_batch(texts) 77 | labels = self.preprocessor.label_decode(pred_probs) 78 | return labels 79 | 80 | def matching_with_prob(self, text: Tuple[str, str]) -> Tuple[str, np.ndarray]: 81 | """Return the classification result for one sentence with probability 82 | 83 | Args: 84 | text: can be untokenized text pair 85 | 86 | Returns: tuple 87 | 88 | """ 89 | pred_result = self.predict_prob(text) 90 | tags = self.preprocessor.label_decode(np.expand_dims(pred_result, 0)) 91 | label_name = tags[0] 92 | return label_name, pred_result 93 | 94 | def matching_with_prob_batch(self, text: Tuple[List[str], List[str]]) -> \ 95 | List[Tuple[str, np.ndarray]]: 96 | """Return the matching results for a batch sentence pairs with probabilities 97 | 98 | Args: 99 | text: a list of text pairs, each text must be untokenized 100 | 101 | Returns: list of tuple 102 | 103 | """ 104 | pred_results = self.predict_prob_batch(text) 105 | tags = self.preprocessor.label_decode(pred_results) 106 | results = list(zip(tags, pred_results)) 107 | return results 108 | -------------------------------------------------------------------------------- /tests/fancy_nlp/preprocessors/test_ner_preprocessor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import jieba 6 | import numpy as np 7 | 8 | from fancy_nlp.utils.data_loader import load_ner_data_and_labels 9 | from fancy_nlp.preprocessors.ner_preprocessor import NERPreprocessor 10 | 11 | 12 | class TestNERPreprocessor: 13 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 14 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 15 | '../../../data/embeddings/bert_sample_model/vocab.txt') 16 | 17 | def setup_class(self): 18 | x_train, y_train = load_ner_data_and_labels(self.test_file) 19 | self.preprocessor = NERPreprocessor(x_train, y_train, use_char=True, use_bert=True, 20 | use_word=True, bert_vocab_file=self.bert_vocab_file, 21 | external_word_dict=['比特币'], 22 | char_embed_type='word2vec', max_len=16) 23 | 24 | def test_init(self): 25 | assert len(self.preprocessor.char_vocab_count) + 4 == len(self.preprocessor.char_vocab) \ 26 | == len(self.preprocessor.id2char) 27 | assert list(self.preprocessor.id2char.keys())[0] == 0 28 | for cnt in self.preprocessor.char_vocab_count.values(): 29 | assert cnt >= 2 30 | 31 | assert self.preprocessor.char_embeddings.shape[0] == len(self.preprocessor.char_vocab) 32 | assert self.preprocessor.char_embeddings.shape[1] == 300 33 | assert not np.any(self.preprocessor.char_embeddings[0]) 34 | 35 | assert len(self.preprocessor.word_vocab_count) + 4 == len(self.preprocessor.word_vocab) \ 36 | == len(self.preprocessor.id2word) 37 | assert list(self.preprocessor.id2word.keys())[0] == 0 38 | for cnt in self.preprocessor.word_vocab_count.values(): 39 | assert cnt >= 2 40 | assert self.preprocessor.word_embeddings is None 41 | 42 | assert len(self.preprocessor.label_vocab) == len(self.preprocessor.id2label) 43 | assert list(self.preprocessor.id2label.keys())[0] == 0 44 | 45 | def test_prepare_input(self): 46 | features, y = self.preprocessor.prepare_input(self.preprocessor.train_data, 47 | self.preprocessor.train_labels) 48 | assert len(features) == 4 49 | assert features[0].shape == features[1].shape == features[2].shape == features[3].shape == \ 50 | (len(self.preprocessor.train_data), self.preprocessor.max_len) 51 | assert 
self.preprocessor.id2char[features[0][0][0]] == self.preprocessor.cls_token 52 | assert self.preprocessor.id2word[features[0][0][0]] == self.preprocessor.cls_token 53 | assert y.shape == (len(self.preprocessor.train_data), self.preprocessor.max_len, 54 | self.preprocessor.num_class) 55 | 56 | def test_get_word_ids(self): 57 | example_text = ''.join(self.preprocessor.train_data[0]) 58 | word_cut = jieba.lcut(example_text) 59 | word_ids = self.preprocessor.get_word_ids(word_cut) 60 | assert len(word_ids) == len(example_text) + 2 61 | 62 | start = 1 63 | for word in word_cut: 64 | if start > len(word_ids): 65 | break 66 | assert len(set(word_ids[start:start+len(word)])) == 1 67 | start += len(word) 68 | 69 | def test_label_decode(self): 70 | rand_pred_probs = np.random.rand(2, 10, self.preprocessor.num_class) 71 | lengths = [8, 9] 72 | pred_labels = self.preprocessor.label_decode(rand_pred_probs, lengths) 73 | assert len(pred_labels) == len(lengths) 74 | for i, length in enumerate(lengths): 75 | assert len(pred_labels[i]) == length 76 | 77 | def test_save_load(self): 78 | pkl_file = 'test_preprocessor.pkl' 79 | self.preprocessor.save(pkl_file) 80 | assert os.path.exists(pkl_file) 81 | new_preprocessor = NERPreprocessor.load(pkl_file) 82 | assert new_preprocessor.num_class == self.preprocessor.num_class 83 | os.remove(pkl_file) 84 | -------------------------------------------------------------------------------- /fancy_nlp/predictors/text_classification_predictor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | from absl import logging 5 | 6 | 7 | class TextClassificationPredictor(object): 8 | """TextClassification predictor for evaluating text classification model, output predictive 9 | probabilities and labels for input sentence""" 10 | def __init__(self, model, preprocessor): 11 | """ 12 | 13 | Args: 14 | model: instance of keras model 15 | preprocessor: `TextClassificationPreprocessor` instance 16 | """ 17 | self.model = model 18 | self.preprocessor = preprocessor 19 | 20 | def predict_prob(self, text): 21 | """Return probabilities for one sentence 22 | 23 | Args: 24 | text: can be untokenized (str) or tokenized in char level (list) 25 | 26 | Returns: np.array, shaped [num_classes,] 27 | 28 | """ 29 | if isinstance(text, list): 30 | logging.warning('Text is passed in a list. Make sure it is tokenized in char level!') 31 | features, _ = self.preprocessor.prepare_input([text]) 32 | else: 33 | assert isinstance(text, str) 34 | features, _ = self.preprocessor.prepare_input([list(text)]) 35 | pred_probs = self.model.predict(features) 36 | return pred_probs[0] 37 | 38 | def predict_prob_batch(self, texts): 39 | """Return probabilities for a batch sentences 40 | 41 | Args: 42 | texts: a list of texts, each text can be untokenized (str) or 43 | tokenized in char level (list) 44 | 45 | Returns: np.array, shaped [num_texts, num_classes] 46 | """ 47 | assert isinstance(texts, list) 48 | if isinstance(texts[0], list): 49 | logging.warning('Text is passed in a list. 
Make sure it is tokenized in char level!') 50 | features, _ = self.preprocessor.prepare_input(texts) 51 | else: 52 | assert isinstance(texts[0], str) 53 | char_cut_texts = [list(text) for text in texts] 54 | features, _ = self.preprocessor.prepare_input(char_cut_texts) 55 | pred_probs = self.model.predict(features) 56 | return pred_probs 57 | 58 | def classify(self, text): 59 | """Return the classification result for one sentence 60 | 61 | Args: 62 | text: can be untokenized (str) or tokenized in char level (list) 63 | 64 | Returns: str 65 | 66 | """ 67 | pred_prob = self.predict_prob(text) 68 | tags = self.preprocessor.label_decode(np.expand_dims(pred_prob, 0), 69 | self.preprocessor.label_dict) 70 | return tags[0] 71 | 72 | def classify_batch(self, texts): 73 | """Return classification result for a batch sentences 74 | 75 | Args: 76 | texts: a list of text, each text can be untokenized (str) or 77 | tokenized in char level (list) 78 | 79 | Returns: list of str 80 | 81 | """ 82 | pred_probs = self.predict_prob_batch(texts) 83 | tags = self.preprocessor.label_decode(pred_probs, self.preprocessor.label_dict) 84 | return tags 85 | 86 | def classification_with_prob(self, text): 87 | """Return the classification result for one sentence with probability 88 | 89 | Args: 90 | text: can be untokenized (str) or tokenized in char level (list) 91 | 92 | Returns: tuple 93 | 94 | """ 95 | pred_result = self.predict_prob(text) 96 | tags = self.preprocessor.label_decode(np.expand_dims(pred_result, 0), 97 | self.preprocessor.label_dict) 98 | label_name = tags[0] 99 | label_prob = np.max(pred_result) 100 | return label_name, label_prob 101 | 102 | def classification_with_prob_batch(self, text): 103 | """Return the classification results for a batch sentences with probabilities 104 | 105 | Args: 106 | text: can be untokenized (str) or tokenized in char level (list) 107 | 108 | Returns: list of tuple 109 | 110 | """ 111 | pred_results = self.predict_prob_batch(text) 112 | tags = self.preprocessor.label_decode(pred_results, self.preprocessor.label_dict) 113 | results = list(zip(tags, pred_results)) 114 | return results 115 | -------------------------------------------------------------------------------- /tests/fancy_nlp/preprocessors/test_preprocessor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import jieba 6 | import numpy as np 7 | 8 | from fancy_nlp.preprocessors.preprocessor import Preprocessor 9 | 10 | 11 | class TestPreprocessor: 12 | sample_texts = ['文献是朝阳区区长', 13 | '拳王阿里是个传奇', 14 | '杭州阿里吸引了很多人才', 15 | '习近平常与特朗普通电话', 16 | '南京市长江大桥'] 17 | embedding_file = os.path.join(os.path.dirname(__file__), 18 | '../../../data/embeddings/Tencent_ChineseEmbedding_example.txt') 19 | 20 | def setup_class(self): 21 | self.preprocessor = Preprocessor(max_len=50) 22 | 23 | def test_build_corpus(self): 24 | char_corpus = self.preprocessor.build_corpus(self.sample_texts, 25 | cut_func=lambda x: list(x)) 26 | assert len(char_corpus) == len(self.sample_texts) 27 | assert ''.join(char_corpus[0]) == self.sample_texts[0] 28 | 29 | word_corpus = self.preprocessor.build_corpus(self.sample_texts, 30 | cut_func=lambda x: jieba.lcut(x)) 31 | assert len(word_corpus) == len(self.sample_texts) 32 | assert ''.join(word_corpus[0]) == self.sample_texts[0] 33 | 34 | def test_build_vocab(self): 35 | char_corpus = self.preprocessor.build_corpus(self.sample_texts, 36 | cut_func=lambda x: list(x)) 37 | char_vocab_count, char_vocab, id2char = 
self.preprocessor.build_vocab( 38 | char_corpus, min_count=1) 39 | assert len(char_vocab_count) + 2 == len(char_vocab) == len(id2char) 40 | assert list(id2char.keys())[0] == 0 41 | 42 | def test_build_embedding(self): 43 | char_corpus = self.preprocessor.build_corpus(self.sample_texts, 44 | cut_func=lambda x: list(x)) 45 | _, char_vocab, _ = self.preprocessor.build_vocab(char_corpus, min_count=1) 46 | emb = self.preprocessor.build_embedding(embed_type=None, vocab=char_vocab) 47 | assert emb is None 48 | 49 | emb = self.preprocessor.build_embedding(embed_type='word2vec', vocab=char_vocab, 50 | corpus=char_corpus) 51 | assert emb.shape[0] == len(char_vocab) and emb.shape[1] == 300 52 | assert not np.any(emb[0]) 53 | 54 | emb = self.preprocessor.build_embedding(embed_type='fasttext', vocab=char_vocab, 55 | corpus=char_corpus, embedding_dim=20) 56 | assert emb.shape[0] == len(char_vocab) and emb.shape[1] == 20 57 | assert not np.any(emb[0]) 58 | 59 | emb = self.preprocessor.build_embedding(embed_type=self.embedding_file, 60 | embedding_dim=200, 61 | vocab=char_vocab, 62 | corpus=char_corpus) 63 | assert emb.shape[0] == len(char_vocab) and emb.shape[1] == 200 64 | assert not np.any(emb[0]) 65 | 66 | def test_build_id_sequence(self): 67 | char_corpus = self.preprocessor.build_corpus(self.sample_texts, 68 | cut_func=lambda x: list(x)) 69 | _, char_vocab, _ = self.preprocessor.build_vocab(char_corpus, min_count=1) 70 | sample_text = list('文献是朝阳区区长吗?') 71 | id_sequence = self.preprocessor.build_id_sequence(sample_text, char_vocab) 72 | assert len(id_sequence) == len(sample_text) 73 | assert id_sequence[-1] == 1 74 | 75 | def test_build_id_matrix(self): 76 | sample_texts = [list('文献是朝阳区区长吗?'), list('拳王阿里是个传奇啊!')] 77 | char_corpus = self.preprocessor.build_corpus(self.sample_texts, 78 | cut_func=lambda x: list(x)) 79 | _, char_vocab, _ = self.preprocessor.build_vocab(char_corpus, min_count=1) 80 | id_matrix = self.preprocessor.build_id_matrix(sample_texts, char_vocab) 81 | assert len(id_matrix) == len(sample_texts) 82 | assert len(id_matrix[-1]) == len(sample_texts[-1]) 83 | assert id_matrix[-1][-1] == 1 84 | 85 | def test_pad_sequence(self): 86 | x = [[1, 3, 5]] 87 | x_padded = self.preprocessor.pad_sequence(x) 88 | 89 | assert x_padded.shape == (1, 50) 90 | assert (np.array(x_padded) == 91 | np.array(x[0] + [0] * (self.preprocessor.max_len - len(x[0]))).reshape(1, -1)).any() 92 | -------------------------------------------------------------------------------- /fancy_nlp/utils/other.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | from typing import List, Union, Dict, Any 5 | 6 | import numpy as np 7 | from bert4keras.tokenizer import Tokenizer 8 | 9 | from fancy_nlp.layers import NonMaskingLayer, FullMatching, MaxPoolingMatching, \ 10 | AttentiveMatching, MaxAttentiveMatching, CRF 11 | 12 | 13 | def pad_sequences_2d(sequences: List[List[List[str]]], 14 | max_len_1: int = None, 15 | max_len_2: int = None, 16 | dtype: str = 'int32', 17 | padding: str = 'post', 18 | truncating: str = 'post', 19 | value: Union[int, float] = 0.) -> np.ndarray: 20 | """Pad sequence for [[[a, b, c, ...], ...], ...] to the same length, similar as 21 | `tf.keras.preprocessing.sequence.pad_sequences` does. 
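    A minimal usage sketch (the nested lists and lengths below are illustrative
    assumptions, not values taken from the library):

        batch = [[[1, 2], [3]], [[4, 5, 6]]]
        padded = pad_sequences_2d(batch, max_len_1=2, max_len_2=3)
        # padded.shape == (2, 2, 3); empty positions are filled with `value` (0 by default)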
22 | 23 | Returns: 24 | np.ndarray, shaped [num_samples, max_len_1, max_len_2} 25 | 26 | """ 27 | lengths_1, lengths_2 = [], [] 28 | for s in sequences: 29 | lengths_1.append(len(s)) 30 | for t in s: 31 | lengths_2.append(len(t)) 32 | if max_len_1 is None: 33 | max_len_1 = np.max(lengths_1) 34 | if max_len_2 is None: 35 | max_len_2 = np.max(lengths_2) 36 | 37 | num_samples = len(sequences) 38 | x = (np.ones((num_samples, max_len_1, max_len_2)) * value).astype(dtype) 39 | for i, s in enumerate(sequences): 40 | if not len(s): 41 | continue # empty list was found 42 | 43 | if truncating == 'pre': 44 | s = s[-max_len_1:] 45 | elif truncating == 'post': 46 | s = s[:max_len_1] 47 | else: 48 | raise ValueError('Truncating type "%s" not understood' % truncating) 49 | 50 | y = (np.ones((len(s), max_len_2)) * value).astype(dtype) 51 | for j, t in enumerate(s): 52 | if not len(t): 53 | continue 54 | 55 | if truncating == 'pre': 56 | trunc = t[-max_len_2:] 57 | elif truncating == 'post': 58 | trunc = t[:max_len_2] 59 | else: 60 | raise ValueError('Truncating type "%s" not understood' % truncating) 61 | 62 | trunc = np.asarray(trunc, dtype=dtype) 63 | 64 | if padding == 'post': 65 | y[j, :len(trunc)] = trunc 66 | elif padding == 'pre': 67 | y[j, -len(trunc):] = trunc 68 | else: 69 | raise ValueError('Padding type "%s" not understood' % padding) 70 | 71 | if padding == 'post': 72 | x[i, :y.shape[0], :] = y 73 | elif padding == 'pre': 74 | x[i, -y.shape[0]:, :] = y 75 | else: 76 | raise ValueError('Padding type "%s" not understood' % padding) 77 | 78 | return x 79 | 80 | 81 | def get_len_from_corpus(corpus: List[List[str]], 82 | mode: str = 'most') -> int: 83 | """Get sequence len from corpus 84 | 85 | Args: 86 | corpus: List of List of str. 87 | mode: str. One of {'avg', 'median', 'max'm, 'most'} 88 | 89 | """ 90 | lengths = [len(seq) for seq in corpus] 91 | if mode == 'avg': 92 | return math.ceil(np.mean(lengths)) 93 | elif mode == 'median': 94 | return math.ceil(np.median(lengths)) 95 | elif mode == 'max': 96 | return np.max(lengths) 97 | elif mode == 'most': 98 | return sorted(lengths)[int(0.95 * len(corpus))] 99 | else: 100 | raise ValueError(f'`mode` not understood: {mode}') 101 | 102 | 103 | def get_custom_objects() -> Dict[str, Any]: 104 | """Get all custom objects for loading saved tf.keras models in Fancy-NLP.""" 105 | custom_objects = {'CRF': CRF, 106 | 'NonMaskingLayer': NonMaskingLayer, 107 | 'FullMatching': FullMatching, 108 | 'MaxPoolingMatching': MaxPoolingMatching, 109 | 'AttentiveMatching': AttentiveMatching, 110 | 'MaxAttentiveMatching': MaxAttentiveMatching} 111 | return custom_objects 112 | 113 | 114 | class ChineseBertTokenizer(Tokenizer): 115 | """bert tokenizer for chinese""" 116 | def _tokenize(self, text): 117 | result = [] 118 | for ch in text: 119 | if ch in self._token_dict: 120 | result.append(ch) 121 | elif self._is_space(ch): 122 | result.append('[unused1]') 123 | else: 124 | result.append('[UNK]') 125 | return result 126 | -------------------------------------------------------------------------------- /fancy_nlp/layers/attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | import tensorflow.keras.backend as K 5 | 6 | 7 | class MultiHeadAttention(tf.keras.layers.Layer): 8 | """ 9 | Multi-head Attention introduced in Transformer, support masking 10 | """ 11 | def __init__(self, num_units=100, num_heads=3, residual=True, normalize=True, 12 | initializer='orthogonal', 
regularizer=None, constraint=None, **kwargs): 13 | self.num_units = num_units 14 | self.num_heads = num_heads 15 | self.model_units = self.num_units * self.num_heads 16 | self.residual = residual 17 | self.normalize = normalize 18 | self.initializer = tf.keras.initializers.get(initializer) 19 | self.regularizer = tf.keras.regularizers.get(regularizer) 20 | self.constraint = tf.keras.constraints.get(constraint) 21 | self.supports_masking = True 22 | super(MultiHeadAttention, self).__init__(**kwargs) 23 | 24 | def build(self, input_shape): 25 | if len(input_shape) != 3: 26 | raise ValueError('Input into MultiHeadAttention should be a 3D input tensor') 27 | 28 | self.w_q = self.add_weight(name='w_q', shape=(input_shape[-1], self.model_units), 29 | initializer=self.initializer, regularizer=self.regularizer, 30 | constraint=self.constraint) 31 | self.w_k = self.add_weight(name='w_k', shape=(input_shape[-1], self.model_units), 32 | initializer=self.initializer, regularizer=self.regularizer, 33 | constraint=self.constraint) 34 | self.w_v = self.add_weight(name='w_v', shape=(input_shape[-1], self.model_units), 35 | initializer=self.initializer, regularizer=self.regularizer, 36 | constraint=self.constraint) 37 | self.w_final = self.add_weight(name='w_v', shape=(self.model_units, self.model_units), 38 | initializer=self.initializer, regularizer=self.regularizer, 39 | constraint=self.constraint) 40 | if self.normalize: 41 | self.gamma = self.add_weight(name='gamma', shape=(self.model_units,), initializer='one', 42 | regularizer=self.regularizer, constraint=self.constraint) 43 | self.beta = self.add_weight(name='beta', shape=(self.model_units,), initializer='zero', 44 | regularizer=self.regularizer, constraint=self.constraint) 45 | super(MultiHeadAttention, self).build(input_shape) 46 | 47 | def call(self, inputs, mask=None): 48 | """ 49 | convert to query, key, value vectors, shaped [batch_size*num_head, time_step, embed_dim] 50 | """ 51 | multihead_query = K.concatenate(tf.split(K.dot(inputs, self.w_q), 52 | self.num_heads, axis=2), axis=0) 53 | multihead_key = K.concatenate(tf.split(K.dot(inputs, self.w_k), 54 | self.num_heads, axis=2), axis=0) 55 | multihead_value = K.concatenate(tf.split(K.dot(inputs, self.w_v), 56 | self.num_heads, axis=2), axis=0) 57 | 58 | """scaled dot product""" 59 | scaled = K.int_shape(inputs)[-1] ** -0.5 60 | attend = K.batch_dot(multihead_query, multihead_key, axes=2) * scaled 61 | # apply mask before normalization (softmax) 62 | if mask is not None: 63 | multihead_mask = K.tile(mask, [self.num_heads, 1]) 64 | attend *= K.expand_dims(K.cast(multihead_mask, K.floatx()), 2) 65 | attend *= K.expand_dims(K.cast(multihead_mask, K.floatx()), 1) 66 | # normalization 67 | attend = attend / K.cast(K.sum(attend, axis=-1, keepdims=True) + K.epsilon(), K.floatx()) 68 | # apply attention 69 | attend = K.batch_dot(attend, multihead_value, axes=(2, 1)) 70 | attend = tf.concat(tf.split(attend, self.num_heads, axis=0), axis=2) 71 | attend = K.dot(attend, self.w_final) 72 | 73 | if self.residual: 74 | attend = attend + inputs 75 | if self.normalize: 76 | mean = K.mean(attend, axis=-1, keepdims=True) 77 | std = K.mean(attend, axis=-1, keepdims=True) 78 | attend = self.gamma * (attend - mean) / (std + K.epsilon()) + self.beta 79 | 80 | return attend 81 | 82 | def compute_output_shape(self, input_shape): 83 | return input_shape[0], input_shape[1], self.num_units*self.num_heads 84 | 85 | def get_config(self): 86 | config = {'num_units': self.num_units, 87 | 'num_heads': self.num_heads, 88 | 
'residual': self.residual, 89 | 'normalize': self.normalize, 90 | 'initializer': tf.keras.initializers.serialize(self.initializer), 91 | 'regularizer': tf.keras.regularizers.serialize(self.regularizer), 92 | 'constraint': tf.keras.constraints.serialize(self.constraint)} 93 | base_config = super(MultiHeadAttention, self).get_config() 94 | return dict(list(base_config.items()) + list(config.items())) 95 | -------------------------------------------------------------------------------- /fancy_nlp/models/text_classification/base_text_classification_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Base text classification model 4 | """ 5 | import tensorflow as tf 6 | from bert4keras.bert import build_bert_model 7 | 8 | from fancy_nlp.layers import NonMaskingLayer 9 | from fancy_nlp.models.base_model import BaseModel 10 | 11 | 12 | class BaseTextClassificationModel(BaseModel): 13 | def __init__(self, 14 | use_char=True, 15 | char_embeddings=None, 16 | char_vocab_size=-1, 17 | char_embed_dim=-1, 18 | char_embed_trainable=False, 19 | use_bert=False, 20 | bert_config_file=None, 21 | bert_checkpoint_file=None, 22 | bert_trainable=False, 23 | use_word=False, 24 | word_embeddings=None, 25 | word_vocab_size=-1, 26 | word_embed_dim=-1, 27 | word_embed_trainable=False, 28 | max_len=None, 29 | dropout=0.2): 30 | 31 | self.use_char = use_char 32 | self.char_embeddings = char_embeddings 33 | self.char_vocab_size = char_vocab_size 34 | self.char_embed_dim = char_embed_dim 35 | self.char_embed_trainable = char_embed_trainable 36 | self.use_bert = use_bert 37 | self.bert_config_file = bert_config_file 38 | self.bert_checkpoint_file = bert_checkpoint_file 39 | self.bert_trainable = bert_trainable 40 | self.use_word = use_word 41 | self.word_embeddings = word_embeddings 42 | self.word_vocab_size = word_vocab_size 43 | self.word_embed_dim = word_embed_dim 44 | self.word_embed_trainable = word_embed_trainable 45 | self.max_len = max_len 46 | self.dropout = dropout 47 | 48 | assert self.use_char or self.use_bert, "must use char or bert embedding as main input" 49 | assert not (self.use_bert and self.max_len is None), \ 50 | "max_len must be provided when using bert embedding as input" 51 | 52 | def build_input(self): 53 | model_inputs = [] 54 | input_embed = [] 55 | 56 | # TODO: consider masking 57 | if self.use_char: 58 | if self.char_embeddings is not None: 59 | char_embedding_layer = tf.keras.layers.Embedding(input_dim=self.char_vocab_size, 60 | output_dim=self.char_embed_dim, 61 | weights=[self.char_embeddings], 62 | trainable=self.char_embed_trainable 63 | ) 64 | else: 65 | char_embedding_layer = tf.keras.layers.Embedding(input_dim=self.char_vocab_size, 66 | output_dim=self.char_embed_dim) 67 | input_char = tf.keras.layers.Input(shape=(self.max_len,)) 68 | model_inputs.append(input_char) 69 | char_embed = char_embedding_layer(input_char) 70 | input_embed.append(tf.keras.layers.SpatialDropout1D(self.dropout)(char_embed)) 71 | 72 | if self.use_bert: 73 | bert_model = build_bert_model(config_path=self.bert_config_file, 74 | checkpoint_path=self.bert_checkpoint_file) 75 | if not self.bert_trainable: 76 | # manually set every layer in bert model to be non-trainable 77 | for layer in bert_model.layers: 78 | layer.trainable = False 79 | model_inputs.extend(bert_model.inputs) 80 | bert_embed = NonMaskingLayer()(bert_model.output) 81 | input_embed.append(tf.keras.layers.SpatialDropout1D(0.2)(bert_embed)) 82 | 83 | if self.use_word: 84 | if 
self.word_embeddings is not None: 85 | word_embedding_layer = tf.keras.layers.Embedding(input_dim=self.word_vocab_size, 86 | output_dim=self.word_embed_dim, 87 | weights=[self.word_embeddings], 88 | trainable=self.word_embed_trainable 89 | ) 90 | else: 91 | word_embedding_layer = tf.keras.layers.Embedding(input_dim=self.word_vocab_size, 92 | output_dim=self.word_embed_dim) 93 | input_word = tf.keras.layers.Input(shape=(self.max_len,)) 94 | model_inputs.append(input_word) 95 | word_embed = word_embedding_layer(input_word) 96 | input_embed.append(tf.keras.layers.SpatialDropout1D(self.dropout)(word_embed)) 97 | 98 | if len(input_embed) > 1: 99 | input_embed = tf.keras.layers.concatenate(input_embed) 100 | else: 101 | input_embed = input_embed[0] 102 | return model_inputs, input_embed 103 | 104 | def build_model(self): 105 | raise NotImplementedError 106 | -------------------------------------------------------------------------------- /fancy_nlp/callbacks/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import List, Tuple 4 | 5 | import tensorflow as tf 6 | from seqeval import metrics 7 | from sklearn.metrics import f1_score, precision_score, recall_score, classification_report 8 | import numpy as np 9 | 10 | from fancy_nlp.preprocessors import NERPreprocessor, SPMPreprocessor 11 | 12 | 13 | class NERMetric(tf.keras.callbacks.Callback): 14 | """Callback for evaluating ner model during training. 15 | """ 16 | def __init__(self, 17 | preprocessor: NERPreprocessor, 18 | valid_data: List[List[str]], 19 | valid_labels: List[List[str]]) -> None: 20 | """ 21 | Args: 22 | preprocessor: Instance of `NERPreprocessor`, which helps to prepare feature input for 23 | ner model. 24 | valid_data: List of List of str, can be None. List of tokenized (in char 25 | level) texts for evaluation, like ``[['我', '在', '上', '海', '上', '学'], ...]``. 26 | valid_labels: List of List of str, can be None. The labels of valid_data, usually in 27 | BIO or BIOES format, like ``[['O', 'O', 'B-LOC', 'I-LOC', 'O', 'O'], ...]``. 
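        A minimal usage sketch (assumes `preprocessor`, `valid_data` and `valid_labels`
        have been built as described above; `ner_model`, `train_features` and `train_y`
        are hypothetical placeholders for a compiled tf.keras model and its training
        arrays):

            metric_callback = NERMetric(preprocessor, valid_data, valid_labels)
            ner_model.fit(train_features, train_y, epochs=5, callbacks=[metric_callback])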
28 | """ 29 | self.preprocessor = preprocessor 30 | self.valid_data = valid_data 31 | self.valid_labels = valid_labels 32 | self.valid_features, self.valid_y = self.preprocessor.prepare_input(valid_data, 33 | valid_labels) 34 | super(NERMetric, self).__init__() 35 | 36 | def get_lengths(self, pred_probs): 37 | return [min(len(valid_label), pred_prob.shape[0]) 38 | for valid_label, pred_prob in zip(self.valid_labels, pred_probs)] 39 | 40 | def on_epoch_end(self, epoch, logs=None): 41 | pred_probs = self.model.predict(self.valid_features) 42 | if self.preprocessor.use_bert: 43 | pred_probs = pred_probs[:, 1:-1, :] # remove and 44 | y_pred = self.preprocessor.label_decode(pred_probs, self.get_lengths(pred_probs)) 45 | 46 | r = metrics.recall_score(self.valid_labels, y_pred) 47 | p = metrics.precision_score(self.valid_labels, y_pred) 48 | f1 = metrics.f1_score(self.valid_labels, y_pred) 49 | 50 | logs['val_r'] = r 51 | logs['val_p'] = p 52 | logs['val_f1'] = f1 53 | print('Epoch {}: val_r: {}, val_p: {}, val_f1: {}'.format(epoch+1, r, p, f1)) 54 | print(metrics.classification_report(self.valid_labels, y_pred)) 55 | 56 | 57 | class TextClassificationMetric(tf.keras.callbacks.Callback): 58 | """ 59 | callback for evaluating text classification model 60 | """ 61 | def __init__(self, preprocessor, valid_data, valid_labels): 62 | """ 63 | Args: 64 | preprocessor: `TextClassificationPreprocessor` instance to help prepare input for 65 | text classification model 66 | valid_data: list of tokenized texts (, like ``[['我', '是', '中', '国', '人']]`` 67 | valid_labels: list of str, the corresponding label strings 68 | """ 69 | self.preprocessor = preprocessor 70 | self.valid_data = valid_data 71 | self.valid_labels = valid_labels 72 | self.valid_features, self.valid_y = self.preprocessor.prepare_input(valid_data, 73 | valid_labels) 74 | super(TextClassificationMetric, self).__init__() 75 | 76 | def on_epoch_end(self, epoch, logs=None): 77 | pred_probs = self.model.predict(self.valid_features) 78 | y_pred = self.preprocessor.label_decode(pred_probs) 79 | 80 | r = recall_score(self.valid_labels, y_pred, average='macro') 81 | p = precision_score(self.valid_labels, y_pred, average='macro') 82 | f1 = f1_score(self.valid_labels, y_pred, average='macro') 83 | 84 | logs['val_r'] = r 85 | logs['val_p'] = p 86 | logs['val_f1'] = f1 87 | print('Epoch {}: val_r: {}, val_p: {}, val_f1: {}'.format(epoch + 1, r, p, f1)) 88 | print(classification_report(self.valid_labels, y_pred)) 89 | 90 | 91 | class SPMMetric(tf.keras.callbacks.Callback): 92 | """ 93 | callback for evaluating spm model 94 | """ 95 | def __init__(self, 96 | preprocessor: SPMPreprocessor, 97 | valid_data: Tuple[List[str], List[str]], 98 | valid_labels: List[str]) -> None: 99 | """ 100 | Args: 101 | preprocessor: `SPMPreprocessor` instance to help prepare input for spm model 102 | valid_data: list of text pairs (, like ``[['我是中国人', ...], ['我爱中国', ...]]`` 103 | valid_labels: list of str, the corresponding label strings 104 | """ 105 | self.preprocessor = preprocessor 106 | self.valid_data = valid_data 107 | self.valid_labels = valid_labels 108 | self.valid_features, self.valid_y = self.preprocessor.prepare_input(valid_data, 109 | valid_labels) 110 | super(SPMMetric, self).__init__() 111 | 112 | def on_epoch_end(self, epoch, logs=None): 113 | pred_probs = self.model.predict(self.valid_features) 114 | y_pred = np.argmax(pred_probs, axis=-1) 115 | y_true = np.argmax(self.valid_y, axis=-1) 116 | 117 | r = recall_score(y_true, y_pred, average='macro') 118 | p = 
precision_score(y_true, y_pred, average='macro') 119 | f1 = f1_score(y_true, y_pred, average='macro') 120 | 121 | logs['val_r'] = r 122 | logs['val_p'] = p 123 | logs['val_f1'] = f1 124 | print('Epoch {}: val_r: {}, val_p: {}, val_f1: {}'.format(epoch, r, p, f1)) 125 | print(classification_report(y_true, y_pred)) 126 | -------------------------------------------------------------------------------- /fancy_nlp/utils/data_generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | from typing import List, Optional, Tuple 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from fancy_nlp.preprocessors import NERPreprocessor, SPMPreprocessor 10 | 11 | 12 | class NERGenerator(tf.keras.utils.Sequence): 13 | """Data Generator for NER 14 | """ 15 | def __init__(self, 16 | preprocessor: NERPreprocessor, 17 | data: List[List[str]], 18 | labels: Optional[List[List[str]]] = None, 19 | batch_size: int = 32, 20 | shuffle: bool = True) -> None: 21 | """ 22 | Args: 23 | preprocessor: Instance of NERPreprocessor, which helps to prepare feature input for 24 | ner model. 25 | data: List of List of str. List of tokenized texts for training, 26 | like ``[['我', '在', '上', '海', '上', '学'], ...]``. 27 | labels: List of List of str. The labels of train_data, usually in BIO or BIOES 28 | format, like ``[['O', 'O', 'B-LOC', 'I-LOC', 'O', 'O'], ...]``. 29 | batch_size: int. How many samples to train on in one iteration 30 | shuffle: Boolean. wWhether to shuffle data after each epoch of training. 31 | """ 32 | self.preprocessor = preprocessor 33 | self.data = data 34 | self.labels = labels 35 | self.data_size = len(self.data) 36 | self.batch_size = batch_size 37 | self.indices = np.arange(self.data_size) 38 | self.steps = int(math.ceil(self.data_size / self.batch_size)) 39 | self.shuffle = shuffle 40 | 41 | def __len__(self): 42 | return self.steps 43 | 44 | def on_epoch_end(self): 45 | if self.shuffle: 46 | np.random.shuffle(self.indices) 47 | 48 | def __getitem__(self, index): 49 | batch_index = self.indices[index * self.batch_size: (index + 1) * self.batch_size] 50 | if self.labels is not None: 51 | batch_data, batch_labels = zip(*[(self.data[i], self.labels[i]) for i in batch_index]) 52 | else: 53 | batch_data = [self.data[i] for i in batch_index] 54 | batch_labels = None 55 | return self.preprocessor.prepare_input(batch_data, batch_labels) 56 | 57 | 58 | class TextClassificationGenerator(tf.keras.utils.Sequence): 59 | """Data Generator for text classification 60 | """ 61 | def __init__(self, preprocessor, data, labels=None, batch_size=32, shuffle=True): 62 | """ 63 | Args: 64 | preprocessor: `TextClassificationPreprocessor` instance to help prepare input for ner model 65 | data: list of tokenized texts (, like ``[['我', '是', '中', '国', '人']]`` 66 | labels: list of str, the corresponding label strings 67 | batch_size: how many samples to train on in one iteration 68 | shuffle: whether to shuffle data after each epoch of training 69 | """ 70 | self.preprocessor = preprocessor 71 | self.data = data 72 | self.labels = labels 73 | self.data_size = len(self.data) 74 | self.batch_size = batch_size 75 | self.indices = np.arange(self.data_size) 76 | self.steps = int(math.ceil(self.data_size / self.batch_size)) 77 | self.shuffle = shuffle 78 | 79 | def __len__(self): 80 | return self.steps 81 | 82 | def on_epoch_end(self): 83 | if self.shuffle: 84 | np.random.shuffle(self.indices) 85 | 86 | def __getitem__(self, index): 87 | 
batch_index = self.indices[index * self.batch_size: (index + 1) * self.batch_size] 88 | if self.labels is not None: 89 | batch_data, batch_labels = zip(*[(self.data[i], self.labels[i]) for i in batch_index]) 90 | else: 91 | batch_data = [self.data[i] for i in batch_index] 92 | batch_labels = None 93 | return self.preprocessor.prepare_input(batch_data, batch_labels) 94 | 95 | 96 | class SPMGenerator(tf.keras.utils.Sequence): 97 | """Data Generator for SPM 98 | """ 99 | def __init__(self, preprocessor: SPMPreprocessor, 100 | data: Tuple[List[str], List[str]], 101 | labels: Optional[List[str]] = None, 102 | batch_size: int = 32, 103 | shuffle: bool = True) -> None: 104 | """ 105 | Args: 106 | preprocessor: `SPMPreprocessor` instance to help prepare input for spm model 107 | data: list of text pairs (, like ``[['我是中国人', ...], ['我爱中国', ...]]`` 108 | labels: list of str, the corresponding label strings 109 | batch_size: how many samples to train on in one iteration 110 | shuffle: whether to shuffle data after each epoch of training 111 | """ 112 | self.preprocessor = preprocessor 113 | self.data = data 114 | self.labels = labels 115 | self.data_size = len(self.data[0]) 116 | self.batch_size = batch_size 117 | self.indices = np.arange(self.data_size) 118 | self.steps = int(math.ceil(self.data_size / self.batch_size)) 119 | self.shuffle = shuffle 120 | 121 | def __len__(self): 122 | return self.steps 123 | 124 | def on_epoch_end(self): 125 | if self.shuffle: 126 | np.random.shuffle(self.indices) 127 | 128 | def __getitem__(self, index): 129 | batch_index = self.indices[index * self.batch_size: (index + 1) * self.batch_size] 130 | if self.labels is not None: 131 | batch_data_a, batch_data_b, batch_labels = \ 132 | zip(*[(self.data[0][i], self.data[1][i], self.labels[i]) for i in batch_index]) 133 | else: 134 | batch_data_a, batch_data_b = zip(*[(self.data[0][i], self.data[1][i]) 135 | for i in batch_index]) 136 | batch_labels = None 137 | return self.preprocessor.prepare_input((batch_data_a, batch_data_b), batch_labels) 138 | -------------------------------------------------------------------------------- /tests/fancy_nlp/applications/test_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_ner_data_and_labels 6 | from fancy_nlp.applications import NER 7 | from fancy_nlp.config import CACHE_DIR 8 | 9 | 10 | class TestNER: 11 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 12 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 13 | '../../../data/embeddings/bert_sample_model/vocab.txt') 14 | bert_config_file = os.path.join(os.path.dirname(__file__), 15 | '../../../data/embeddings/bert_sample_model/bert_config.json') 16 | bert_model_file = os.path.join(os.path.dirname(__file__), 17 | '../../../data/embeddings/bert_sample_model/bert_model.ckpt') 18 | 19 | def setup_class(self): 20 | self.train_data, self.train_labels, self.valid_data, self.valid_labels = \ 21 | load_ner_data_and_labels(self.test_file, split=True) 22 | 23 | self.checkpoint_dir = os.path.dirname(__file__) 24 | self.model_name = 'bilstm_cnn_ner' 25 | self.json_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.json') 26 | self.weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.hdf5') 27 | self.swa_weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner_swa.hdf5') 28 | self.preprocessor_file = os.path.join(self.checkpoint_dir, 
'bilstm_cnn_preprocessor.pkl') 29 | 30 | def test_ner(self): 31 | ner = NER() 32 | 33 | # cache_dir = os.path.expanduser(CACHE_DIR) 34 | # cache_subdir = 'pretrained_models' 35 | # preprocessor_file = os.path.join(cache_dir, cache_subdir, 36 | # 'msra_ner_bilstm_cnn_crf_preprocessor.pkl') 37 | # json_file = os.path.join(cache_dir, cache_subdir, 'msra_ner_bilstm_cnn_crf.json') 38 | # 39 | # weights_file = os.path.join(cache_dir, cache_subdir, 'msra_ner_bilstm_cnn_crf.hdf5') 40 | # assert os.path.exists(preprocessor_file) 41 | # assert os.path.exists(json_file) 42 | # assert os.path.exists(weights_file) 43 | # 44 | # ner.analyze('同济大学位于上海市杨浦区,成立于1907年') 45 | # ner.restrict_analyze('同济大学位于上海市杨浦区,成立于1907年') 46 | 47 | # test train 48 | ner.fit(train_data=self.train_data, 49 | train_labels=self.train_labels, 50 | valid_data=self.valid_data, 51 | valid_labels=self.valid_labels, 52 | ner_model_type='bilstm_cnn', 53 | use_char=True, 54 | use_bert=True, 55 | bert_vocab_file=self.bert_vocab_file, 56 | bert_config_file=self.bert_config_file, 57 | bert_checkpoint_file=self.bert_model_file, 58 | use_word=True, 59 | max_len=16, 60 | batch_size=2, 61 | epochs=7, 62 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 63 | checkpoint_dir=self.checkpoint_dir, 64 | model_name=self.model_name, 65 | load_swa_model=True) 66 | 67 | assert not os.path.exists(self.json_file) 68 | assert os.path.exists(self.weights_file) 69 | assert os.path.exists(self.swa_weights_file) 70 | os.remove(self.weights_file) 71 | os.remove(self.swa_weights_file) 72 | assert not os.path.exists(self.weights_file) 73 | assert not os.path.exists(self.swa_weights_file) 74 | 75 | # test score 76 | score = ner.score(self.valid_data, self.valid_labels) 77 | assert isinstance(score, (float, int)) 78 | 79 | # test predict 80 | valid_tag = ner.predict(self.valid_data[0]) 81 | assert isinstance(valid_tag, list) and isinstance(valid_tag[0], str) 82 | assert len(valid_tag) == len(self.valid_data[0]) or \ 83 | len(valid_tag) == ner.preprocessor.max_len - 2 84 | 85 | # test predict_batch 86 | valid_tags = ner.predict_batch(self.valid_data) 87 | assert isinstance(valid_tags, list) and isinstance(valid_tags[-1], list) 88 | assert isinstance(valid_tags[-1][0], str) 89 | assert len(valid_tags) == len(self.valid_data) 90 | assert len(valid_tags[-1]) == len(self.valid_data[-1]) or \ 91 | len(valid_tags[-1]) == ner.preprocessor.max_len - 2 92 | 93 | # test analyze 94 | result = ner.analyze(self.valid_data[0]) 95 | assert isinstance(result, dict) and 'text' in result and 'entities' in result 96 | 97 | # test analyze_batch 98 | results = ner.analyze_batch(self.valid_data) 99 | assert isinstance(results, list) and len(results) == len(self.valid_data) 100 | assert isinstance(results[-1], dict) 101 | assert 'text' in results[-1] and 'entities' in results[-1] 102 | 103 | # test restrict analyze 104 | result = ner.restrict_analyze(self.valid_data[0], threshold=0.85) 105 | entity_types = [entity['type'] for entity in result['entities']] 106 | assert len(set(entity_types)) == len(entity_types) 107 | for score in [entity['score'] for entity in result['entities']]: 108 | assert score >= 0.85 109 | 110 | # test restrict analyze batch 111 | results = ner.restrict_analyze_batch(self.valid_data, threshold=0.85) 112 | entity_types = [entity['type'] for entity in results[-1]['entities']] 113 | assert len(set(entity_types)) == len(entity_types) 114 | for score in [entity['score'] for entity in results[-1]['entities']]: 115 | assert score >= 0.85 116 | 117 | # test save 
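        # `save` below is expected to write three artifacts: the fitted preprocessor
        # (pickle), the model architecture (json) and the model weights (hdf5). The
        # asserts and the later `load` call rely on all three files existing.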
118 | ner.save(self.preprocessor_file, self.json_file, self.weights_file) 119 | assert os.path.exists(self.json_file) 120 | assert os.path.exists(self.weights_file) 121 | assert os.path.exists(self.preprocessor_file) 122 | 123 | # test load 124 | ner.load(self.preprocessor_file, self.json_file, self.weights_file) 125 | os.remove(self.json_file) 126 | os.remove(self.weights_file) 127 | os.remove(self.preprocessor_file) 128 | 129 | 130 | test_ner = TestNER() 131 | test_ner.setup_class() 132 | test_ner.test_ner() 133 | -------------------------------------------------------------------------------- /tests/fancy_nlp/applications/test_spm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_spm_data_and_labels 6 | from fancy_nlp.applications import SPM 7 | 8 | 9 | class TestSPM: 10 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt') 11 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 12 | '../../../data/embeddings/bert_sample_model/vocab.txt') 13 | bert_config_file = os.path.join(os.path.dirname(__file__), 14 | '../../../data/embeddings/bert_sample_model/bert_config.json') 15 | bert_model_file = os.path.join(os.path.dirname(__file__), 16 | '../../../data/embeddings/bert_sample_model/bert_model.ckpt') 17 | 18 | def setup_class(self): 19 | self.train_data, self.train_labels, self.valid_data, self.valid_labels = \ 20 | load_spm_data_and_labels(self.test_file, split_mode=1, split_size=0.3) 21 | 22 | self.checkpoint_dir = os.path.dirname(__file__) 23 | self.model_name = 'siamese_cnn_spm' 24 | self.json_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm.json') 25 | self.weights_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm.hdf5') 26 | self.swa_weights_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm_swa.hdf5') 27 | self.preprocessor_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_preprocessor.pkl') 28 | 29 | def test_spm(self): 30 | spm = SPM() 31 | # spm.predict(['未满足微众银行审批是什么意思', '为什么我未满足微众银行审批']) 32 | # spm.analyze(['未满足微众银行审批是什么意思', '为什么我未满足微众银行审批']) 33 | 34 | # test train word and char 35 | spm.fit(train_data=self.train_data, 36 | train_labels=self.train_labels, 37 | valid_data=self.valid_data, 38 | valid_labels=self.valid_labels, 39 | spm_model_type='siamese_cnn', 40 | use_word=True, 41 | word_embed_type='fasttext', 42 | word_embed_dim=20, 43 | use_char=True, 44 | char_embed_dim=20, 45 | use_bert=False, 46 | bert_vocab_file=self.bert_vocab_file, 47 | bert_config_file=self.bert_config_file, 48 | bert_checkpoint_file=self.bert_model_file, 49 | batch_size=6, 50 | epochs=2, 51 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 52 | checkpoint_dir=self.checkpoint_dir, 53 | model_name=self.model_name, 54 | load_swa_model=True) 55 | 56 | # test train char and bert 57 | spm.fit(train_data=self.train_data, 58 | train_labels=self.train_labels, 59 | valid_data=self.valid_data, 60 | valid_labels=self.valid_labels, 61 | spm_model_type='siamese_cnn', 62 | use_word=False, 63 | use_char=True, 64 | use_bert=True, 65 | bert_vocab_file=self.bert_vocab_file, 66 | bert_config_file=self.bert_config_file, 67 | bert_checkpoint_file=self.bert_model_file, 68 | max_len=10, 69 | batch_size=6, 70 | epochs=2, 71 | callback_list=['modelcheckpoint', 'earlystopping'], 72 | checkpoint_dir=self.checkpoint_dir, 73 | model_name=self.model_name, 74 | load_swa_model=True) 75 | 76 | # test train bert model 77 | 
spm.fit(train_data=self.train_data, 78 | train_labels=self.train_labels, 79 | valid_data=self.valid_data, 80 | valid_labels=self.valid_labels, 81 | spm_model_type='bert', 82 | use_word=False, 83 | use_char=False, 84 | use_bert=True, 85 | bert_vocab_file=self.bert_vocab_file, 86 | bert_config_file=self.bert_config_file, 87 | bert_checkpoint_file=self.bert_model_file, 88 | bert_output_layer_num=2, 89 | max_len=10, 90 | batch_size=6, 91 | epochs=2, 92 | callback_list=['modelcheckpoint', 'earlystopping', 'swa'], 93 | checkpoint_dir=self.checkpoint_dir, 94 | model_name=self.model_name, 95 | load_swa_model=True) 96 | 97 | assert not os.path.exists(self.json_file) 98 | assert os.path.exists(self.weights_file) 99 | assert os.path.exists(self.swa_weights_file) 100 | os.remove(self.weights_file) 101 | os.remove(self.swa_weights_file) 102 | assert not os.path.exists(self.weights_file) 103 | assert not os.path.exists(self.swa_weights_file) 104 | 105 | # test score 106 | score = spm.score(self.valid_data, self.valid_labels) 107 | assert isinstance(score, (float, int)) 108 | 109 | # test predict 110 | valid_label = spm.predict((self.valid_data[0][0], self.valid_data[1][0])) 111 | assert isinstance(valid_label, str) 112 | 113 | # test predict_batch 114 | valid_labels = spm.predict_batch(self.valid_data) 115 | assert isinstance(valid_labels, list) and isinstance(valid_labels[-1], str) 116 | assert len(valid_labels) == len(self.valid_data[0]) 117 | assert valid_label == valid_labels[0] 118 | 119 | # test analyze 120 | valid_label = spm.analyze((self.valid_data[0][0], self.valid_data[1][0])) 121 | assert isinstance(valid_label, tuple) 122 | assert len(valid_label) == spm.preprocessor.num_class 123 | 124 | # test analyze_batch 125 | valid_labels = spm.analyze_batch(self.valid_data) 126 | assert isinstance(valid_labels, list) and isinstance(valid_labels[-1], tuple) 127 | assert len(valid_labels) == len(self.valid_data[0]) 128 | 129 | # test save 130 | spm.save(self.preprocessor_file, self.json_file, self.weights_file) 131 | assert os.path.exists(self.json_file) 132 | assert os.path.exists(self.weights_file) 133 | assert os.path.exists(self.preprocessor_file) 134 | 135 | # test load 136 | spm.load(self.preprocessor_file, self.json_file, self.weights_file) 137 | os.remove(self.json_file) 138 | os.remove(self.weights_file) 139 | os.remove(self.preprocessor_file) 140 | -------------------------------------------------------------------------------- /fancy_nlp/layers/matching.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | import tensorflow.keras.backend as K 5 | 6 | 7 | class FullMatching(tf.keras.layers.Layer): 8 | """ 9 | Full Matching strategy, each contextual embedding is compared with the average 10 | representation of the other sentence. 
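    A minimal usage sketch (shapes are assumptions inferred from this layer's code:
    the first input is a contextual encoding shaped [batch, time_steps, units], the
    second is an aggregated encoding of the other sentence shaped [batch, units]):

        sent2_avg = tf.keras.layers.GlobalAveragePooling1D()(sent2_encoded)
        matching = FullMatching(perspective_num=10)([sent1_encoded, sent2_avg])
        # matching: [batch, time_steps, perspective_num]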
11 | """ 12 | def __init__(self, perspective_num=10, **kwargs): 13 | self.perspective_num = perspective_num 14 | self.kernel = None 15 | super(FullMatching, self).__init__(**kwargs) 16 | 17 | def build(self, input_shape): 18 | self.dim = input_shape[0][-1] 19 | self.max_len = input_shape[0][1] 20 | self.kernel = self.add_weight(name='kernel', shape=(self.perspective_num, self.dim), 21 | initializer='glorot_uniform') 22 | super(FullMatching, self).build(input_shape) 23 | 24 | def call(self, inputs, **kwargs): 25 | sent1 = inputs[0] 26 | sent2 = inputs[1] 27 | 28 | v1 = K.expand_dims(sent1, -2) * self.kernel 29 | v2 = self.kernel * K.expand_dims(sent2, 1) 30 | v2 = K.expand_dims(v2, 1) 31 | v1 = K.l2_normalize(v1, axis=-1) 32 | v2 = K.l2_normalize(v2, axis=-1) 33 | matching = K.sum(v1 * v2, axis=-1) 34 | return matching 35 | 36 | def compute_output_shape(self, input_shape): 37 | return input_shape[0][0], input_shape[0][1], self.perspective_num 38 | 39 | def get_config(self): 40 | config = {'perspective_num': self.perspective_num} 41 | base_config = super(FullMatching, self).get_config() 42 | return dict(list(base_config.items()) + list(config.items())) 43 | 44 | 45 | class MaxPoolingMatching(tf.keras.layers.Layer): 46 | """ 47 | MaxPooling Matching strategy, each contextual embedding is compared with every 48 | contextual embeddings of the other sentence, and only the maximum value of each 49 | dimension is retained. 50 | """ 51 | def __init__(self, perspective_num=10, **kwargs): 52 | self.perspective_num = perspective_num 53 | self.kernel = None 54 | super(MaxPoolingMatching, self).__init__(**kwargs) 55 | 56 | def build(self, input_shape): 57 | self.dim = input_shape[0][-1] 58 | self.max_len = input_shape[0][1] 59 | self.kernel = self.add_weight(name='kernel', shape=(self.perspective_num, self.dim), 60 | initializer='glorot_uniform') 61 | super(MaxPoolingMatching, self).build(input_shape) 62 | 63 | def call(self, inputs, **kwargs): 64 | sent1 = inputs[0] 65 | sent2 = inputs[1] 66 | 67 | v1 = K.expand_dims(sent1, -2) * self.kernel 68 | v2 = K.expand_dims(sent2, -2) * self.kernel 69 | v1 = K.l2_normalize(v1, axis=-1) 70 | v2 = K.l2_normalize(v2, axis=-1) 71 | matching = K.max(K.sum(K.expand_dims(v1, 2) * K.expand_dims(v2, 1), axis=-1), axis=-2) 72 | return matching 73 | 74 | def compute_output_shape(self, input_shape): 75 | return input_shape[0][0], input_shape[0][1], self.perspective_num 76 | 77 | def get_config(self): 78 | config = {'perspective_num': self.perspective_num} 79 | base_config = super(MaxPoolingMatching, self).get_config() 80 | return dict(list(base_config.items()) + list(config.items())) 81 | 82 | 83 | class AttentiveMatching(tf.keras.layers.Layer): 84 | """ 85 | Attentive Matching strategy, each contextual embedding is compared with its attentive 86 | weighted representation of the other sentence. 
87 | """ 88 | def __init__(self, perspective_num=10, **kwargs): 89 | self.perspective_num = perspective_num 90 | self.kernel = None 91 | super(AttentiveMatching, self).__init__(**kwargs) 92 | 93 | def build(self, input_shape): 94 | self.dim = input_shape[0][-1] 95 | self.max_len = input_shape[0][1] 96 | self.kernel = self.add_weight(name='kernel', shape=(self.perspective_num, self.dim), 97 | initializer='glorot_uniform') 98 | super(AttentiveMatching, self).build(input_shape) 99 | 100 | def call(self, inputs, **kwargs): 101 | sent1 = inputs[0] 102 | sent2 = inputs[1] 103 | 104 | v1 = K.expand_dims(sent1, -2) * self.kernel 105 | v2 = K.expand_dims(sent2, -2) * self.kernel 106 | v1 = K.l2_normalize(v1, axis=-1) 107 | v2 = K.l2_normalize(v2, axis=-1) 108 | matching = K.sum(v1 * v2, axis=-1) 109 | return matching 110 | 111 | def compute_output_shape(self, input_shape): 112 | return input_shape[0][0], input_shape[0][1], self.perspective_num 113 | 114 | def get_config(self): 115 | config = {'perspective_num': self.perspective_num} 116 | base_config = super(AttentiveMatching, self).get_config() 117 | return dict(list(base_config.items()) + list(config.items())) 118 | 119 | 120 | class MaxAttentiveMatching(tf.keras.layers.Layer): 121 | """ 122 | MaxAttentive Matching strategy, each contextual embedding picks the contextual 123 | embedding of the other sentence with the highest cosine similarity as the 124 | attentive vector. 125 | """ 126 | def __init__(self, perspective_num=10, **kwargs): 127 | self.perspective_num = perspective_num 128 | self.kernel = None 129 | super(MaxAttentiveMatching, self).__init__(**kwargs) 130 | 131 | def build(self, input_shape): 132 | self.dim = input_shape[0][-1] 133 | self.max_len = input_shape[0][1] 134 | self.kernel = self.add_weight(name='kernel', shape=(self.perspective_num, self.dim), 135 | initializer='glorot_uniform') 136 | super(MaxAttentiveMatching, self).build(input_shape) 137 | 138 | def call(self, inputs, **kwargs): 139 | sent1 = inputs[0] 140 | sent2 = inputs[1] 141 | 142 | v1 = K.expand_dims(sent1, -2) * self.kernel 143 | v2 = K.expand_dims(sent2, -2) * self.kernel 144 | v1 = K.l2_normalize(v1, axis=-1) 145 | v2 = K.l2_normalize(v2, axis=-1) 146 | matching = K.sum(v1 * v2, axis=-1) 147 | return matching 148 | 149 | def compute_output_shape(self, input_shape): 150 | return input_shape[0][0], input_shape[0][1], self.perspective_num 151 | 152 | def get_config(self): 153 | config = {'perspective_num': self.perspective_num} 154 | base_config = super(MaxAttentiveMatching, self).get_config() 155 | return dict(list(base_config.items()) + list(config.items())) 156 | -------------------------------------------------------------------------------- /tests/fancy_nlp/preprocessors/test_spm_preprocessor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | import jieba 6 | import numpy as np 7 | 8 | from fancy_nlp.utils.data_loader import load_spm_data_and_labels 9 | from fancy_nlp.preprocessors.spm_preprocessor import SPMPreprocessor 10 | 11 | 12 | class TestSPMPreprocessor: 13 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt') 14 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 15 | '../../../data/embeddings/bert_sample_model/vocab.txt') 16 | 17 | def setup_class(self): 18 | self.x_train, self.y_train = load_spm_data_and_labels(self.test_file) 19 | self.preprocessor = SPMPreprocessor(self.x_train, self.y_train, use_word=True, 20 | 
use_char=True, use_bert=False, 21 | bert_vocab_file=self.bert_vocab_file, 22 | external_word_dict=['微众'], 23 | word_embed_type='word2vec', 24 | max_len=16, max_word_len=3) 25 | 26 | def test_no_word(self): 27 | preprocessor = SPMPreprocessor(self.x_train, self.y_train, use_word=False, 28 | use_char=True, use_bert=True, 29 | bert_vocab_file=self.bert_vocab_file, 30 | external_word_dict=['微众'], 31 | char_embed_type='word2vec', max_len=16) 32 | 33 | assert len(preprocessor.char_vocab_count) + 4 == len(preprocessor.char_vocab) \ 34 | == len(preprocessor.id2char) 35 | assert list(preprocessor.id2char.keys())[0] == 0 36 | for cnt in preprocessor.char_vocab_count.values(): 37 | assert cnt >= 2 38 | assert preprocessor.char_embeddings.shape[0] == len(preprocessor.char_vocab) 39 | assert preprocessor.char_embeddings.shape[1] == 300 40 | assert not np.any(preprocessor.char_embeddings[0]) 41 | assert preprocessor.word_embeddings is None 42 | 43 | assert len(preprocessor.label_vocab) == len(preprocessor.id2label) 44 | assert list(preprocessor.id2label.keys())[0] == 0 45 | 46 | features, y = preprocessor.prepare_input(preprocessor.train_data, 47 | preprocessor.train_labels) 48 | assert len(features) == 6 49 | assert features[0].shape == features[1].shape == features[2].shape == features[3].shape == \ 50 | features[4].shape == features[5].shape == \ 51 | (len(self.x_train[0]), preprocessor.max_len) 52 | assert preprocessor.id2char[features[0][0][0]] == preprocessor.cls_token 53 | assert y.shape == (len(self.x_train[0]), preprocessor.num_class) 54 | 55 | def test_bert_model(self): 56 | preprocessor = SPMPreprocessor(self.x_train, self.y_train, use_word=False, 57 | use_char=False, use_bert=True, use_bert_model=True, 58 | bert_vocab_file=self.bert_vocab_file, 59 | max_len=16) 60 | 61 | assert preprocessor.word_embeddings is None 62 | assert preprocessor.char_embeddings is None 63 | 64 | assert len(preprocessor.label_vocab) == len(preprocessor.id2label) 65 | assert list(preprocessor.id2label.keys())[0] == 0 66 | 67 | features, y = preprocessor.prepare_input(self.x_train, self.y_train) 68 | assert len(features) == 2 69 | assert features[0].shape == features[1].shape == \ 70 | (len(self.x_train[0]), preprocessor.max_len) 71 | assert y.shape == (len(self.x_train[0]), preprocessor.num_class) 72 | 73 | def test_no_bert(self): 74 | preprocessor = SPMPreprocessor(self.x_train, self.y_train, use_word=True, 75 | use_char=True, use_bert=False, 76 | bert_vocab_file=self.bert_vocab_file, 77 | external_word_dict=['微众'], 78 | word_embed_type='word2vec', 79 | max_len=16, max_word_len=3) 80 | 81 | assert len(preprocessor.word_vocab_count) + 2 == len(preprocessor.word_vocab) \ 82 | == len(preprocessor.id2word) 83 | assert list(preprocessor.id2word.keys())[0] == 0 84 | for cnt in preprocessor.word_vocab_count.values(): 85 | assert cnt >= 2 86 | assert preprocessor.word_embeddings.shape[0] == len(preprocessor.word_vocab) 87 | assert preprocessor.word_embeddings.shape[1] == 300 88 | assert not np.any(preprocessor.word_embeddings[0]) 89 | 90 | assert len(preprocessor.char_vocab_count) + 2 == len(preprocessor.char_vocab) \ 91 | == len(preprocessor.id2char) 92 | assert list(preprocessor.id2char.keys())[0] == 0 93 | for cnt in preprocessor.char_vocab_count.values(): 94 | assert cnt >= 2 95 | assert preprocessor.char_embeddings is None 96 | 97 | assert len(preprocessor.label_vocab) == len(preprocessor.id2label) 98 | assert list(preprocessor.id2label.keys())[0] == 0 99 | 100 | features, y = preprocessor.prepare_input(self.x_train, 
self.y_train) 101 | assert len(features) == 4 102 | assert features[0].shape == features[2].shape == \ 103 | (len(self.x_train[0]), preprocessor.max_len) and \ 104 | features[1].shape == features[3].shape == \ 105 | (len(self.x_train[0]), preprocessor.max_len, preprocessor.max_word_len) 106 | assert y.shape == (len(self.x_train[0]), preprocessor.num_class) 107 | 108 | def test_get_word_ids(self): 109 | example_text = ''.join(self.x_train[0][0]) 110 | word_cut = jieba.lcut(example_text) 111 | word_ids = self.preprocessor.get_word_ids(word_cut) 112 | assert len(word_ids) == len(word_cut) 113 | 114 | def test_label_decode(self): 115 | rand_pred_probs = np.random.rand(2, self.preprocessor.num_class) 116 | pred_labels = self.preprocessor.label_decode(rand_pred_probs) 117 | assert isinstance(pred_labels[0], str) 118 | assert len(pred_labels) == 2 119 | 120 | def test_save_load(self): 121 | pkl_file = 'test_preprocessor.pkl' 122 | self.preprocessor.save(pkl_file) 123 | assert os.path.exists(pkl_file) 124 | new_preprocessor = SPMPreprocessor.load(pkl_file) 125 | assert new_preprocessor.num_class == self.preprocessor.num_class 126 | os.remove(pkl_file) 127 | -------------------------------------------------------------------------------- /fancy_nlp/models/ner/base_ner_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Base NER model 4 | """ 5 | 6 | from typing import Optional 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | from bert4keras.bert import build_bert_model 11 | 12 | from fancy_nlp.layers import NonMaskingLayer 13 | from fancy_nlp.models.base_model import BaseModel 14 | 15 | 16 | class BaseNERModel(BaseModel): 17 | """The basic class for ner models. All the ner models will inherit from it. 18 | 19 | """ 20 | def __init__(self, 21 | use_char: bool = True, 22 | char_embeddings: Optional[np.ndarray] = None, 23 | char_vocab_size: int = -1, 24 | char_embed_dim: int = -1, 25 | char_embed_trainable: bool = False, 26 | use_bert: bool = False, 27 | bert_config_file: Optional[str] = None, 28 | bert_checkpoint_file: Optional[str] = None, 29 | bert_trainable: bool = False, 30 | use_word: bool = False, 31 | word_embeddings: Optional[np.ndarray] = None, 32 | word_vocab_size: int = -1, 33 | word_embed_dim: int = -1, 34 | word_embed_trainable: bool = False, 35 | max_len: Optional[int] = None, 36 | dropout: float = 0.2) -> None: 37 | """ 38 | 39 | Args: 40 | use_char: Boolean. Whether to use character embedding as input. 41 | char_embeddings: Optional np.ndarray. Char embedding matrix, shaped 42 | [char_vocab_size, char_embed_dim]. There are 2 cases when char_embeddings is None: 43 | 1) use_char is False, do not use char embedding as input; 2) user did not 44 | provide valid pre-trained embedding file or any embedding training method. In 45 | this case, use randomly initialized embedding instead. 46 | char_vocab_size: int. The size of char vocabulary. 47 | char_embed_dim: int. Dimensionality of char embedding. 48 | char_embed_trainable: Boolean. Whether to update char embedding during training. 49 | use_bert: Boolean. Whether to use bert embedding as input. 50 | bert_config_file: Optional str, can be None. Path to bert's configuration file. 51 | bert_checkpoint_file: Optional str, can be None. Path to bert's checkpoint file. 52 | bert_trainable: Boolean. Whether to update bert during training. 53 | use_word: Boolean. Whether to use word as additional input. 54 | word_embeddings: Optional np.ndarray. 
Similar as char_embeddings. 55 | word_vocab_size: int. Similar as char_vocab_size. 56 | word_embed_dim: int. Similar as char_embed_dim. 57 | word_embed_trainable: Boolean. Similar as char_embed_trainable. 58 | max_len: Optional int, can be None. Max length of one sequence. 59 | dropout: float. The dropout rate applied to embedding layer. 60 | """ 61 | 62 | self.use_char = use_char 63 | self.char_embeddings = char_embeddings 64 | self.char_vocab_size = char_vocab_size 65 | self.char_embed_dim = char_embed_dim 66 | self.char_embed_trainable = char_embed_trainable 67 | self.use_bert = use_bert 68 | self.bert_config_file = bert_config_file 69 | self.bert_checkpoint_file = bert_checkpoint_file 70 | self.bert_trainable = bert_trainable 71 | self.use_word = use_word 72 | self.word_embeddings = word_embeddings 73 | self.word_vocab_size = word_vocab_size 74 | self.word_embed_dim = word_embed_dim 75 | self.word_embed_trainable = word_embed_trainable 76 | self.max_len = max_len 77 | self.dropout = dropout 78 | 79 | assert self.use_char or self.use_bert, "must use char or bert embedding as main input" 80 | assert not (self.use_bert and self.max_len is None), \ 81 | "max_len must be provided when using bert embedding as input" 82 | 83 | def build_input(self): 84 | """Build input placeholder and prepare embedding for ner model. 85 | 86 | Returns: Tuples of 2 tensor: 87 | 1). Input tensor(s), depending whether using multiple inputs; 88 | 2). Embedding tensor, which will be passed to following layers of ner models. 89 | 90 | """ 91 | model_inputs = [] 92 | input_embed = [] 93 | 94 | # TODO: consider masking 95 | if self.use_char: 96 | if self.char_embeddings is not None: 97 | char_embedding_layer = tf.keras.layers.Embedding( 98 | input_dim=self.char_vocab_size, 99 | output_dim=self.char_embed_dim, 100 | weights=[self.char_embeddings], 101 | trainable=self.char_embed_trainable) 102 | else: 103 | char_embedding_layer = tf.keras.layers.Embedding(input_dim=self.char_vocab_size, 104 | output_dim=self.char_embed_dim) 105 | input_char = tf.keras.layers.Input(shape=(self.max_len,)) 106 | model_inputs.append(input_char) 107 | 108 | char_embed = char_embedding_layer(input_char) 109 | input_embed.append(tf.keras.layers.SpatialDropout1D(self.dropout)(char_embed)) 110 | 111 | if self.use_bert: 112 | bert_model = build_bert_model(config_path=self.bert_config_file, 113 | checkpoint_path=self.bert_checkpoint_file) 114 | if not self.bert_trainable: 115 | # manually set every layer in bert model to be non-trainable 116 | for layer in bert_model.layers: 117 | layer.trainable = False 118 | 119 | model_inputs.extend(bert_model.inputs) 120 | bert_embed = NonMaskingLayer()(bert_model.output) 121 | input_embed.append(tf.keras.layers.SpatialDropout1D(0.2)(bert_embed)) 122 | 123 | if self.use_word: 124 | if self.word_embeddings is not None: 125 | word_embedding_layer = tf.keras.layers.Embedding( 126 | input_dim=self.word_vocab_size, 127 | output_dim=self.word_embed_dim, 128 | weights=[self.word_embeddings], 129 | trainable=self.word_embed_trainable) 130 | else: 131 | word_embedding_layer = tf.keras.layers.Embedding(input_dim=self.word_vocab_size, 132 | output_dim=self.word_embed_dim) 133 | input_word = tf.keras.layers.Input(shape=(self.max_len,)) 134 | model_inputs.append(input_word) 135 | 136 | word_embed = word_embedding_layer(input_word) 137 | input_embed.append(tf.keras.layers.SpatialDropout1D(self.dropout)(word_embed)) 138 | 139 | if len(input_embed) > 1: 140 | input_embed = tf.keras.layers.concatenate(input_embed) 141 | 
else: 142 | input_embed = input_embed[0] 143 | return model_inputs, input_embed 144 | 145 | def build_model(self): 146 | """Build ner model's architecture.""" 147 | raise NotImplementedError 148 | -------------------------------------------------------------------------------- /fancy_nlp/models/text_classification/text_classification_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import tensorflow as tf 4 | 5 | from fancy_nlp.models.text_classification.base_text_classification_model import \ 6 | BaseTextClassificationModel 7 | 8 | 9 | class CNNTextClassification(BaseTextClassificationModel): 10 | """CNN model for text classification. 11 | """ 12 | def __init__(self, 13 | num_class, 14 | use_char=True, 15 | char_embeddings=None, 16 | char_vocab_size=-1, 17 | char_embed_dim=-1, 18 | char_embed_trainable=False, 19 | use_bert=False, 20 | bert_config_file=None, 21 | bert_checkpoint_file=None, 22 | bert_trainable=False, 23 | use_word=False, 24 | word_embeddings=None, 25 | word_vocab_size=-1, 26 | word_embed_dim=-1, 27 | word_embed_trainable=False, 28 | max_len=None, 29 | dropout=0.2, 30 | rnn_units=150, 31 | fc_dim=100, 32 | activation='tanh', 33 | optimizer='adam'): 34 | self.num_class = num_class 35 | self.rnn_units = rnn_units 36 | self.fc_fim = fc_dim 37 | self.activation = activation 38 | self.optimizer = optimizer 39 | super(CNNTextClassification, self).__init__( 40 | use_char, char_embeddings, char_vocab_size, char_embed_dim, 41 | char_embed_trainable, use_bert, bert_config_file, 42 | bert_checkpoint_file, bert_trainable, use_word, 43 | word_embeddings, word_vocab_size, word_embed_dim, 44 | word_embed_trainable, max_len, dropout) 45 | 46 | def build_model(self): 47 | model_inputs, input_embed = self.build_input() 48 | filter_lengths = [3, 4, 5] 49 | conv_layers = [] 50 | for filter_length in filter_lengths: 51 | conv_layer = tf.keras.layers.Conv1D(filters=128, 52 | kernel_size=filter_length, 53 | padding='valid', 54 | activation='relu', 55 | strides=1)(input_embed) 56 | conv_layers.append(conv_layer) 57 | poolings = [tf.keras.layers.GlobalMaxPooling1D()(conv) for conv in conv_layers] 58 | x = tf.keras.layers.Concatenate()(poolings) 59 | output_layer = tf.keras.layers.Dense(self.num_class, activation='softmax')(x) 60 | text_classification_loss = 'categorical_crossentropy' 61 | text_classification_metrics = 'accuracy' 62 | text_classification_model = tf.keras.models.Model( 63 | model_inputs if len(model_inputs) > 1 else model_inputs[0], output_layer) 64 | text_classification_model.compile( 65 | optimizer=self.optimizer, loss=text_classification_loss, 66 | metrics=[text_classification_metrics]) 67 | return text_classification_model 68 | 69 | 70 | class RCNNTextClassification(BaseTextClassificationModel): 71 | """RCNN model for text classification. 
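    The implementation below first encodes the input embeddings with a bidirectional LSTM, concatenates the recurrent outputs with the original embeddings, applies 1D convolutions with kernel sizes 1 to 4, and feeds the concatenated average- and max-pooled features into a softmax output layer.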
72 | """ 73 | def __init__(self, 74 | num_class, 75 | use_char=True, 76 | char_embeddings=None, 77 | char_vocab_size=-1, 78 | char_embed_dim=-1, 79 | char_embed_trainable=False, 80 | use_bert=False, 81 | bert_config_file=None, 82 | bert_checkpoint_file=None, 83 | bert_trainable=False, 84 | use_word=False, 85 | word_embeddings=None, 86 | word_vocab_size=-1, 87 | word_embed_dim=-1, 88 | word_embed_trainable=False, 89 | max_len=None, 90 | dropout=0.2, 91 | rnn_units=150, 92 | fc_dim=100, 93 | activation='tanh', 94 | optimizer='adam'): 95 | self.num_class = num_class 96 | self.rnn_units = rnn_units 97 | self.fc_fim = fc_dim 98 | self.activation = activation 99 | self.optimizer = optimizer 100 | super(RCNNTextClassification, self).__init__( 101 | use_char, char_embeddings, char_vocab_size, char_embed_dim, 102 | char_embed_trainable, use_bert, bert_config_file, 103 | bert_checkpoint_file, bert_trainable, use_word, 104 | word_embeddings, word_vocab_size, word_embed_dim, 105 | word_embed_trainable, max_len, dropout) 106 | 107 | def build_model(self): 108 | model_inputs, input_embed = self.build_input() 109 | input_encode = tf.keras.layers.Bidirectional( 110 | tf.keras.layers.LSTM(self.rnn_units, return_sequences=True))(input_embed) 111 | x = tf.keras.layers.Concatenate()([input_embed, input_encode]) 112 | convs = [] 113 | for kernel_size in range(1, 5): 114 | conv = tf.keras.layers.Conv1D(128, kernel_size, activation='relu')(x) 115 | convs.append(conv) 116 | poolings = [tf.keras.layers.GlobalAveragePooling1D()(conv) for conv in convs] + \ 117 | [tf.keras.layers.GlobalMaxPooling1D()(conv) for conv in convs] 118 | x = tf.keras.layers.Concatenate()(poolings) 119 | output_layer = tf.keras.layers.Dense(self.num_class, activation='softmax')(x) 120 | 121 | text_classification_loss = 'categorical_crossentropy' 122 | text_classification_metrics = 'accuracy' 123 | text_classification_model = tf.keras.models.Model( 124 | model_inputs if len(model_inputs) > 1 else model_inputs[0], output_layer) 125 | text_classification_model.compile( 126 | optimizer=self.optimizer, loss=text_classification_loss, 127 | metrics=[text_classification_metrics]) 128 | return text_classification_model 129 | 130 | 131 | class BertTextClassification(BaseTextClassificationModel): 132 | """Bert model for text classification. 133 | We suggest you to train bert on machines with GPU cause it will be very slow to be trained with 134 | cpu. You will have to re-install a gpu version of tensorflow to do so. 
135 | """ 136 | 137 | def __init__(self, 138 | num_class, 139 | bert_config_file, 140 | bert_checkpoint_file, 141 | bert_trainable, 142 | max_len, 143 | dropout=0.2, 144 | fc_dim=100, 145 | activation='tanh', 146 | optimizer='adam'): 147 | self.num_class = num_class 148 | self.fc_fim = fc_dim 149 | self.activation = activation 150 | self.optimizer = optimizer 151 | super(BertTextClassification, self).__init__( 152 | use_char=False, use_bert=True, 153 | bert_config_file=bert_config_file, 154 | bert_checkpoint_file=bert_checkpoint_file, 155 | bert_trainable=bert_trainable, use_word=False, 156 | max_len=max_len, dropout=dropout) 157 | 158 | def build_model(self): 159 | model_inputs, input_embed = self.build_input() 160 | x = tf.keras.layers.GlobalAveragePooling1D()(input_embed) 161 | 162 | output_layer = tf.keras.layers.Dense(self.num_class, activation='softmax')(x) 163 | text_classification_loss = 'categorical_crossentropy' 164 | text_classification_metrics = 'accuracy' 165 | text_classification_model = tf.keras.models.Model( 166 | model_inputs if len(model_inputs) > 1 else model_inputs[0], output_layer) 167 | text_classification_model.compile( 168 | optimizer=self.optimizer, loss=text_classification_loss, 169 | metrics=[text_classification_metrics]) 170 | return text_classification_model 171 | -------------------------------------------------------------------------------- /fancy_nlp/utils/data_loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Tuple, Union, List 4 | 5 | import codecs 6 | 7 | from sklearn.model_selection import train_test_split 8 | 9 | 10 | def load_ner_data_and_labels(filename: str, 11 | delimiter: str = '\t', 12 | split: bool = False, 13 | split_size: float = 0.1, 14 | seed: int = 42): 15 | """Load ner data and label from a file. 16 | 17 | The file should follow CoNLL format: 18 | Each line is a token and its label separated by 'delimiter', or a blank line indicating the end of 19 | a sentence. Like: 20 | ``` 21 | 我 O 22 | 在 O 23 | 上 B-LOC 24 | 海 I-LOC 25 | 上 O 26 | 学 O 27 | 28 | ... 29 | ``` 30 | 31 | Args: 32 | filename: str. Path to ner file. 33 | delimiter: str. Delimiter to split token and label. 34 | split: Boolean. Whether to split into train and test subsets. 35 | split_size: float. The proportion of test subset, between 0.0 and 1.0 36 | seed: int. Random seed. 37 | 38 | Returns: If split: return a tuple of 4 list: train data and labels as well as test data and 39 | labels. 40 | Otherwise: return a tuple of 2 list: data and labels 41 | 42 | """ 43 | with codecs.open(filename, 'r', encoding='utf8') as reader: 44 | token_seqs, label_seqs = [], [] 45 | tokens, labels = [], [] 46 | for i, line in enumerate(reader): 47 | line = line.rstrip() 48 | if line: 49 | line_split = line.split(delimiter) 50 | if len(line_split) == 2: 51 | token, label = line_split 52 | tokens.append(token) 53 | labels.append(label) 54 | else: 55 | raise Exception(f'Format Error at line {i}!' 
56 | f' Input file should follow CoNLL format.') 57 | else: 58 | if tokens: 59 | token_seqs.append(tokens) 60 | label_seqs.append(labels) 61 | tokens, labels = [], [] 62 | 63 | if tokens: # in case there's no blank line at the end of the file 64 | token_seqs.append(tokens) 65 | label_seqs.append(labels) 66 | 67 | if split: 68 | x_train, x_test, y_train, y_test = train_test_split(token_seqs, label_seqs, 69 | test_size=split_size, random_state=seed) 70 | return x_train, y_train, x_test, y_test 71 | else: 72 | return token_seqs, label_seqs 73 | 74 | 75 | def load_text_classification_data_and_labels( 76 | filename, label_index=0, text_index=1, delimiter='\t', split_mode=0, split_size=0.1, 77 | use_header=False, seed=42): 78 | """Load text classification data and label from a file. 79 | 80 | Args: 81 | filename: str, path to ner file 82 | delimiter: str, delimiter to split token and label 83 | label_index: int, refer to which column store the classification label 84 | text_index:int, refer to which column store the text information 85 | split_mode: int, if `split_mode` is 1, it will split the dataset into train and valid; 86 | if `split_mode` is 2, it will split the dataset into train, valid and test 87 | 88 | split_size: float, the proportion of test subset, between 0.0 and 1.0 89 | use_header: bool, whether to use first line as the header 90 | seed: int, random seed 91 | 92 | Returns: If split: tuple(list, list, list, list), train data and labels as well as test data and 93 | labels. 94 | Otherwise: tuple(list, list), data and labels 95 | 96 | """ 97 | with codecs.open(filename, 'r', encoding='utf8') as reader: 98 | token_seqs, label_seqs = [], [] 99 | for i, line in enumerate(reader): 100 | if use_header and i == 0: 101 | continue 102 | line_items = line.strip().split(delimiter) 103 | 104 | text, label = line_items[text_index], line_items[label_index] 105 | token_seqs.append(list(text)) 106 | label_seqs.append(label) 107 | 108 | if split_mode == 1: 109 | x_train, x_valid, y_train, y_valid = train_test_split( 110 | token_seqs, label_seqs, test_size=split_size, stratify=label_seqs, random_state=seed) 111 | return x_train, y_train, x_valid, y_valid 112 | elif split_mode == 2: 113 | x_train, x_holdout, y_train, y_holdout = train_test_split( 114 | token_seqs, label_seqs, test_size=split_size, stratify=label_seqs, random_state=seed) 115 | x_valid, x_test, y_valid, y_test = train_test_split( 116 | x_holdout, y_holdout, test_size=0.5, stratify=y_holdout, random_state=seed) 117 | return x_train, y_train, x_valid, y_valid, x_test, y_test 118 | else: 119 | return token_seqs, label_seqs 120 | 121 | 122 | def load_spm_data_and_labels(filename: str, delimiter: str = '\t', split_mode: int = 0, 123 | split_size: float = 0.2, seed: int = 42) -> \ 124 | Union[Tuple[Tuple[List[str], List[str]], List[str]], 125 | Tuple[Tuple[List[str], List[str]], List[str], Tuple[List[str], List[str]], List[str]], 126 | Tuple[Tuple[List[str], List[str]], List[str], Tuple[List[str], List[str]], List[str], 127 | Tuple[List[str], List[str]], List[str]]]: 128 | """Load spm data and label from a file. 129 | 130 | The file should follow fixed format: 131 | Each line is a pair of sentences and its label separated by tab, like text_a \t text_b \t label. 
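    For example (the sentence pairs and labels below are illustrative samples, not lines taken from the bundled dataset):
    ```
    怎么开通微粒贷	如何申请微粒贷	1
    微粒贷的还款日期	借钱不还会怎么样	0
    ```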
132 | 133 | Args: 134 | filename: str, path to spm file 135 | delimiter: str, delimiter to split texts and label 136 | split_mode: int, if `split_mode` is 1, it will split the dataset into train and valid; 137 | if `split_mode` is 2, it will split the dataset into train, valid and test 138 | split_size: float, the proportion of test subset, between 0.0 and 1.0 139 | seed: int, random seed 140 | 141 | Returns: If split: tuple((list_a, list_b), list, (list_a, list_b), list), train data pairs and 142 | labels as well as test data pairs and labels. 143 | Otherwise: tuple((list_a, list_b), list), data pairs and labels 144 | 145 | """ 146 | # read data file 147 | with codecs.open(filename, 'r', encoding='utf8') as reader: 148 | text_a, text_b, labels = [], [], [] 149 | for i, line in enumerate(reader): 150 | line = line.rstrip() 151 | line_split = line.split(delimiter) 152 | if line: 153 | if len(line_split) == 3: 154 | a, b, label = line_split 155 | text_a.append(a) 156 | text_b.append(b) 157 | labels.append(label) 158 | else: 159 | raise Exception(f'Format Error at line {i}!' 160 | f'Input file should follow fixed spm format.') 161 | 162 | # randomly split train data and valid data 163 | if split_mode == 1: 164 | x_a_train, x_a_valid, x_b_train, x_b_valid, y_train, y_valid = train_test_split( 165 | text_a, text_b, labels, test_size=split_size, stratify=labels, random_state=seed) 166 | return (x_a_train, x_b_train), y_train, (x_a_valid, x_b_valid), y_valid 167 | # randomly split train data, valid data and test data 168 | elif split_mode == 2: 169 | x_a_train, x_a_holdout, x_b_train, x_b_holdout, y_train, y_holdout = train_test_split( 170 | text_a, text_b, labels, test_size=split_size, stratify=labels, random_state=seed) 171 | x_a_valid, x_a_test, x_b_valid, x_b_test, y_valid, y_test = train_test_split( 172 | x_a_holdout, x_b_holdout, y_holdout, test_size=0.5, stratify=y_holdout, 173 | random_state=seed) 174 | return (x_a_train, x_b_train), y_train, (x_a_valid, x_b_valid), y_valid, \ 175 | (x_a_test, x_b_test), y_test 176 | else: 177 | return (text_a, text_b), labels 178 | -------------------------------------------------------------------------------- /fancy_nlp/trainers/text_classification_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from absl import logging 6 | import tensorflow as tf 7 | from sklearn.metrics import f1_score, precision_score, recall_score, classification_report 8 | 9 | from fancy_nlp.utils import TextClassificationGenerator 10 | from fancy_nlp.callbacks import TextClassificationMetric 11 | from fancy_nlp.callbacks import SWA 12 | 13 | 14 | class TextClassificationTrainer(object): 15 | def __init__(self, model, preprocessor): 16 | """ 17 | 18 | Args: 19 | model: instance of keras Model 20 | preprocessor: instance of TextClassificationPreprocessor 21 | """ 22 | self.model = model 23 | self.preprocessor = preprocessor 24 | 25 | def train(self, train_data, train_labels, valid_data=None, valid_labels=None, batch_size=32, 26 | epochs=5, callback_list=None, checkpoint_dir=None, model_name=None, swa_model=None, 27 | load_swa_model=False): 28 | callbacks = self.prepare_callback(callback_list, valid_data, valid_labels, checkpoint_dir, 29 | model_name, swa_model) 30 | 31 | train_features, train_y = self.preprocessor.prepare_input(train_data, train_labels) 32 | if valid_data is not None and valid_labels is not None: 33 | valid_features, valid_y = 
self.preprocessor.prepare_input(valid_data, valid_labels) 34 | validation_data = (valid_features, valid_y) 35 | else: 36 | validation_data = None 37 | 38 | logging.info('Training start...') 39 | self.model.fit(x=train_features, y=train_y, batch_size=batch_size, epochs=epochs, 40 | validation_data=validation_data, callbacks=callbacks) 41 | logging.info('Training end...') 42 | 43 | if load_swa_model and callback_list is not None and 'swa' in callback_list: 44 | logging.info('Loading swa model after using SWA callback') 45 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 46 | 47 | elif callback_list is not None and 'modelcheckpoint' in callback_list: 48 | logging.info('Loading best model after using ModelCheckpoint callback...') 49 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}.hdf5')) 50 | 51 | def train_generator(self, train_data, train_labels, valid_data=None, valid_labels=None, 52 | batch_size=32, epochs=50, callback_list=None, checkpoint_dir=None, 53 | model_name=None, swa_model=None, load_swa_model=False): 54 | callbacks = self.prepare_callback(callback_list, valid_data, valid_labels, checkpoint_dir, 55 | model_name, swa_model) 56 | 57 | train_generator = TextClassificationGenerator( 58 | self.preprocessor, train_data, train_labels, batch_size) 59 | 60 | if valid_data and valid_labels: 61 | valid_generator = TextClassificationGenerator( 62 | self.preprocessor, valid_data, valid_labels, batch_size) 63 | else: 64 | valid_generator = None 65 | 66 | logging.info('Training start...') 67 | # Note: Model.fit now supports generators 68 | self.model.fit(x=train_generator, 69 | epochs=epochs, 70 | callbacks=callbacks, 71 | validation_data=valid_generator) 72 | logging.info('Training end...') 73 | 74 | if load_swa_model and callback_list is not None and 'swa' in callback_list: 75 | logging.info('Loading swa model after using SWA callback') 76 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 77 | 78 | elif callback_list is not None and 'modelcheckpoint' in callback_list: 79 | logging.info('Loading best model after using ModelCheckpoint callback...') 80 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}.hdf5')) 81 | 82 | def prepare_callback(self, callback_list, valid_data=None, valid_labels=None, 83 | checkpoint_dir=None, model_name=None, swa_model=None): 84 | """ 85 | 86 | Args: 87 | callback_list: list of str, each item indicates the callback to apply during training. 88 | For example, 'earlystopping' means using 'EarlyStopping' callback. 89 | valid_data: list of tokenized (in char level) texts for evaluation 90 | valid_labels: label strings of valid data 91 | checkpoint_dir: str, directory to save the text classification model, must be provided 92 | when using `ModelCheckpoint` or `SWA` callback. 93 | model_name: str, prefix of the text classification model's weights file, must be provided 94 | when using `ModelCheckpoint` or `SWA` callback.
95 | For example, if checkpoint_dir is 'ckpt' and model_name is 'model', the 96 | weights of the text classification model saved by `ModelCheckpoint` callback will be 97 | 'ckpt/model.hdf5' and by `SWA` callback will be 'ckpt/model_swa.hdf5' 98 | 99 | Returns: a list of `keras.callbacks.Callback` instances 100 | 101 | """ 102 | assert not isinstance(callback_list, str) 103 | callback_list = callback_list or [] 104 | callbacks = [] 105 | if valid_data is not None and valid_labels is not None: 106 | callbacks.append(TextClassificationMetric(self.preprocessor, valid_data, valid_labels)) 107 | add_metric = True 108 | else: 109 | add_metric = False 110 | 111 | if 'modelcheckpoint' in callback_list: 112 | if not add_metric: 113 | logging.warning('Using `ModelCheckpoint` without validation data is not ' 114 | 'recommended! We will use `loss` (of training data) as monitor.') 115 | 116 | assert checkpoint_dir is not None, \ 117 | '`checkpoint_dir` must be provided when using "ModelCheckpoint" callback' 118 | assert model_name is not None, \ 119 | '`model_name` must be provided when using "ModelCheckpoint" callback' 120 | callbacks.append(tf.keras.callbacks.ModelCheckpoint( 121 | filepath=os.path.join(checkpoint_dir, f'{model_name}.hdf5'), 122 | monitor='val_f1' if add_metric else 'loss', 123 | save_best_only=True, 124 | save_weights_only=True, 125 | mode='max' if add_metric else 'min', 126 | verbose=1)) 127 | logging.info('ModelCheckpoint Callback added') 128 | 129 | if 'earlystopping' in callback_list: 130 | if not add_metric: 131 | logging.warning( 132 | 'Using `EarlyStopping` without validation data is not recommended! ' 133 | 'We will use `loss` (of training data) as monitor.') 134 | callbacks.append(tf.keras.callbacks.EarlyStopping( 135 | monitor='val_f1' if add_metric else 'loss', 136 | mode='max' if add_metric else 'min', 137 | patience=5, 138 | verbose=1)) 139 | logging.info('EarlyStopping Callback added') 140 | 141 | if 'swa' in callback_list: 142 | assert checkpoint_dir is not None, \ 143 | '`checkpoint_dir` must be provided when using "SWA" callback' 144 | assert model_name is not None, \ 145 | '`model_name` must be provided when using "SWA" callback' 146 | assert swa_model is not None, \ 147 | '`swa_model` must be provided when using "SWA" callback' 148 | callbacks.append(SWA(swa_model=swa_model, checkpoint_dir=checkpoint_dir, 149 | model_name=model_name, swa_start=5)) 150 | logging.info('SWA Callback added') 151 | 152 | return callbacks 153 | 154 | def load_model_weights(self, weights_file): 155 | self.model.load_weights(weights_file) 156 | 157 | def evaluate(self, data, labels): 158 | """Evaluate the performance of the text classification model.
159 | 160 | Args: 161 | data: list of tokenized texts, like ``[['我', '是', '中', '国', '人']]`` 162 | labels: list of str, the corresponding label strings 163 | 164 | """ 165 | features, y = self.preprocessor.prepare_input(data, labels) 166 | pred_probs = self.model.predict(features) 167 | 168 | y_pred = self.preprocessor.label_decode(pred_probs) 169 | 170 | r = recall_score(labels, y_pred, average='macro') 171 | p = precision_score(labels, y_pred, average='macro') 172 | f1 = f1_score(labels, y_pred, average='macro') 173 | 174 | logging.info('Recall: {}, Precision: {}, F1: {}'.format(r, p, f1)) 175 | logging.info(classification_report(labels, y_pred)) 176 | return f1 177 | -------------------------------------------------------------------------------- /fancy_nlp/utils/embedding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Dict, Union, List, Optional 4 | 5 | from absl import logging 6 | import numpy as np 7 | from gensim.models import Word2Vec 8 | from gensim.models import KeyedVectors 9 | from gensim.models import FastText 10 | 11 | 12 | def load_glove_format(filename: str, embedding_dim: int) -> Dict[str, np.ndarray]: 13 | """Load pre-trained embedding from a file in glove-embedding-like format: 14 | Each line is a token and its embedding separated by blank space. 15 | 16 | Args: 17 | filename: str. File path to pre-trained embedding. 18 | embedding_dim: int. Dimensionality of embedding, used to validate the embedding file. 19 | Returns: 20 | word_vectors: Dict[str, np.ndarray], a mapping of words to embeddings. 21 | 22 | """ 23 | word_vectors = {} 24 | with open(filename, 'r') as reader: 25 | for i, line in enumerate(reader): 26 | line = line.strip().split() 27 | word = line[0] 28 | word_vector = np.array([float(v) for v in line[1:]]) 29 | 30 | if word_vector.shape[0] != embedding_dim: 31 | raise ValueError(f'Format error at line {i}! The size of word embedding does not ' 32 | f'equal {embedding_dim}') 33 | 34 | word_vectors[word] = word_vector 35 | 36 | return word_vectors 37 | 38 | 39 | def filter_embeddings(trained_embedding: Dict[str, np.ndarray], 40 | embedding_dim: int, 41 | vocabulary: Dict[str, int], 42 | zero_init_indices: Union[int, List[int]] = 0, 43 | rand_init_indices: Union[int, List[int]] = 1) -> np.ndarray: 44 | """Build a word embedding matrix from pre-trained embeddings and a word vocabulary. 45 | 46 | Args: 47 | trained_embedding: Dict[str, np.ndarray]. A mapping of words to pre-trained embeddings 48 | embedding_dim: int. Dimensionality of embedding. 49 | vocabulary: Dict[str, int]. A mapping of words to indices 50 | zero_init_indices: int or a List of int. The indices which use zero-initialization. These 51 | indices usually represent padding token. 52 | rand_init_indices: int or a List of int. The indices which use random initialization. These 53 | indices usually represent other special tokens, such as "unk" token. 54 | 55 | Returns: np.ndarray, a word embedding matrix, shaped [vocab_size, embedding_dim].
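    Example (a toy illustration; the vocabulary and vectors below are made up, with index 0 as the padding slot and index 1 as the unk slot):
    ```
    vocab = {'<PAD>': 0, '<UNK>': 1, '我': 2, '你': 3}
    trained = {'我': np.random.rand(300), '你': np.random.rand(300)}
    emb = filter_embeddings(trained, 300, vocab)
    # emb.shape == (4, 300); emb[0] is all zeros, emb[1] is randomly initialized
    ```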
56 | 57 | """ 58 | emb = np.zeros(shape=(len(vocabulary), embedding_dim), dtype='float32') 59 | nb_unk = 0 60 | for w, i in vocabulary.items(): 61 | if w not in trained_embedding: 62 | nb_unk += 1 63 | emb[i, :] = np.random.normal(0, 0.05, embedding_dim) 64 | else: 65 | emb[i, :] = trained_embedding[w] 66 | 67 | if isinstance(zero_init_indices, int): 68 | zero_init_indices = [zero_init_indices] 69 | if isinstance(rand_init_indices, int): 70 | rand_init_indices = [rand_init_indices] 71 | for idx in zero_init_indices: 72 | emb[idx] = np.zeros(embedding_dim) 73 | for idx in rand_init_indices: 74 | emb[idx] = np.random.normal(0, 0.05, embedding_dim) 75 | 76 | logging.info('Embedding matrix created, shaped: {}, not found tokens: {}'.format(emb.shape, 77 | nb_unk)) 78 | return emb 79 | 80 | 81 | def load_pre_trained(load_filename: str, 82 | embedding_dim: Optional[int] = None, 83 | vocabulary: Optional[Dict[str, int]] = None, 84 | zero_init_indices: Union[int, List[int]] = 0, 85 | rand_init_indices: Union[int, List[int]] = 1) \ 86 | -> Union[Dict[str, np.ndarray], np.ndarray]: 87 | """Load pre-trained embedding and fit into vocabulary if provided 88 | 89 | Args: 90 | load_filename: str. Pre-trained embedding file, in word2vec-like format or glove-like format 91 | embedding_dim: int. Dimensionality of embeddings in the embedding file, must be provided 92 | when the file is in glove-like format. 93 | vocabulary: Dict[str, int]. A mapping of words to indices. 94 | zero_init_indices: int or a List of int. The indices which use zero-initialization. These 95 | indices usually represent padding token. 96 | rand_init_indices: int or a List of int. The indices which use randomly-initialization.These 97 | indices usually represent other special tokens, such as "unk" token. 98 | 99 | Returns: If vocabulary is None: Dict(str, np.ndarray), a mapping of words to embeddings. 100 | Otherwise: np.ndarray, a word embedding matrix. 101 | 102 | """ 103 | word_vectors = {} 104 | try: 105 | # load word2vec-like format pre-trained embedding file 106 | model = KeyedVectors.load_word2vec_format(load_filename) 107 | weights = model.wv.vectors 108 | word_embed_dim = weights.shape[1] 109 | for k, v in model.wv.vocab.items(): 110 | word_vectors[k] = weights[v.index, :] 111 | except ValueError: 112 | # load glove-like format pre-trained embedding file 113 | if embedding_dim is None: 114 | raise ValueError('`embedding_dim` must be provided when pre-trained embedding file is' 115 | 'in glove-like format!') 116 | word_vectors = load_glove_format(load_filename, embedding_dim) 117 | word_embed_dim = embedding_dim 118 | 119 | if vocabulary is not None: 120 | emb = filter_embeddings(word_vectors, word_embed_dim, vocabulary, zero_init_indices, 121 | rand_init_indices) 122 | return emb 123 | else: 124 | logging.info('Loading Embedding from: {}, shaped: {}'.format(load_filename, 125 | (len(word_vectors), 126 | word_embed_dim))) 127 | return word_vectors 128 | 129 | 130 | def train_w2v(corpus: List[List[str]], 131 | vocabulary: Dict[str, int], 132 | zero_init_indices: Union[int, List[int]] = 0, 133 | rand_init_indices: Union[int, List[int]] = 1, 134 | embedding_dim: int = 300) -> np.ndarray: 135 | """Use word2vec to train on corpus to obtain embedding. 136 | 137 | Args: 138 | corpus: List of List of str. List of tokenized texts, the corpus to train on, like ``[['我', 139 | '是', '中', '国', '人'], ...]``. 140 | vocabulary: Dict[str, int']. A mapping of words to indices 141 | zero_init_indices: int or a List of int. 
The indices which use zero-initialization. These 142 | indices usually represent padding token. 143 | rand_init_indices: int or a List of int. The indices which use randomly-initialization.These 144 | indices usually represent other special tokens, such as "unk" token. 145 | embedding_dim: int. Dimensionality of embedding 146 | 147 | Returns: np.ndarray, a word embedding matrix, shaped [vocab_size, embedding_dim]. 148 | 149 | """ 150 | model = Word2Vec(corpus, size=embedding_dim, min_count=1, window=5, sg=1, iter=10) 151 | weights = model.wv.vectors 152 | d = dict([(k, v.index) for k, v in model.wv.vocab.items()]) 153 | word_vectors = dict((w, weights[d[w], :]) for w in d) 154 | emb = filter_embeddings(word_vectors, embedding_dim, vocabulary, zero_init_indices, 155 | rand_init_indices) 156 | return emb 157 | 158 | 159 | def train_fasttext(corpus: List[List[str]], 160 | vocabulary: Dict[str, int], 161 | zero_init_indices: Union[int, List[int]] = 0, 162 | rand_init_indices: Union[int, List[int]] = 1, 163 | embedding_dim: int = 300) -> np.ndarray: 164 | """Use fasttext to train on corpus to obtain embedding 165 | 166 | Args: 167 | corpus: List of List of str. List of tokenized texts, the corpus to train on, like ``[['我', 168 | '是', '中', '国', '人'], ...]``. 169 | vocabulary: Dict[str, int']. A mapping of words to indices 170 | zero_init_indices: int or a List of int. The indices which use zero-initialization. These 171 | indices usually represent padding token. 172 | rand_init_indices: int or a List of int. The indices which use randomly-initialization.These 173 | indices usually represent other special tokens, such as "unk" token. 174 | embedding_dim: int. Dimensionality of embedding 175 | 176 | Returns: np.ndarray, a word embedding matrix, shaped [vocab_size, embedding_dim]. 
177 | 178 | """ 179 | model = FastText(size=embedding_dim, min_count=1, window=5, sg=1, word_ngrams=1) 180 | model.build_vocab(sentences=corpus) 181 | model.train(sentences=corpus, total_examples=len(corpus), epochs=10) 182 | 183 | emb = np.zeros(shape=(len(vocabulary), embedding_dim), dtype='float32') 184 | 185 | for w, i in vocabulary.items(): 186 | emb[i, :] = model.wv[w] # note that oov words can still have word vectors 187 | 188 | if isinstance(zero_init_indices, int): 189 | zero_init_indices = [zero_init_indices] 190 | if isinstance(rand_init_indices, int): 191 | rand_init_indices = [rand_init_indices] 192 | for idx in zero_init_indices: 193 | emb[idx] = np.zeros(embedding_dim) 194 | for idx in rand_init_indices: 195 | emb[idx] = np.random.normal(0, 0.05, embedding_dim) 196 | 197 | return emb 198 | -------------------------------------------------------------------------------- /fancy_nlp/trainers/spm_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | from typing import Tuple, List, Union, Optional 5 | 6 | from absl import logging 7 | import numpy as np 8 | import tensorflow as tf 9 | from sklearn import metrics 10 | 11 | from fancy_nlp.utils import SPMGenerator 12 | from fancy_nlp.callbacks import SPMMetric 13 | from fancy_nlp.callbacks import SWA 14 | from fancy_nlp.preprocessors import SPMPreprocessor 15 | 16 | 17 | class SPMTrainer(object): 18 | def __init__(self, 19 | model: tf.keras.models.Model, 20 | preprocessor: SPMPreprocessor) -> None: 21 | """ 22 | 23 | Args: 24 | model: instance of keras Model 25 | preprocessor: instance of SPMPreporcessor 26 | """ 27 | self.model = model 28 | self.preprocessor = preprocessor 29 | 30 | def train(self, 31 | train_data: Tuple[List[str], List[str]], 32 | train_labels: List[str], 33 | valid_data: Optional[Tuple[List[str], List[str]]] = None, 34 | valid_labels: Optional[List[str]] = None, 35 | batch_size: int = 32, 36 | epochs: int = 50, 37 | callback_list: Optional[List[str]] = None, 38 | checkpoint_dir: Optional[str] = None, 39 | model_name: Optional[str] = None, 40 | swa_model: Optional[tf.keras.models.Model] = None, 41 | load_swa_model: bool = False) -> None: 42 | callbacks = self.prepare_callback(callback_list, valid_data, valid_labels, checkpoint_dir, 43 | model_name, swa_model) 44 | 45 | train_features, train_y = self.preprocessor.prepare_input(train_data, train_labels) 46 | if valid_data is not None and valid_labels is not None: 47 | valid_features, valid_y = self.preprocessor.prepare_input(valid_data, valid_labels) 48 | validation_data = (valid_features, valid_y) 49 | else: 50 | validation_data = None 51 | 52 | logging.info('Training start...') 53 | self.model.fit(x=train_features, y=train_y, batch_size=batch_size, epochs=epochs, 54 | validation_data=validation_data, callbacks=callbacks) 55 | logging.info('Training end...') 56 | 57 | if load_swa_model and callback_list is not None and 'swa' in callback_list: 58 | logging.info('Loading swa model after using SWA callback') 59 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 60 | 61 | elif callback_list is not None and 'modelcheckpoint' in callback_list: 62 | logging.info('Loading best model after using ModelCheckpoint callback...') 63 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}.hdf5')) 64 | 65 | def train_generator(self, 66 | train_data: Tuple[List[str], List[str]], 67 | train_labels: List[str], 68 | valid_data: Optional[Tuple[List[str], 
List[str]]] = None, 69 | valid_labels: Optional[List[str]] = None, 70 | batch_size: int = 32, 71 | epochs: int = 50, 72 | callback_list: Optional[List[str]] = None, 73 | checkpoint_dir: Optional[str] = None, 74 | model_name: Optional[str] = None, 75 | swa_model: Optional[tf.keras.models.Model] = None, 76 | load_swa_model: bool = False) -> None: 77 | callbacks = self.prepare_callback(callback_list, valid_data, valid_labels, checkpoint_dir, 78 | model_name, swa_model) 79 | 80 | train_generator = SPMGenerator(self.preprocessor, train_data, train_labels, batch_size) 81 | 82 | if valid_data and valid_labels: 83 | valid_generator = SPMGenerator(self.preprocessor, valid_data, valid_labels, 84 | batch_size) 85 | else: 86 | valid_generator = None 87 | 88 | print('Training start...') 89 | # Note: Model.fit now supports generators 90 | self.model.fit(x=train_generator, 91 | epochs=epochs, 92 | callbacks=callbacks, 93 | validation_data=valid_generator) 94 | print('Training end...') 95 | 96 | if load_swa_model and callback_list is not None and 'swa' in callback_list: 97 | logging.info('Loading swa model after using SWA callback') 98 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}_swa.hdf5')) 99 | 100 | elif callback_list is not None and 'modelcheckpoint' in callback_list: 101 | logging.info('Loading best model after using ModelCheckpoint callback...') 102 | self.load_model_weights(os.path.join(checkpoint_dir, f'{model_name}.hdf5')) 103 | 104 | def prepare_callback(self, 105 | callback_list: List[str], 106 | valid_data: Optional[Tuple[List[str], List[str]]] = None, 107 | valid_labels: Optional[List[str]] = None, 108 | checkpoint_dir: Optional[str] = None, 109 | model_name: Optional[str] = None, 110 | swa_model: Optional[tf.keras.models.Model] = None) -> \ 111 | Optional[List[tf.keras.callbacks.Callback]]: 112 | """ 113 | 114 | Args: 115 | callback_list: list of str, each item indicate the callback to apply during training. 116 | For example, 'earlystopping' means using 'EarlyStopping' callback. 117 | valid_data: list of tokenized (in char level) texts for evaluation 118 | valid_labels: labels string of valid data 119 | checkpoint_dir: str, directory to save spm model, must be provided when using 120 | `ModelCheckpoint` or `SWA` callback. 121 | model_name: str, prefix of spm model's weights file must be provided when using 122 | `ModelCheckpoint` or `SWA` callback. 123 | For example, if checkpoint_dir is 'ckpt' and model_name is 'model', the 124 | weights of spm model saved by `ModelCheckpoint` callback will be 125 | 'ckpt/model.hdf5' and by `SWA` callback will be 'ckpt/model_swa.hdf5' 126 | 127 | Returns: a list of `keras.callbacks.Callback` instances 128 | 129 | """ 130 | assert not isinstance(callback_list, str) 131 | callback_list = callback_list or [] 132 | callbacks = [] 133 | if valid_data is not None and valid_labels is not None: 134 | callbacks.append(SPMMetric(self.preprocessor, valid_data, valid_labels)) 135 | add_metric = True 136 | else: 137 | add_metric = False 138 | 139 | if 'modelcheckpoint' in callback_list: 140 | if not add_metric: 141 | logging.warning( 142 | 'Using `ModelCheckpoint` without validation data provided is not Recommended! 
' 143 | 'We will use `loss` (of training data) as monitor.') 144 | 145 | assert checkpoint_dir is not None, \ 146 | '`checkpoint_dir` must must be provided when using "ModelCheckpoint" callback' 147 | assert model_name is not None, \ 148 | '`model_name` must must be provided when using "ModelCheckpoint" callback' 149 | callbacks.append(tf.keras.callbacks.ModelCheckpoint( 150 | filepath=os.path.join(checkpoint_dir, f'{model_name}.hdf5'), 151 | monitor='val_f1' if add_metric else 'loss', 152 | save_best_only=True, 153 | save_weights_only=True, 154 | mode='max' if add_metric else 'min', 155 | verbose=1)) 156 | logging.info('ModelCheckpoint Callback added') 157 | 158 | if 'earlystopping' in callback_list: 159 | if not add_metric: 160 | logging.warning('Using `Earlystopping` with validation data not provided is not ' 161 | 'Recommended! We will use `loss` (of training data) as monitor.') 162 | callbacks.append(tf.keras.callbacks.EarlyStopping( 163 | monitor='val_f1' if add_metric else 'loss', 164 | mode='max' if add_metric else 'min', 165 | patience=5, 166 | verbose=1)) 167 | logging.info('Earlystopping Callback added') 168 | 169 | if 'swa' in callback_list: 170 | assert checkpoint_dir is not None, \ 171 | '`checkpoint_dir` must must be provided when using "SWA" callback' 172 | assert model_name is not None, \ 173 | '`model_name` must must be provided when using "SWA" callback' 174 | assert swa_model is not None, \ 175 | '`swa_model` must must be provided when using "SWA" callback' 176 | callbacks.append(SWA(swa_model=swa_model, checkpoint_dir=checkpoint_dir, 177 | model_name=model_name, swa_start=5)) 178 | logging.info('SWA Callback added') 179 | 180 | return callbacks 181 | 182 | def load_model_weights(self, weights_file: str) -> None: 183 | self.model.load_weights(weights_file) 184 | 185 | def evaluate(self, data: Tuple[List[str], List[str]], labels: List[str]) -> float: 186 | """Evaluate the performance of spm model. 
187 | 188 | Args: 189 | data: list of text pairs, like ``[['我是中国人', ...], ['我爱中国', ...]]`` 190 | labels: list of str, the corresponding label strings 191 | 192 | """ 193 | features, y = self.preprocessor.prepare_input(data, labels) 194 | pred_probs = self.model.predict(features) 195 | 196 | y_pred = np.argmax(pred_probs, axis=-1) 197 | labels = np.argmax(y, axis=-1) 198 | 199 | r = metrics.recall_score(labels, y_pred, average='macro') 200 | p = metrics.precision_score(labels, y_pred, average='macro') 201 | f1 = metrics.f1_score(labels, y_pred, average='macro') 202 | 203 | logging.info('Recall: {}, Precision: {}, F1: {}'.format(r, p, f1)) 204 | logging.info(metrics.classification_report(labels, y_pred)) 205 | return f1 206 | -------------------------------------------------------------------------------- /tests/fancy_nlp/trainer/test_spm_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_spm_data_and_labels 6 | from fancy_nlp.models.spm import SiameseCNN, BertSPM 7 | from fancy_nlp.preprocessors import SPMPreprocessor 8 | from fancy_nlp.trainers import SPMTrainer 9 | 10 | 11 | class TestSPMTrainer: 12 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/spm/webank/example.txt') 13 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 14 | '../../../data/embeddings/bert_sample_model/vocab.txt') 15 | bert_config_file = os.path.join(os.path.dirname(__file__), 16 | '../../../data/embeddings/bert_sample_model/bert_config.json') 17 | bert_model_file = os.path.join(os.path.dirname(__file__), 18 | '../../../data/embeddings/bert_sample_model/bert_model.ckpt') 19 | 20 | def setup_class(self): 21 | self.train_data, self.train_labels, self.valid_data, self.valid_labels = \ 22 | load_spm_data_and_labels(self.test_file, split_mode=1) 23 | self.preprocessor = SPMPreprocessor((self.train_data[0] + self.valid_data[0], 24 | self.train_data[1] + self.valid_data[1]), 25 | self.train_labels + self.valid_labels, 26 | use_word=True, 27 | use_char=True, 28 | bert_vocab_file=self.bert_vocab_file, 29 | word_embed_type='word2vec', 30 | char_embed_type='word2vec', 31 | max_len=10) 32 | self.num_class = self.preprocessor.num_class 33 | self.char_embeddings = self.preprocessor.char_embeddings 34 | self.char_vocab_size = self.preprocessor.char_vocab_size 35 | self.char_embed_dim = self.preprocessor.char_embed_dim 36 | 37 | self.word_embeddings = self.preprocessor.word_embeddings 38 | self.word_vocab_size = self.preprocessor.word_vocab_size 39 | self.word_embed_dim = self.preprocessor.word_embed_dim 40 | self.checkpoint_dir = os.path.dirname(__file__) 41 | 42 | self.spm_model = SiameseCNN(num_class=self.num_class, 43 | use_word=True, 44 | word_embeddings=self.word_embeddings, 45 | word_vocab_size=self.word_vocab_size, 46 | word_embed_dim=self.word_embed_dim, 47 | word_embed_trainable=False, 48 | use_char=True, 49 | char_embeddings=self.char_embeddings, 50 | char_vocab_size=self.char_vocab_size, 51 | char_embed_dim=self.char_embed_dim, 52 | char_embed_trainable=False, 53 | use_bert=False, 54 | bert_config_file=self.bert_config_file, 55 | bert_checkpoint_file=self.bert_model_file, 56 | bert_trainable=True, 57 | max_len=self.preprocessor.max_len, 58 | max_word_len=self.preprocessor.max_word_len).build_model() 59 | 60 | self.swa_model = SiameseCNN(num_class=self.num_class, 61 | use_word=True, 62 | word_embeddings=self.word_embeddings, 63 | word_vocab_size=self.word_vocab_size, 
64 | word_embed_dim=self.word_embed_dim, 65 | word_embed_trainable=False, 66 | use_char=True, 67 | char_embeddings=self.char_embeddings, 68 | char_vocab_size=self.char_vocab_size, 69 | char_embed_dim=self.char_embed_dim, 70 | char_embed_trainable=False, 71 | use_bert=False, 72 | bert_config_file=self.bert_config_file, 73 | bert_checkpoint_file=self.bert_model_file, 74 | bert_trainable=True, 75 | max_len=self.preprocessor.max_len, 76 | max_word_len=self.preprocessor.max_word_len).build_model() 77 | 78 | self.spm_trainer = SPMTrainer(self.spm_model, self.preprocessor) 79 | 80 | self.json_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm.json') 81 | self.weights_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm.hdf5') 82 | 83 | def test_train(self): 84 | self.spm_trainer.train(self.train_data, self.train_labels, self.valid_data, 85 | self.valid_labels, batch_size=6, epochs=2) 86 | assert not os.path.exists(self.json_file) 87 | assert not os.path.exists(self.weights_file) 88 | 89 | def test_train_no_word(self): 90 | preprocessor = SPMPreprocessor((self.train_data[0] + self.valid_data[0], 91 | self.train_data[1] + self.valid_data[1]), 92 | self.train_labels + self.valid_labels, 93 | use_word=False, 94 | use_char=True, 95 | use_bert=True, 96 | bert_vocab_file=self.bert_vocab_file, 97 | char_embed_type='word2vec', 98 | max_len=10) 99 | self.num_class = preprocessor.num_class 100 | self.char_embeddings = preprocessor.char_embeddings 101 | self.char_vocab_size = preprocessor.char_vocab_size 102 | self.char_embed_dim = preprocessor.char_embed_dim 103 | 104 | spm_model = SiameseCNN(num_class=self.num_class, 105 | use_word=False, 106 | use_char=True, 107 | char_embeddings=self.char_embeddings, 108 | char_vocab_size=self.char_vocab_size, 109 | char_embed_dim=self.char_embed_dim, 110 | char_embed_trainable=False, 111 | use_bert=True, 112 | bert_config_file=self.bert_config_file, 113 | bert_checkpoint_file=self.bert_model_file, 114 | max_len=preprocessor.max_len).build_model() 115 | 116 | spm_trainer = SPMTrainer(spm_model, preprocessor) 117 | spm_trainer.train(self.train_data, self.train_labels, self.valid_data, self.valid_labels, 118 | batch_size=6, epochs=2) 119 | assert not os.path.exists(self.json_file) 120 | assert not os.path.exists(self.weights_file) 121 | 122 | def test_train_bert_model(self): 123 | preprocessor = SPMPreprocessor((self.train_data[0] + self.valid_data[0], 124 | self.train_data[1] + self.valid_data[1]), 125 | self.train_labels + self.valid_labels, 126 | use_word=False, 127 | use_char=False, 128 | use_bert=True, 129 | use_bert_model=True, 130 | bert_vocab_file=self.bert_vocab_file, 131 | max_len=10) 132 | spm_model = BertSPM(num_class=self.num_class, 133 | bert_config_file=self.bert_config_file, 134 | bert_checkpoint_file=self.bert_model_file, 135 | bert_trainable=True, 136 | max_len=preprocessor.max_len).build_model() 137 | 138 | spm_trainer = SPMTrainer(spm_model, preprocessor) 139 | spm_trainer.train(self.train_data, self.train_labels, self.valid_data, self.valid_labels, 140 | batch_size=6, epochs=2) 141 | assert not os.path.exists(self.json_file) 142 | assert not os.path.exists(self.weights_file) 143 | 144 | def test_train_no_valid_data(self): 145 | self.spm_trainer.train(self.train_data, self.train_labels, batch_size=6, epochs=2) 146 | assert not os.path.exists(self.json_file) 147 | assert not os.path.exists(self.weights_file) 148 | 149 | def test_train_callbacks(self): 150 | self.spm_trainer.train(self.train_data, self.train_labels, self.valid_data, 151 | 
self.valid_labels, batch_size=6, epochs=2, 152 | callback_list=['modelcheckpoint', 'earlystopping'], 153 | checkpoint_dir=os.path.dirname(__file__), 154 | model_name='siamese_cnn_spm') 155 | 156 | assert not os.path.exists(self.json_file) 157 | assert os.path.exists(self.weights_file) 158 | os.remove(self.weights_file) 159 | assert not os.path.exists(self.weights_file) 160 | 161 | def test_train_swa(self): 162 | self.spm_trainer.train(self.train_data, self.train_labels, self.valid_data, 163 | self.valid_labels, batch_size=6, epochs=5, callback_list=['swa'], 164 | checkpoint_dir=os.path.dirname(__file__), 165 | model_name='siamese_cnn_spm', 166 | swa_model=self.swa_model, 167 | load_swa_model=True) 168 | 169 | assert not os.path.exists(self.json_file) 170 | assert not os.path.exists(self.weights_file) 171 | 172 | json_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm_swa.json') 173 | weights_file = os.path.join(self.checkpoint_dir, 'siamese_cnn_spm_swa.hdf5') 174 | assert not os.path.exists(json_file) 175 | assert os.path.exists(weights_file) 176 | os.remove(weights_file) 177 | assert not os.path.exists(weights_file) 178 | 179 | def test_generator(self): 180 | self.spm_trainer.train_generator(self.train_data, self.train_labels, 181 | self.valid_data, self.valid_labels, batch_size=6, epochs=2) 182 | 183 | assert not os.path.exists(self.json_file) 184 | assert not os.path.exists(self.weights_file) 185 | -------------------------------------------------------------------------------- /fancy_nlp/preprocessors/preprocessor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from typing import Callable, List, Tuple, Dict, Optional 4 | 5 | from absl import logging 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | from ..utils import load_pre_trained, train_w2v, train_fasttext 10 | 11 | 12 | class Preprocessor(object): 13 | """Basic class for Fancy-NLP Preprocessor. All the preprocessors will inherit from it. 14 | 15 | Preprocessor is used to 16 | 1) build vocabulary from training data; 17 | 2) build pre-trained embedding matrix using training corpus; 18 | 3) prepare feature input for model; 19 | 4) decode model predictions to label strings 20 | 21 | """ 22 | 23 | def __init__(self, 24 | max_len: int = None, 25 | padding_mode: str = 'post', 26 | truncating_mode: str = 'post') -> None: 27 | """ 28 | 29 | Args: 30 | max_len: Optional int, can be None. Max length of one sequence. 31 | padding_mode: str. 'pre' or 'post': pad either before or after each sequence, used when 32 | preparing feature input for model. 33 | truncating_mode: str. 'pre' or 'post': remove values from sequences larger than 34 | `maxlen`, either at the beginning or at the end of the sequences, used when 35 | preparing feature input for model. 36 | """ 37 | self.pad_token = '<PAD>' 38 | self.unk_token = '<UNK>' 39 | self.cls_token = '<CLS>' 40 | self.seq_token = '<SEQ>' 41 | 42 | self.max_len = max_len 43 | self.padding_mode = padding_mode 44 | self.truncating_mode = truncating_mode 45 | 46 | @staticmethod 47 | def build_corpus(untokenized_texts: List[str], 48 | cut_func: Callable[[str], List[str]]) -> List[List[str]]: 49 | """Build corpus from untokenized texts. 50 | 51 | Args: 52 | untokenized_texts: List of str. List of un-tokenized texts, like ['我是中国人', ...]. 53 | cut_func: Function to tokenize texts.
For example, cut_func=lambda x: list(x) can be 54 | used to tokenize texts at char level, while cut_func=lambda x: jieba.lcut(x) can 55 | be used to tokenize Chinese texts at word level. 56 | 57 | Returns: 58 | List of List of str. List of tokenized texts, like 59 | ``[['我', '是', '中', '国', '人'], ...]``. 60 | 61 | """ 62 | corpus = [] 63 | for text in untokenized_texts: 64 | corpus.append(cut_func(text)) 65 | return corpus 66 | 67 | def build_vocab(self, 68 | corpus: List[List[str]], 69 | min_count: int = 3, 70 | special_token: str = 'standard') \ 71 | -> Tuple[Dict[str, int], Dict[str, int], Dict[int, str]]: 72 | """Build vocabulary using corpus. 73 | 74 | Args: 75 | corpus: List of List of str. List of tokenized texts, like 76 | ``[['我', '是', '中', '国', '人'], ...]`` 77 | min_count: int. Tokens whose frequency is less than min_count will be ignored 78 | special_token: str. 'standard' or 'bert': determine how to handle special tokens. 79 | If special_token is 'standard', we add 2 special tokens: [('<PAD>', 0), ('<UNK>', 80 | 1)]. If special_token is 'bert', we add 4 special tokens: [('<PAD>', 0), 81 | ('<UNK>', 1), ('<CLS>', 2), ('<SEQ>', 3)] 82 | 83 | Returns: Tuple of 3 dicts: 84 | 1. token_count: a mapping of tokens to frequencies 85 | 2. token_vocab: a mapping of tokens to indices 86 | 3. id2token: a mapping of indices to tokens 87 | 88 | """ 89 | if special_token == 'standard': 90 | token_vocab = {self.pad_token: 0, 91 | self.unk_token: 1} 92 | elif special_token == 'bert': 93 | token_vocab = {self.pad_token: 0, 94 | self.unk_token: 1, 95 | self.cls_token: 2, 96 | self.seq_token: 3} 97 | else: 98 | raise ValueError('Argument `special_token` can only be "standard" or "bert", ' 99 | 'got: {}'.format(special_token)) 100 | 101 | token_count = {} 102 | for tokenized_text in corpus: 103 | for token in tokenized_text: 104 | token_count[token] = token_count.get(token, 0) + 1 105 | # filter out low-frequency tokens 106 | token_count = {token: count for token, count in token_count.items() 107 | if count >= min_count} 108 | 109 | for token in token_count: 110 | token_vocab[token] = len(token_vocab) 111 | id2token = dict((idx, token) for token, idx in token_vocab.items()) 112 | 113 | logging.info('Build vocabulary finished, vocabulary size: {}'.format(len(token_vocab))) 114 | return token_count, token_vocab, id2token 115 | 116 | def build_label_vocab(self, labels): 117 | """Build label vocabulary. 118 | """ 119 | raise NotImplementedError 120 | 121 | def build_embedding(self, 122 | embed_type: Optional[str], 123 | vocab: Dict[str, int], 124 | corpus: Optional[List[List[str]]] = None, 125 | embedding_dim: int = 300, 126 | special_token: str = 'standard'): 127 | """Prepare embeddings for the tokens in vocab. 128 | We support loading external pre-trained embeddings from a provided embedding file as well as 129 | training embeddings on the provided corpus. 130 | 131 | Args: 132 | embed_type: Optional str, can be None. The type of embedding, can be a 133 | pre-trained embedding filename that is used to load pre-trained embedding, 134 | or an embedding training method (one of {'word2vec', 'fasttext'}) that is used to 135 | train character embedding with dataset. If None, do not apply any pre-trained 136 | embedding, and use randomly initialized embedding instead. 137 | vocab: Dict[str, int]. A mapping of tokens to indices 138 | corpus: List of tokenized texts, like ``[['我', '是', '中', '国', '人'], ...]``. 139 | embedding_dim: int. Dimensionality of embedding 140 | special_token: str.
'standard' or 'bert': determine how to handle special tokens. 141 | If special_token is 'standard', we add 2 special tokens: [('<PAD>', 0), ('<UNK>', 142 | 1)]. If special_token is 'bert', we add 4 special tokens: [('<PAD>', 0), 143 | ('<UNK>', 1), ('<CLS>', 2), ('<SEQ>', 3)] 144 | We will use zero-initializer for the '<PAD>' token and random-initializer for other 145 | special tokens. 146 | """ 147 | zero_init_indices = vocab.get(self.pad_token) 148 | if special_token == 'standard': 149 | rand_init_indices = vocab.get(self.unk_token) 150 | elif special_token == 'bert': 151 | rand_init_indices = [vocab.get(self.unk_token), 152 | vocab.get(self.cls_token), 153 | vocab.get(self.seq_token)] 154 | else: 155 | raise ValueError('Argument `special_token` can only be "standard" or "bert", ' 156 | 'got: {}'.format(special_token)) 157 | 158 | if embed_type is None: 159 | return None # do not adopt any pre-trained embeddings 160 | if embed_type == 'word2vec': 161 | return train_w2v(corpus, vocab, zero_init_indices, rand_init_indices, embedding_dim) 162 | elif embed_type == 'fasttext': 163 | return train_fasttext(corpus, vocab, zero_init_indices, rand_init_indices, 164 | embedding_dim) 165 | else: 166 | try: 167 | return load_pre_trained(embed_type, embedding_dim, vocab, 168 | zero_init_indices, rand_init_indices) 169 | except FileNotFoundError: 170 | raise ValueError('Argument `embed_type` input error: {}'.format(embed_type)) 171 | 172 | def prepare_input(self, data, label=None): 173 | """Prepare feature input for neural model training, evaluating and testing. 174 | """ 175 | raise NotImplementedError 176 | 177 | @staticmethod 178 | def build_id_sequence(tokenized_text: List[str], 179 | vocabulary: Dict[str, int], 180 | unk_idx: int = 1) -> List[int]: 181 | """Given a token list, return the corresponding id sequence. 182 | 183 | Args: 184 | tokenized_text: List of str, like `['我', '是', '中', '国', '人']`. 185 | vocabulary: Dict[str, int]. A mapping of tokens to indices. 186 | unk_idx: int. The index of tokens that do not appear in vocabulary. We usually set it 187 | to 1. 188 | 189 | Returns: 190 | List of indices. 191 | 192 | """ 193 | return [vocabulary.get(token, unk_idx) for token in tokenized_text] 194 | 195 | @staticmethod 196 | def build_id_matrix(tokenized_texts: List[List[str]], 197 | vocabulary, unk_idx=1): 198 | """Given a list where each item is a token list, return the corresponding id matrix. 199 | 200 | Args: 201 | tokenized_texts: List of List of str. List of tokenized texts, like ``[['我', '是', '中', 202 | '国', '人'], ...]``. 203 | vocabulary: Dict[str, int]. A mapping of tokens to indices 204 | unk_idx: int. The index of tokens that do not appear in vocabulary. We usually set it 205 | to 1. 206 | 207 | Returns: 208 | List of List of indices 209 | 210 | """ 211 | return [[vocabulary.get(token, unk_idx) for token in text] for text in tokenized_texts] 212 | 213 | def pad_sequence(self, 214 | sequence_list: List[List[int]]) -> np.ndarray: 215 | """Given a list where each item is an id sequence, return the padded sequence. 216 | 217 | Args: 218 | sequence_list: List of List of int, where each element is a sequence.
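            For example (an illustrative call, assuming a preprocessor built with max_len=4, padding_mode='post' and truncating_mode='post'):
            ```
            >>> preprocessor.pad_sequence([[2, 3], [4, 5, 6, 7, 8]])
            array([[2, 3, 0, 0],
                   [4, 5, 6, 7]], dtype=int32)
            ```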
219 | 220 | Returns: 221 | a 2D Numpy array of shape `(num_samples, num_timesteps)` 222 | 223 | """ 224 | return tf.keras.preprocessing.sequence.pad_sequences(sequence_list, 225 | maxlen=self.max_len, 226 | padding=self.padding_mode, 227 | truncating=self.truncating_mode) 228 | 229 | def label_decode(self, predictions, label_dict): 230 | """Decode model predictions to label strings 231 | """ 232 | raise NotImplementedError 233 | -------------------------------------------------------------------------------- /data/ner/msra/example.txt: -------------------------------------------------------------------------------- 1 | 当 O 2 | 希 O 3 | 望 O 4 | 工 O 5 | 程 O 6 | 救 O 7 | 助 O 8 | 的 O 9 | 百 O 10 | 万 O 11 | 儿 O 12 | 童 O 13 | 成 O 14 | 长 O 15 | 起 O 16 | 来 O 17 | , O 18 | 科 O 19 | 教 O 20 | 兴 O 21 | 国 O 22 | 蔚 O 23 | 然 O 24 | 成 O 25 | 风 O 26 | 时 O 27 | , O 28 | 今 O 29 | 天 O 30 | 有 O 31 | 收 O 32 | 藏 O 33 | 价 O 34 | 值 O 35 | 的 O 36 | 书 O 37 | 你 O 38 | 没 O 39 | 买 O 40 | , O 41 | 明 O 42 | 日 O 43 | 就 O 44 | 叫 O 45 | 你 O 46 | 悔 O 47 | 不 O 48 | 当 O 49 | 初 O 50 | ! O 51 | 52 | 藏 O 53 | 书 O 54 | 本 O 55 | 来 O 56 | 就 O 57 | 是 O 58 | 所 O 59 | 有 O 60 | 传 O 61 | 统 O 62 | 收 O 63 | 藏 O 64 | 门 O 65 | 类 O 66 | 中 O 67 | 的 O 68 | 第 O 69 | 一 O 70 | 大 O 71 | 户 O 72 | , O 73 | 只 O 74 | 是 O 75 | 我 O 76 | 们 O 77 | 结 O 78 | 束 O 79 | 温 O 80 | 饱 O 81 | 的 O 82 | 时 O 83 | 间 O 84 | 太 O 85 | 短 O 86 | 而 O 87 | 已 O 88 | 。 O 89 | 90 | 因 O 91 | 有 O 92 | 关 O 93 | 日 B-LOC 94 | 寇 O 95 | 在 O 96 | 京 B-LOC 97 | 掠 O 98 | 夺 O 99 | 文 O 100 | 物 O 101 | 详 O 102 | 情 O 103 | , O 104 | 藏 O 105 | 界 O 106 | 较 O 107 | 为 O 108 | 重 O 109 | 视 O 110 | , O 111 | 也 O 112 | 是 O 113 | 我 O 114 | 们 O 115 | 收 O 116 | 藏 O 117 | 北 B-LOC 118 | 京 I-LOC 119 | 史 O 120 | 料 O 121 | 中 O 122 | 的 O 123 | 要 O 124 | 件 O 125 | 之 O 126 | 一 O 127 | 。 O 128 | 129 | 我 O 130 | 们 O 131 | 藏 O 132 | 有 O 133 | 一 O 134 | 册 O 135 | 1 O 136 | 9 O 137 | 4 O 138 | 5 O 139 | 年 O 140 | 6 O 141 | 月 O 142 | 油 O 143 | 印 O 144 | 的 O 145 | 《 O 146 | 北 B-LOC 147 | 京 I-LOC 148 | 文 O 149 | 物 O 150 | 保 O 151 | 存 O 152 | 保 O 153 | 管 O 154 | 状 O 155 | 态 O 156 | 之 O 157 | 调 O 158 | 查 O 159 | 报 O 160 | 告 O 161 | 》 O 162 | , O 163 | 调 O 164 | 查 O 165 | 范 O 166 | 围 O 167 | 涉 O 168 | 及 O 169 | 故 B-LOC 170 | 宫 I-LOC 171 | 、 O 172 | 历 B-LOC 173 | 博 I-LOC 174 | 、 O 175 | 古 B-ORG 176 | 研 I-ORG 177 | 所 I-ORG 178 | 、 O 179 | 北 B-LOC 180 | 大 I-LOC 181 | 清 I-LOC 182 | 华 I-LOC 183 | 图 I-LOC 184 | 书 I-LOC 185 | 馆 I-LOC 186 | 、 O 187 | 北 B-LOC 188 | 图 I-LOC 189 | 、 O 190 | 日 B-LOC 191 | 伪 O 192 | 资 O 193 | 料 O 194 | 库 O 195 | 等 O 196 | 二 O 197 | 十 O 198 | 几 O 199 | 家 O 200 | , O 201 | 言 O 202 | 及 O 203 | 文 O 204 | 物 O 205 | 二 O 206 | 十 O 207 | 万 O 208 | 件 O 209 | 以 O 210 | 上 O 211 | , O 212 | 洋 O 213 | 洋 O 214 | 三 O 215 | 万 O 216 | 余 O 217 | 言 O 218 | , O 219 | 是 O 220 | 珍 O 221 | 贵 O 222 | 的 O 223 | 北 B-LOC 224 | 京 I-LOC 225 | 史 O 226 | 料 O 227 | 。 O 228 | 229 | 以 O 230 | 家 O 231 | 乡 O 232 | 的 O 233 | 历 O 234 | 史 O 235 | 文 O 236 | 献 O 237 | 、 O 238 | 特 O 239 | 定 O 240 | 历 O 241 | 史 O 242 | 时 O 243 | 期 O 244 | 书 O 245 | 刊 O 246 | 、 O 247 | 某 O 248 | 一 O 249 | 名 O 250 | 家 O 251 | 或 O 252 | 名 O 253 | 著 O 254 | 的 O 255 | 多 O 256 | 种 O 257 | 出 O 258 | 版 O 259 | 物 O 260 | 为 O 261 | 专 O 262 | 题 O 263 | , O 264 | 注 O 265 | 意 O 266 | 精 O 267 | 品 O 268 | 、 O 269 | 非 O 270 | 卖 O 271 | 品 O 272 | 、 O 273 | 纪 O 274 | 念 O 275 | 品 O 276 | , O 277 | 集 O 278 | 成 O 279 | 系 O 280 | 列 O 281 | , O 282 | 那 O 283 | 收 O 284 | 藏 O 285 | 的 O 286 | 过 O 287 | 程 O 288 | 就 O 289 | 已 O 290 | 经 O 291 | 够 O 292 | 您 O 293 | 玩 O 294 | 味 O 295 | 无 O 296 | 
穷 O 297 | 了 O 298 | 。 O 299 | 300 | 我 O 301 | 们 O 302 | 是 O 303 | 受 O 304 | 到 O 305 | 郑 B-PER 306 | 振 I-PER 307 | 铎 I-PER 308 | 先 O 309 | 生 O 310 | 、 O 311 | 阿 B-PER 312 | 英 I-PER 313 | 先 O 314 | 生 O 315 | 著 O 316 | 作 O 317 | 的 O 318 | 启 O 319 | 示 O 320 | , O 321 | 从 O 322 | 个 O 323 | 人 O 324 | 条 O 325 | 件 O 326 | 出 O 327 | 发 O 328 | , O 329 | 瞄 O 330 | 准 O 331 | 现 O 332 | 代 O 333 | 出 O 334 | 版 O 335 | 史 O 336 | 研 O 337 | 究 O 338 | 的 O 339 | 空 O 340 | 白 O 341 | , O 342 | 重 O 343 | 点 O 344 | 集 O 345 | 藏 O 346 | 解 O 347 | 放 O 348 | 区 O 349 | 、 O 350 | 国 B-ORG 351 | 民 I-ORG 352 | 党 I-ORG 353 | 毁 O 354 | 禁 O 355 | 出 O 356 | 版 O 357 | 物 O 358 | 。 O 359 | 360 | 靠 O 361 | 自 O 362 | 己 O 363 | 的 O 364 | 鉴 O 365 | 赏 O 366 | 能 O 367 | 力 O 368 | , O 369 | 将 O 370 | 某 O 371 | 一 O 372 | 专 O 373 | 题 O 374 | 尽 O 375 | 可 O 376 | 能 O 377 | 多 O 378 | 的 O 379 | 书 O 380 | 籍 O 381 | 汇 O 382 | 集 O 383 | 在 O 384 | 身 O 385 | 旁 O 386 | , O 387 | 并 O 388 | 得 O 389 | 到 O 390 | 收 O 391 | 藏 O 392 | 界 O 393 | 的 O 394 | 认 O 395 | 可 O 396 | , O 397 | 那 O 398 | 你 O 399 | 就 O 400 | 是 O 401 | 真 O 402 | 正 O 403 | 意 O 404 | 义 O 405 | 上 O 406 | 的 O 407 | 藏 O 408 | 书 O 409 | 家 O 410 | 了 O 411 | 。 O 412 | 413 | 精 O 414 | 品 O 415 | 、 O 416 | 专 O 417 | 题 O 418 | 、 O 419 | 系 O 420 | 列 O 421 | 、 O 422 | 稀 O 423 | 见 O 424 | 程 O 425 | 度 O 426 | 才 O 427 | 是 O 428 | 质 O 429 | 量 O 430 | 的 O 431 | 核 O 432 | 心 O 433 | 。 O 434 | 435 | 藏 O 436 | 书 O 437 | 的 O 438 | 数 O 439 | 量 O 440 | 多 O 441 | 少 O 442 | 不 O 443 | 能 O 444 | 反 O 445 | 映 O 446 | 收 O 447 | 藏 O 448 | 的 O 449 | 质 O 450 | 量 O 451 | , O 452 | 更 O 453 | 不 O 454 | 是 O 455 | 工 O 456 | 薪 O 457 | 层 O 458 | 的 O 459 | 承 O 460 | 受 O 461 | 范 O 462 | 围 O 463 | 。 O 464 | 465 | 书 O 466 | 籍 O 467 | 浩 O 468 | 如 O 469 | 烟 O 470 | 海 O 471 | , O 472 | 靠 O 473 | 个 O 474 | 人 O 475 | 的 O 476 | 精 O 477 | 力 O 478 | 与 O 479 | 财 O 480 | 力 O 481 | 不 O 482 | 可 O 483 | 能 O 484 | 广 O 485 | 而 O 486 | 博 O 487 | 之 O 488 | 。 O 489 | 490 | 保 O 491 | 存 O 492 | 典 O 493 | 籍 O 494 | , O 495 | 怡 O 496 | 情 O 497 | 雅 O 498 | 好 O 499 | , O 500 | 著 O 501 | 书 O 502 | 立 O 503 | 说 O 504 | , O 505 | 这 O 506 | 才 O 507 | 是 O 508 | 藏 O 509 | 书 O 510 | 家 O 511 | 的 O 512 | 全 O 513 | 部 O 514 | 。 O 515 | 516 | 闲 O 517 | 钱 O 518 | 、 O 519 | 闲 O 520 | 地 O 521 | 方 O 522 | 、 O 523 | 闲 O 524 | 工 O 525 | 夫 O 526 | 固 O 527 | 然 O 528 | 重 O 529 | 要 O 530 | , O 531 | 但 O 532 | 藏 O 533 | 书 O 534 | 要 O 535 | 有 O 536 | 文 O 537 | 化 O 538 | 底 O 539 | 蕴 O 540 | , O 541 | 对 O 542 | 于 O 543 | 书 O 544 | 所 O 545 | 传 O 546 | 达 O 547 | 出 O 548 | 来 O 549 | 的 O 550 | 美 O 551 | 感 O 552 | 和 O 553 | 情 O 554 | 趣 O 555 | , O 556 | 要 O 557 | 放 O 558 | 到 O 559 | 一 O 560 | 定 O 561 | 的 O 562 | 专 O 563 | 业 O 564 | 背 O 565 | 景 O 566 | 下 O 567 | 才 O 568 | 能 O 569 | 显 O 570 | 现 O 571 | 。 O 572 | 573 | 对 O 574 | 于 O 575 | 靠 O 576 | 藏 O 577 | 品 O 578 | 增 O 579 | 值 O 580 | 来 O 581 | 致 O 582 | 富 O 583 | 的 O 584 | 说 O 585 | 法 O 586 | , O 587 | 我 O 588 | 们 O 589 | 不 O 590 | 敢 O 591 | 非 O 592 | 议 O 593 | , O 594 | 但 O 595 | 从 O 596 | 没 O 597 | 听 O 598 | 说 O 599 | 过 O 600 | 哪 O 601 | 位 O 602 | 藏 O 603 | 书 O 604 | 家 O 605 | 靠 O 606 | 书 O 607 | 发 O 608 | 财 O 609 | 了 O 610 | 。 O 611 | 612 | 否 O 613 | 则 O 614 | , O 615 | 不 O 616 | 吃 O 617 | 不 O 618 | 喝 O 619 | 也 O 620 | 得 O 621 | 当 O 622 | 个 O 623 | 藏 O 624 | 书 O 625 | 家 O 626 | 。 O 627 | 628 | 许 O 629 | 多 O 630 | 初 O 631 | 入 O 632 | 此 O 633 | 道 O 634 | 的 O 635 | 朋 O 636 | 友 O 637 | 叹 O 638 | 惜 O 639 | 没 O 640 | 赶 O 641 | 上 O 642 | 我 O 643 | 们 O 644 | 开 O 645 | 始 O 646 | 藏 O 647 | 书 O 648 | 时 O 649 | 那 O 650 | 个 O 651 | 三 O 652 | 五 
O 653 | 角 O 654 | 钱 O 655 | 买 O 656 | 本 O 657 | 线 O 658 | 装 O 659 | 书 O 660 | 的 O 661 | 时 O 662 | 代 O 663 | 。 O 664 | 665 | 近 O 666 | 年 O 667 | 来 O 668 | 兴 O 669 | 起 O 670 | 的 O 671 | 集 O 672 | 藏 O 673 | 热 O 674 | 极 O 675 | 大 O 676 | 地 O 677 | 带 O 678 | 动 O 679 | 了 O 680 | 各 O 681 | 类 O 682 | 藏 O 683 | 品 O 684 | 的 O 685 | 价 O 686 | 格 O 687 | 。 O 688 | 689 | 去 O 690 | 年 O 691 | , O 692 | 我 O 693 | 们 O 694 | 又 O 695 | 被 O 696 | 评 O 697 | 为 O 698 | “ O 699 | 北 B-LOC 700 | 京 I-LOC 701 | 市 I-LOC 702 | 首 O 703 | 届 O 704 | 家 O 705 | 庭 O 706 | 藏 O 707 | 书 O 708 | 状 O 709 | 元 O 710 | 明 O 711 | 星 O 712 | 户 O 713 | ” O 714 | 。 O 715 | 716 | 藏 O 717 | 书 O 718 | 家 O 719 | 、 O 720 | 作 O 721 | 家 O 722 | 姜 B-PER 723 | 德 I-PER 724 | 明 I-PER 725 | 先 O 726 | 生 O 727 | 在 O 728 | 1 O 729 | 9 O 730 | 9 O 731 | 7 O 732 | 年 O 733 | 出 O 734 | 版 O 735 | 的 O 736 | 书 O 737 | 话 O 738 | 专 O 739 | 集 O 740 | 《 O 741 | 文 O 742 | 林 O 743 | 枝 O 744 | 叶 O 745 | 》 O 746 | 中 O 747 | 以 O 748 | “ O 749 | 爱 O 750 | 书 O 751 | 的 O 752 | 朋 O 753 | 友 O 754 | ” O 755 | 为 O 756 | 题 O 757 | , O 758 | 详 O 759 | 细 O 760 | 介 O 761 | 绍 O 762 | 了 O 763 | 我 O 764 | 们 O 765 | 夫 O 766 | 妇 O 767 | 的 O 768 | 藏 O 769 | 品 O 770 | 及 O 771 | 三 O 772 | 口 O 773 | 之 O 774 | 家 O 775 | 以 O 776 | 书 O 777 | 为 O 778 | 友 O 779 | 、 O 780 | 好 O 781 | 乐 O 782 | 清 O 783 | 贫 O 784 | 的 O 785 | 逸 O 786 | 闻 O 787 | 趣 O 788 | 事 O 789 | 。 O 790 | 791 | 每 O 792 | 当 O 793 | 接 O 794 | 到 O 795 | 海 O 796 | 外 O 797 | 书 O 798 | 友 O 799 | 或 O 800 | 归 O 801 | 国 O 802 | 人 O 803 | 员 O 804 | 汇 O 805 | 至 O 806 | 一 O 807 | 册 O 808 | 精 O 809 | 美 O 810 | 的 O 811 | 食 O 812 | 品 O 813 | 图 O 814 | 谱 O 815 | 时 O 816 | , O 817 | 全 O 818 | 家 O 819 | 人 O 820 | 欣 O 821 | 喜 O 822 | 若 O 823 | 狂 O 824 | ; O 825 | 826 | 而 O 827 | 一 O 828 | 册 O 829 | 交 O 830 | 换 O 831 | 品 O 832 | 离 O 833 | 我 O 834 | 而 O 835 | 去 O 836 | , O 837 | 虽 O 838 | 为 O 839 | 复 O 840 | 本 O 841 | , O 842 | 也 O 843 | 大 O 844 | 有 O 845 | 李 B-PER 846 | 后 I-PER 847 | 主 I-PER 848 | “ O 849 | 挥 O 850 | 泪 O 851 | 别 O 852 | 宫 O 853 | 娥 O 854 | ” O 855 | 之 O 856 | 感 O 857 | 。 O 858 | 859 | 我 O 860 | 们 O 861 | 变 O 862 | 而 O 863 | 以 O 864 | 书 O 865 | 会 O 866 | 友 O 867 | , O 868 | 以 O 869 | 书 O 870 | 结 O 871 | 缘 O 872 | , O 873 | 把 O 874 | 欧 B-LOC 875 | 美 B-LOC 876 | 、 O 877 | 港 B-LOC 878 | 台 B-LOC 879 | 流 O 880 | 行 O 881 | 的 O 882 | 食 O 883 | 品 O 884 | 类 O 885 | 图 O 886 | 谱 O 887 | 、 O 888 | 画 O 889 | 册 O 890 | 、 O 891 | 工 O 892 | 具 O 893 | 书 O 894 | 汇 O 895 | 集 O 896 | 一 O 897 | 堂 O 898 | 。 O 899 | 900 | 然 O 901 | 而 O 902 | 海 O 903 | 外 O 904 | 图 O 905 | 书 O 906 | 价 O 907 | 格 O 908 | 奇 O 909 | 昂 O 910 | , O 911 | 非 O 912 | 工 O 913 | 薪 O 914 | 层 O 915 | 所 O 916 | 能 O 917 | 承 O 918 | 受 O 919 | 。 O 920 | 921 | 为 O 922 | 了 O 923 | 跟 O 924 | 踪 O 925 | 国 O 926 | 际 O 927 | 最 O 928 | 新 O 929 | 食 O 930 | 品 O 931 | 工 O 932 | 艺 O 933 | 、 O 934 | 流 O 935 | 行 O 936 | 趋 O 937 | 势 O 938 | , O 939 | 大 O 940 | 量 O 941 | 搜 O 942 | 集 O 943 | 海 O 944 | 外 O 945 | 专 O 946 | 业 O 947 | 书 O 948 | 刊 O 949 | 资 O 950 | 料 O 951 | 是 O 952 | 提 O 953 | 高 O 954 | 技 O 955 | 艺 O 956 | 的 O 957 | 捷 O 958 | 径 O 959 | 。 O 960 | 961 | 我 O 962 | 俩 O 963 | 都 O 964 | 是 O 965 | 从 O 966 | 事 O 967 | 食 O 968 | 品 O 969 | 生 O 970 | 产 O 971 | 营 O 972 | 销 O 973 | 工 O 974 | 作 O 975 | 的 O 976 | 。 O 977 | 978 | 每 O 979 | 当 O 980 | 夜 O 981 | 深 O 982 | 人 O 983 | 静 O 984 | , O 985 | 一 O 986 | 册 O 987 | 册 O 988 | 把 O 989 | 玩 O 990 | 那 O 991 | 土 O 992 | 纸 O 993 | 毛 O 994 | 边 O 995 | 、 O 996 | 久 O 997 | 经 O 998 | 战 O 999 | 火 O 1000 | 的 O 1001 | 藏 O 1002 | 品 O 1003 | 时 O 1004 | , O 1005 | 一 O 1006 
| 幕 O 1007 | 幕 O 1008 | 可 O 1009 | 歌 O 1010 | 可 O 1011 | 泣 O 1012 | 、 O 1013 | 血 O 1014 | 雨 O 1015 | 腥 O 1016 | 风 O 1017 | 的 O 1018 | 悲 O 1019 | 壮 O 1020 | 场 O 1021 | 景 O 1022 | 便 O 1023 | 历 O 1024 | 历 O 1025 | 在 O 1026 | 目 O 1027 | , O 1028 | 怎 O 1029 | 不 O 1030 | 催 O 1031 | 人 O 1032 | 奋 O 1033 | 进 O 1034 | ! O 1035 | -------------------------------------------------------------------------------- /fancy_nlp/preprocessors/text_classification_preprocessor.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import codecs 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import jieba 9 | from absl import logging 10 | 11 | from fancy_nlp.preprocessors.preprocessor import Preprocessor 12 | from fancy_nlp.utils import get_len_from_corpus, ChineseBertTokenizer 13 | 14 | 15 | class TextClassificationPreprocessor(Preprocessor): 16 | """Text Classification preprocessor. 17 | """ 18 | def __init__(self, train_data, train_labels, min_count=2, use_char=True, use_bert=False, 19 | use_word=False, external_word_dict=None, label_dict_file=None, 20 | bert_vocab_file=None, char_embed_type=None, char_embed_dim=300, 21 | word_embed_type=None, word_embed_dim=300, max_len=None, padding_mode='post', 22 | truncating_mode='post'): 23 | """ 24 | 25 | Args: 26 | train_data: a list of tokenized (in char level) texts 27 | train_labels: list of str, train_data's labels 28 | min_count: int, tokens whose frequency is lower than min_count will be ignored 29 | use_char: whether to use char embedding as input 30 | use_bert: whether to use bert embedding as input 31 | use_word: whether to use word embedding as additional input 32 | external_word_dict: external word dictionary, only applies when use_word is True 33 | label_dict_file: a file with two columns separated by tab: the first column is the raw 34 | label name, and the second column is the corresponding human-readable 35 | name 36 | bert_vocab_file: vocabulary file of pre-trained bert model, only applies when use_bert is 37 | True 38 | char_embed_type: str, can be a pre-trained embedding filename or a pre-trained embedding 39 | method (word2vec, glove, fasttext) 40 | char_embed_dim: dimensionality of char embedding 41 | word_embed_type: same as char_embed_type, only applies when use_word is True 42 | word_embed_dim: dimensionality of word embedding 43 | max_len: int, maximum sequence length 44 | padding_mode: 'pre' or 'post', which side to pad sequences shorter than max_len on 45 | truncating_mode: 'pre' or 'post', which side to truncate sequences longer than max_len from 46 | """ 47 | super(TextClassificationPreprocessor, self).__init__(max_len, padding_mode, truncating_mode) 48 | 49 | self.train_data = train_data 50 | self.train_labels = train_labels 51 | self.min_count = min_count 52 | self.use_char = use_char 53 | self.use_bert = use_bert 54 | self.use_word = use_word 55 | self.external_word_dict = external_word_dict 56 | self.char_embed_type = char_embed_type 57 | self.word_embed_type = word_embed_type 58 | 59 | self.label_dict = self.load_label_dict(label_dict_file) 60 | 61 | assert self.use_char or self.use_bert, "must use char or bert embedding as main input" 62 | special_token = 'bert' if self.use_bert else 'standard' 63 | 64 | # build char vocabulary and char embedding 65 | if self.use_char: 66 | self.char_vocab_count, self.char_vocab, self.id2char = \ 67 | self.build_vocab(self.train_data, self.min_count, special_token) 68 | self.char_vocab_size = len(self.char_vocab) 69 | self.char_embeddings = self.build_embedding(char_embed_type, self.char_vocab, 70 | self.train_data, char_embed_dim, 71 | special_token) 72 | if
self.char_embeddings is not None: 73 | self.char_embed_dim = self.char_embeddings.shape[1] 74 | else: 75 | self.char_embed_dim = char_embed_dim 76 | else: 77 | self.char_vocab_count, self.char_vocab, self.id2char = None, None, None 78 | self.char_vocab_size = -1 79 | self.char_embeddings = None 80 | self.char_embed_dim = -1 81 | 82 | # build bert vocabulary 83 | if self.use_bert: 84 | # lower-case non-chinese characters 85 | self.bert_tokenizer = ChineseBertTokenizer(bert_vocab_file) 86 | 87 | # build word vocabulary and word embedding 88 | if self.use_word: 89 | self.load_word_dict() 90 | 91 | untokenized_texts = [''.join(text) for text in self.train_data] 92 | word_corpus = self.build_corpus(untokenized_texts, cut_func=lambda x: jieba.lcut(x)) 93 | 94 | self.word_vocab_count, self.word_vocab, self.id2word = \ 95 | self.build_vocab(word_corpus, self.min_count, special_token) 96 | self.word_vocab_size = len(self.word_vocab) 97 | self.word_embeddings = self.build_embedding(word_embed_type, self.word_vocab, 98 | word_corpus, word_embed_dim, 99 | special_token) 100 | if self.word_embeddings is not None: 101 | self.word_embed_dim = self.word_embeddings.shape[1] 102 | else: 103 | self.word_embed_dim = word_embed_dim 104 | else: 105 | self.word_vocab_count, self.word_vocab, self.id2word = None, None, None 106 | self.word_vocab_size = -1 107 | self.word_embeddings = None 108 | self.word_embed_dim = -1 109 | 110 | # build label vocabulary 111 | self.label_vocab, self.id2label = self.build_label_vocab(self.train_labels) 112 | self.num_class = len(self.label_vocab) 113 | 114 | if self.use_bert and self.max_len is None: 115 | # max_len is required when using bert as input, 116 | # so we infer max_len from train_data when it is not provided. 117 | self.max_len = get_len_from_corpus(self.train_data) 118 | 119 | # make sure max_len does not exceed bert's max length (512) 120 | # bert adds 2 special tokens ([CLS] and [SEP]), so add 2 121 | self.max_len = min(self.max_len + 2, 512) 122 | 123 | def load_word_dict(self): 124 | if self.external_word_dict: 125 | for word in self.external_word_dict: 126 | jieba.add_word(word, freq=1000000) 127 | 128 | @staticmethod 129 | def load_label_dict(label_dict_file): 130 | result_dict = dict() 131 | if label_dict_file: 132 | with codecs.open(label_dict_file, encoding='utf-8') as f_label_dict: 133 | for line in f_label_dict: 134 | line_items = line.strip().split('\t') 135 | result_dict[line_items[0]] = line_items[1] 136 | return result_dict 137 | else: 138 | return None 139 | 140 | def build_label_vocab(self, labels): 141 | """Build label vocabulary 142 | 143 | Args: 144 | labels: list of str, the label strings 145 | """ 146 | label_count = {} 147 | for label in labels: 148 | label_count[label] = label_count.get(label, 0) + 1 149 | 150 | # sorted by frequency, so that the label with the highest frequency will be given 151 | # id of 0, which is the default id for unknown labels 152 | sorted_label_count = sorted(label_count.items(), key=lambda x: x[1], reverse=True) 153 | sorted_label_count = dict(sorted_label_count) 154 | 155 | label_vocab = {} 156 | for label in sorted_label_count: 157 | label_vocab[label] = len(label_vocab) 158 | 159 | id2label = dict((idx, label) for label, idx in label_vocab.items()) 160 | 161 | logging.info('Build label vocabulary finished, ' 162 | 'vocabulary size: {}'.format(len(label_vocab))) 163 | return label_vocab, id2label 164 | 165 | def prepare_input(self, data, labels=None): 166 | """Prepare input (features and labels) for text
classification model. 167 | Here we not only use character embeddings (or bert embeddings) as main input, but also 168 | support word embeddings and other hand-crafted features embeddings as additional input. 169 | 170 | Args: 171 | data: list of tokenized (in char level) texts, like ``[['我', '是', '中', '国', '人']]`` 172 | labels: list of str, the corresponding label strings 173 | 174 | Returns: 175 | features: id matrix 176 | y: label id matrix (only if labels is provided) 177 | 178 | """ 179 | batch_char_ids, batch_bert_ids, batch_bert_seg_ids, batch_word_ids = [], [], [], [] 180 | batch_label_ids = [] 181 | for i, char_text in enumerate(data): 182 | if self.use_char: 183 | if self.use_bert: 184 | text_for_char_input = [self.cls_token] + char_text + [self.seq_token] 185 | else: 186 | text_for_char_input = char_text 187 | char_ids = [self.char_vocab.get(token, self.char_vocab[self.unk_token]) 188 | for token in text_for_char_input] 189 | batch_char_ids.append(char_ids) 190 | 191 | if self.use_bert: 192 | indices, segments = self.bert_tokenizer.encode(first_text=''.join(char_text), 193 | max_length=self.max_len) 194 | batch_bert_ids.append(indices) 195 | batch_bert_seg_ids.append(segments) 196 | 197 | if self.use_word: 198 | word_text = jieba.lcut(''.join(char_text)) 199 | word_ids = self.get_word_ids(word_text) 200 | batch_word_ids.append(word_ids) 201 | 202 | if labels is not None: 203 | batch_label_ids = [self.label_vocab.get(l, 0) for l in labels] 204 | batch_label_ids = tf.keras.utils.to_categorical(batch_label_ids, 205 | self.num_class).astype(int) 206 | 207 | features = [] 208 | if self.use_char: 209 | features.append(self.pad_sequence(batch_char_ids)) 210 | if self.use_bert: 211 | features.append(self.pad_sequence(batch_bert_ids)) 212 | features.append(self.pad_sequence(batch_bert_seg_ids)) 213 | if self.use_word: 214 | features.append(self.pad_sequence(batch_word_ids)) 215 | 216 | if len(features) == 1: 217 | features = features[0] 218 | 219 | if not list(batch_label_ids): 220 | return features, None 221 | else: 222 | y = batch_label_ids 223 | return features, y 224 | 225 | def get_word_ids(self, word_cut): 226 | """Given a word-level tokenized text, return the corresponding word ids in char-level 227 | sequence. We add the same word id to each character in the word. 228 | 229 | Args: 230 | word_cut: list of str, like ['我', '是'. 
'中国人'] 231 | unk_idx: the index of words that do not appear in vocabulary, we usually set it to 1 232 | 233 | Returns: list of int, id sequence 234 | 235 | """ 236 | word_ids = [] 237 | for word in word_cut: 238 | for _ in word: 239 | word_ids.append(self.word_vocab.get(word, self.word_vocab[self.unk_token])) 240 | if self.use_bert: 241 | word_ids = [self.word_vocab[self.cls_token]] + word_ids + \ 242 | [self.word_vocab[self.seq_token]] 243 | return word_ids 244 | 245 | def label_decode(self, pred_probs, label_dict=None): 246 | pred_ids = np.argmax(pred_probs, axis=-1) 247 | pred_labels = [self.id2label[pred_id] for pred_id in pred_ids] 248 | if label_dict: 249 | pred_labels = [label_dict[raw_label] for raw_label in pred_labels] 250 | return pred_labels 251 | 252 | def save(self, preprocessor_file): 253 | pickle.dump(self, open(preprocessor_file, 'wb')) 254 | 255 | @classmethod 256 | def load(cls, preprocessor_file): 257 | p = pickle.load(open(preprocessor_file, 'rb')) 258 | p.load_word_dict() # reload external word dict into jieba 259 | return p 260 | -------------------------------------------------------------------------------- /tests/fancy_nlp/trainer/test_ner_trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | from fancy_nlp.utils import load_ner_data_and_labels 6 | from fancy_nlp.models.ner import BiLSTMCNNNER 7 | from fancy_nlp.preprocessors import NERPreprocessor 8 | from fancy_nlp.trainers import NERTrainer 9 | 10 | 11 | class TestNERTrainer: 12 | test_file = os.path.join(os.path.dirname(__file__), '../../../data/ner/msra/example.txt') 13 | bert_vocab_file = os.path.join(os.path.dirname(__file__), 14 | '../../../data/embeddings/bert_sample_model/vocab.txt') 15 | bert_config_file = os.path.join(os.path.dirname(__file__), 16 | '../../../data/embeddings/bert_sample_model/bert_config.json') 17 | bert_model_file = os.path.join(os.path.dirname(__file__), 18 | '../../../data/embeddings/bert_sample_model/bert_model.ckpt') 19 | 20 | def setup_class(self): 21 | self.train_data, self.train_labels, self.valid_data, self.valid_labels = \ 22 | load_ner_data_and_labels(self.test_file, split=True) 23 | self.preprocessor = NERPreprocessor(self.train_data+self.valid_data, 24 | self.train_labels+self.valid_labels, 25 | use_bert=True, 26 | use_word=True, 27 | bert_vocab_file=self.bert_vocab_file, 28 | char_embed_type='word2vec', 29 | word_embed_type='word2vec', 30 | max_len=16) 31 | self.num_class = self.preprocessor.num_class 32 | self.char_embeddings = self.preprocessor.char_embeddings 33 | self.char_vocab_size = self.preprocessor.char_vocab_size 34 | self.char_embed_dim = self.preprocessor.char_embed_dim 35 | 36 | self.word_embeddings = self.preprocessor.word_embeddings 37 | self.word_vocab_size = self.preprocessor.word_vocab_size 38 | self.word_embed_dim = self.preprocessor.word_embed_dim 39 | self.checkpoint_dir = os.path.dirname(__file__) 40 | 41 | self.ner_model = BiLSTMCNNNER(num_class=self.num_class, 42 | use_char=True, 43 | char_embeddings=self.char_embeddings, 44 | char_vocab_size=self.char_vocab_size, 45 | char_embed_dim=self.char_embed_dim, 46 | char_embed_trainable=False, 47 | use_bert=True, 48 | bert_config_file=self.bert_config_file, 49 | bert_checkpoint_file=self.bert_model_file, 50 | use_word=True, 51 | word_embeddings=self.word_embeddings, 52 | word_vocab_size=self.word_vocab_size, 53 | word_embed_dim=self.word_embed_dim, 54 | word_embed_trainable=False, 55 | 
max_len=self.preprocessor.max_len, 56 | use_crf=True).build_model() 57 | 58 | self.swa_model = BiLSTMCNNNER(num_class=self.num_class, 59 | use_char=True, 60 | char_embeddings=self.char_embeddings, 61 | char_vocab_size=self.char_vocab_size, 62 | char_embed_dim=self.char_embed_dim, 63 | char_embed_trainable=False, 64 | use_bert=True, 65 | bert_config_file=self.bert_config_file, 66 | bert_checkpoint_file=self.bert_model_file, 67 | use_word=True, 68 | word_embeddings=self.word_embeddings, 69 | word_vocab_size=self.word_vocab_size, 70 | word_embed_dim=self.word_embed_dim, 71 | word_embed_trainable=False, 72 | max_len=self.preprocessor.max_len, 73 | use_crf=True).build_model() 74 | 75 | self.ner_trainer = NERTrainer(self.ner_model, self.preprocessor) 76 | 77 | self.json_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.json') 78 | self.weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner.hdf5') 79 | 80 | def test_train(self): 81 | self.ner_trainer.train(self.train_data, self.train_labels, self.valid_data, 82 | self.valid_labels, batch_size=2, epochs=2) 83 | assert not os.path.exists(self.json_file) 84 | assert not os.path.exists(self.weights_file) 85 | 86 | def test_train_no_crf(self): 87 | ner_model = BiLSTMCNNNER(num_class=self.num_class, 88 | use_char=True, 89 | char_embeddings=self.char_embeddings, 90 | char_vocab_size=self.char_vocab_size, 91 | char_embed_dim=self.char_embed_dim, 92 | char_embed_trainable=False, 93 | use_bert=True, 94 | bert_config_file=self.bert_config_file, 95 | bert_checkpoint_file=self.bert_model_file, 96 | use_word=True, 97 | word_embeddings=self.word_embeddings, 98 | word_vocab_size=self.word_vocab_size, 99 | word_embed_dim=self.word_embed_dim, 100 | word_embed_trainable=False, 101 | max_len=self.preprocessor.max_len, 102 | use_crf=False).build_model() 103 | 104 | ner_trainer = NERTrainer(ner_model, self.preprocessor) 105 | ner_trainer.train(self.train_data, self.train_labels, self.valid_data, self.valid_labels, 106 | batch_size=2, epochs=2) 107 | assert not os.path.exists(self.json_file) 108 | assert not os.path.exists(self.weights_file) 109 | 110 | def test_train_no_word(self): 111 | preprocessor = NERPreprocessor(self.train_data+self.valid_data, 112 | self.train_labels+self.valid_labels, 113 | use_bert=True, 114 | use_word=False, 115 | bert_vocab_file=self.bert_vocab_file, 116 | max_len=16, 117 | char_embed_type='word2vec') 118 | ner_model = BiLSTMCNNNER(num_class=preprocessor.num_class, 119 | use_char=True, 120 | char_embeddings=preprocessor.char_embeddings, 121 | char_vocab_size=preprocessor.char_vocab_size, 122 | char_embed_dim=preprocessor.char_embed_dim, 123 | char_embed_trainable=False, 124 | use_bert=True, 125 | bert_config_file=self.bert_config_file, 126 | bert_checkpoint_file=self.bert_model_file, 127 | use_word=False, 128 | word_embeddings=preprocessor.word_embeddings, 129 | word_vocab_size=preprocessor.word_vocab_size, 130 | word_embed_dim=preprocessor.word_embed_dim, 131 | word_embed_trainable=False, 132 | max_len=preprocessor.max_len, 133 | use_crf=True).build_model() 134 | 135 | ner_trainer = NERTrainer(ner_model, preprocessor) 136 | ner_trainer.train(self.train_data, self.train_labels, self.valid_data, self.valid_labels, 137 | batch_size=2, epochs=2) 138 | assert not os.path.exists(self.json_file) 139 | assert not os.path.exists(self.weights_file) 140 | 141 | def test_train_no_bert(self): 142 | preprocessor = NERPreprocessor(self.train_data + self.valid_data, 143 | self.train_labels + self.valid_labels, 144 | use_word=True, 145 | 
char_embed_type='word2vec') 146 | ner_model = BiLSTMCNNNER(num_class=preprocessor.num_class, 147 | use_char=True, 148 | char_embeddings=preprocessor.char_embeddings, 149 | char_vocab_size=preprocessor.char_vocab_size, 150 | char_embed_dim=preprocessor.char_embed_dim, 151 | char_embed_trainable=False, 152 | use_word=True, 153 | word_embeddings=preprocessor.word_embeddings, 154 | word_vocab_size=preprocessor.word_vocab_size, 155 | word_embed_dim=preprocessor.word_embed_dim, 156 | word_embed_trainable=False, 157 | max_len=preprocessor.max_len, 158 | use_crf=True).build_model() 159 | 160 | ner_trainer = NERTrainer(ner_model, preprocessor) 161 | ner_trainer.train(self.train_data, self.train_labels, self.valid_data, self.valid_labels, 162 | batch_size=2, epochs=2) 163 | assert not os.path.exists(self.json_file) 164 | assert not os.path.exists(self.weights_file) 165 | 166 | def test_train_no_valid_data(self): 167 | self.ner_trainer.train(self.train_data, self.train_labels, batch_size=2, epochs=2) 168 | assert not os.path.exists(self.json_file) 169 | assert not os.path.exists(self.weights_file) 170 | 171 | def test_train_callbacks(self): 172 | self.ner_trainer.train(self.train_data, self.train_labels, self.valid_data, 173 | self.valid_labels, batch_size=2, epochs=2, 174 | callback_list=['modelcheckpoint', 'earlystopping'], 175 | checkpoint_dir=os.path.dirname(__file__), 176 | model_name='bilstm_cnn_ner') 177 | 178 | assert not os.path.exists(self.json_file) 179 | assert os.path.exists(self.weights_file) 180 | os.remove(self.weights_file) 181 | assert not os.path.exists(self.weights_file) 182 | 183 | def test_train_swa(self): 184 | self.ner_trainer.train(self.train_data, self.train_labels, self.valid_data, 185 | self.valid_labels, batch_size=2, epochs=7, callback_list=['swa'], 186 | checkpoint_dir=os.path.dirname(__file__), 187 | model_name='bilstm_cnn_ner', 188 | swa_model=self.swa_model, 189 | load_swa_model=True) 190 | 191 | assert not os.path.exists(self.json_file) 192 | assert not os.path.exists(self.weights_file) 193 | 194 | json_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner_swa.json') 195 | weights_file = os.path.join(self.checkpoint_dir, 'bilstm_cnn_ner_swa.hdf5') 196 | assert not os.path.exists(json_file) 197 | assert os.path.exists(weights_file) 198 | os.remove(weights_file) 199 | assert not os.path.exists(weights_file) 200 | 201 | def test_generator(self): 202 | self.ner_trainer.train_generator(self.train_data, self.train_labels, 203 | self.valid_data, self.valid_labels, batch_size=2, epochs=2) 204 | 205 | assert not os.path.exists(self.json_file) 206 | assert not os.path.exists(self.weights_file) 207 | --------------------------------------------------------------------------------
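The NER example file shown above (data/ner/msra/example.txt) stores one character and its BIO tag per whitespace-separated line, with a blank line marking each sentence boundary; the trainer tests read it through fancy_nlp.utils.load_ner_data_and_labels. As a hedged illustration of the format only (not the library's loader), the standalone sketch below shows how such a file could be parsed into parallel character and label sequences; the function name read_bio_file is an assumption made for this example.

# -*- coding: utf-8 -*-
"""Minimal sketch (not fancy_nlp.utils.load_ner_data_and_labels): parse a
CoNLL-style file such as data/ner/msra/example.txt, where each non-empty
line holds one character and its BIO tag, and a blank line ends a sentence.
"""
import codecs


def read_bio_file(filename):
    """Return (list of char sequences, list of tag sequences)."""
    data, labels = [], []
    chars, tags = [], []
    with codecs.open(filename, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:                  # blank line: end of current sentence
                if chars:
                    data.append(chars)
                    labels.append(tags)
                    chars, tags = [], []
                continue
            char, tag = line.split()[:2]  # e.g. "北 B-LOC"
            chars.append(char)
            tags.append(tag)
    if chars:                             # flush a last sentence with no trailing blank line
        data.append(chars)
        labels.append(tags)
    return data, labels


if __name__ == '__main__':
    data, labels = read_bio_file('data/ner/msra/example.txt')
    print(len(data), data[2][:6], labels[2][:6])

The final flush covers files that end without a trailing blank line, which matches the last record in the example file.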
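TextClassificationPreprocessor builds its vocabularies inside __init__ and then turns raw char-tokenized texts into padded id matrices via prepare_input, mirroring how the NER test above constructs NERPreprocessor. The following is a minimal usage sketch under an assumed char-only configuration; the toy texts, labels, min_count=1 setting, and max_len=10 are illustrative assumptions, not values from the repository.

# -*- coding: utf-8 -*-
"""Hedged usage sketch for TextClassificationPreprocessor (char-only setup)."""
import numpy as np

from fancy_nlp.preprocessors.text_classification_preprocessor import TextClassificationPreprocessor

# toy char-level tokenized texts and their labels (assumed data)
train_data = [['我', '爱', '这', '本', '书'],
              ['这', '部', '电', '影', '很', '差']]
train_labels = ['positive', 'negative']

preprocessor = TextClassificationPreprocessor(train_data=train_data,
                                              train_labels=train_labels,
                                              min_count=1,  # keep every char in the toy corpus
                                              use_char=True,
                                              use_bert=False,
                                              use_word=False,
                                              max_len=10)

# features: padded char-id matrix of shape (num_samples, max_len)
# y: one-hot label matrix of shape (num_samples, num_class)
features, y = preprocessor.prepare_input(train_data, train_labels)
print(features.shape, y.shape)

# map model output probabilities back to label strings
fake_probs = np.eye(preprocessor.num_class)  # one fake prediction per class
print(preprocessor.label_decode(fake_probs))

Note that prepare_input returns a single array when only char features are enabled; once use_bert or use_word is switched on, it instead returns a list of input arrays in the order char ids, bert ids, bert segment ids, word ids.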