├── requirements.txt ├── LICENSE ├── README.md ├── crf_test.py ├── .gitignore └── crf.py /requirements.txt: -------------------------------------------------------------------------------- 1 | bleach==1.5.0 2 | h5py==2.7.1 3 | html5lib==0.9999999 4 | Keras==2.0.8 5 | Markdown==2.6.9 6 | numpy==1.13.1 7 | protobuf==3.4.0 8 | PyYAML==3.12 9 | scipy==0.19.1 10 | six==1.10.0 11 | tensorflow==1.3.0 12 | tensorflow-tensorboard==0.1.5 13 | Werkzeug==0.12.2 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Hiroki Nakayama 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras-CRF-Layer 2 | The Keras-CRF-Layer module implements a linear-chain CRF layer for learning to predict tag sequences. 3 | This variant of the CRF is factored into unary potentials for every element in the sequence and binary potentials for every transition between output tags. 4 | 5 | ## Usage 6 | Below is an example of the API, which learns a CRF for some random data. 7 | The linear layer in the example can be replaced by any neural network. 8 | 9 | ```python 10 | import numpy as np 11 | from keras.layers import Embedding, Input 12 | from keras.models import Model 13 | 14 | from crf import CRFLayer 15 | 16 | # Hyperparameter settings. 17 | vocab_size = 20 18 | n_classes = 11 19 | batch_size = 2 20 | maxlen = 2 21 | 22 | # Random features. 23 | x = np.random.randint(1, vocab_size, size=(batch_size, maxlen)) 24 | 25 | # Random tag indices representing the gold sequence. 26 | y = np.random.randint(n_classes, size=(batch_size, maxlen)) 27 | y = np.eye(n_classes)[y] 28 | 29 | # All sequences in this example have the same length, but they can be variable in a real model. 30 | s = np.asarray([maxlen] * batch_size, dtype='int32') 31 | 32 | # Build an example model. 33 | word_ids = Input(batch_shape=(batch_size, maxlen), dtype='int32') 34 | sequence_lengths = Input(batch_shape=[batch_size, 1], dtype='int32') 35 | 36 | word_embeddings = Embedding(vocab_size, n_classes)(word_ids) 37 | crf = CRFLayer() 38 | pred = crf(inputs=[word_embeddings, sequence_lengths]) 39 | model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred]) 40 | model.compile(loss=crf.loss, optimizer='sgd') 41 | 42 | # Train first 1 batch. 43 | model.train_on_batch([x, s], y) 44 | 45 | # Save the model 46 | model.save('model.h5') 47 | ``` 48 | 49 | ### Model loading 50 | When you want to load a saved model that has a crf output, then loading 51 | the model with 'keras.models.load_model' won't work properly because 52 | the reference of the loss function to the transition parameters is lost. To 53 | fix this, you need to use the parameter 'custom_objects' as follows: 54 | 55 | ```python 56 | from keras.models import load_model 57 | 58 | from crf import create_custom_objects 59 | 60 | model = load_model('model.h5', custom_objects=create_custom_objects()) 61 | ``` 62 | -------------------------------------------------------------------------------- /crf_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from keras.layers import Embedding, Input, LSTM, Dropout, Dense, Bidirectional 5 | from keras.models import Model, load_model 6 | 7 | from crf import CRFLayer, create_custom_objects 8 | 9 | 10 | class LayerTest(unittest.TestCase): 11 | 12 | def setUp(self): 13 | self.filename = 'test.h5' 14 | 15 | def test_crf_layer(self): 16 | 17 | # Hyperparameter settings. 18 | vocab_size = 20 19 | n_classes = 11 20 | batch_size = 2 21 | maxlen = 2 22 | 23 | # Random features. 24 | x = np.random.randint(1, vocab_size, size=(batch_size, maxlen)) 25 | 26 | # Random tag indices representing the gold sequence. 27 | y = np.random.randint(n_classes, size=(batch_size, maxlen)) 28 | y = np.eye(n_classes)[y] 29 | 30 | # All sequences in this example have the same length, but they can be variable in a real model. 31 | s = np.asarray([maxlen] * batch_size, dtype='int32') 32 | 33 | # Build a model. 34 | word_ids = Input(batch_shape=(batch_size, maxlen), dtype='int32') 35 | word_embeddings = Embedding(vocab_size, n_classes)(word_ids) 36 | sequence_lengths = Input(batch_shape=[batch_size, 1], dtype='int32') 37 | crf = CRFLayer() 38 | pred = crf([word_embeddings, sequence_lengths]) 39 | model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred]) 40 | model.compile(loss=crf.loss, optimizer='sgd') 41 | 42 | # Train first 1 batch. 43 | model.train_on_batch([x, s], y) 44 | 45 | # Save the model. 46 | model.save(self.filename) 47 | 48 | def test_load_model(self): 49 | model = load_model(self.filename, custom_objects=create_custom_objects()) 50 | 51 | def test_bilstm_crf(self): 52 | 53 | # Hyperparameter settings. 54 | vocab_size = 10000 55 | word_embedding_size = 100 56 | num_word_lstm_units = 100 57 | dropout = 0.5 58 | ntags = 10 59 | 60 | # Build bidirectional lstm-crf model. 61 | word_ids = Input(batch_shape=(None, None), dtype='int32') 62 | word_embeddings = Embedding(input_dim=vocab_size, 63 | output_dim=word_embedding_size, 64 | mask_zero=True)(word_ids) 65 | 66 | x = Bidirectional(LSTM(units=num_word_lstm_units, return_sequences=True))(word_embeddings) 67 | x = Dropout(dropout)(x) 68 | x = Dense(ntags)(x) 69 | 70 | sequence_lengths = Input(batch_shape=(None, 1), dtype='int32') 71 | 72 | crf = CRFLayer() 73 | pred = crf([x, sequence_lengths]) 74 | 75 | model = Model(inputs=[word_ids, sequence_lengths], outputs=[pred]) 76 | model.compile(loss=crf.loss, optimizer='sgd') 77 | 78 | model.save(self.filename) 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff: 7 | .idea/workspace.xml 8 | .idea/tasks.xml 9 | 10 | # Sensitive or high-churn files: 11 | .idea/dataSources/ 12 | .idea/dataSources.ids 13 | .idea/dataSources.xml 14 | .idea/dataSources.local.xml 15 | .idea/sqlDataSources.xml 16 | .idea/dynamic.xml 17 | .idea/uiDesigner.xml 18 | 19 | # Gradle: 20 | .idea/gradle.xml 21 | .idea/libraries 22 | 23 | # Mongo Explorer plugin: 24 | .idea/mongoSettings.xml 25 | 26 | ## File-based project format: 27 | *.iws 28 | 29 | ## Plugin-specific files: 30 | 31 | # IntelliJ 32 | /out/ 33 | 34 | # mpeltonen/sbt-idea plugin 35 | .idea_modules/ 36 | 37 | # JIRA plugin 38 | atlassian-ide-plugin.xml 39 | 40 | # Crashlytics plugin (for Android Studio and IntelliJ) 41 | com_crashlytics_export_strings.xml 42 | crashlytics.properties 43 | crashlytics-build.properties 44 | fabric.properties 45 | ### VirtualEnv template 46 | # Virtualenv 47 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 48 | .Python 49 | [Bb]in 50 | [Ii]nclude 51 | [Ll]ib 52 | [Ll]ib64 53 | [Ll]ocal 54 | [Ss]cripts 55 | pyvenv.cfg 56 | .venv 57 | pip-selfcheck.json 58 | ### Python template 59 | # Byte-compiled / optimized / DLL files 60 | __pycache__/ 61 | *.py[cod] 62 | *$py.class 63 | 64 | # C extensions 65 | *.so 66 | 67 | # Distribution / packaging 68 | env/ 69 | build/ 70 | develop-eggs/ 71 | dist/ 72 | downloads/ 73 | eggs/ 74 | .eggs/ 75 | lib/ 76 | lib64/ 77 | parts/ 78 | sdist/ 79 | var/ 80 | *.egg-info/ 81 | .installed.cfg 82 | *.egg 83 | 84 | # PyInstaller 85 | # Usually these files are written by a python script from a template 86 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 87 | *.manifest 88 | *.spec 89 | 90 | # Installer logs 91 | pip-log.txt 92 | pip-delete-this-directory.txt 93 | 94 | # Unit test / coverage reports 95 | htmlcov/ 96 | .tox/ 97 | .coverage 98 | .coverage.* 99 | .cache 100 | nosetests.xml 101 | coverage.xml 102 | *,cover 103 | .hypothesis/ 104 | 105 | # Translations 106 | *.mo 107 | *.pot 108 | 109 | # Django stuff: 110 | *.log 111 | local_settings.py 112 | 113 | # Flask stuff: 114 | instance/ 115 | .webassets-cache 116 | 117 | # Scrapy stuff: 118 | .scrapy 119 | 120 | # Sphinx documentation 121 | docs/_build/ 122 | 123 | # PyBuilder 124 | target/ 125 | 126 | # Jupyter Notebook 127 | .ipynb_checkpoints 128 | 129 | # pyenv 130 | .python-version 131 | 132 | # celery beat schedule file 133 | celerybeat-schedule 134 | 135 | # dotenv 136 | .env 137 | 138 | # virtualenv 139 | .venv/ 140 | venv/ 141 | ENV/ 142 | 143 | # Spyder project settings 144 | .spyderproject 145 | 146 | # Rope project settings 147 | .ropeproject 148 | ### macOS template 149 | *.DS_Store 150 | .AppleDouble 151 | .LSOverride 152 | 153 | # Icon must end with two \r 154 | Icon 155 | 156 | 157 | # Thumbnails 158 | ._* 159 | 160 | # Files that might appear in the root of a volume 161 | .DocumentRevisions-V100 162 | .fseventsd 163 | .Spotlight-V100 164 | .TemporaryItems 165 | .Trashes 166 | .VolumeIcon.icns 167 | .com.apple.timemachine.donotpresent 168 | 169 | # Directories potentially created on remote AFP share 170 | .AppleDB 171 | .AppleDesktop 172 | Network Trash Folder 173 | Temporary Items 174 | .apdisk 175 | -------------------------------------------------------------------------------- /crf.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from keras import backend as K 3 | from keras.engine import Layer, InputSpec 4 | 5 | try: 6 | from tensorflow.contrib.crf import crf_decode 7 | except ImportError: 8 | from tensorflow.python.framework import dtypes 9 | from tensorflow.python.ops import array_ops, gen_array_ops, math_ops, rnn, rnn_cell 10 | 11 | 12 | class CrfDecodeForwardRnnCell(rnn_cell.RNNCell): 13 | """Computes the forward decoding in a linear-chain CRF. 14 | """ 15 | 16 | def __init__(self, transition_params): 17 | """Initialize the CrfDecodeForwardRnnCell. 18 | Args: 19 | transition_params: A [num_tags, num_tags] matrix of binary 20 | potentials. This matrix is expanded into a 21 | [1, num_tags, num_tags] in preparation for the broadcast 22 | summation occurring within the cell. 23 | """ 24 | self._transition_params = array_ops.expand_dims(transition_params, 0) 25 | self._num_tags = transition_params.get_shape()[0].value 26 | 27 | @property 28 | def state_size(self): 29 | return self._num_tags 30 | 31 | @property 32 | def output_size(self): 33 | return self._num_tags 34 | 35 | def __call__(self, inputs, state, scope=None): 36 | """Build the CrfDecodeForwardRnnCell. 37 | Args: 38 | inputs: A [batch_size, num_tags] matrix of unary potentials. 39 | state: A [batch_size, num_tags] matrix containing the previous step's 40 | score values. 41 | scope: Unused variable scope of this cell. 42 | Returns: 43 | backpointers: [batch_size, num_tags], containing backpointers. 44 | new_state: [batch_size, num_tags], containing new score values. 45 | """ 46 | # For simplicity, in shape comments, denote: 47 | # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). 48 | state = array_ops.expand_dims(state, 2) # [B, O, 1] 49 | 50 | # This addition op broadcasts self._transitions_params along the zeroth 51 | # dimension and state along the second dimension. 52 | # [B, O, 1] + [1, O, O] -> [B, O, O] 53 | transition_scores = state + self._transition_params # [B, O, O] 54 | new_state = inputs + math_ops.reduce_max(transition_scores, [1]) # [B, O] 55 | backpointers = math_ops.argmax(transition_scores, 1) 56 | backpointers = math_ops.cast(backpointers, dtype=dtypes.int32) # [B, O] 57 | return backpointers, new_state 58 | 59 | 60 | class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell): 61 | """Computes backward decoding in a linear-chain CRF. 62 | """ 63 | 64 | def __init__(self, num_tags): 65 | """Initialize the CrfDecodeBackwardRnnCell. 66 | Args: 67 | num_tags 68 | """ 69 | self._num_tags = num_tags 70 | 71 | @property 72 | def state_size(self): 73 | return 1 74 | 75 | @property 76 | def output_size(self): 77 | return 1 78 | 79 | def __call__(self, inputs, state, scope=None): 80 | """Build the CrfDecodeBackwardRnnCell. 81 | Args: 82 | inputs: [batch_size, num_tags], backpointer of next step (in time order). 83 | state: [batch_size, 1], next position's tag index. 84 | scope: Unused variable scope of this cell. 85 | Returns: 86 | new_tags, new_tags: A pair of [batch_size, num_tags] 87 | tensors containing the new tag indices. 88 | """ 89 | state = array_ops.squeeze(state, axis=[1]) # [B] 90 | batch_size = array_ops.shape(inputs)[0] 91 | b_indices = math_ops.range(batch_size) # [B] 92 | indices = array_ops.stack([b_indices, state], axis=1) # [B, 2] 93 | new_tags = array_ops.expand_dims( 94 | gen_array_ops.gather_nd(inputs, indices), # [B] 95 | axis=-1) # [B, 1] 96 | 97 | return new_tags, new_tags 98 | 99 | 100 | def crf_decode(potentials, transition_params, sequence_length): 101 | """Decode the highest scoring sequence of tags in TensorFlow. 102 | This is a function for tensor. 103 | Args: 104 | potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of 105 | unary potentials. 106 | transition_params: A [num_tags, num_tags] tensor, matrix of 107 | binary potentials. 108 | sequence_length: A [batch_size] tensor, containing sequence lengths. 109 | Returns: 110 | decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. 111 | Contains the highest scoring tag indicies. 112 | best_score: A [batch_size] tensor, containing the score of decode_tags. 113 | """ 114 | # For simplicity, in shape comments, denote: 115 | # 'batch_size' by 'B', 'max_seq_len' by 'T' , 'num_tags' by 'O' (output). 116 | num_tags = potentials.get_shape()[2].value 117 | 118 | # Computes forward decoding. Get last score and backpointers. 119 | crf_fwd_cell = CrfDecodeForwardRnnCell(transition_params) 120 | initial_state = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1]) 121 | initial_state = array_ops.squeeze(initial_state, axis=[1]) # [B, O] 122 | inputs = array_ops.slice(potentials, [0, 1, 0], [-1, -1, -1]) # [B, T-1, O] 123 | backpointers, last_score = rnn.dynamic_rnn( 124 | crf_fwd_cell, 125 | inputs=inputs, 126 | sequence_length=sequence_length - 1, 127 | initial_state=initial_state, 128 | time_major=False, 129 | dtype=dtypes.int32) # [B, T - 1, O], [B, O] 130 | backpointers = gen_array_ops.reverse_sequence(backpointers, sequence_length - 1, seq_dim=1) # [B, T-1, O] 131 | 132 | # Computes backward decoding. Extract tag indices from backpointers. 133 | crf_bwd_cell = CrfDecodeBackwardRnnCell(num_tags) 134 | initial_state = math_ops.cast(math_ops.argmax(last_score, axis=1), dtype=dtypes.int32) # [B] 135 | initial_state = array_ops.expand_dims(initial_state, axis=-1) # [B, 1] 136 | decode_tags, _ = rnn.dynamic_rnn( 137 | crf_bwd_cell, 138 | inputs=backpointers, 139 | sequence_length=sequence_length - 1, 140 | initial_state=initial_state, 141 | time_major=False, 142 | dtype=dtypes.int32) # [B, T - 1, 1] 143 | decode_tags = array_ops.squeeze(decode_tags, axis=[2]) # [B, T - 1] 144 | decode_tags = array_ops.concat([initial_state, decode_tags], axis=1) # [B, T] 145 | decode_tags = gen_array_ops.reverse_sequence(decode_tags, sequence_length, seq_dim=1) # [B, T] 146 | 147 | best_score = math_ops.reduce_max(last_score, axis=1) # [B] 148 | return decode_tags, best_score 149 | 150 | 151 | class CRFLayer(Layer): 152 | 153 | def __init__(self, transition_params=None, **kwargs): 154 | super(CRFLayer, self).__init__(**kwargs) 155 | self.transition_params = transition_params 156 | self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=2)] 157 | self.supports_masking = True 158 | 159 | def compute_output_shape(self, input_shape): 160 | assert input_shape and len(input_shape[0]) == 3 161 | 162 | return input_shape[0] 163 | 164 | def build(self, input_shape): 165 | """Creates the layer weights. 166 | 167 | Args: 168 | input_shape (list(tuple, tuple)): [(batch_size, n_steps, n_classes), (batch_size, 1)] 169 | """ 170 | assert len(input_shape) == 2 171 | assert len(input_shape[0]) == 3 172 | assert len(input_shape[1]) == 2 173 | n_steps = input_shape[0][1] 174 | n_classes = input_shape[0][2] 175 | assert n_steps is None or n_steps >= 2 176 | 177 | self.transition_params = self.add_weight(shape=(n_classes, n_classes), 178 | initializer='uniform', 179 | name='transition') 180 | self.input_spec = [InputSpec(dtype=K.floatx(), shape=(None, n_steps, n_classes)), 181 | InputSpec(dtype='int32', shape=(None, 1))] 182 | self.built = True 183 | 184 | def viterbi_decode(self, potentials, sequence_length): 185 | """Decode the highest scoring sequence of tags in TensorFlow. 186 | 187 | This is a function for tensor. 188 | 189 | Args: 190 | potentials: A [batch_size, max_seq_len, num_tags] tensor, matrix of unary potentials. 191 | sequence_length: A [batch_size] tensor, containing sequence lengths. 192 | 193 | Returns: 194 | decode_tags: A [batch_size, max_seq_len] tensor, with dtype tf.int32. 195 | Contains the highest scoring tag indicies. 196 | """ 197 | decode_tags, best_score = crf_decode(potentials, self.transition_params, sequence_length) 198 | 199 | return decode_tags 200 | 201 | def call(self, inputs, mask=None, **kwargs): 202 | inputs, sequence_lengths = inputs 203 | self.sequence_lengths = K.flatten(sequence_lengths) 204 | y_pred = self.viterbi_decode(inputs, self.sequence_lengths) 205 | nb_classes = self.input_spec[0].shape[2] 206 | y_pred_one_hot = K.one_hot(y_pred, nb_classes) 207 | 208 | return K.in_train_phase(inputs, y_pred_one_hot) 209 | 210 | def loss(self, y_true, y_pred): 211 | """Computes the log-likelihood of tag sequences in a CRF. 212 | 213 | Args: 214 | y_true : A (batch_size, n_steps, n_classes) tensor. 215 | y_pred : A (batch_size, n_steps, n_classes) tensor. 216 | 217 | Returns: 218 | loss: A scalar containing the log-likelihood of the given sequence of tag indices. 219 | """ 220 | y_true = K.cast(K.argmax(y_true, axis=-1), dtype='int32') 221 | log_likelihood, self.transition_params = tf.contrib.crf.crf_log_likelihood( 222 | y_pred, y_true, self.sequence_lengths, self.transition_params) 223 | loss = tf.reduce_mean(-log_likelihood) 224 | 225 | return loss 226 | 227 | def get_config(self): 228 | config = { 229 | 'transition_params': K.eval(self.transition_params), 230 | } 231 | base_config = super(CRFLayer, self).get_config() 232 | 233 | return dict(list(base_config.items()) + list(config.items())) 234 | 235 | 236 | def create_custom_objects(): 237 | """Returns the custom objects, needed for loading a persisted model.""" 238 | instanceHolder = {'instance': None} 239 | 240 | class ClassWrapper(CRFLayer): 241 | def __init__(self, *args, **kwargs): 242 | instanceHolder['instance'] = self 243 | super(ClassWrapper, self).__init__(*args, **kwargs) 244 | 245 | def loss(*args): 246 | method = getattr(instanceHolder['instance'], 'loss') 247 | return method(*args) 248 | 249 | return {'CRFLayer': ClassWrapper, 'loss': loss} 250 | --------------------------------------------------------------------------------