├── requirements.txt ├── assets ├── 1.png ├── lstm_after.png ├── attention_1.png ├── lstm_before.png ├── graph_multi_attention.png └── graph_single_attention.png ├── .gitignore ├── attention_dense.py ├── attention_utils.py ├── attention_lstm.py ├── attention_lstm_todimensions.py ├── README.md └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | Keras==2.0.2 2 | matplotlib==2.0.0 3 | numpy==1.13.0 4 | pandas==0.18.1 5 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/1.png -------------------------------------------------------------------------------- /assets/lstm_after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/lstm_after.png -------------------------------------------------------------------------------- /assets/attention_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/attention_1.png -------------------------------------------------------------------------------- /assets/lstm_before.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/lstm_before.png -------------------------------------------------------------------------------- /assets/graph_multi_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/graph_multi_attention.png -------------------------------------------------------------------------------- /assets/graph_single_attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PatientEz/keras-attention-mechanism/HEAD/assets/graph_single_attention.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | .idea/ 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /attention_dense.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from attention_utils import get_activations, get_data 4 | 5 | np.random.seed(1337) # for reproducibility 6 | from keras.models import * 7 | from keras.layers import Input, Dense, merge 8 | 9 | input_dim = 32 10 | 11 | 12 | def build_model(): 13 | inputs = Input(shape=(input_dim,)) 14 | 15 | # ATTENTION PART STARTS HERE 16 | attention_probs = Dense(input_dim, activation='softmax', name='attention_vec')(inputs) 17 | attention_mul = merge([inputs, attention_probs], output_shape=32, name='attention_mul', mode='mul') 18 | # ATTENTION PART FINISHES HERE 19 | 20 | attention_mul = Dense(64)(attention_mul) 21 | output = Dense(1, activation='sigmoid')(attention_mul) 22 | model = Model(input=[inputs], output=output) 23 | return model 24 | 25 | 26 | if __name__ == '__main__': 27 | N = 10000 28 | inputs_1, outputs = get_data(N, input_dim) 29 | 30 | m = build_model() 31 | m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) 32 | print(m.summary()) 33 | 34 | m.fit([inputs_1], outputs, epochs=20, batch_size=64, validation_split=0.5) 35 | 36 | testing_inputs_1, testing_outputs = get_data(1, input_dim) 37 | 38 | # Attention vector corresponds to the second matrix. 39 | # The first one is the Inputs output. 40 | attention_vector = get_activations(m, testing_inputs_1, 41 | print_shape_only=True, 42 | layer_name='attention_vec')[0].flatten() 43 | print('attention =', attention_vector) 44 | 45 | # plot part. 46 | import matplotlib.pyplot as plt 47 | import pandas as pd 48 | 49 | pd.DataFrame(attention_vector, columns=['attention (%)']).plot(kind='bar', 50 | title='Attention Mechanism as ' 51 | 'a function of input' 52 | ' dimensions.') 53 | plt.show() 54 | -------------------------------------------------------------------------------- /attention_utils.py: -------------------------------------------------------------------------------- 1 | import keras.backend as K 2 | import numpy as np 3 | 4 | 5 | def get_activations(model, inputs, print_shape_only=False, layer_name=None): 6 | # Documentation is available online on Github at the address below. 
7 |     # From: https://github.com/philipperemy/keras-visualize-activations
8 |     print('----- activations -----')
9 |     activations = []
10 |     inp = model.input
11 |     if layer_name is None:
12 |         outputs = [layer.output for layer in model.layers]
13 |     else:
14 |         outputs = [layer.output for layer in model.layers if layer.name == layer_name]  # all layer outputs
15 |     funcs = [K.function([inp] + [K.learning_phase()], [out]) for out in outputs]  # evaluation functions
16 |     layer_outputs = [func([inputs, 1.])[0] for func in funcs]
17 |     for layer_activations in layer_outputs:
18 |         activations.append(layer_activations)
19 |         if print_shape_only:
20 |             print(layer_activations.shape)
21 |         else:
22 |             print('shape:', layer_activations.shape)
23 |             print(layer_activations)
24 |     return activations
25 | 
26 | 
27 | def get_data(n, input_dim, attention_column=1):
28 |     """
29 |     Data generation. x is purely random except that the column at `attention_column` equals the target y.
30 |     In practice, the network should learn that the target = x[attention_column].
31 |     Therefore, most of its attention should be focused on the value addressed by attention_column.
32 |     :param n: the number of samples to retrieve.
33 |     :param input_dim: the number of dimensions of each element in the series.
34 |     :param attention_column: the column linked to the target. Everything else is purely random.
35 |     :return: x: model inputs, y: model targets
36 |     """
37 |     x = np.random.standard_normal(size=(n, input_dim))
38 |     y = np.random.randint(low=0, high=2, size=(n, 1))
39 |     x[:, attention_column] = y[:, 0]
40 |     return x, y
41 | 
42 | 
43 | def get_data_recurrent(n, time_steps, input_dim, attention_column=10):
44 |     """
45 |     Data generation. x is purely random except that the time step at `attention_column` equals the target y.
46 |     In practice, the network should learn that the target = x[attention_column].
47 |     Therefore, most of its attention should be focused on the value addressed by attention_column.
48 |     :param n: the number of samples to retrieve.
49 |     :param time_steps: the number of time steps of your series.
50 |     :param input_dim: the number of dimensions of each element in the series.
51 |     :param attention_column: the column linked to the target. Everything else is purely random.
52 |     :return: x: model inputs, y: model targets
53 |     """
54 |     x = np.random.standard_normal(size=(n, time_steps, input_dim))
55 |     y = np.random.randint(low=0, high=2, size=(n, 1))
56 |     x[:, attention_column, :] = np.tile(y[:], (1, input_dim))
57 |     return x, y
58 | 
59 | 
60 | def get_data_recurrent2(n, time_steps, input_dim, attention_dim=5):
61 |     """
62 |     Data generation for the per-dimension attention demo: assuming input_dim = 10 and time_steps = 6,
63 |     produce an n x 6 x 10 array in which, at every time step, the dimension at `attention_dim` equals the target y.
64 | 
65 |     """
66 |     x = np.random.standard_normal(size=(n, time_steps, input_dim))
67 |     y = np.random.randint(low=0, high=2, size=(n, 1))
68 |     x[:, :, attention_dim] = np.tile(y[:], (1, time_steps))
69 | 
70 | 
71 |     return x, y
72 | 
73 | if __name__ == '__main__':  # quick sanity check; guarded so importing this module stays silent
74 |     print(get_data_recurrent2(1, 6, 10)[0])
75 |     print(get_data_recurrent2(1, 6, 10)[1])
76 | 
--------------------------------------------------------------------------------
/attention_lstm.py:
--------------------------------------------------------------------------------
1 | from keras.layers import merge
2 | from keras.layers.core import *
3 | from keras.layers.recurrent import LSTM
4 | from keras.models import *
5 | 
6 | from attention_utils import get_activations, get_data_recurrent
7 | 
8 | INPUT_DIM = 2
9 | TIME_STEPS = 20
10 | # if True, the attention vector is shared across the input dimensions where the attention is applied.
11 | SINGLE_ATTENTION_VECTOR = False 12 | APPLY_ATTENTION_BEFORE_LSTM = False 13 | 14 | 15 | def attention_3d_block(inputs): 16 | # inputs.shape = (batch_size, time_steps, input_dim) 17 | input_dim = int(inputs.shape[2]) 18 | a = Permute((2, 1))(inputs) 19 | a = Reshape((input_dim, TIME_STEPS))(a) # this line is not useful. It's just to know which dimension is what. 20 | a = Dense(TIME_STEPS, activation='softmax')(a) 21 | if SINGLE_ATTENTION_VECTOR: 22 | a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a) 23 | a = RepeatVector(input_dim)(a) 24 | a_probs = Permute((2, 1), name='attention_vec')(a) 25 | output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul') 26 | return output_attention_mul 27 | 28 | 29 | def model_attention_applied_after_lstm(): 30 | inputs = Input(shape=(TIME_STEPS, INPUT_DIM,)) 31 | lstm_units = 32 32 | lstm_out = LSTM(lstm_units, return_sequences=True)(inputs) 33 | attention_mul = attention_3d_block(lstm_out) 34 | attention_mul = Flatten()(attention_mul) 35 | output = Dense(1, activation='sigmoid')(attention_mul) 36 | model = Model(input=[inputs], output=output) 37 | return model 38 | 39 | 40 | def model_attention_applied_before_lstm(): 41 | inputs = Input(shape=(TIME_STEPS, INPUT_DIM,)) 42 | attention_mul = attention_3d_block(inputs) 43 | lstm_units = 32 44 | attention_mul = LSTM(lstm_units, return_sequences=False)(attention_mul) 45 | output = Dense(1, activation='sigmoid')(attention_mul) 46 | model = Model(input=[inputs], output=output) 47 | return model 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | N = 300000 53 | # N = 300 -> too few = no training 54 | inputs_1, outputs = get_data_recurrent(N, TIME_STEPS, INPUT_DIM) 55 | 56 | if APPLY_ATTENTION_BEFORE_LSTM: 57 | m = model_attention_applied_before_lstm() 58 | else: 59 | m = model_attention_applied_after_lstm() 60 | 61 | m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) 62 | print(m.summary()) 63 | 64 | m.fit([inputs_1], outputs, epochs=3, batch_size=64, validation_split=0.1) 65 | 66 | attention_vectors = [] 67 | for i in range(300): 68 | testing_inputs_1, testing_outputs = get_data_recurrent(1, TIME_STEPS, INPUT_DIM) 69 | attention_vector = np.mean(get_activations(m, 70 | testing_inputs_1, 71 | print_shape_only=True, 72 | layer_name='attention_vec')[0], axis=2).squeeze() 73 | print('attention =', attention_vector) 74 | assert (np.sum(attention_vector) - 1.0) < 1e-5 75 | attention_vectors.append(attention_vector) 76 | 77 | attention_vector_final = np.mean(np.array(attention_vectors), axis=0) 78 | # plot part. 79 | import matplotlib.pyplot as plt 80 | import pandas as pd 81 | 82 | pd.DataFrame(attention_vector_final, columns=['attention (%)']).plot(kind='bar', 83 | title='Attention Mechanism as ' 84 | 'a function of input' 85 | ' dimensions.') 86 | plt.show() 87 | -------------------------------------------------------------------------------- /attention_lstm_todimensions.py: -------------------------------------------------------------------------------- 1 | from keras.layers import merge 2 | from keras.layers.core import * 3 | from keras.layers.recurrent import LSTM 4 | from keras.models import * 5 | 6 | from attention_utils import get_activations, get_data_recurrent , get_data_recurrent2 7 | 8 | INPUT_DIM = 10 9 | TIME_STEPS = 6 10 | # if True, the attention vector is shared across the input_dimensions where the attention is applied. 
11 | SINGLE_ATTENTION_VECTOR = False 12 | APPLY_ATTENTION_BEFORE_LSTM = True 13 | 14 | 15 | def attention_3d_block(inputs): 16 | # inputs.shape = (batch_size, time_steps, input_dim) 17 | input_dim = int(inputs.shape[2]) 18 | a = inputs 19 | #a = Permute((2, 1))(inputs) 20 | #a = Reshape((input_dim, TIME_STEPS))(a) # this line is not useful. It's just to know which dimension is what. 21 | a = Dense(input_dim, activation='softmax')(a) 22 | if SINGLE_ATTENTION_VECTOR: 23 | a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a) 24 | a = RepeatVector(input_dim)(a) 25 | a_probs = Permute((1, 2), name='attention_vec')(a) 26 | #a_probs = a 27 | output_attention_mul = merge([inputs, a_probs], name='attention_mul', mode='mul') 28 | return output_attention_mul 29 | 30 | 31 | def model_attention_applied_after_lstm(): 32 | inputs = Input(shape=(TIME_STEPS, INPUT_DIM,)) 33 | lstm_units = 32 34 | lstm_out = LSTM(lstm_units, return_sequences=True)(inputs) 35 | attention_mul = attention_3d_block(lstm_out) 36 | attention_mul = Flatten()(attention_mul) 37 | output = Dense(1, activation='sigmoid')(attention_mul) 38 | model = Model(input=[inputs], output=output) 39 | return model 40 | 41 | 42 | def model_attention_applied_before_lstm(): 43 | inputs = Input(shape=(TIME_STEPS, INPUT_DIM,)) 44 | attention_mul = attention_3d_block(inputs) 45 | lstm_units = 32 46 | attention_mul = LSTM(lstm_units, return_sequences=False)(attention_mul) 47 | output = Dense(1, activation='sigmoid')(attention_mul) 48 | model = Model(input=[inputs], output=output) 49 | return model 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | N = 300000 55 | # N = 300 -> too few = no training 56 | inputs_1, outputs = get_data_recurrent2(N, TIME_STEPS, INPUT_DIM) 57 | 58 | if APPLY_ATTENTION_BEFORE_LSTM: 59 | m = model_attention_applied_before_lstm() 60 | else: 61 | m = model_attention_applied_after_lstm() 62 | 63 | m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy']) 64 | print(m.summary()) 65 | 66 | m.fit([inputs_1], outputs, epochs=10, batch_size=64, validation_split=0.1) 67 | 68 | attention_vectors = [] 69 | for i in range(300): 70 | testing_inputs_1, testing_outputs = get_data_recurrent2(1, TIME_STEPS, INPUT_DIM) 71 | attention_vector = np.mean(get_activations(m, 72 | testing_inputs_1, 73 | print_shape_only=False, 74 | layer_name='attention_vec')[0], axis=2).squeeze() 75 | print('attention =', attention_vector) 76 | assert (np.sum(attention_vector) - 1.0) < 1e-5 77 | attention_vectors.append(attention_vector) 78 | 79 | attention_vector_final = np.mean(np.array(attention_vectors), axis=0) 80 | # plot part. 81 | import matplotlib.pyplot as plt 82 | import pandas as pd 83 | 84 | pd.DataFrame(attention_vector_final, columns=['attention (%)']).plot(kind='bar', 85 | title='Attention Mechanism as ' 86 | 'a function of input' 87 | ' timesteps.') 88 | plt.show() 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras Attention Mechanism 2 | 3 | Simple attention mechanism implemented in Keras for the following layers: 4 | 5 | - [x] **Dense (attention 2D block)** 6 | - [x] **LSTM, GRU (attention 3D block)** 7 | 8 |
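For reference, the 2D block amounts to a softmax over the input dimensions that gates the inputs element-wise. The sketch below is condensed from `attention_dense.py` above; it uses the legacy `merge` API and `Model(input=..., output=...)` arguments that the pinned Keras 2.0.2 still ships, so newer Keras versions would need the `Multiply` layer instead.

```python
# Condensed from attention_dense.py (Keras 2.0.2, as pinned in requirements.txt).
from keras.models import Model
from keras.layers import Input, Dense, merge

input_dim = 32

inputs = Input(shape=(input_dim,))
# Attention part: a softmax over the input dimensions...
attention_probs = Dense(input_dim, activation='softmax', name='attention_vec')(inputs)
# ...used to gate the inputs element-wise.
attention_mul = merge([inputs, attention_probs], output_shape=32, name='attention_mul', mode='mul')
x = Dense(64)(attention_mul)
output = Dense(1, activation='sigmoid')(x)
model = Model(input=[inputs], output=output)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
```

Training then proceeds with a plain `model.fit(...)` on data from `get_data`, exactly as the `__main__` block of `attention_dense.py` does.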
9 | 
10 | *(figure)* Example: Attention block
11 | 
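The learned attention weights can be read back through the `attention_vec` layer, as the bottom of `attention_dense.py` does with `get_activations` from `attention_utils.py`. A rough usage sketch, assuming the `model` and `input_dim` from the snippet above:

```python
# Inspect the learned attention weights (condensed from the end of attention_dense.py).
from attention_utils import get_activations, get_data

testing_inputs, _ = get_data(1, input_dim)
attention_vector = get_activations(model, testing_inputs,
                                   print_shape_only=True,
                                   layer_name='attention_vec')[0].flatten()
print('attention =', attention_vector)  # should peak at the attention_column used by get_data
```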
31 | *(figure)* Attention Mechanism explained
32 | 
33 | 
41 | *(figure)* Attention Mechanism explained
42 | 
43 | 
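For the recurrent case, the 3D block computes a softmax over the time axis for every input dimension and multiplies it back onto the sequence. The sketch below is a condensed copy of `attention_3d_block` from `attention_lstm.py` (the no-op `Reshape` is dropped):

```python
import keras.backend as K
from keras.layers import merge
from keras.layers.core import Dense, Lambda, Permute, RepeatVector

TIME_STEPS = 20
SINGLE_ATTENTION_VECTOR = False


def attention_3d_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)                     # -> (batch, input_dim, time_steps)
    a = Dense(TIME_STEPS, activation='softmax')(a)  # softmax over the time axis, per input dimension
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    return merge([inputs, a_probs], name='attention_mul', mode='mul')
```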
77 | *(figure)* Attention vector applied on the inputs (before)
78 | 
79 | 
86 | *(figure)* Attention vector applied on the output of the LSTM layer (after)
87 | 
88 | 
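The "before" and "after" variants differ only in where `attention_3d_block` is inserted relative to the LSTM. Both model builders below are condensed from `attention_lstm.py` and reuse the block and constants shown above:

```python
from keras.layers import Input
from keras.layers.core import Dense, Flatten
from keras.layers.recurrent import LSTM
from keras.models import Model

INPUT_DIM = 2


def model_attention_applied_before_lstm():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIM,))
    x = attention_3d_block(inputs)               # attention on the raw inputs
    x = LSTM(32, return_sequences=False)(x)
    return Model(input=[inputs], output=Dense(1, activation='sigmoid')(x))


def model_attention_applied_after_lstm():
    inputs = Input(shape=(TIME_STEPS, INPUT_DIM,))
    x = LSTM(32, return_sequences=True)(inputs)
    x = attention_3d_block(x)                    # attention on the LSTM outputs
    x = Flatten()(x)
    return Model(input=[inputs], output=Dense(1, activation='sigmoid')(x))
```

The `APPLY_ATTENTION_BEFORE_LSTM` flag at the top of the script selects which of the two is trained.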
99 | *(figure)* Attention defined per time series (each TS has its own attention)
100 | 
101 | 
106 | *(figure)* Attention shared across all the time series
107 | 
108 | 
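Both behaviours are driven by the `SINGLE_ATTENTION_VECTOR` flag at the top of `attention_lstm.py` (and `attention_lstm_todimensions.py`). When it is `True`, the per-dimension attention weights are averaged and then repeated, so every input dimension shares a single attention vector over time; the relevant excerpt from `attention_3d_block`:

```python
# Excerpt from attention_3d_block: collapse per-dimension attention into one shared vector.
if SINGLE_ATTENTION_VECTOR:
    a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)  # average over input dimensions
    a = RepeatVector(input_dim)(a)                                    # broadcast back to every dimension
```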