├── .gitignore
├── README.md
├── arguments.py
├── config.py
├── data
│   ├── __init__.py
│   └── load.py
├── layers
│   ├── bert.py
│   ├── normalization.py
│   └── transformer.py
├── models
│   ├── __init__.py
│   ├── abstractive_summarizer.py
│   └── transformer.py
├── ops
│   ├── __init__.py
│   ├── attention.py
│   ├── beam_search.py
│   ├── data.py
│   ├── encoding.py
│   ├── masking.py
│   ├── metrics.py
│   ├── optimization.py
│   ├── regularization.py
│   ├── session.py
│   ├── tensor.py
│   └── tokenization.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    ├── decorators.py
    └── recipes.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 |
62 | # Flask stuff:
63 | instance/
64 | .webassets-cache
65 |
66 | # Scrapy stuff:
67 | .scrapy
68 |
69 | # Sphinx documentation
70 | docs/_build/
71 |
72 | # PyBuilder
73 | target/
74 |
75 | # Jupyter Notebook
76 | .ipynb_checkpoints
77 |
78 | # IPython
79 | profile_default/
80 | ipython_config.py
81 |
82 | # pyenv
83 | .python-version
84 |
85 | # pipenv
86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
89 | # install all needed dependencies.
90 | #Pipfile.lock
91 |
92 | # celery beat schedule file
93 | celerybeat-schedule
94 |
95 | # SageMath parsed files
96 | *.sage.py
97 |
98 | # Environments
99 | .env
100 | .venv
101 | env/
102 | venv/
103 | ENV/
104 | env.bak/
105 | venv.bak/
106 |
107 | # Spyder project settings
108 | .spyderproject
109 | .spyproject
110 |
111 | # Rope project settings
112 | .ropeproject
113 |
114 | # mkdocs documentation
115 | /site
116 |
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
124 |
125 | log/
126 | checkpoint/
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Pretraining-Based Natural Language Generation for Text Summarization
2 |
3 | Implementation of an abstractive text-summarization architecture, as proposed in this [paper](https://arxiv.org/pdf/1902.09243.pdf).
4 |
5 | The solution makes use of a pre-trained language model (e.g. BERT) to obtain contextualized representations of words; these models were trained on a huge corpus of unlabelled data.
6 |
7 | This widens the range of possible applications to settings where labelled data is scarce; since the model learns to summarize from the contextualized BERT representations, there is a good chance it generalizes to other domains, even though training is done on the CNN/DM dataset.
8 |
9 | `tensorflow_hub` is used to load the BERT module.
10 |
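For reference, `layers/bert.py` wraps the TF-Hub module roughly like the sketch below (condensed; the full layer also manages the trainable/non-trainable variables):

```
import tensorflow_hub as hub

BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

# instantiated inside BertLayer.build(); returns contextualized sequence outputs
bert = hub.Module(BERT_MODEL_URL, trainable=False)
```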
11 |
12 | #### Environment
13 |
14 | Python: `3.6`
15 |
16 | TensorFlow version: `1.13.1`
17 |
18 | `requirements.txt` lists the library dependencies.
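
They can be installed in the usual way, e.g.:

```
pip install -r requirements.txt
```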
19 |
20 | #### Run
21 |
22 | To run the training job:
23 |
24 | ```
25 | python train.py
26 | ```
27 |
28 | The arguments can be changed in `config.py`.
29 |
30 | By default, TensorBoard summaries are written to `log`; training checkpoints are written at the end of each epoch to `checkpoint`.
31 |
32 | ```
33 | tensorboard --logdir log/
34 | ```
35 |
36 | There, the training and evaluation metrics can be tracked.
37 |
38 | Note that, by default, `train.py` does not run the validation graph (validation set) alongside the training job.
39 |
40 | To run the validation job in parallel with the training:
41 |
42 | ```
43 | python train.py --eval
44 | ```
45 |
46 | This takes longer, since the inference graph also has to be built before training starts; at the end of each epoch, the model is run in inference mode (autoregressive) to calculate the loss and ROUGE metrics, and a random article/summary prediction is also written to TensorBoard.
47 |
48 | #### Resource Limitations
49 |
50 | Due to GPU memory limits when using an AWS SageMaker notebook with a `Tesla V100 (16 GB RAM)`, the batch size must be set to 2, otherwise an OOM error will be thrown.
51 | Furthermore, the target sequence length has a huge impact on the memory required during training; 75 is the default value.
52 |
53 | To compensate for this resource constraint we use gradient accumulation over `N` steps, i.e. we run the forward and backward passes `N` times with a batch size of 2 and accumulate the gradients; after the `N` steps we apply the accumulated gradients, updating the weights.
54 |
55 | This gives us an effective `2N` batch size; `N` (default=12) can be controlled by changing `GRADIENT_ACCUMULATION_N_STEPS` in `config.py`.
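
The gradient-accumulation ops are built in `models/abstractive_summarizer.py`; a minimal sketch of how a training loop might drive them is shown below (the loop structure and names such as `model`, `xs`, `ys` and `steps_per_epoch` are illustrative, not the exact contents of `train.py`):

```
import tensorflow as tf

from config import config
from models.abstractive_summarizer import train

# `model`, `xs` and `ys` are assumed to be built elsewhere
# (the model plus one batch of the tf.data pipeline)
loss_op, zero_op, accum_op, train_op, global_step, summaries = train(
    model, xs, ys, gradient_accumulation=True
)

N = config.GRADIENT_ACCUMULATION_N_STEPS

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(steps_per_epoch):
        sess.run(zero_op)                  # reset the accumulation variables
        for _ in range(N):                 # N forward/backward passes with batch size 2
            sess.run([loss_op, accum_op])  # accumulate gradients, no weight update yet
        sess.run(train_op)                 # apply the accumulated gradients once
```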
56 |
57 | #### Notes to the reader
58 |
59 | The core of the solution is implemented; there are however some missing pieces.
60 |
61 | Implemented:
62 |
63 | * Encoder (BERT)
64 | * Draft Decoder
65 | * Refined Decoder
66 | * Autoregressive evaluation (greedy)
67 | * Gradient accumulation
68 |
69 | Missing:
70 |
71 | * RL loss component
72 | * Beam-search mechanism for the draft-decoder
73 | * Copy mechanism
74 |
75 | Any help to implement these is appreciated.
76 |
77 | #### Configuration
78 |
79 | | **Parameter**                 | **Default** | **Description**                                                     |
80 | |-------------------------------|-------------|---------------------------------------------------------------------|
81 | | NUM_EPOCHS                    | 4           | Number of epochs to train                                           |
82 | | BATCH_SIZE                    | 2           | Batch size for each training step                                   |
83 | | GRADIENT_ACCUMULATION_N_STEPS | 12          | Number of gradient-accumulation steps before applying the update    |
84 | | BUFFER_SIZE                   | 1000        | Shuffle buffer size for the tf.data pipeline                        |
85 | | INITIAL_LR                    | 0.003       | Initial learning rate                                               |
86 | | WARMUP_STEPS                  | 4000        | Number of warmup steps for the Noam learning-rate schedule          |
87 | | INPUT_SEQ_LEN                 | 512         | Length to which articles are truncated                              |
88 | | OUTPUT_SEQ_LEN                | 75          | Length to which summaries are truncated                             |
89 | | MAX_EXAMPLE_LEN               | None        | Threshold used to filter out long examples (articles)               |
90 | | VOCAB_SIZE                    | 30522       | Size of the vocabulary                                              |
91 | | NUM_LAYERS                    | 8           | Number of layers of the Transformer decoder                         |
92 | | D_MODEL                       | 768         | Base embedding dimensionality (same as BERT)                        |
93 | | D_FF                          | 2048        | Dimensionality of the Transformer feed-forward layer                |
94 | | NUM_HEADS                     | 8           | Number of attention heads                                           |
95 | | DROPOUT_RATE                  | 0.1         | Dropout rate used during training                                   |
96 | | LOGDIR                        | log         | Location to write TensorBoard objects                               |
97 | | CHECKPOINTDIR                 | checkpoint  | Location to write model checkpoints                                 |
98 |
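These values live in `config.py` as a `Bunch`, so they are read as plain attributes, e.g.:

```
from config import config

batch_size = config.BATCH_SIZE
effective_batch_size = config.BATCH_SIZE * config.GRADIENT_ACCUMULATION_N_STEPS
```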
99 |
100 | #### Data
101 |
102 | To train the model we use the [CNN/DM dataset](https://www.tensorflow.org/datasets/datasets#cnn_dailymail), directly from TensorFlow Datasets.
103 | The first time it runs, it will download the dataset from the Google source (~500 MB).
104 |
105 | The details of how the data is downloaded and prepared can be found in `data/load.py`.
106 |
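In short, the split datasets are built roughly as follows (a condensed sketch of `data/load.py`; `pipeline` handles tokenization, truncation/padding to `INPUT_SEQ_LEN`/`OUTPUT_SEQ_LEN`, shuffling and batching):

```
import tensorflow_datasets as tfds

from data.load import pipeline
from ops.tokenization import tokenizer

# downloads and caches CNN/DailyMail on first use (~500 MB)
examples, metadata = tfds.load('cnn_dailymail', with_info=True, as_supervised=True)
train, val, test = examples['train'], examples['validation'], examples['test']

train_dataset = pipeline(train, tokenizer)
val_dataset = pipeline(val, tokenizer)
```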
107 |
108 | ##### Debug
109 |
110 | Track GPU memory usage with:
111 |
112 | ```
113 | watch -n 2 nvidia-smi
114 | ```
115 |
116 | System RAM usage with:
117 |
118 | ```
119 | watch -n 2 cat /proc/meminfo
120 | ```
121 |
122 | `report_tensor_allocations_upon_oom` is set to `True` so that we can see which variables
123 | are filling up the memory.
124 |
125 | ```
126 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)
127 | ...
128 | sess.run(..., options=run_options)
129 | ```
130 |
131 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | parser = argparse.ArgumentParser()
5 |
6 | parser.add_argument("--eval", action="store_true")
7 |
8 | args = parser.parse_args()
9 |
10 |
11 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from bunch import Bunch
2 |
3 |
4 | config = {
5 |     "NUM_EPOCHS": 4,
6 | "BATCH_SIZE": 2,
7 | "GRADIENT_ACCUMULATION_N_STEPS": 18,
8 | "BUFFER_SIZE": 1,
9 | "INITIAL_LR": 0.0003,
10 | "WARMUP_STEPS": 4000,
11 | "INPUT_SEQ_LEN": 512,
12 | "OUTPUT_SEQ_LEN": 72,
13 | "MAX_EXAMPLE_LEN": None,
14 | "VOCAB_SIZE": 30522,
15 | "NUM_LAYERS": 8,
16 | "D_MODEL": 768,
17 | "D_FF": 2048,
18 | "NUM_HEADS": 8,
19 | "DROPOUT_RATE": 0.1,
20 | "LOGDIR": 'log',
21 | "CHECKPOINTDIR": 'checkpoint2'
22 | }
23 |
24 | config = Bunch(config)
25 |
26 |
27 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/data/__init__.py
--------------------------------------------------------------------------------
/data/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow_hub as hub
6 | import tensorflow_datasets as tfds
7 |
8 | from functools import partial
9 | from tensorflow.keras import backend as K
10 | from ops.tokenization import tokenizer
11 |
12 | from config import config
13 |
14 |
15 | # Special token ids (from the bert-base-uncased vocabulary)
16 | UNK_ID = 100
17 | CLS_ID = 101
18 | SEP_ID = 102
19 | MASK_ID = 103
20 |
21 |
22 | def pad(l, n, pad=0):
23 | """
24 |     Pad the list 'l' to have size 'n' using the padding value 'pad'
25 | """
26 | pad_with = (0, max(0, n - len(l)))
27 | return np.pad(l, pad_with, mode='constant', constant_values=pad)
28 |
29 |
30 | def encode(sent_1, sent_2, tokenizer, input_seq_len, output_seq_len):
31 | """
32 | Encode the text to the BERT expected format
33 |
34 |     'input_seq_len' is used to truncate the article length
35 |     'output_seq_len' is used to truncate the summary length
36 |
37 | BERT has the following special tokens:
38 |
39 | [CLS] : The first token of every sequence. A classification token
40 | which is normally used in conjunction with a softmax layer for classification
41 | tasks. For anything else, it can be safely ignored.
42 |
43 | [SEP] : A sequence delimiter token which was used at pre-training for
44 | sequence-pair tasks (i.e. Next sentence prediction). Must be used when
45 | sequence pair tasks are required. When a single sequence is used it is just appended at the end.
46 |
47 | [MASK] : Token used for masked words. Only used for pre-training.
48 |
49 |     In addition, BERT requires two more inputs to work correctly:
50 | - Mask IDs
51 | - Segment IDs
52 |
53 | The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
54 |     Segment IDs are just a numeric class to distinguish between pairs of sentences.
55 | """
56 | tokens_1 = tokenizer.tokenize(sent_1.numpy())
57 | tokens_2 = tokenizer.tokenize(sent_2.numpy())
58 |
59 | # Account for [CLS] and [SEP] with "- 2"
60 | if len(tokens_1) > input_seq_len - 2:
61 | tokens_1 = tokens_1[0:(input_seq_len - 2)]
62 | if len(tokens_2) > (output_seq_len + 1) - 2:
63 | tokens_2 = tokens_2[0:((output_seq_len + 1) - 2)]
64 |
65 | tokens_1 = ["[CLS]"] + tokens_1 + ["[SEP]"]
66 | tokens_2 = ["[CLS]"] + tokens_2 + ["[SEP]"]
67 |
68 | input_ids_1 = tokenizer.convert_tokens_to_ids(tokens_1)
69 | input_ids_2 = tokenizer.convert_tokens_to_ids(tokens_2)
70 |
71 | input_mask_1 = [1] * len(input_ids_1)
72 | input_mask_2 = [1] * len(input_ids_2)
73 |
74 | input_ids_1 = pad(input_ids_1, input_seq_len, 0)
75 | input_ids_2 = pad(input_ids_2, output_seq_len + 1, 0)
76 | input_mask_1 = pad(input_mask_1, input_seq_len, 0)
77 | input_mask_2 = pad(input_mask_2, output_seq_len + 1, 0)
78 |
79 | input_type_ids_1 = [0] * len(input_ids_1)
80 | input_type_ids_2 = [0] * len(input_ids_2)
81 |
82 | return input_ids_1, input_mask_1, input_type_ids_1, input_ids_2, input_mask_2, input_type_ids_2
83 |
84 |
85 | def tf_encode(tokenizer, input_seq_len, output_seq_len):
86 | """
87 | Operations inside `.map()` run in graph mode and receive a graph
88 |     tensor that does not have a `numpy` attribute.
89 | The tokenizer expects a string or Unicode symbol to encode it into integers.
90 | Hence, you need to run the encoding inside a `tf.py_function`,
91 | which receives an eager tensor having a numpy attribute that contains the string value.
92 | """
93 | def f(s1, s2):
94 | encode_ = partial(encode, tokenizer=tokenizer, input_seq_len=input_seq_len, output_seq_len=output_seq_len)
95 | return tf.py_function(encode_, [s1, s2], [tf.int32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32])
96 |
97 | return f
98 |
99 |
100 | def filter_max_length(x, x1, x2, y, y1, y2, max_length=config.MAX_EXAMPLE_LEN):
101 | predicate = tf.logical_and(
102 | tf.size(x[0]) <= max_length,
103 | tf.size(y[0]) <= max_length
104 | )
105 | return predicate
106 |
107 |
108 | def pipeline(examples, tokenizer, cache=False):
109 | """
110 | Prepare a Dataset to return the following elements
111 |     x_ids, x_mask, x_segments, y_ids, y_mask, y_segments
112 | """
113 |
114 | dataset = examples.map(tf_encode(tokenizer, config.INPUT_SEQ_LEN, config.OUTPUT_SEQ_LEN), num_parallel_calls=tf.data.experimental.AUTOTUNE)
115 |
116 | if config.MAX_EXAMPLE_LEN is not None:
117 | dataset = dataset.filter(filter_max_length)
118 |
119 | if cache:
120 | dataset = dataset.cache()
121 |
122 | dataset = dataset.shuffle(config.BUFFER_SIZE).padded_batch(config.BATCH_SIZE, padded_shapes=([-1], [-1], [-1], [-1], [-1], [-1]))
123 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
124 |
125 | return dataset
126 |
127 |
128 | def load_cnn_dailymail(tokenizer=tokenizer):
129 | """
130 |     Load the CNN/DM data from TensorFlow Datasets
131 | """
132 | examples, metadata = tfds.load('cnn_dailymail', with_info=True, as_supervised=True)
133 | train, val, test = examples['train'], examples['validation'], examples['test']
134 | train_dataset = pipeline(train, tokenizer)
135 | val_dataset = pipeline(val, tokenizer)
136 | test_dataset = pipeline(test, tokenizer)
137 |
138 | # grab information regarding the number of examples
139 | metadata = json.loads(metadata.as_json)
140 |
141 | n_test_examples = int(metadata['splits'][0]['statistics']['numExamples'])
142 | n_train_examples = int(metadata['splits'][1]['statistics']['numExamples'])
143 | n_val_examples = int(metadata['splits'][2]['statistics']['numExamples'])
144 |
145 | return train_dataset, val_dataset, test_dataset, n_train_examples, n_val_examples, n_test_examples
146 |
147 |
--------------------------------------------------------------------------------
/layers/bert.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_hub as hub
3 |
4 | from tensorflow.keras import backend as K
5 |
6 |
7 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
8 |
9 |
10 |
11 | class BertLayer(tf.keras.layers.Layer):
12 | """
13 | Custom Keras layer, integrating BERT from tf-hub
14 | """
15 | def __init__(self, url=BERT_MODEL_URL, d_embedding=768, n_fine_tune_layers=0, **kwargs):
16 | self.url = url
17 | self.n_fine_tune_layers = n_fine_tune_layers
18 | self.d_embedding = d_embedding
19 |
20 | super(BertLayer, self).__init__(**kwargs)
21 |
22 | def build(self, input_shape):
23 |
24 | self.bert = hub.Module(
25 | self.url,
26 | trainable=False,
27 | name="{}_bert_module".format(self.name)
28 | )
29 |
30 | trainable_vars = self.bert.variables
31 |
32 | # Remove unused layers
33 |         trainable_vars = [var for var in trainable_vars if "/cls/" not in var.name]
34 |
35 |         # Select how many of the trailing variables to fine-tune (none by default)
36 |         trainable_vars = trainable_vars[-self.n_fine_tune_layers:] if self.n_fine_tune_layers > 0 else []
37 |
38 | # Add to trainable weights
39 | for var in trainable_vars:
40 | self._trainable_weights.append(var)
41 |
42 | for var in self.bert.variables:
43 | if var not in self._trainable_weights:
44 | self._non_trainable_weights.append(var)
45 |
46 | super(BertLayer, self).build(input_shape)
47 |
48 | def call(self, inputs):
49 | inputs = [K.cast(x, dtype="int32") for x in inputs]
50 |
51 | input_ids, input_mask, segment_ids = inputs
52 |
53 | bert_inputs = dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)
54 | result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)["sequence_output"]
55 | return result
56 |
57 | def compute_output_shape(self, input_shape):
58 | return (input_shape[0], input_shape[1], self.d_embedding)
59 |
60 |
--------------------------------------------------------------------------------
/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow_hub as hub
6 |
7 | from tensorflow.keras import backend as K
8 |
9 | from tensorflow.python.keras import constraints
10 | from tensorflow.python.keras import initializers
11 | from tensorflow.python.keras import regularizers
12 | from tensorflow.python.ops import nn, array_ops
13 |
14 |
15 | class LayerNormalization(tf.keras.layers.Layer):
16 | """
17 | Layer normalization layer (Ba et al., 2016).
18 | Normalize the activations of the previous layer for each given example in a
19 | batch independently, rather than across a batch like Batch Normalization.
20 | i.e. applies a transformation that maintains the mean activation within each
21 | example close to 0 and the activation standard deviation close to 1.
22 | Arguments:
23 | axis: Integer or List/Tuple. The axis that should be normalized
24 | (typically the features axis).
25 | epsilon: Small float added to variance to avoid dividing by zero.
26 | center: If True, add offset of `beta` to normalized tensor.
27 | If False, `beta` is ignored.
28 | scale: If True, multiply by `gamma`.
29 | If False, `gamma` is not used.
30 | When the next layer is linear (also e.g. `nn.relu`),
31 | this can be disabled since the scaling
32 | will be done by the next layer.
33 | beta_initializer: Initializer for the beta weight.
34 | gamma_initializer: Initializer for the gamma weight.
35 | beta_regularizer: Optional regularizer for the beta weight.
36 | gamma_regularizer: Optional regularizer for the gamma weight.
37 | beta_constraint: Optional constraint for the beta weight.
38 | gamma_constraint: Optional constraint for the gamma weight.
39 | trainable: Boolean, if `True` the variables will be marked as trainable.
40 | Input shape:
41 | Arbitrary. Use the keyword argument `input_shape`
42 | (tuple of integers, does not include the samples axis)
43 | when using this layer as the first layer in a model.
44 | Output shape:
45 | Same shape as input.
46 | References:
47 | - [Layer Normalization](https://arxiv.org/abs/1607.06450)
48 | """
49 |
50 | def __init__(self,
51 | axis=-1,
52 | epsilon=1e-3,
53 | center=True,
54 | scale=True,
55 | beta_initializer='zeros',
56 | gamma_initializer='ones',
57 | beta_regularizer=None,
58 | gamma_regularizer=None,
59 | beta_constraint=None,
60 | gamma_constraint=None,
61 | trainable=True,
62 | name=None,
63 | **kwargs):
64 | super(LayerNormalization, self).__init__(
65 | name=name, trainable=trainable, **kwargs)
66 | if isinstance(axis, (list, tuple)):
67 | self.axis = axis[:]
68 | elif isinstance(axis, int):
69 | self.axis = axis
70 | else:
71 | raise ValueError('Expected an int or a list/tuple of ints for the argument \'axis\', but received instead: %s' % axis)
72 |
73 | self.epsilon = epsilon
74 | self.center = center
75 | self.scale = scale
76 | self.beta_initializer = initializers.get(beta_initializer)
77 | self.gamma_initializer = initializers.get(gamma_initializer)
78 | self.beta_regularizer = regularizers.get(beta_regularizer)
79 | self.gamma_regularizer = regularizers.get(gamma_regularizer)
80 | self.beta_constraint = constraints.get(beta_constraint)
81 | self.gamma_constraint = constraints.get(gamma_constraint)
82 |
83 | self.supports_masking = True
84 |
85 | def build(self, input_shape):
86 | ndims = len(input_shape)
87 | if ndims is None:
88 | raise ValueError('Input shape %s has undefined rank.' % input_shape)
89 |
90 | # Convert axis to list and resolve negatives
91 | if isinstance(self.axis, int):
92 | self.axis = [self.axis]
93 | for idx, x in enumerate(self.axis):
94 | if x < 0:
95 | self.axis[idx] = ndims + x
96 |
97 | # Validate axes
98 | for x in self.axis:
99 | if x < 0 or x >= ndims:
100 | raise ValueError('Invalid axis: %d' % x)
101 | if len(self.axis) != len(set(self.axis)):
102 | raise ValueError('Duplicate axis: {}'.format(tuple(self.axis)))
103 |
104 | param_shape = [input_shape[dim] for dim in self.axis]
105 | if self.scale:
106 | self.gamma = self.add_weight(
107 | name='gamma',
108 | shape=param_shape,
109 | initializer=self.gamma_initializer,
110 | regularizer=self.gamma_regularizer,
111 | constraint=self.gamma_constraint,
112 | trainable=True)
113 | else:
114 | self.gamma = None
115 |
116 | if self.center:
117 | self.beta = self.add_weight(
118 | name='beta',
119 | shape=param_shape,
120 | initializer=self.beta_initializer,
121 | regularizer=self.beta_regularizer,
122 | constraint=self.beta_constraint,
123 | trainable=True)
124 | else:
125 | self.beta = None
126 |
127 | def call(self, inputs):
128 | # Compute the axes along which to reduce the mean / variance
129 | input_shape = inputs.shape
130 | ndims = len(input_shape)
131 |
132 | # Calculate the moments on the last axis (layer activations).
133 | mean, variance = nn.moments(inputs, self.axis, keep_dims=True)
134 |
135 | # Broadcasting only necessary for norm where the axis is not just
136 | # the last dimension
137 | broadcast_shape = [1] * ndims
138 | for dim in self.axis:
139 | broadcast_shape[dim] = input_shape.dims[dim].value
140 | def _broadcast(v):
141 | if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]):
142 | return array_ops.reshape(v, broadcast_shape)
143 | return v
144 | scale, offset = _broadcast(self.gamma), _broadcast(self.beta)
145 |
146 | # Compute layer normalization using the batch_normalization function.
147 | outputs = nn.batch_normalization(
148 | inputs,
149 | mean,
150 | variance,
151 | offset=offset,
152 | scale=scale,
153 | variance_epsilon=self.epsilon)
154 |
155 | # If some components of the shape got lost due to adjustments, fix that.
156 | outputs.set_shape(input_shape)
157 |
158 | return outputs
159 |
160 | def compute_output_shape(self, input_shape):
161 | return input_shape
162 |
163 | def get_config(self):
164 | config = {
165 | 'axis': self.axis,
166 | 'epsilon': self.epsilon,
167 | 'center': self.center,
168 | 'scale': self.scale,
169 | 'beta_initializer': initializers.serialize(self.beta_initializer),
170 | 'gamma_initializer': initializers.serialize(self.gamma_initializer),
171 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer),
172 | 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
173 | 'beta_constraint': constraints.serialize(self.beta_constraint),
174 | 'gamma_constraint': constraints.serialize(self.gamma_constraint)
175 | }
176 | base_config = super(LayerNormalization, self).get_config()
177 | return dict(list(base_config.items()) + list(config.items()))
178 |
179 |
--------------------------------------------------------------------------------
/layers/transformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow_hub as hub
6 |
7 | from tensorflow.keras import backend as K
8 | from tensorflow.python.keras.initializers import Constant
9 |
10 | from layers.normalization import LayerNormalization
11 | from ops.encoding import positional_encoding
12 | from ops.attention import scaled_dot_product_attention
13 |
14 |
15 | class MultiHeadAttention(tf.keras.layers.Layer):
16 | """
17 | Multi-head attention consists of four parts:
18 | * Linear layers and split into heads.
19 | * Scaled dot-product attention.
20 | * Concatenation of heads.
21 | * Final linear layer.
22 |
23 | Each multi-head attention block gets three inputs;
24 | Q (query), K (key), V (value).
25 | These are put through linear (Dense) layers and split up into multiple heads.
26 |
27 | Instead of one single attention head, Q, K, and V
28 | are split into multiple heads because it allows the
29 | model to jointly attend to information at different
30 | positions from different representational spaces.
31 |
32 | After the split each head has a reduced dimensionality,
33 | so the total computation cost is the same as a single head
34 | attention with full dimensionality.
35 | """
36 | def __init__(self, d_model, num_heads):
37 | super(MultiHeadAttention, self).__init__()
38 | self.num_heads = num_heads
39 | self.d_model = d_model
40 |
41 | assert d_model % self.num_heads == 0
42 |
43 | self.depth = d_model // self.num_heads
44 |
45 | def build(self, input_shape):
46 |
47 | self.wq = tf.keras.layers.Dense(self.d_model)
48 | self.wk = tf.keras.layers.Dense(self.d_model)
49 | self.wv = tf.keras.layers.Dense(self.d_model)
50 |
51 | self.dense = tf.keras.layers.Dense(self.d_model)
52 |
53 | for i in [self.wq, self.wk, self.wv, self.dense]:
54 | i.build(input_shape)
55 | for weight in i.trainable_weights:
56 | if weight not in self._trainable_weights:
57 | self._trainable_weights.append(weight)
58 | for weight in i.non_trainable_weights:
59 | if weight not in self._non_trainable_weights:
60 | self._non_trainable_weights.append(weight)
61 |
62 | super(MultiHeadAttention, self).build(input_shape)
63 |
64 | def split_heads(self, x, batch_size):
65 | """Split the last dimension into (num_heads, depth).
66 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
67 | """
68 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
69 | return tf.transpose(x, perm=[0, 2, 1, 3])
70 |
71 | def call(self, v, k, q, mask):
72 | batch_size = tf.shape(q)[0]
73 |
74 | q = self.wq(q) # (batch_size, seq_len, d_model)
75 | k = self.wk(k) # (batch_size, seq_len, d_model)
76 | v = self.wv(v) # (batch_size, seq_len, d_model)
77 |
78 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth)
79 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth)
80 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth)
81 |
82 | # scaled_attention.shape == (batch_size, num_heads, seq_len_v, depth)
83 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
84 | scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
85 |
86 | scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_v, num_heads, depth)
87 |
88 | concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) # (batch_size, seq_len_v, d_model)
89 |
90 | output = self.dense(concat_attention) # (batch_size, seq_len_v, d_model)
91 |
92 | return output, attention_weights
93 |
94 |
95 | def point_wise_feed_forward_network(d_model, dff):
96 | """
97 | Point wise feed forward network consists of two
98 | fully-connected layers with a ReLU activation in between.
99 | """
100 | return tf.keras.Sequential([
101 | tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff)
102 | tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model)
103 | ])
104 |
105 |
106 |
107 | class EncoderLayer(tf.keras.layers.Layer):
108 | """
109 | Each encoder layer consists of sublayers:
110 |
111 | * Multi-head attention (with padding mask)
112 | * Point wise feed forward networks.
113 |
114 | Each of these sublayers has a residual connection around it followed by a layer normalization.
115 | Residual connections help in avoiding the vanishing gradient problem in deep networks.
116 |
117 | The output of each sublayer is `LayerNorm(x + Sublayer(x))`.
118 | The normalization is done on the d_model (last) axis. There are N encoder layers in the transformer.
119 | """
120 | def __init__(self, d_model, num_heads, dff, rate=0.1):
121 | self.d_model = d_model
122 | self.num_heads = num_heads
123 | self.dff = dff
124 | self.rate = rate
125 | super(EncoderLayer, self).__init__()
126 |
127 | def build(self, input_shape):
128 |
129 | self.mha = MultiHeadAttention(self.d_model, self.num_heads)
130 | self.ffn = point_wise_feed_forward_network(self.d_model, self.dff)
131 |
132 | self.layernorm1 = LayerNormalization(epsilon=1e-6)
133 | self.layernorm2 = LayerNormalization(epsilon=1e-6)
134 |
135 | self.dropout1 = tf.keras.layers.Dropout(self.rate)
136 | self.dropout2 = tf.keras.layers.Dropout(self.rate)
137 |
138 | for i in [self.mha, self.ffn, self.layernorm1, self.layernorm2, self.dropout1, self.dropout2]:
139 | i.build(input_shape)
140 | for weight in i.trainable_weights:
141 | if weight not in self._trainable_weights:
142 | self._trainable_weights.append(weight)
143 | for weight in i.non_trainable_weights:
144 | if weight not in self._non_trainable_weights:
145 | self._non_trainable_weights.append(weight)
146 |
147 | super(EncoderLayer, self).build(input_shape)
148 |
149 | def call(self, x, training, mask):
150 |
151 | attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model)
152 | attn_output = self.dropout1(attn_output, training=training)
153 | out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model)
154 |
155 | ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model)
156 | ffn_output = self.dropout2(ffn_output, training=training)
157 | out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model)
158 |
159 | return out2
160 |
161 |
162 | class DecoderLayer(tf.keras.layers.Layer):
163 | """
164 | Each decoder layer consists of sublayers:
165 |
166 | * Masked multi-head attention (with look ahead mask and padding mask)
167 | * Multi-head attention (with padding mask). V (value) and K (key)
168 | receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer.
169 | * Point wise feed forward networks
170 |
171 | Each of these sublayers has a residual connection around it followed by a layer normalization.
172 | The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis.
173 |
174 | As Q receives the output from decoder's first attention block,
175 | and K receives the encoder output, the attention weights represent
176 | the importance given to the decoder's input based on the encoder's output.
177 | In other words, the decoder predicts the next word by looking at the encoder
178 | output and self-attending to its own output.
179 | """
180 | def __init__(self, d_model, num_heads, dff, rate=0.1):
181 | self.d_model = d_model
182 | self.num_heads = num_heads
183 | self.dff = dff
184 | self.rate = rate
185 | super(DecoderLayer, self).__init__()
186 |
187 | def build(self, input_shape):
188 |
189 | self.mha1 = MultiHeadAttention(self.d_model, self.num_heads)
190 | self.mha2 = MultiHeadAttention(self.d_model, self.num_heads)
191 |
192 | self.ffn = point_wise_feed_forward_network(self.d_model, self.dff)
193 |
194 | self.layernorm1 = LayerNormalization(epsilon=1e-6)
195 | self.layernorm2 = LayerNormalization(epsilon=1e-6)
196 | self.layernorm3 = LayerNormalization(epsilon=1e-6)
197 |
198 | self.dropout1 = tf.keras.layers.Dropout(self.rate)
199 | self.dropout2 = tf.keras.layers.Dropout(self.rate)
200 | self.dropout3 = tf.keras.layers.Dropout(self.rate)
201 |
202 | for i in [self.mha1, self.mha2, self.ffn, self.layernorm1, self.layernorm2, self.layernorm3, self.dropout1, self.dropout2, self.dropout3]:
203 | i.build(input_shape)
204 | for weight in i.trainable_weights:
205 | if weight not in self._trainable_weights:
206 | self._trainable_weights.append(weight)
207 | for weight in i.non_trainable_weights:
208 | if weight not in self._non_trainable_weights:
209 | self._non_trainable_weights.append(weight)
210 |
211 | super(DecoderLayer, self).build(input_shape)
212 |
213 |
214 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
215 | # enc_output.shape == (batch_size, input_seq_len, d_model)
216 |
217 | attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model)
218 | attn1 = self.dropout1(attn1, training=training)
219 | out1 = self.layernorm1(attn1 + x)
220 |
221 | attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model)
222 | attn2 = self.dropout2(attn2, training=training)
223 | out2 = self.layernorm2(attn2 + out1) # (batch_size, target_seq_len, d_model)
224 |
225 | ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model)
226 | ffn_output = self.dropout3(ffn_output, training=training)
227 | out3 = self.layernorm3(ffn_output + out2) # (batch_size, target_seq_len, d_model)
228 |
229 | return out3, attn_weights_block1, attn_weights_block2
230 |
231 |
232 | class Encoder(tf.keras.layers.Layer):
233 | """
234 | The Encoder consists of:
235 |
236 | 1. Input Embedding
237 | 2. Positional Encoding
238 | 3. N encoder layers
239 |
240 | The input is put through an embedding which is summed with the positional encoding.
241 | The output of this summation is the input to the encoder layers.
242 | The output of the encoder is the input to the decoder.
243 | """
244 | def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, rate=0.1):
245 | super(Encoder, self).__init__()
246 |
247 | self.d_model = d_model
248 | self.num_layers = num_layers
249 | self.num_heads = num_heads
250 | self.dff = dff
251 | self.input_vocab_size = input_vocab_size
252 | self.rate = rate
253 |
254 | def build(self, input_shape):
255 |
256 | self.embedding = tf.keras.layers.Embedding(self.input_vocab_size, self.d_model)
257 | self.pos_encoding = positional_encoding(self.input_vocab_size, self.d_model)
258 |
259 |
260 | self.enc_layers = [
261 | EncoderLayer(self.d_model, self.num_heads, self.dff, self.rate)
262 | for _ in range(self.num_layers)
263 | ]
264 |
265 | self.dropout = tf.keras.layers.Dropout(self.rate)
266 |
267 | for i in [self.embedding] + self.enc_layers + [self.dropout]:
268 | i.build(input_shape)
269 | for weight in i.trainable_weights:
270 | if weight not in self._trainable_weights:
271 | self._trainable_weights.append(weight)
272 | for weight in i.non_trainable_weights:
273 | if weight not in self._non_trainable_weights:
274 | self._non_trainable_weights.append(weight)
275 |
276 | super(Encoder, self).build(input_shape)
277 |
278 | def call(self, x, training, mask):
279 |
280 | seq_len = tf.shape(x)[1]
281 |
282 | # adding embedding and position encoding.
283 | x = self.embedding(x) # (batch_size, input_seq_len, d_model)
284 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
285 | x += self.pos_encoding[:, :seq_len, :]
286 |
287 | x = self.dropout(x, training=training)
288 |
289 | for i in range(self.num_layers):
290 | x = self.enc_layers[i](x, training, mask)
291 |
292 | return x # (batch_size, input_seq_len, d_model)
293 |
294 |
295 | class Decoder(tf.keras.layers.Layer):
296 | """
297 | The Decoder consists of:
298 | 1. Output Embedding
299 | 2. Positional Encoding
300 | 3. N decoder layers
301 |
302 | The target is put through an embedding which is summed with the positional encoding.
303 | The output of this summation is the input to the decoder layers.
304 | The output of the decoder is the input to the final linear layer.
305 |
306 |     If pretrained embeddings are available they are used as the word embedding matrix and are not trained further (in this implementation the inputs arrive already embedded).
307 | """
308 | def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, rate=0.1):
309 | super(Decoder, self).__init__()
310 |
311 | self.d_model = d_model
312 | self.num_layers = num_layers
313 | self.num_heads = num_heads
314 | self.dff = dff
315 | self.target_vocab_size = target_vocab_size
316 | self.rate = rate
317 |
318 | # if pretrained_embeddings is None:
319 | # self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
320 | # else:
321 | # self.embedding = tf.keras.layers.Embedding(
322 | # target_vocab_size, d_model, trainable=False,
323 | # embeddings_initializer=Constant(pretrained_embeddings)
324 | # )
325 |
326 | def build(self, input_shape):
327 |
328 | self.pos_encoding = positional_encoding(self.target_vocab_size, self.d_model)
329 |
330 | self.dec_layers = [
331 | DecoderLayer(self.d_model, self.num_heads, self.dff, self.rate)
332 | for _ in range(self.num_layers)
333 | ]
334 | self.dropout = tf.keras.layers.Dropout(self.rate)
335 |
336 | for i in self.dec_layers + [self.dropout]:
337 | i.build(input_shape)
338 | for weight in i.trainable_weights:
339 | if weight not in self._trainable_weights:
340 | self._trainable_weights.append(weight)
341 | for weight in i.non_trainable_weights:
342 | if weight not in self._non_trainable_weights:
343 | self._non_trainable_weights.append(weight)
344 |
345 | super(Decoder, self).build(input_shape)
346 |
347 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
348 |
349 | seq_len = tf.shape(x)[1]
350 | attention_weights = {}
351 |
352 | # if not input_alreay_embedded:
353 | # x = self.embedding(x) # (batch_size, target_seq_len, d_model)
354 |
355 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
356 | x += self.pos_encoding[:, :seq_len, :]
357 |
358 | x = self.dropout(x, training=training)
359 |
360 | for i in range(self.num_layers):
361 |
365 |
366 | x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)
367 |
368 | attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
369 | attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
370 |
371 | # x.shape == (batch_size, target_seq_len, d_model)
372 | return x, attention_weights
373 |
374 |
375 |
376 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/models/__init__.py
--------------------------------------------------------------------------------
/models/abstractive_summarizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import logging
3 | import tensorflow as tf
4 | import tensorflow_hub as hub
5 |
6 | from tqdm import tqdm
7 | from random import randint
8 | from tensorflow.keras import backend as K
9 | from tensorflow.python.keras.initializers import Constant
10 |
11 | from layers.transformer import Encoder
12 | from layers.transformer import Decoder
13 | from layers.bert import BertLayer, BERT_MODEL_URL
14 |
15 |
16 | from ops.masking import create_masks
17 | from ops.masking import create_look_ahead_mask
18 | from ops.masking import create_padding_mask
19 | from ops.masking import mask_timestamp
20 | from ops.masking import tile_and_mask_diagonal
21 |
22 | from ops.session import initialize_vars
23 | from ops.metrics import calculate_rouge
24 | from ops.tensor import with_column
25 | from ops.regularization import label_smoothing
26 | from ops.optimization import noam_scheme
27 | from ops.tokenization import tokenizer
28 | from ops.tokenization import convert_idx_to_token_tensor
29 |
30 | from data.load import UNK_ID
31 | from data.load import CLS_ID
32 | from data.load import SEP_ID
33 | from data.load import MASK_ID
34 |
35 | from config import config
36 |
37 |
38 | logger = logging.getLogger()
39 | logger.setLevel(logging.INFO)
40 |
41 |
42 | warmup_steps = config.WARMUP_STEPS
43 | initial_lr = config.INITIAL_LR
44 |
45 |
46 |
47 | def _embedding_from_bert():
48 | """
49 |     Extract the pretrained word embeddings from a BERT model
50 | Returns a numpy matrix with the embeddings
51 | """
52 | logger.info("Extracting pretrained word embeddings weights from BERT")
53 |
54 | with tf.device("/device:CPU:0"):
55 | bert = hub.Module(BERT_MODEL_URL, trainable=False, name="embeddings_from_bert_module")
56 |
57 | with tf.Session() as sess:
58 | initialize_vars(sess)
59 | embedding_matrix = sess.run(bert.variable_map['bert/embeddings/word_embeddings'])
60 |
61 | tf.reset_default_graph()
62 |
63 | logger.info(f"Embedding matrix shape '{embedding_matrix.shape}'")
64 | return embedding_matrix
65 |
66 |
67 | class AbstractiveSummarization(tf.keras.Model):
68 | """
69 | Pretraining-Based Natural Language Generation for Text Summarization
70 | https://arxiv.org/pdf/1902.09243.pdf
71 | """
72 | def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, input_seq_len, output_seq_len, rate=0.1):
73 | super(AbstractiveSummarization, self).__init__()
74 |
75 | self.input_seq_len = input_seq_len
76 | self.output_seq_len = output_seq_len
77 |
78 | self.vocab_size = vocab_size
79 |
80 | self.bert = BertLayer(d_embedding=d_model, trainable=False)
81 |
82 | embedding_matrix = _embedding_from_bert()
83 |
84 | self.embedding = tf.keras.layers.Embedding(
85 | vocab_size, d_model, trainable=False,
86 | embeddings_initializer=Constant(embedding_matrix)
87 | )
88 |
89 | self.decoder = Decoder(num_layers, d_model, num_heads, dff, vocab_size, rate)
90 |
91 | self.final_layer = tf.keras.layers.Dense(vocab_size)
92 |
93 | def encode(self, ids, mask, segment_ids):
94 | # (batch_size, seq_len, d_bert)
95 | return self.bert((ids, mask, segment_ids))
96 |
97 | def draft_summary(self, enc_output, look_ahead_mask, padding_mask, target_ids=None, training=True):
98 |
99 | logging.info("Building:'Draft summary'")
100 |
101 | # (batch_size, seq_len)
102 | dec_input = target_ids
103 |
104 | # (batch_size, seq_len, d_bert)
105 | embeddings = self.embedding(target_ids)
106 |
107 | # (batch_size, seq_len, d_bert), (_)
108 | dec_output, attention_dist = self.decoder(embeddings, enc_output, training, look_ahead_mask, padding_mask)
109 |
110 | # (batch_size, seq_len, vocab_len)
111 | logits = self.final_layer(dec_output)
112 |
113 | # (batch_size, seq_len)
114 | preds = tf.to_int32(tf.argmax(logits, axis=-1))
115 |
116 | return logits, preds, attention_dist
117 |
118 | def draft_summary_greedy(self, enc_output, look_ahead_mask, padding_mask, training=False):
119 | """
120 | Inference call, builds a draft summary autoregressively
121 | """
122 |
123 | logging.info("Building: 'Greedy Draft Summary'")
124 |
125 | N = tf.shape(enc_output)[0]
126 | T = tf.shape(enc_output)[1]
127 |
128 | # (batch_size, 1)
129 | dec_input = tf.ones([N, 1], dtype=tf.int32) * CLS_ID
130 |
131 | summary, dec_outputs, dec_logits, attention_dists = [], [], [], []
132 |
133 | summary += [dec_input]
134 | dec_logits += [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1])]
135 |
136 | for i in tqdm(range(0, self.output_seq_len - 1)):
137 |
138 | # (batch_size, i+1, d_bert)
139 | embeddings = self.embedding(dec_input)
140 |
141 | # (batch_size, i+1, d_bert), (_)
142 | dec_output, attention_dist = self.decoder(embeddings, enc_output, training, look_ahead_mask, padding_mask)
143 |
144 | # (batch_size, 1, d_bert)
145 | dec_output_i = dec_output[:, -1: ,:]
146 |
147 | # (batch_size, 1, d_bert)
148 | dec_outputs += [dec_output_i]
149 | attention_dists += [{k: v[:, -1:, :] for k, v in attention_dist.items()}]
150 |
151 | # (batch_size, 1, vocab_len)
152 | logits = self.final_layer(dec_output_i)
153 |
154 | dec_logits += [logits]
155 |
156 | # (batch_size, 1)
157 | preds = tf.to_int32(tf.argmax(logits, axis=-1))
158 |
159 | summary += [preds]
160 |
161 | # (batch_size, i+2)
162 | dec_input = with_column(dec_input, i+1, preds)
163 |
164 | # (batch_size, seq_len, vocab_len)
165 | dec_logits = tf.concat(dec_logits, axis=1)
166 |
167 | # (batch_size, seq_len)
168 | summary = tf.concat(summary, axis=1)
169 |
170 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
171 | return dec_logits, summary, attention_dists
172 |
173 | def refined_summary_iter(self, enc_output, target, padding_mask, training=True):
174 | """
175 | Iterative version of refined summary using teacher forcing
176 | """
177 |
178 | logging.info("Building: 'Refined Summary'")
179 |
180 | N = tf.shape(enc_output)[0]
181 |
182 | dec_inp_ids, dec_inp_mask, dec_inp_segment_ids = target
183 |
184 | dec_outputs, attention_dists = [], []
185 |
186 | for i in tqdm(range(1, self.output_seq_len)):
187 |
188 | # (batch_size, seq_len)
189 | dec_inp_ids_ = mask_timestamp(dec_inp_ids, i, MASK_ID)
190 |
191 | # (batch_size, seq_len, d_bert)
192 | context_vectors = self.bert((dec_inp_ids_, dec_inp_mask, dec_inp_segment_ids))
193 |
194 | # (batch_size, seq_len, d_bert), (_)
195 | dec_output, attention_dist = self.decoder(
196 | context_vectors,
197 | enc_output,
198 | training,
199 | look_ahead_mask=None,
200 | padding_mask=padding_mask
201 | )
202 |
203 |             # (batch_size, 1, d_bert)
204 | dec_outputs += [dec_output[:,i:i+1,:]]
205 | attention_dists += [{k: v[:, i:i+1, :] for k, v in attention_dist.items()}]
206 |
207 | # (batch_size, seq_len - 1, d_bert)
208 | dec_outputs = tf.concat(dec_outputs, axis=1)
209 |
210 | # (batch_size, seq_len - 1, vocab_len)
211 | logits = self.final_layer(dec_outputs)
212 |
213 | # (batch_size, seq_len, vocab_len), accommodate for initial [CLS]
214 | logits = tf.concat(
215 | [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1]), logits],
216 | axis=1
217 | )
218 |
219 | # (batch_size, seq_len)
220 | preds = tf.to_int32(tf.argmax(logits, axis=-1))
221 |
222 | return logits, preds, attention_dists
223 |
224 | def refined_summary(self, enc_output, target, padding_mask, training=True):
225 |
226 | logging.info("Building: 'Refined Summary'")
227 |
228 | N = tf.shape(enc_output)[0]
229 | T = self.output_seq_len
230 |
231 | # (batch_size, seq_len) x3
232 | dec_inp_ids, dec_inp_mask, dec_inp_segment_ids = target
233 |
234 |         # since we are using teacher forcing we do not need an autoregressive mechanism here
235 |
236 | # (batch_size x (seq_len - 1), seq_len)
237 | dec_inp_ids = tile_and_mask_diagonal(dec_inp_ids, mask_with=MASK_ID)
238 |
239 | # (batch_size x (seq_len - 1), seq_len)
240 | dec_inp_mask = tf.tile(dec_inp_mask, [T-1, 1])
241 |
242 | # (batch_size x (seq_len - 1), seq_len)
243 | dec_inp_segment_ids = tf.tile(dec_inp_segment_ids, [T-1, 1])
244 |
245 | # (batch_size x (seq_len - 1), seq_len, d_bert)
246 | enc_output = tf.tile(enc_output, [T-1, 1, 1])
247 |
248 | # (batch_size x (seq_len - 1), 1, 1, seq_len)
249 | padding_mask = tf.tile(padding_mask, [T-1, 1, 1, 1])
250 |
251 | # (batch_size x (seq_len - 1), seq_len, d_bert)
252 | context_vectors = self.bert((dec_inp_ids, dec_inp_mask, dec_inp_segment_ids))
253 |
254 | # (batch_size x (seq_len - 1), seq_len, d_bert), (_)
255 | dec_outputs, attention_dists = self.decoder(
256 | context_vectors,
257 | enc_output,
258 | training,
259 | look_ahead_mask=None,
260 | padding_mask=padding_mask
261 | )
262 |
263 | # (batch_size x (seq_len - 1), seq_len - 1, d_bert)
264 | dec_outputs = dec_outputs[:, 1:, :]
265 |
266 | # (batch_size x (seq_len - 1), (seq_len - 1))
267 | diag = tf.linalg.set_diag(tf.zeros([T-1, T-1]), tf.ones([T-1]))
268 | diag = tf.tile(diag, [N, 1])
269 |
270 | where = tf.not_equal(diag, 0)
271 | indices = tf.where(where)
272 |
273 | # (batch_size x (seq_len - 1), d_bert)
274 | dec_outputs = tf.gather_nd(dec_outputs, indices)
275 |
276 | # (batch_size, seq_len - 1, d_bert)
277 | dec_outputs = tf.reshape(dec_outputs, [N, T-1, -1])
278 |
279 | # (batch_size, seq_len - 1, vocab_len)
280 | logits = self.final_layer(dec_outputs)
281 |
282 | # (batch_size, seq_len, vocab_len), accommodate for initial [CLS]
283 | logits = tf.concat(
284 | [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1]), logits],
285 | axis=1
286 | )
287 |
288 | # (batch_size, seq_len)
289 | preds = tf.to_int32(tf.argmax(logits, axis=-1))
290 |
291 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
292 | return logits, preds, attention_dists
293 |
294 | def refined_summary_greedy(self, enc_output, draft_summary, padding_mask, training=False):
295 | """
296 | Inference call, builds a refined summary
297 |
298 | It first masks each word in the summary draft one by one,
299 | then feeds the draft to BERT to generate context vectors.
300 | """
301 |
302 | logging.info("Building: 'Greedy Refined Summary'")
303 |
304 | refined_summary = draft_summary
305 |         refined_summary_mask = tf.cast(tf.math.not_equal(draft_summary, 0), tf.float32)  # 1 for real tokens, 0 for padding
306 | refined_summary_segment_ids = tf.zeros(tf.shape(draft_summary))
307 |
308 | N = tf.shape(draft_summary)[0]
309 | T = tf.shape(draft_summary)[1]
310 |
311 | dec_outputs, dec_logits, attention_dists = [], [], []
312 |
313 | dec_logits += [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1])]
314 |
315 | for i in tqdm(range(1, self.output_seq_len)):
316 |
317 | # (batch_size, seq_len)
318 | refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID)
319 |
320 | # (batch_size, seq_len, d_bert)
321 | context_vectors = self.bert((refined_summary_, refined_summary_mask, refined_summary_segment_ids))
322 |
323 | # (batch_size, seq_len, d_bert), (_)
324 | dec_output, attention_dist = self.decoder(
325 | context_vectors,
326 | enc_output,
327 | training=training,
328 | look_ahead_mask=None,
329 | padding_mask=padding_mask
330 | )
331 |
332 |             # (batch_size, 1, d_bert)
333 |             dec_output_i = dec_output[:, i:i+1, :]
334 |
335 | dec_outputs += [dec_output_i]
336 | attention_dists += [{k: v[:, i:i+1, :] for k, v in attention_dist.items()}]
337 |
338 | # (batch_size, 1, vocab_len)
339 | logits = self.final_layer(dec_output_i)
340 |
341 | dec_logits += [logits]
342 |
343 | # (batch_size, 1)
344 | preds = tf.to_int32(tf.argmax(logits, axis=-1))
345 |
346 | # (batch_size, seq_len)
347 | refined_summary = with_column(refined_summary, i, preds)
348 |
349 | # (batch_size, seq_len, vocab_len)
350 | dec_logits = tf.concat(dec_logits, axis=1)
351 |
352 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
353 | return dec_logits, refined_summary, attention_dists
354 |
355 | def call(self, inp, tar=None, training=False):
356 | """
357 | Run the model for training/inference
358 | For training, the target is needed (ids, mask, segments)
359 | """
360 |
361 | if training:
362 | return self.fit(inp, tar)
363 |
364 | else:
365 | return self.predict(inp)
366 |
367 |
368 | def fit(self, inp, tar):
369 | """
370 | __call__ for training; uses teacher forcing for both the draft
371 |         and the refined decoder
372 | """
373 | # (batch_size, seq_len) x3
374 | input_ids, input_mask, input_segment_ids = inp
375 |
376 | # (batch_size, seq_len + 1) x3
377 | target_ids, target_mask, target_segment_ids = tar
378 |
379 | # (batch_size, 1, 1, seq_len), (_), (batch_size, 1, 1, seq_len)
380 | combined_mask, dec_padding_mask = create_masks(input_ids, target_ids[:, :-1])
381 |
382 | # (batch_size, seq_len, d_bert)
383 | enc_output = self.encode(input_ids, input_mask, input_segment_ids)
384 |
385 | # (batch_size, seq_len , vocab_len), (batch_size, seq_len), (_)
386 | logits_draft_summary, preds_draft_summary, draft_attention_dist = self.draft_summary(
387 | enc_output=enc_output,
388 | look_ahead_mask=combined_mask,
389 | padding_mask=dec_padding_mask,
390 | target_ids=target_ids[:, :-1],
391 | training=True
392 | )
393 |
394 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
395 | logits_refined_summary, preds_refined_summary, refined_attention_dist = self.refined_summary(
396 | enc_output=enc_output,
397 | target=(target_ids[:, :-1], target_mask[:, :-1], target_segment_ids[:, :-1]),
398 | padding_mask=dec_padding_mask,
399 | training=True
400 | )
401 |
402 | return logits_draft_summary, logits_refined_summary
403 |
404 |
405 | def predict(self, inp):
406 | """
407 |         __call__ for inference; generates the draft and the refined summary
408 |         autoregressively (greedy decoding)
409 | """
410 | # (batch_size, seq_len) x3
411 | input_ids, input_mask, input_segment_ids = inp
412 |
413 | dec_padding_mask = create_padding_mask(input_ids)
414 |
415 | # (batch_size, seq_len, d_bert)
416 | enc_output = self.encode(input_ids, input_mask, input_segment_ids)
417 |
418 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
419 | logits_draft_summary, preds_draft_summary, draft_attention_dist = self.draft_summary_greedy(
420 | enc_output=enc_output,
421 | look_ahead_mask=None,
422 | padding_mask=dec_padding_mask
423 | )
424 |
425 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
426 | logits_refined_summary, preds_refined_summary, refined_attention_dist = self.refined_summary_greedy(
427 | enc_output=enc_output,
428 | padding_mask=dec_padding_mask,
429 | draft_summary=preds_draft_summary
430 | )
431 |
432 | return logits_draft_summary, preds_draft_summary, draft_attention_dist, logits_refined_summary, preds_refined_summary, refined_attention_dist
433 |
434 |
435 | def train(model, xs, ys, gradient_accumulation=False):
436 |
437 | logging.info("Building Training Graph")
438 | logging.info(f"w/ Gradient Accumulation: {str(gradient_accumulation)}")
439 |
440 | # (batch_size, seq_len + 1) x3
441 | target_ids, _, _ = ys
442 |
443 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len, vocab_len)
444 | logits_draft_summary, logits_refined_summary = model(xs, ys, True)
445 |
446 | target_ids_ = label_smoothing(tf.one_hot(target_ids, depth=model.vocab_size))
447 |
448 | # use right shifted target, (batch_size, seq_len)
449 | loss_draft = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_draft_summary, labels=target_ids_[:, 1:, :])
450 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, 1:], 0))
451 | mask = tf.cast(mask, dtype=loss_draft.dtype)
452 | loss_draft *= mask
453 |
454 | # use non-shifted target (we want to predict the masked word), (batch_size, seq_len)
455 | loss_refined = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_refined_summary, labels=target_ids_[:, :-1, :])
456 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, :-1], 0))
457 | mask = tf.cast(mask, dtype=loss_refined.dtype)
458 | loss_refined *= mask
459 |
460 | # (batch_size, seq_len)
461 | loss = loss_draft + loss_refined
462 | # scalar
463 | loss = tf.reduce_mean(loss)
464 |
465 | global_step = tf.train.get_or_create_global_step()
466 | learning_rate = noam_scheme(initial_lr, global_step, warmup_steps)
467 | optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-9)
468 |
469 | tf.summary.scalar('learning_rate', learning_rate, family='train')
470 |     tf.summary.scalar('loss_draft', tf.reduce_mean(loss_draft), family='train')
471 |     tf.summary.scalar('loss_refined', tf.reduce_mean(loss_refined), family='train')
472 | tf.summary.scalar("loss", loss, family='train')
473 | tf.summary.scalar("global_step", global_step, family='train')
474 |
475 | summaries = tf.summary.merge_all()
476 |
477 | if gradient_accumulation:
478 |
479 | tvs = tf.trainable_variables()
480 |
481 | accumulation_variables = [
482 | tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
483 | for tv in tvs
484 | ]
485 |
486 | zero_op = [tv.assign(tf.zeros_like(tv)) for tv in accumulation_variables]
487 |
488 | gradients_vs = optimizer.compute_gradients(
489 | loss=loss,
490 | var_list=tvs,
491 | colocate_gradients_with_ops=True,
492 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
493 | )
494 |
495 |         accumulation_op = [accumulation_variables[i].assign_add(gv[0]) for i, gv in enumerate(gradients_vs) if gv[0] is not None]
496 |
497 | # pass list of (gradient, variable) pairs
498 | train_op = optimizer.apply_gradients([(accumulation_variables[i], gv[1]) for i, gv in enumerate(gradients_vs)], global_step)
499 |
500 |         return loss, zero_op, accumulation_op, train_op, global_step, summaries
501 |
502 | else:
503 |
504 | train_op = optimizer.minimize(
505 | loss,
506 | global_step=global_step,
507 | colocate_gradients_with_ops=True,
508 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
509 | )
510 |
511 | return loss, train_op, global_step, summaries
512 |
513 |
514 | def eval(model, xs, ys):
515 |
516 | logging.info("Building Evaluation Graph")
517 |
518 | # (batch_size, seq_len + 1) x3
519 | target_ids, _, _ = ys
520 |
521 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_), (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_)
522 | logits_draft_summary, preds_draft_summary, _, logits_refined_summary, preds_refined_summary, _ = model(xs)
523 |
524 | target_ids_ = label_smoothing(tf.one_hot(target_ids, depth=model.vocab_size))
525 |
526 | # use right shifted target, (batch_size, seq_len)
527 | loss_draft = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_draft_summary, labels=target_ids_[:, 1:, :])
528 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, 1:], 0))
529 | mask = tf.cast(mask, dtype=loss_draft.dtype)
530 | loss_draft *= mask
531 |
532 | # use non-shifted target (we want to predict the masked word), (batch_size, seq_len)
533 | loss_refined = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_refined_summary, labels=target_ids_[:, :-1, :])
534 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, :-1], 0))
535 | mask = tf.cast(mask, dtype=loss_refined.dtype)
536 | loss_refined *= mask
537 |
538 | # (batch_size, seq_len)
539 | loss = loss_draft + loss_refined
540 |
541 | # scalar
542 | loss = tf.reduce_mean(loss)
543 |
544 | # monitor a random sample
545 | n = tf.random_uniform((), 0, tf.shape(xs[0])[0], tf.int32)
546 |
547 | x_rnd = convert_idx_to_token_tensor(xs[0][n])
548 | y_rnd = convert_idx_to_token_tensor(target_ids[n, :-1])
549 | y_hat_rnd = convert_idx_to_token_tensor(preds_refined_summary[n])
550 |
551 | r1_val, r2_val, rl_val, r_avg = calculate_rouge(y_rnd, y_hat_rnd)
552 |
553 | tf.summary.text("input", x_rnd)
554 | tf.summary.text("target", y_rnd)
555 | tf.summary.text("prediction", y_hat_rnd)
556 |
557 | tf.summary.scalar('ROUGE-1', r1_val, family='eval')
558 | tf.summary.scalar('ROUGE-2', r2_val, family='eval')
559 | tf.summary.scalar("ROUGE-L", rl_val, family='eval')
560 | tf.summary.scalar("R-AVG", r_avg, family='eval')
561 |
562 | tf.summary.scalar('loss_draft', tf.reduce_mean(loss_draft), family='eval')
563 | tf.summary.scalar('loss_refined', tf.reduce_mean(loss_refined), family='eval')
564 | tf.summary.scalar("loss", loss, family='eval')
565 |
566 | summaries = tf.summary.merge_all()
567 |
568 | # (batch_size, seq_len), (batch_size, seq_len), scalar, object
569 | return target_ids[:, :-1], preds_refined_summary, loss, summaries
570 |
--------------------------------------------------------------------------------
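Note: the `gradient_accumulation` branch of `train` above follows the usual three-op pattern: zero the accumulation buffers, add one gradient per mini-batch, then apply the accumulated sum (see `train.py` further down for how the three ops are driven). A stripped-down sketch of the same pattern on a toy loss, independent of the summarizer; every name and value below is illustrative only:

```python
import tensorflow as tf

ACCUMULATION_STEPS = 4

x = tf.placeholder(tf.float32, shape=(None, 3))
w = tf.Variable(tf.ones([3, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))

optimizer = tf.train.GradientDescentOptimizer(0.1)
tvs = tf.trainable_variables()

# one non-trainable buffer per trainable variable
buffers = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]

zero_op = [b.assign(tf.zeros_like(b)) for b in buffers]

grads_vs = optimizer.compute_gradients(loss, var_list=tvs)
accumulation_op = [buffers[i].assign_add(g) for i, (g, _) in enumerate(grads_vs) if g is not None]

# apply the accumulated (summed) gradients in a single update
train_op = optimizer.apply_gradients([(buffers[i], v) for i, (_, v) in enumerate(grads_vs)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(zero_op)
    for _ in range(ACCUMULATION_STEPS):
        sess.run(accumulation_op, feed_dict={x: [[1., 2., 3.]]})
    sess.run(train_op)
```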
/models/transformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | import tensorflow as tf
5 | import tensorflow_hub as hub
6 |
7 | from tensorflow.keras import backend as K
8 | from layers.transformer import Encoder
9 | from layers.transformer import Decoder
10 |
11 |
12 | class Transformer(tf.keras.Model):
13 | """
14 | Transformer consists of the encoder, decoder and a final linear layer.
15 | The output of the decoder is the input to the linear layer and its output is returned.
16 | """
17 | def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, rate=0.1):
18 | super(Transformer, self).__init__()
19 |
20 | self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, rate)
21 |
22 | self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, rate)
23 |
24 | self.final_layer = tf.keras.layers.Dense(target_vocab_size)
25 |
26 | def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask):
27 |
28 | enc_output = self.encoder(inp, training, enc_padding_mask) # (batch_size, inp_seq_len, d_model)
29 |
30 | # dec_output.shape == (batch_size, tar_seq_len, d_model)
31 | dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask)
32 |
33 | final_output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size)
34 |
35 | return final_output, attention_weights
36 |
--------------------------------------------------------------------------------
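A minimal sketch of wiring this `Transformer` together with the masking helpers from `ops/masking.py`, assuming the `Encoder`/`Decoder` layers accept the constructor and call signatures used above; the hyper-parameters, shapes and random inputs are illustrative only:

```python
import tensorflow as tf

from models.transformer import Transformer
from ops.masking import create_padding_mask, create_masks

# toy hyper-parameters, for illustration only
transformer = Transformer(num_layers=2, d_model=128, num_heads=8, dff=512,
                          input_vocab_size=8000, target_vocab_size=8000)

inp = tf.random_uniform((4, 38), 0, 8000, dtype=tf.int32)  # (batch_size, inp_seq_len)
tar = tf.random_uniform((4, 36), 0, 8000, dtype=tf.int32)  # (batch_size, tar_seq_len)

enc_padding_mask = create_padding_mask(inp)
combined_mask, dec_padding_mask = create_masks(inp, tar)

# (batch_size, tar_seq_len, target_vocab_size)
logits, attention_weights = transformer(inp, tar, False,
                                        enc_padding_mask, combined_mask, dec_padding_mask)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(logits)))  # [4, 36, 8000]
```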
/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/ops/__init__.py
--------------------------------------------------------------------------------
/ops/attention.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def scaled_dot_product_attention(q, k, v, mask):
5 | """Calculate the attention weights.
6 | q, k, v must have matching leading dimensions.
7 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
8 | The mask has different shapes depending on its type(padding or look ahead)
9 | but it must be broadcastable for addition.
10 |
11 | The mask is multiplied with -1e9 (close to negative infinity).
12 | This is done because the mask is summed with the scaled matrix
13 | multiplication of Q and K and is applied immediately before a softmax.
14 | The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.
15 |
16 | Args:
17 | q: query shape == (..., seq_len_q, depth)
18 | k: key shape == (..., seq_len_k, depth)
19 | v: value shape == (..., seq_len_v, depth_v)
20 | mask: Float tensor with shape broadcastable
21 | to (..., seq_len_q, seq_len_k). Pass None when no masking is needed.
22 |
23 | Returns:
24 | output, attention_weights
25 | """
26 |
27 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k)
28 |
29 | # scale matmul_qk
30 | dk = tf.cast(tf.shape(k)[-1], tf.float32)
31 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
32 |
33 | # add the mask to the scaled tensor.
34 | if mask is not None:
35 | scaled_attention_logits += (mask * -1e9)
36 |
37 | # softmax is normalized on the last axis (seq_len_k) so that the scores
38 | # add up to 1.
39 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k)
40 |
41 | output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
42 |
43 | return output, attention_weights
44 |
45 |
--------------------------------------------------------------------------------
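A small self-contained check of `scaled_dot_product_attention` on toy tensors (values chosen so the query attends almost entirely to the second key):

```python
import tensorflow as tf

from ops.attention import scaled_dot_product_attention

q = tf.constant([[0., 10., 0.]])                 # (seq_len_q=1, depth=3)
k = tf.constant([[10., 0., 0.],
                 [0., 10., 0.],
                 [0., 0., 10.],
                 [0., 0., 10.]])                 # (seq_len_k=4, depth=3)
v = tf.constant([[1., 0.],
                 [10., 0.],
                 [100., 5.],
                 [1000., 6.]])                   # (seq_len_v=4, depth_v=2)

output, attention_weights = scaled_dot_product_attention(q, k, v, mask=None)

with tf.Session() as sess:
    out, attn = sess.run([output, attention_weights])
    print(attn)  # weights concentrate on the second key
    print(out)   # approximately [[10., 0.]]
```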
/ops/beam_search.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.util import nest
3 |
4 | # Default value for INF
5 | INF = 1. * 1e7
6 |
7 |
8 | class _StateKeys(object):
9 | """Keys to dictionary storing the state of the beam search loop."""
10 |
11 | # Variable storing the loop index.
12 | CUR_INDEX = "CUR_INDEX"
13 |
14 | # Top sequences that are alive for each batch item. Alive sequences are ones
15 | # that have not generated an EOS token. Sequences that reach EOS are marked as
16 | # finished and moved to the FINISHED_SEQ tensor.
17 | # Has shape [batch_size, beam_size, CUR_INDEX + 1]
18 | ALIVE_SEQ = "ALIVE_SEQ"
19 | # Log probabilities of each alive sequence. Shape [batch_size, beam_size]
20 | ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS"
21 | # Dictionary of cached values for each alive sequence. The cache stores
22 | # the encoder output, attention bias, and the decoder attention output from
23 | # the previous iteration.
24 | ALIVE_CACHE = "ALIVE_CACHE"
25 |
26 | # Top finished sequences for each batch item.
27 | # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are
28 | # shorter than CUR_INDEX + 1 are padded with 0s.
29 | FINISHED_SEQ = "FINISHED_SEQ"
30 | # Scores for each finished sequence. Score = log probability / length norm
31 | # Shape [batch_size, beam_size]
32 | FINISHED_SCORES = "FINISHED_SCORES"
33 | # Flags indicating which sequences in the finished sequences are finished.
34 | # At the beginning, all of the sequences in FINISHED_SEQ are filler values.
35 | # True -> finished sequence, False -> filler. Shape [batch_size, beam_size]
36 | FINISHED_FLAGS = "FINISHED_FLAGS"
37 |
38 |
39 | class SequenceBeamSearch(object):
40 | """Implementation of beam search loop."""
41 |
42 | def __init__(self, symbols_to_logits_fn, vocab_size, batch_size,
43 | beam_size, alpha, max_decode_length, eos_id):
44 | self.symbols_to_logits_fn = symbols_to_logits_fn
45 | self.vocab_size = vocab_size
46 | self.batch_size = batch_size
47 | self.beam_size = beam_size
48 | self.alpha = alpha
49 | self.max_decode_length = max_decode_length
50 | self.eos_id = eos_id
51 |
52 | def search(self, initial_ids, initial_cache):
53 | """Beam search for sequences with highest scores."""
54 | state, state_shapes = self._create_initial_state(initial_ids, initial_cache)
55 |
56 | finished_state = tf.while_loop(
57 | self._continue_search, self._search_step, loop_vars=[state],
58 | shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False)
59 | finished_state = finished_state[0]
60 |
61 | alive_seq = finished_state[_StateKeys.ALIVE_SEQ]
62 | alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS]
63 | finished_seq = finished_state[_StateKeys.FINISHED_SEQ]
64 | finished_scores = finished_state[_StateKeys.FINISHED_SCORES]
65 | finished_flags = finished_state[_StateKeys.FINISHED_FLAGS]
66 |
67 | # Account for corner case where there are no finished sequences for a
68 | # particular batch item. In that case, return alive sequences for that batch
69 | # item.
70 | finished_seq = tf.where(
71 | tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
72 | finished_scores = tf.where(
73 | tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
74 | return finished_seq, finished_scores
75 |
76 | def _create_initial_state(self, initial_ids, initial_cache):
77 | """Return initial state dictionary and its shape invariants.
78 |
79 | Args:
80 | initial_ids: initial ids to pass into the symbols_to_logits_fn.
81 | int tensor with shape [batch_size, 1]
82 | initial_cache: dictionary storing values to be passed into the
83 | symbols_to_logits_fn.
84 |
85 | Returns:
86 | state and shape invariant dictionaries with keys from _StateKeys
87 | """
88 | # Current loop index (starts at 0)
89 | cur_index = tf.constant(0)
90 |
91 | # Create alive sequence with shape [batch_size, beam_size, 1]
92 | alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
93 | alive_seq = tf.expand_dims(alive_seq, axis=2)
94 |
95 | # Create tensor for storing initial log probabilities.
96 | # Assume initial_ids are prob 1.0
97 | initial_log_probs = tf.constant(
98 | [[0.] + [-float("inf")] * (self.beam_size - 1)])
99 | alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])
100 |
101 | # Expand all values stored in the dictionary to the beam size, so that each
102 | # beam has a separate cache.
103 | alive_cache = nest.map_structure(
104 | lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)
105 |
106 | # Initialize tensor storing finished sequences with filler values.
107 | finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
108 |
109 | # Set scores of the initial finished seqs to negative infinity.
110 | finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF
111 |
112 | # Initialize finished flags with all False values.
113 | finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)
114 |
115 | # Create state dictionary
116 | state = {
117 | _StateKeys.CUR_INDEX: cur_index,
118 | _StateKeys.ALIVE_SEQ: alive_seq,
119 | _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
120 | _StateKeys.ALIVE_CACHE: alive_cache,
121 | _StateKeys.FINISHED_SEQ: finished_seq,
122 | _StateKeys.FINISHED_SCORES: finished_scores,
123 | _StateKeys.FINISHED_FLAGS: finished_flags
124 | }
125 |
126 | # Create state invariants for each value in the state dictionary. Each
127 | # dimension must be a constant or None. A None dimension means either:
128 | # 1) the dimension's value is a tensor that remains the same but may
129 | # depend on the input sequence to the model (e.g. batch size).
130 | # 2) the dimension may have different values on different iterations.
131 | state_shape_invariants = {
132 | _StateKeys.CUR_INDEX: tf.TensorShape([]),
133 | _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]),
134 | _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]),
135 | _StateKeys.ALIVE_CACHE: nest.map_structure(
136 | _get_shape_keep_last_dim, alive_cache),
137 | _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]),
138 | _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]),
139 | _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size])
140 | }
141 |
142 | return state, state_shape_invariants
143 |
144 | def _continue_search(self, state):
145 | """Return whether to continue the search loop.
146 |
147 | The loops should terminate when
148 | 1) when decode length has been reached, or
149 | 2) when the worst score in the finished sequences is better than the best
150 | score in the alive sequences (i.e. the finished sequences are provably
151 | unchanging)
152 |
153 | Args:
154 | state: A dictionary with the current loop state.
155 |
156 | Returns:
157 | Bool tensor with value True if loop should continue, False if loop should
158 | terminate.
159 | """
160 | i = state[_StateKeys.CUR_INDEX]
161 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
162 | finished_scores = state[_StateKeys.FINISHED_SCORES]
163 | finished_flags = state[_StateKeys.FINISHED_FLAGS]
164 |
165 | not_at_max_decode_length = tf.less(i, self.max_decode_length)
166 |
167 | # Calculate largest length penalty (the larger penalty, the better score).
168 | max_length_norm = _length_normalization(self.alpha, self.max_decode_length)
169 | # Get the best possible scores from alive sequences.
170 | best_alive_scores = alive_log_probs[:, 0] / max_length_norm
171 |
172 | # Compute worst score in finished sequences for each batch element
173 | finished_scores *= tf.cast(finished_flags,
174 | tf.float32) # set filler scores to zero
175 | lowest_finished_scores = tf.reduce_min(finished_scores, axis=1)
176 |
177 | # If there are no finished sequences in a batch element, then set the lowest
178 | # finished score to -INF for that element.
179 | finished_batches = tf.reduce_any(finished_flags, 1)
180 | lowest_finished_scores += (1.0 -
181 | tf.cast(finished_batches, tf.float32)) * -INF
182 |
183 | worst_finished_score_better_than_best_alive_score = tf.reduce_all(
184 | tf.greater(lowest_finished_scores, best_alive_scores)
185 | )
186 |
187 | return tf.logical_and(
188 | not_at_max_decode_length,
189 | tf.logical_not(worst_finished_score_better_than_best_alive_score)
190 | )
191 |
192 | def _search_step(self, state):
193 | """Beam search loop body.
194 |
195 | Grow alive sequences by a single ID. Sequences that have reached the EOS
196 | token are marked as finished. The alive and finished sequences with the
197 | highest log probabilities and scores are returned.
198 |
199 | A sequence's finished score is calculating by dividing the log probability
200 | by the length normalization factor. Without length normalization, the
201 | search is more likely to return shorter sequences.
202 |
203 | Args:
204 | state: A dictionary with the current loop state.
205 |
206 | Returns:
207 | new state dictionary.
208 | """
209 | # Grow alive sequences by one token.
210 | new_seq, new_log_probs, new_cache = self._grow_alive_seq(state)
211 | # Collect top beam_size alive sequences
212 | alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache)
213 |
214 | # Combine newly finished sequences with existing finished sequences, and
215 | # collect the top k scoring sequences.
216 | finished_state = self._get_new_finished_state(state, new_seq, new_log_probs)
217 |
218 | # Increment loop index and create new state dictionary
219 | new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1}
220 | new_state.update(alive_state)
221 | new_state.update(finished_state)
222 | return [new_state]
223 |
224 | def _grow_alive_seq(self, state):
225 | """Grow alive sequences by one token, and collect top 2*beam_size sequences.
226 |
227 | 2*beam_size sequences are collected because some sequences may have reached
228 | the EOS token. 2*beam_size ensures that at least beam_size sequences are
229 | still alive.
230 |
231 | Args:
232 | state: A dictionary with the current loop state.
233 | Returns:
234 | Tuple of
235 | (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1],
236 | Scores of returned sequences [batch_size, 2 * beam_size],
237 | New alive cache, for each of the 2 * beam_size sequences)
238 | """
239 | i = state[_StateKeys.CUR_INDEX]
240 | alive_seq = state[_StateKeys.ALIVE_SEQ]
241 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS]
242 | alive_cache = state[_StateKeys.ALIVE_CACHE]
243 |
244 | beams_to_keep = 2 * self.beam_size
245 |
246 | # Get logits for the next candidate IDs for the alive sequences. Get the new
247 | # cache values at the same time.
248 | flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size]
249 | flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache)
250 |
251 | flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i)
252 |
253 | # Unflatten logits to shape [batch_size, beam_size, vocab_size]
254 | logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size)
255 | new_cache = nest.map_structure(
256 | lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size),
257 | flat_cache)
258 |
259 | # Convert logits to normalized log probs
260 | candidate_log_probs = _log_prob_from_logits(logits)
261 |
262 | # Calculate new log probabilities if each of the alive sequences were
263 | # extended # by the the candidate IDs.
264 | # Shape [batch_size, beam_size, vocab_size]
265 | log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)
266 |
267 | # Each batch item has beam_size * vocab_size candidate sequences. For each
268 | # batch item, get the k candidates with the highest log probabilities.
269 | flat_log_probs = tf.reshape(log_probs,
270 | [-1, self.beam_size * self.vocab_size])
271 | topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep)
272 |
273 | # Extract the alive sequences that generate the highest log probabilities
274 | # after being extended.
275 | topk_beam_indices = topk_indices // self.vocab_size
276 | topk_seq, new_cache = _gather_beams(
277 | [alive_seq, new_cache], topk_beam_indices, self.batch_size,
278 | beams_to_keep)
279 |
280 | # Append the most probable IDs to the topk sequences
281 | topk_ids = topk_indices % self.vocab_size
282 | topk_ids = tf.expand_dims(topk_ids, axis=2)
283 | topk_seq = tf.concat([topk_seq, topk_ids], axis=2)
284 | return topk_seq, topk_log_probs, new_cache
285 |
286 | def _get_new_alive_state(self, new_seq, new_log_probs, new_cache):
287 | """Gather the top k sequences that are still alive.
288 |
289 | Args:
290 | new_seq: New sequences generated by growing the current alive sequences
291 | int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1]
292 | new_log_probs: Log probabilities of new sequences
293 | float32 tensor with shape [batch_size, beam_size]
294 | new_cache: Dict of cached values for each sequence.
295 |
296 | Returns:
297 | Dictionary with alive keys from _StateKeys:
298 | {Top beam_size sequences that are still alive (don't end with eos_id)
299 | Log probabilities of top alive sequences
300 | Dict cache storing decoder states for top alive sequences}
301 | """
302 | # To prevent finished sequences from being considered, set log probs to -INF
303 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
304 | new_log_probs += tf.cast(new_finished_flags, tf.float32) * -INF
305 |
306 | top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams(
307 | [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size,
308 | self.beam_size)
309 |
310 | return {
311 | _StateKeys.ALIVE_SEQ: top_alive_seq,
312 | _StateKeys.ALIVE_LOG_PROBS: top_alive_log_probs,
313 | _StateKeys.ALIVE_CACHE: top_alive_cache
314 | }
315 |
316 | def _get_new_finished_state(self, state, new_seq, new_log_probs):
317 | """Combine new and old finished sequences, and gather the top k sequences.
318 |
319 | Args:
320 | state: A dictionary with the current loop state.
321 | new_seq: New sequences generated by growing the current alive sequences
322 | int32 tensor with shape [batch_size, beam_size, i + 1]
323 | new_log_probs: Log probabilities of new sequences
324 | float32 tensor with shape [batch_size, beam_size]
325 |
326 | Returns:
327 | Dictionary with finished keys from _StateKeys:
328 | {Top beam_size finished sequences based on score,
329 | Scores of finished sequences,
330 | Finished flags of finished sequences}
331 | """
332 | i = state[_StateKeys.CUR_INDEX]
333 | finished_seq = state[_StateKeys.FINISHED_SEQ]
334 | finished_scores = state[_StateKeys.FINISHED_SCORES]
335 | finished_flags = state[_StateKeys.FINISHED_FLAGS]
336 |
337 | # First append a column of 0-ids to finished_seq to increment the length.
338 | # New shape of finished_seq: [batch_size, beam_size, i + 1]
339 | finished_seq = tf.concat(
340 | [finished_seq,
341 | tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2)
342 |
343 | # Calculate new seq scores from log probabilities.
344 | length_norm = _length_normalization(self.alpha, i + 1)
345 | new_scores = new_log_probs / length_norm
346 |
347 | # Set the scores of the still-alive seq in new_seq to large negative values.
348 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id)
349 | new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * -INF
350 |
351 | # Combine sequences, scores, and flags.
352 | finished_seq = tf.concat([finished_seq, new_seq], axis=1)
353 | finished_scores = tf.concat([finished_scores, new_scores], axis=1)
354 | finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1)
355 |
356 | # Return the finished sequences with the best scores.
357 | top_finished_seq, top_finished_scores, top_finished_flags = (
358 | _gather_topk_beams([finished_seq, finished_scores, finished_flags],
359 | finished_scores, self.batch_size, self.beam_size))
360 |
361 | return {
362 | _StateKeys.FINISHED_SEQ: top_finished_seq,
363 | _StateKeys.FINISHED_SCORES: top_finished_scores,
364 | _StateKeys.FINISHED_FLAGS: top_finished_flags
365 | }
366 |
367 |
368 | def sequence_beam_search(
369 | symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size,
370 | alpha, max_decode_length, eos_id):
371 | """Search for sequence of subtoken ids with the largest probability.
372 |
373 | Args:
374 | symbols_to_logits_fn: A function that takes in ids, index, and cache as
375 | arguments. The passed in arguments will have shape:
376 | ids -> [batch_size * beam_size, index]
377 | index -> [] (scalar)
378 | cache -> nested dictionary of tensors [batch_size * beam_size, ...]
379 | The function must return logits and new cache.
380 | logits -> [batch * beam_size, vocab_size]
381 | new cache -> same shape/structure as inputted cache
382 | initial_ids: Starting ids for each batch item.
383 | int32 tensor with shape [batch_size]
384 | initial_cache: dict containing starting decoder variables information
385 | vocab_size: int size of tokens
386 | beam_size: int number of beams
387 | alpha: float defining the strength of length normalization
388 | max_decode_length: maximum length to decoded sequence
389 | eos_id: int id of eos token, used to determine when a sequence has finished
390 |
391 | Returns:
392 | Top decoded sequences [batch_size, beam_size, max_decode_length]
393 | sequence scores [batch_size, beam_size]
394 | """
395 | batch_size = tf.shape(initial_ids)[0]
396 | sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size,
397 | beam_size, alpha, max_decode_length, eos_id)
398 | return sbs.search(initial_ids, initial_cache)
399 |
400 |
401 | def _log_prob_from_logits(logits):
402 | return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)
403 |
404 |
405 | def _length_normalization(alpha, length):
406 | """Return length normalization factor."""
407 | return tf.pow(((5. + tf.cast(length, tf.float32)) / 6.), alpha)
408 |
409 |
410 | def _expand_to_beam_size(tensor, beam_size):
411 | """Tiles a given tensor by beam_size.
412 |
413 | Args:
414 | tensor: tensor to tile [batch_size, ...]
415 | beam_size: How much to tile the tensor by.
416 |
417 | Returns:
418 | Tiled tensor [batch_size, beam_size, ...]
419 | """
420 | tensor = tf.expand_dims(tensor, axis=1)
421 | tile_dims = [1] * tensor.shape.ndims
422 | tile_dims[1] = beam_size
423 |
424 | return tf.tile(tensor, tile_dims)
425 |
426 |
427 | def _shape_list(tensor):
428 | """Return a list of the tensor's shape, and ensure no None values in list."""
429 | # Get statically known shape (may contain None's for unknown dimensions)
430 | shape = tensor.get_shape().as_list()
431 |
432 | # Ensure that the shape values are not None
433 | dynamic_shape = tf.shape(tensor)
434 | for i in range(len(shape)): # pylint: disable=consider-using-enumerate
435 | if shape[i] is None:
436 | shape[i] = dynamic_shape[i]
437 | return shape
438 |
439 |
440 | def _get_shape_keep_last_dim(tensor):
441 | shape_list = _shape_list(tensor)
442 |
443 | # Keep only the last dimension (when statically known); all other
444 | # dimensions may change between loop iterations.
444 | for i in range(len(shape_list) - 1):
445 | shape_list[i] = None
446 |
447 | if isinstance(shape_list[-1], tf.Tensor):
448 | shape_list[-1] = None
449 | return tf.TensorShape(shape_list)
450 |
451 |
452 | def _flatten_beam_dim(tensor):
453 | """Reshapes first two dimensions in to single dimension.
454 |
455 | Args:
456 | tensor: Tensor to reshape of shape [A, B, ...]
457 |
458 | Returns:
459 | Reshaped tensor of shape [A*B, ...]
460 | """
461 | shape = _shape_list(tensor)
462 | shape[0] *= shape[1]
463 | shape.pop(1) # Remove beam dim
464 | return tf.reshape(tensor, shape)
465 |
466 |
467 | def _unflatten_beam_dim(tensor, batch_size, beam_size):
468 | """Reshapes first dimension back to [batch_size, beam_size].
469 |
470 | Args:
471 | tensor: Tensor to reshape of shape [batch_size*beam_size, ...]
472 | batch_size: Tensor, original batch size.
473 | beam_size: int, original beam size.
474 |
475 | Returns:
476 | Reshaped tensor of shape [batch_size, beam_size, ...]
477 | """
478 | shape = _shape_list(tensor)
479 | new_shape = [batch_size, beam_size] + shape[1:]
480 | return tf.reshape(tensor, new_shape)
481 |
482 |
483 | def _gather_beams(nested, beam_indices, batch_size, new_beam_size):
484 | """Gather beams from nested structure of tensors.
485 |
486 | Each tensor in nested represents a batch of beams, where beam refers to a
487 | single search state (beam search involves searching through multiple states
488 | in parallel).
489 |
490 | This function is used to gather the top beams, specified by
491 | beam_indices, from the nested tensors.
492 |
493 | Args:
494 | nested: Nested structure (tensor, list, tuple or dict) containing tensors
495 | with shape [batch_size, beam_size, ...].
496 | beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each
497 | value in beam_indices must be between [0, beam_size), and are not
498 | necessarily unique.
499 | batch_size: int size of batch
500 | new_beam_size: int number of beams to be pulled from the nested tensors.
501 |
502 | Returns:
503 | Nested structure containing tensors with shape
504 | [batch_size, new_beam_size, ...]
505 | """
506 | # Computes the i'th coodinate that contains the batch index for gather_nd.
507 | # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..].
508 | batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size
509 | batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size])
510 |
511 | # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor
512 | # with shape [batch_size, beam_size, 2], where the last dimension contains
513 | # the (i, j) gathering coordinates.
514 | coordinates = tf.stack([batch_pos, beam_indices], axis=2)
515 |
516 | return nest.map_structure(
517 | lambda state: tf.gather_nd(state, coordinates), nested)
518 |
519 |
520 | def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size):
521 | """Gather top beams from nested structure."""
522 | _, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size)
523 | return _gather_beams(nested, topk_indexes, batch_size, beam_size)
524 |
--------------------------------------------------------------------------------
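A toy end-to-end call of `sequence_beam_search`. Note that, as implemented above, `symbols_to_logits_fn` is invoked as `symbols_to_logits_fn(ids, index)` and must return `(logits, cache)`; the scorer below is a stand-in that ignores the decoding history, and every value is illustrative only:

```python
import tensorflow as tf

from ops.beam_search import sequence_beam_search

VOCAB_SIZE = 6
EOS_ID = 1
BATCH_SIZE = 2

def symbols_to_logits_fn(ids, index):
    # toy scorer: always prefer token 3, with EOS (token 1) as runner-up
    batch_by_beam = tf.shape(ids)[0]
    logits = tf.tile(tf.constant([[0., 4., 0., 5., 0., 0.]]), [batch_by_beam, 1])
    return logits, {}

decoded, scores = sequence_beam_search(
    symbols_to_logits_fn=symbols_to_logits_fn,
    initial_ids=tf.zeros([BATCH_SIZE], tf.int32),
    initial_cache={},
    vocab_size=VOCAB_SIZE,
    beam_size=3,
    alpha=0.6,
    max_decode_length=5,
    eos_id=EOS_ID
)

with tf.Session() as sess:
    seqs, seq_scores = sess.run([decoded, scores])
    print(seqs.shape)        # (batch_size, beam_size, decoded_length)
    print(seq_scores.shape)  # (batch_size, beam_size)
```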
/ops/data.py:
--------------------------------------------------------------------------------
1 | import re
2 | import time
3 | import glob
4 | import struct
5 | import shutil
6 | import logging
7 | import pickle
8 | import numpy as np
9 | import tensorflow as tf
10 | import tensorflow_hub as hub
11 |
12 | from collections import namedtuple
13 | from itertools import count
14 | from functools import partial
15 | from functools import reduce
16 | from functools import wraps
17 |
18 | from tensorflow.core.example import example_pb2
19 | from bert.tokenization import FullTokenizer
20 | from tqdm import tqdm
21 |
22 | from utils.decorators import timeit
23 | from config import config
24 |
25 |
26 | logging.basicConfig(level=logging.INFO)
27 |
28 | logger = logging.getLogger()
29 |
30 |
31 |
32 | SENTENCE_START = '<s>'
33 | SENTENCE_END = '</s>'
34 |
35 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
36 |
37 |
38 | InputExample = namedtuple('InputExample', ['guid', 'text_a', 'text_b'])
39 |
40 | InputFeatures = namedtuple('InputFeatures', ['guid', 'tokens', 'input_ids', 'input_mask', 'input_type_ids'])
41 |
42 |
43 | def pad(l, n, pad_element):
44 | """
45 | Pad the list 'l' to have size 'n' using 'pad_element'
46 | """
47 | return l + [pad_element] * (n - len(l))
48 |
49 |
50 | def calc_num_batches(total_num, batch_size):
51 | """
52 | Calculates the number of batches, allowing for remainders.
53 | """
54 | return total_num // batch_size + int(total_num % batch_size != 0)
55 |
56 |
57 | def convert_single_example(tokenizer, example, max_seq_len=config.SEQ_LEN):
58 | """
59 | Convert `text` to the Bert input format
60 | """
61 | tokens = tokenizer.tokenize(example.text_a)
62 |
63 | if len(tokens) > max_seq_len - 2:
64 | tokens = tokens[0:(max_seq_len - 2)]
65 |
66 | tokens = ['[CLS]'] + tokens + ['[SEP]']
67 |
68 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
69 | input_type_ids = [0] * len(input_ids)
70 | input_mask = [1] * len(input_ids)
71 |
72 | input_ids = pad(input_ids, max_seq_len, 0)
73 | input_type_ids = pad(input_type_ids, max_seq_len, 0)
74 | input_mask = pad(input_mask, max_seq_len, 0)
75 |
76 | return tokens, input_ids, input_mask, input_type_ids
77 |
78 |
79 | def convert_examples_to_features(tokenizer, examples, max_seq_len=config.SEQ_LEN):
80 | """
81 | Convert raw features to Bert specific representation
82 | """
83 | converter = partial(convert_single_example, tokenizer=tokenizer, max_seq_len=max_seq_len)
84 | examples = [converter(example=example) for example in examples]
85 | return examples
86 |
87 |
88 | def create_tokenizer_from_hub_module(bert_hub_url):
89 | """
90 | Get the vocab file and casing info from the Hub module.
91 | """
92 | bert_module = hub.Module(bert_hub_url)
93 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
94 |
95 | with tf.Session() as sess:
96 | vocab_file, do_lower_case = sess.run([
97 | tokenization_info["vocab_file"],
98 | tokenization_info["do_lower_case"]
99 | ])
100 |
101 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
102 |
103 |
104 | def abstract2sents(abstract):
105 | """
106 | Use the <s> and </s> tags in abstract to get a list of sentences.
107 | """
108 | sentences_pattern = re.compile(r"<s>(.+?)<\/s>")
109 | sentences = sentences_pattern.findall(abstract)
110 | return sentences
111 |
112 |
113 | def _load_single(file):
114 | """
115 | Opens and prepares a single chunked file
116 | """
117 | article_texts = []
118 | abstract_texts = []
119 |
120 | with open(file, 'rb') as f:
121 |
122 | while True:
123 |
124 | len_bytes = f.read(8)
125 |
126 | if not len_bytes:
127 | break
128 |
129 | str_len = struct.unpack('q', len_bytes)[0]
130 | str_bytes = struct.unpack('%ds' % str_len, f.read(str_len))[0]
131 | example = example_pb2.Example.FromString(str_bytes)
132 |
133 | article_text = example.features.feature['article'].bytes_list.value[0].decode('unicode_escape').strip()
134 | abstract_text = example.features.feature['abstract'].bytes_list.value[0].decode('unicode_escape').strip()
135 | abstract_text = ' '.join([sent.strip() for sent in abstract2sents(abstract_text)])
136 |
137 | article_texts.append(article_text)
138 | abstract_texts.append(abstract_text)
139 |
140 | return article_texts, abstract_texts
141 |
142 |
143 | def load_data(files):
144 | """
145 | Reads binary data and returns chunks of [(articles, summaries)]
146 | """
147 | logger.info(f"'{len(files)}' files found")
148 | data = [_load_single(file) for file in files]
149 | articles = sum([a for a, _ in data], [])
150 | summaries = sum([b for _, b in data], [])
151 | return articles, summaries
152 |
153 |
154 | def convert_single_example(example, tokenizer, max_seq_len=config.SEQ_LEN):
155 | """
156 | Convert `text` to the Bert input format
157 | """
158 | tokens = tokenizer.tokenize(example)
159 |
160 | if len(tokens) > max_seq_len - 2:
161 | tokens = tokens[0:(max_seq_len - 2)]
162 |
163 | tokens = ['[CLS]'] + tokens + ['[SEP]']
164 |
165 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
166 | input_type_ids = [0] * len(input_ids)
167 | input_mask = [1] * len(input_ids)
168 |
169 | input_ids = pad(input_ids, max_seq_len, 0)
170 | input_type_ids = pad(input_type_ids, max_seq_len, 0)
171 | input_mask = pad(input_mask, max_seq_len, 0)
172 |
173 | return tokens, input_ids, input_mask, input_type_ids
174 |
175 |
176 | def generator_fn(sents1, sents2, tokenizer):
177 | """
178 | Generates training / evaluation examples
179 | sents1: list of source articles
180 | sents2: list of target summaries
181 | tokenizer: BERT FullTokenizer used to encode both sides
182 | yields
183 | xs: tuple of
184 | input_ids_x: list of source token ids, padded to max_seq_len
185 | input_mask_x: list with 1s over real tokens and 0s over padding
186 | input_type_ids_x: list of segment ids
187 | x_seqlen: int. number of source tokens before padding
188 | article: str. raw source article
189 | ys: tuple of
190 | the same five fields, computed for the target summary
191 | """
192 | for article, summary in zip(sents1, sents2):
193 | tokens_x, input_ids_x, input_mask_x, input_type_ids_x = convert_single_example(article, tokenizer)
194 | tokens_y, input_ids_y, input_mask_y, input_type_ids_y = convert_single_example(summary, tokenizer)
195 |
196 | x_seqlen, y_seqlen = len(tokens_x), len(tokens_y)
197 |
198 | yield (input_ids_x, input_mask_x, input_type_ids_x, x_seqlen, article), (input_ids_y, input_mask_y, input_type_ids_y, y_seqlen, summary)
199 |
200 |
201 |
202 | def input_fn(sents1, sents2, tokenizer, batch_size, shuffle=False):
203 | """
204 | Batchify data
205 | sents1, sents2: lists of articles and summaries
206 | batch_size: scalar
207 | shuffle: boolean
208 |
209 | Returns a tf.data.Dataset yielding (xs, ys), where
210 | xs: tuple of
211 | input_ids_x: int32 tensor. (N, T1)
212 | input_mask_x: int32 tensor. (N, T1)
213 | input_type_ids_x: int32 tensor. (N, T1)
214 | x_seqlens: int32 tensor. (N,)
215 | sents1: str tensor. (N,)
216 | ys: tuple of
217 | the same five fields for the target summaries,
218 | with shapes (N, T2), (N, T2), (N, T2), (N,), (N,)
219 | """
220 | shapes = (
221 | ([None], [None], [None], (), ()),
222 | ([None], [None], [None], (), ())
223 | )
224 |
225 | types = (
226 | (tf.int32, tf.int32, tf.int32, tf.int32, tf.string),
227 | (tf.int32, tf.int32, tf.int32, tf.int32, tf.string)
228 | )
229 |
230 | paddings = (
231 | (0, 0, 0, 0, ''),
232 | (0, 0, 0, 0, '')
233 | )
234 |
235 | dataset = tf.data.Dataset.from_generator(
236 | partial(generator_fn, tokenizer=tokenizer),
237 | output_shapes=shapes,
238 | output_types=types,
239 | args=(sents1, sents2)) # <- arguments for generator_fn. converted to np string arrays
240 |
241 | if shuffle: # for training
242 | dataset = dataset.shuffle(128*batch_size)
243 |
244 | dataset = dataset.repeat() # iterate forever
245 | dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1)
246 |
247 | return dataset
248 |
249 |
250 | @timeit
251 | def prepare_data(inputpath, batch_size, shuffle=False):
252 | """
253 | """
254 | files = glob.glob(inputpath + '*.bin')[:2]
255 |
256 | sents1, sents2 = load_data(files)
257 |
258 | tokenizer = create_tokenizer_from_hub_module(BERT_MODEL_URL)
259 |
260 | batches = input_fn(sents1, sents2, tokenizer, batch_size, shuffle=shuffle)
261 |
262 | num_batches = calc_num_batches(len(sents1), batch_size)
263 |
264 | return batches, num_batches, len(sents1)
265 |
266 |
267 |
--------------------------------------------------------------------------------
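A sketch of driving the pipeline above end to end; the input path is hypothetical, and `prepare_data` downloads the BERT hub module to build the tokenizer:

```python
import tensorflow as tf

from ops.data import prepare_data

# hypothetical location of the chunked CNN/DM .bin files
batches, num_batches, num_examples = prepare_data('data/cnn_dm/chunked/train_', batch_size=4, shuffle=True)

iterator = batches.make_initializable_iterator()
xs, ys = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    input_ids_x, input_mask_x, input_type_ids_x, x_seqlen, article = sess.run(xs)
    print(input_ids_x.shape)  # (4, config.SEQ_LEN)
```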
/ops/encoding.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
5 | def get_angles(pos, i, d_model):
6 | angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
7 | return pos * angle_rates
8 |
9 |
10 | def positional_encoding(position, d_model):
11 | """
12 | The positional encoding vector is added to the embedding vector.
13 | Embeddings represent a token in a d-dimensional space where tokens
14 | with similar meaning will be closer to each other.
15 | But the embeddings do not encode the relative position of words in a sentence.
16 |
17 | So after adding the positional encoding, words will be closer to each other
18 | based on the similarity of their meaning and their position in the sentence,
19 | in the d-dimensional space.
20 |
21 | >>> pos_encoding = positional_encoding(50, 512)
22 | >>> print (pos_encoding.shape)
23 |
24 | >>> plt.pcolormesh(pos_encoding[0].eval(), cmap='RdBu')
25 | >>> plt.xlabel('Depth')
26 | >>> plt.xlim((0, 512))
27 | >>> plt.ylabel('Position')
28 | >>> plt.colorbar()
29 | >>> plt.show()
30 | """
31 | angle_rads = get_angles(
32 | np.arange(position)[:, np.newaxis],
33 | np.arange(d_model)[np.newaxis, :],
34 | d_model
35 | )
36 |
37 | # apply sin to even indices in the array; 2i
38 | sines = np.sin(angle_rads[:, 0::2])
39 |
40 | # apply cos to odd indices in the array; 2i+1
41 | cosines = np.cos(angle_rads[:, 1::2])
42 |
43 | pos_encoding = np.concatenate([sines, cosines], axis=-1)
44 |
45 | pos_encoding = pos_encoding[np.newaxis, ...]
46 |
47 | return tf.cast(pos_encoding, dtype=tf.float32)
48 |
49 |
--------------------------------------------------------------------------------
/ops/masking.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def create_padding_mask(seq):
5 | """
6 | Mask all the pad tokens in the batch of sequence.
7 | It ensures that the model does not treat padding as the input.
8 | The mask indicates where pad value 0 is present:
9 | it outputs a 1 at those locations, and a 0 otherwise.
10 | """
11 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
12 |
13 | # add extra dimensions so that we can add the padding
14 | # to the attention logits.
15 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len)
16 |
17 |
18 | def create_look_ahead_mask(size):
19 | """
20 | The look-ahead mask is used to mask the future tokens in a sequence.
21 | In other words, the mask indicates which entries should not be used.
22 |
23 | This means that to predict the third word, only the first and second word will be used.
24 | Similarly to predict the fourth word, only the first, second and the third word will be used and so on.
25 | """
26 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
27 | return mask # (seq_len, seq_len)
28 |
29 |
30 | def create_masks(inp, tar):
31 | # Encoder padding mask
32 | # enc_padding_mask = create_padding_mask(inp)
33 |
34 | # Used in the 2nd attention block in the decoder.
35 | # This padding mask is used to mask the encoder outputs.
36 | dec_padding_mask = create_padding_mask(inp)
37 |
38 | # Used in the 1st attention block in the decoder.
39 | # It is used to pad and mask future tokens in the input received by
40 | # the decoder.
41 | look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
42 | dec_target_padding_mask = create_padding_mask(tar)
43 | combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
44 |
45 | # return enc_padding_mask, combined_mask, dec_padding_mask
46 | return combined_mask, dec_padding_mask
47 |
48 |
49 | def mask_timestamp(x, i, mask_with):
50 | """
51 | Masks each word in the summary draft one by one with the [MASK] token
52 | At t-th time step the t-th word of input summary is
53 | masked, and the decoder predicts the refined word given other
54 | words of the summary.
55 |
56 | x :: (N, T)
57 | return :: (N, T)
58 | """
59 |
60 | N, T = tf.shape(x)[0], tf.shape(x)[1]
61 |
62 | left = x[:, :i]
63 | right = x[:, i+1:]
64 |
65 | mask = tf.ones([N, 1], dtype=x.dtype) * mask_with
66 |
67 | masked = tf.concat([left, mask, right], axis=1)
68 |
69 | return masked
70 |
71 |
72 | def tile_and_mask_diagonal(x, mask_with):
73 | """
74 | Masks each word in the summary draft one by one with the [MASK] token
75 | At t-th time step the t-th word of input summary is
76 | masked, and the decoder predicts the refined word given other
77 | words of the summary.
78 |
79 | x :: (N, T)
80 | return :: (N*(T-1), T), i.e. (N, T-1, T) flattened over the first two axes
81 |
82 | We do not mask the first position (corresponding to [CLS])
83 | """
84 |
85 | N, T = tf.shape(x)[0], tf.shape(x)[1]
86 |
87 | first = tf.reshape(tf.tile(x[:, 0], [T-1]), [N, T-1, 1])
88 |
89 | x = x[:, 1:]
90 | T = T - 1
91 |
92 | masked = tf.reshape(tf.tile(x, [1, T]), [N, T, T])
93 |
94 | diag = tf.ones([N, T], dtype=masked.dtype) * mask_with
95 | masked = tf.linalg.set_diag(masked, diag)
96 |
97 | masked = tf.concat([first, masked], axis=2)
98 |
99 | masked = tf.reshape(masked, [N*T, T+1])
100 |
101 | return masked
102 |
103 |
104 |
--------------------------------------------------------------------------------
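A concrete look at what the two refine-decoder helpers produce on a toy batch (103 is the `[MASK]` id in the uncased BERT vocabulary):

```python
import tensorflow as tf

from ops.masking import mask_timestamp, tile_and_mask_diagonal

MASK_ID = 103

x = tf.constant([[101, 7, 8, 9, 102]])  # (N=1, T=5): [CLS] w1 w2 w3 [SEP]

masked_once = mask_timestamp(x, i=2, mask_with=MASK_ID)
masked_all = tile_and_mask_diagonal(x, mask_with=MASK_ID)

with tf.Session() as sess:
    print(sess.run(masked_once))
    # [[101   7 103   9 102]]            <- only position 2 replaced by [MASK]
    print(sess.run(masked_all))
    # one row per maskable position, [CLS] kept intact:
    # [[101 103   8   9 102]
    #  [101   7 103   9 102]
    #  [101   7   8 103 102]
    #  [101   7   8   9 103]]
```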
/ops/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 |
4 | from rouge import Rouge
5 |
6 |
7 | rouge = Rouge()
8 |
9 |
10 | def calculate_rouge(y, y_hat):
11 | """
12 | Calculate ROUGE scores between the target 'y' and
13 | the model prediction 'y_hat'
14 | """
15 |
16 | def f(a, b):
17 | rouges = rouge.get_scores(a.decode("utf-8"), b.decode("utf-8"))[0]
18 | r1_val, r2_val, rl_val = rouges['rouge-1']["f"], rouges['rouge-2']["f"], rouges['rouge-l']["f"]
19 | r_avg = np.mean([r1_val, r2_val, rl_val], dtype=np.float64)
20 | return r1_val, r2_val, rl_val, r_avg
21 |
22 | return tf.py_func(f, [y, y_hat], [tf.float64, tf.float64, tf.float64, tf.float64])
23 |
--------------------------------------------------------------------------------
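A minimal usage sketch of `calculate_rouge` on two scalar string tensors:

```python
import tensorflow as tf

from ops.metrics import calculate_rouge

y = tf.constant("the cat sat on the mat")
y_hat = tf.constant("the cat is on the mat")

r1, r2, rl, r_avg = calculate_rouge(y, y_hat)

with tf.Session() as sess:
    print(sess.run([r1, r2, rl, r_avg]))  # four float64 scores in [0, 1]
```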
/ops/optimization.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def noam_scheme(init_lr, global_step, warmup_steps=4000.):
5 | '''Noam scheme learning rate decay
6 | init_lr: initial learning rate. scalar.
7 | global_step: scalar.
8 | warmup_steps: scalar. During warmup_steps, learning rate increases
9 | until it reaches init_lr.
10 | '''
11 | step = tf.cast(global_step + 1, dtype=tf.float32)
12 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)
--------------------------------------------------------------------------------
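Evaluating the schedule at a few steps shows the shape it produces: the rate grows roughly linearly during warmup and then decays like 1/sqrt(step), peaking around `warmup_steps`; the values below are illustrative:

```python
import tensorflow as tf

from ops.optimization import noam_scheme

steps = [1, 1000, 4000, 16000, 64000]
lrs = [noam_scheme(1e-3, tf.constant(s, tf.int64), warmup_steps=4000.) for s in steps]

with tf.Session() as sess:
    for s, lr in zip(steps, sess.run(lrs)):
        print(s, lr)
```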
/ops/regularization.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def label_smoothing(inputs, epsilon=0.1):
5 | '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567.
6 | inputs: 3d tensor. [N, T, V], where V is the number of vocabulary.
7 | epsilon: Smoothing rate.
8 |
9 | For example,
10 |
11 | ```
12 | import tensorflow as tf
13 | inputs = tf.convert_to_tensor([[[0, 0, 1],
14 | [0, 1, 0],
15 | [1, 0, 0]],
16 | [[1, 0, 0],
17 | [1, 0, 0],
18 | [0, 1, 0]]], tf.float32)
19 |
20 | outputs = label_smoothing(inputs)
21 |
22 | with tf.Session() as sess:
23 | print(sess.run([outputs]))
24 |
25 | >>
26 | [array([[[ 0.03333334, 0.03333334, 0.93333334],
27 | [ 0.03333334, 0.93333334, 0.03333334],
28 | [ 0.93333334, 0.03333334, 0.03333334]],
29 | [[ 0.93333334, 0.03333334, 0.03333334],
30 | [ 0.93333334, 0.03333334, 0.03333334],
31 | [ 0.03333334, 0.93333334, 0.03333334]]], dtype=float32)]
32 | ```
33 | '''
34 | V = inputs.get_shape().as_list()[-1] # number of channels
35 | return ((1-epsilon) * inputs) + (epsilon / V)
36 |
37 |
--------------------------------------------------------------------------------
/ops/session.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import logging
3 |
4 | from tensorflow.keras import backend as K
5 |
6 |
7 | logging.basicConfig(level=logging.INFO)
8 |
9 |
10 | def initialize_vars(sess):
11 | sess.run(tf.local_variables_initializer())
12 | sess.run(tf.global_variables_initializer())
13 | sess.run(tf.tables_initializer())
14 | K.set_session(sess)
15 |
16 |
17 | def save_variable_specs(fpath):
18 | '''
19 | Saves information about variables such as
20 | their name, shape, and total parameter number
21 | fpath: string. output file path
22 | Writes
23 | a text file named fpath.
24 | '''
25 | def _get_size(shp):
26 | '''Gets size of tensor shape
27 | shp: TensorShape
28 | Returns
29 | size
30 | '''
31 | size = 1
32 | for d in range(len(shp)):
33 | size *=shp[d]
34 | return size
35 |
36 | params, num_params = [], 0
37 | for v in tf.global_variables():
38 | params.append("{}==={}".format(v.name, v.shape))
39 | num_params += _get_size(v.shape)
40 | print("num_params: ", num_params)
41 | with open(fpath, 'w') as fout:
42 | fout.write("num_params: {}\n".format(num_params))
43 | fout.write("\n".join(params))
44 | logging.info("Variables info has been saved.")
--------------------------------------------------------------------------------
/ops/tensor.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def with_column(x, i, column):
5 | """
6 | Given a tensor `x`, change its i-th column with `column`
7 | x :: (N, T)
8 | return :: (N, T)
9 | """
10 |
11 | N, T = tf.shape(x)[0], tf.shape(x)[1]
12 |
13 | left = x[:, :i]
14 | right = x[:, i+1:]
15 |
16 | return tf.concat([left, column, right], axis=1)
17 |
--------------------------------------------------------------------------------
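A quick check of `with_column` on a toy matrix:

```python
import tensorflow as tf

from ops.tensor import with_column

x = tf.constant([[1, 2, 3],
                 [4, 5, 6]])
column = tf.constant([[9], [9]])  # (N, 1)

with tf.Session() as sess:
    print(sess.run(with_column(x, i=1, column=column)))
    # [[1 9 3]
    #  [4 9 6]]
```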
/ops/tokenization.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import tensorflow_hub as hub
3 |
4 | from bert.tokenization import FullTokenizer
5 |
6 |
7 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"
8 |
9 |
10 | def create_bert_tokenizer(vocab_file, do_lower_case=True):
11 | """
12 | Return a BERT FullTokenizer
13 | """
14 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
15 |
16 |
17 | def create_tokenizer_from_hub_module(bert_hub_url, bert_module=None):
18 | """
19 | Get the vocab file and casing info from the Hub module.
20 | """
21 | if bert_module is None:
22 | bert_module = hub.Module(bert_hub_url)
23 |
24 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True)
25 |
26 | with tf.Session() as sess:
27 | vocab_file, do_lower_case = sess.run([
28 | tokenization_info["vocab_file"],
29 | tokenization_info["do_lower_case"]
30 | ])
31 |
32 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case)
33 |
34 |
35 | tokenizer = create_tokenizer_from_hub_module(BERT_MODEL_URL)
36 |
37 |
38 | def convert_idx_to_token_tensor(inputs, tokenizer=tokenizer):
39 | '''
40 | Converts int32 tensor to string tensor.
41 | inputs: 1d int32 tensor. indices.
42 | tokenizer :: [int] -> str
43 | Returns
44 | 1d string tensor.
45 | '''
46 | def f(inputs):
47 | return ' '.join(tokenizer.convert_ids_to_tokens(inputs))
48 |
49 | return tf.py_func(f, [inputs], tf.string)
50 |
--------------------------------------------------------------------------------
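A round-trip through the module-level tokenizer and `convert_idx_to_token_tensor` (importing `ops.tokenization` downloads the BERT hub module and opens a session to read its vocab file):

```python
import tensorflow as tf

from ops.tokenization import tokenizer, convert_idx_to_token_tensor

tokens = tokenizer.tokenize("the quick brown fox jumps over the lazy dog")
ids = tokenizer.convert_tokens_to_ids(['[CLS]'] + tokens + ['[SEP]'])

decoded = convert_idx_to_token_tensor(tf.constant(ids, dtype=tf.int32))

with tf.Session() as sess:
    print(sess.run(decoded).decode('utf-8'))
    # [CLS] the quick brown fox jumps over the lazy dog [SEP]
```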
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-hub
2 | tensorflow-datasets==1.0.2
3 | tqdm==4.31.1
4 | bert-tensorflow==1.0.1
5 | bunch
6 | rouge==0.3.2
7 | pyyaml
8 |
9 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import math
4 | import shutil
5 | import logging
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | import tensorflow as tf
10 | import tensorflow_hub as hub
11 | import tensorflow_datasets as tfds
12 |
13 | from tensorflow.keras import backend as K
14 | from tensorflow.python.keras.initializers import Constant
15 |
16 | from tqdm import tqdm
17 | from config import config
18 | from arguments import args
19 |
20 | from data.load import load_cnn_dailymail
21 |
22 | from random import randint
23 | from rouge import Rouge
24 |
25 | from ops.tokenization import tokenizer
26 | from ops.tokenization import convert_idx_to_token_tensor
27 |
28 | from ops.session import initialize_vars
29 | from ops.session import save_variable_specs
30 |
31 | from ops.metrics import calculate_rouge
32 | from ops.tensor import with_column
33 | from ops.regularization import label_smoothing
34 | from ops.optimization import noam_scheme
35 |
36 | from models.abstractive_summarizer import AbstractiveSummarization
37 | from models.abstractive_summarizer import train
38 | from models.abstractive_summarizer import eval
39 |
40 |
41 | logger = logging.getLogger()
42 | logger.setLevel(logging.INFO)
43 |
44 | tf.logging.set_verbosity(tf.logging.INFO)
45 | tf.enable_resource_variables()
46 |
47 | logging.info('Job Configuration:\n' + str(config))
48 |
49 |
50 | model = AbstractiveSummarization(
51 | num_layers=config.NUM_LAYERS,
52 | d_model=config.D_MODEL,
53 | num_heads=config.NUM_HEADS,
54 | dff=config.D_FF,
55 | vocab_size=config.VOCAB_SIZE,
56 | input_seq_len=config.INPUT_SEQ_LEN,
57 | output_seq_len=config.OUTPUT_SEQ_LEN,
58 | rate=config.DROPOUT_RATE
59 | )
60 |
61 |
62 | train_dataset, val_dataset, test_dataset, n_train_examples, n_val_examples, n_test_examples = load_cnn_dailymail()
63 |
64 | n_train_batches = n_train_examples // config.BATCH_SIZE
65 | n_val_batches = n_val_examples // config.BATCH_SIZE
66 | n_test_batches = n_test_examples // config.BATCH_SIZE
67 |
68 | logging.info(f"'{n_train_examples}' training examples, '{n_train_batches}' batches")
69 | logging.info(f"'{n_val_examples}' validation examples, '{n_val_batches}' batches")
70 | logging.info(f"'{n_test_examples}' testing examples, '{n_test_batches}' batches")
71 |
72 |
73 | train_iterator = train_dataset.make_initializable_iterator()
74 | train_stream = train_iterator.get_next()
75 |
76 | xs, ys = train_stream[:3], train_stream[3:]
77 | train_loss, zero_op, accumulation_op, train_op, global_step, train_summaries = train(model, xs, ys, gradient_accumulation=True)
78 |
79 | if args.eval:
80 | val_iterator = val_dataset.make_initializable_iterator()
81 | val_stream = val_iterator.get_next()
82 |
83 | xs, ys = val_stream[:3], val_stream[3:]
84 | y, y_hat, eval_loss, eval_summaries = eval(model, xs, ys)
85 |
86 | saver = tf.train.Saver(max_to_keep=config.NUM_EPOCHS)
87 |
88 | # config_tf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
89 | config_tf = tf.ConfigProto(allow_soft_placement=True)
90 | config_tf.gpu_options.allow_growth=True
91 |
92 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)
93 |
94 |
95 | with tf.Session(config=config_tf) as sess:
96 |
97 | if os.path.isdir(config.LOGDIR):
98 | shutil.rmtree(config.LOGDIR)
99 |
100 | os.mkdir(config.LOGDIR)
101 |
102 | ckpt = tf.train.latest_checkpoint(config.CHECKPOINTDIR)
103 |
104 | rouge = Rouge()
105 |
106 | if ckpt is None:
107 | logging.info("Initializing from scratch")
108 | sess.run(tf.global_variables_initializer())
109 | save_variable_specs(os.path.join(config.LOGDIR, "specs"))
110 | else:
111 | saver.restore(sess, ckpt)
112 |
113 | summary_writer_train = tf.summary.FileWriter(os.path.join(config.LOGDIR), sess.graph)
114 | summary_writer_eval = tf.summary.FileWriter(os.path.join(config.LOGDIR, 'eval'), sess.graph)
115 |
116 | initialize_vars(sess)
117 |
118 | _gs = sess.run(global_step)
119 |
120 | sess.run(train_iterator.initializer)
121 | if args.eval:
122 | sess.run(val_iterator.initializer)
123 |
124 | total_steps = int(config.NUM_EPOCHS * (n_train_batches / config.GRADIENT_ACCUMULATION_N_STEPS))
125 |
126 | logger.info(f"Running Training Job for '{total_steps}' steps")
127 |
128 | for i in tqdm(range(_gs, total_steps+1)):
129 |
130 | # gradient accumulation mechanism
131 | sess.run(zero_op)
132 |
133 | for _ in range(config.GRADIENT_ACCUMULATION_N_STEPS):
134 | sess.run(accumulation_op)
135 |
136 | _loss, _, _gs, _summary = sess.run([train_loss, train_op, global_step, train_summaries], options=run_options)
137 |
138 | epoch = math.ceil(_gs / n_train_batches)
139 |
140 | summary_writer_train.add_summary(_summary, _gs)
141 | summary_writer_train.flush()
142 |
143 | if (_gs % n_train_batches == 0):
144 |
145 | if args.eval:
146 |
147 | logger.info(f"Epoch '{epoch}' done")
148 | logger.info(f"Current training step: '{_gs}'")
149 |
150 | _y, _y_hat, _eval_summary = sess.run([y, y_hat, eval_summaries])
151 |
152 | summary_writer_eval.add_summary(_eval_summary, _gs)
153 | summary_writer_eval.flush()
154 |
155 | # monitor a random sample
156 | rnd = randint(0, _y.shape[0] - 1)
157 |
158 | y_rnd = ' '.join(tokenizer.convert_ids_to_tokens(_y[rnd]))
159 | y_hat_rnd = ' '.join(tokenizer.convert_ids_to_tokens(_y_hat[rnd]))
160 |
161 | rouges = rouge.get_scores(y_rnd, y_hat_rnd)[0]
162 | r1_val, r2_val, rl_val = rouges['rouge-1']["f"], rouges['rouge-2']["f"], rouges['rouge-l']["f"]
163 |
164 | print('Target:')
165 | print(y_rnd)
166 | print('Prediction:')
167 | print(y_hat_rnd)
168 |
169 | print(f"ROUGE-1 '{r1_val}'")
170 | print(f"ROUGE-2 '{r2_val}'")
171 | print(f"ROUGE-L '{rl_val}'")
172 | print(f"ROUGE-AVG '{np.mean([r1_val, r2_val, rl_val])}'", '\n--\n')
173 |
174 | logging.info("Checkpoint: Saving Model")
175 |
176 | model_output = f"abstractive_summarization_2019_epoch_{epoch}_loss_{str(round(_loss, 4))}"
177 |
178 | ckpt_name = os.path.join(config.CHECKPOINTDIR, model_output)
179 |
180 | saver.save(sess, ckpt_name, global_step=_gs)
181 |
182 | logging.info(f"After training '{_gs}' steps, '{ckpt_name}' has been saved.")
183 |
184 | model_output = "abstractive_summarization_2019_final"
185 | ckpt_name = os.path.join(config.CHECKPOINTDIR, model_output)
186 | saver.save(sess, ckpt_name, global_step=_gs)
187 |
188 | summary_writer_train.close()
189 | summary_writer_eval.close()
190 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/utils/__init__.py
--------------------------------------------------------------------------------
/utils/decorators.py:
--------------------------------------------------------------------------------
1 | import time
2 | from functools import wraps
3 |
4 |
5 | def timeit(f):
6 | @wraps(f)
7 | def timed(*args, **kw):
8 | ts = time.time()
9 | result = f(*args, **kw)
10 | te = time.time()
11 | print(f"'{f.__name__}' {round(te - ts, 2)} s")
12 | return result
13 | return timed
14 |
15 |
--------------------------------------------------------------------------------
/utils/recipes.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 |
3 |
4 | def flatten(ls):
5 | """
6 | Flatten one level of nesting
7 | """
8 | return chain.from_iterable(ls)
9 |
10 |
--------------------------------------------------------------------------------