├── .gitignore ├── README.md ├── arguments.py ├── config.py ├── data ├── __init__.py └── load.py ├── layers ├── bert.py ├── normalization.py └── transformer.py ├── models ├── __init__.py ├── abstractive_summarizer.py └── transformer.py ├── ops ├── __init__.py ├── attention.py ├── beam_search.py ├── data.py ├── encoding.py ├── masking.py ├── metrics.py ├── optimization.py ├── regularization.py ├── session.py ├── tensor.py └── tokenization.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── decorators.py └── recipes.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | log/ 126 | checkpoint/ 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pretraining-Based Natural Language Generation for Text Summarization 2 | 3 | Implementation of an abstractive text-summarization architecture, as proposed in this [paper](https://arxiv.org/pdf/1902.09243.pdf). 4 | 5 | The solution makes use of a pre-trained language model to get contextualized representations of words; these models were trained on a huge corpus of unlabelled data, e.g. BERT.
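For reference, the contextualized word representations come from the pre-trained uncased BERT-Base module on TensorFlow Hub (the repo wraps it in `layers/bert.py`). A minimal sketch of how such a module is loaded and queried with the TF 1.x `tensorflow_hub` API (the placeholder names here are illustrative):

```
import tensorflow as tf
import tensorflow_hub as hub

# same module URL used by layers/bert.py
BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1"

bert = hub.Module(BERT_MODEL_URL, trainable=False)

# BERT expects token ids, an attention mask (1 = real token, 0 = padding) and segment ids
input_ids = tf.placeholder(tf.int32, shape=[None, None])
input_mask = tf.placeholder(tf.int32, shape=[None, None])
segment_ids = tf.placeholder(tf.int32, shape=[None, None])

bert_inputs = dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids)

# (batch_size, seq_len, 768) contextualized token embeddings
sequence_output = bert(bert_inputs, signature="tokens", as_dict=True)["sequence_output"]
```

The module is kept frozen (`trainable=False`), exactly as in `BertLayer`, and its `sequence_output` tensor is the encoder representation that the decoders attend over.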
6 | 7 | This extends the sphere of possible applications to settings where labelled data is scarce; since the model learns to summarize given the contextualized BERT representations, we have a good chance of generalizing to other domains, even though training was done using the CNN/DM dataset. 8 | 9 | We use `tensorflow_hub` to load the pre-trained BERT module (see `layers/bert.py`). 10 | 11 | 12 | #### Environment 13 | 14 | Python: `3.6` 15 | 16 | TensorFlow version: `1.13.1` 17 | 18 | `requirements.txt` lists the library dependencies. 19 | 20 | #### Run 21 | 22 | To run the training job: 23 | 24 | ``` 25 | python train.py 26 | ``` 27 | 28 | The training parameters can be changed in `config.py`. 29 | 30 | By default, TensorBoard objects are written to `log`; training checkpoints are written at the end of each epoch to `checkpoint`. 31 | 32 | ``` 33 | tensorboard --logdir log/ 34 | ``` 35 | 36 | Here the training and evaluation metrics can be tracked. 37 | 38 | Note that `train.py` does not run the validation graph (validation set) in parallel with the training job. 39 | 40 | To run the validation job in parallel with the training: 41 | 42 | ``` 43 | python train.py --eval 44 | ``` 45 | 46 | This takes a long time to build the inference graph and start the training; at the end of each epoch, it uses the model's inference mode (autoregressive) to calculate the loss and ROUGE metrics; it also writes a random article/summary prediction. 47 | 48 | #### Resource Limitations 49 | 50 | Due to GPU memory limits, using an AWS SageMaker notebook with a `Tesla V100 (16 GB RAM)`, the batch size must be set to 2, otherwise an OOM error will be thrown. 51 | Furthermore, the target sequence length has a huge impact on the memory needed during training; 75 is the value used by default. 52 | 53 | To compensate for this resource constraint we use gradient accumulation over `N` steps, i.e. we run `N` forward/backward steps with a batch size of 2 and accumulate the gradients; after the `N` steps we apply the accumulated gradients, updating the weights (a simplified sketch of this pattern is shown after the notes below). 54 | 55 | This gives us an effective `2N` batch size; `N` (default=12) can be controlled by changing `GRADIENT_ACCUMULATION_N_STEPS` in `config.py`. 56 | 57 | #### Notes to the reader 58 | 59 | The core of the solution is implemented; there are, however, some missing pieces. 60 | 61 | Implemented: 62 | 63 | * Encoder (BERT) 64 | * Draft Decoder 65 | * Refined Decoder 66 | * Autoregressive evaluation (greedy) 67 | * Gradient accumulation 68 | 69 | Missing: 70 | 71 | * RL loss component 72 | * Beam-search mechanism for the draft-decoder 73 | * Copy mechanism 74 | 75 | Any help implementing these is appreciated.
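For reference, the gradient-accumulation scheme described under *Resource Limitations* boils down to the following TF 1.x pattern. This is a simplified sketch of what the `train()` function in `models/abstractive_summarizer.py` builds when `gradient_accumulation=True` (the `optimizer`, `loss` and `global_step` tensors are assumed to exist already):

```
tvs = tf.trainable_variables()

# one non-trainable accumulator per trainable variable, initialised to zeros
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]

# reset the accumulators at the start of each accumulation cycle
zero_op = [av.assign(tf.zeros_like(av)) for av in accum_vars]

# compute the gradients of the current small batch and add them to the accumulators
grads_and_vars = optimizer.compute_gradients(loss, var_list=tvs)
accum_op = [accum_vars[i].assign_add(g) for i, (g, _) in enumerate(grads_and_vars) if g is not None]

# after N accumulation steps, apply the summed gradients in a single weight update
train_op = optimizer.apply_gradients(
    [(accum_vars[i], v) for i, (_, v) in enumerate(grads_and_vars)], global_step)
```

In the training loop, `zero_op` is run once, `accum_op` is run for `GRADIENT_ACCUMULATION_N_STEPS` batches, and then `train_op` applies the update, giving the effective `2N` batch size mentioned above.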
76 | 77 | #### Configuration 78 | 79 | | **Parameter** | **Default** | **Description** | 80 | |-------------------------------|-------------|-------------------------------------------------------------------| 81 | | NUM_EPOCHS | 4 | Number of epochs to train | 82 | | BATCH_SIZE | 2 | Batch size for each training step | 83 | | GRADIENT_ACCUMULATION_N_STEPS | 12 | Number of gradient accumulation steps before applying the gradients | 84 | | BUFFER_SIZE | 1000 | Shuffle buffer size for the tf.data pipeline | 85 | | INITIAL_LR | 0.003 | Initial learning rate | 86 | | WARMUP_STEPS | 4000 | Number of warmup steps for the learning-rate schedule | 87 | | INPUT_SEQ_LEN | 512 | Length to which articles are truncated | 88 | | OUTPUT_SEQ_LEN | 75 | Length to which summaries are truncated | 89 | | MAX_EXAMPLE_LEN | None | Length threshold used to filter examples (articles) | 90 | | VOCAB_SIZE | 30522 | Size of the vocabulary | 91 | | NUM_LAYERS | 8 | Number of layers of the Transformer decoder | 92 | | D_MODEL | 768 | Base embedding dimensionality (same as BERT) | 93 | | D_FF | 2048 | Dimensionality of the Transformer feed-forward layer | 94 | | NUM_HEADS | 8 | Number of attention heads of the Transformer | 95 | | DROPOUT_RATE | 0.1 | Dropout rate used during training | 96 | | LOGDIR | log | Directory where TensorBoard objects are written | 97 | | CHECKPOINTDIR | checkpoint | Directory where model checkpoints are written | 98 | 99 | 100 | #### Data 101 | 102 | To train the model we use the [CNN/DM dataset](https://www.tensorflow.org/datasets/datasets#cnn_dailymail), directly from TensorFlow Datasets. 103 | The first time it runs, it will download the dataset from the Google source (~ 500 MB). 104 | 105 | The details on how the data is downloaded and prepared can be found in `data/load.py`. 106 | 107 | 108 | ##### Debug 109 | 110 | Track GPU memory usage with: 111 | 112 | ``` 113 | watch -n 2 nvidia-smi 114 | ``` 115 | 116 | Track system RAM usage with: 117 | 118 | ``` 119 | watch -n 2 cat /proc/meminfo 120 | ``` 121 | 122 | `report_tensor_allocations_upon_oom` is set to `True` so that we can see which variables 123 | are filling up the memory. 124 | 125 | ``` 126 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True) 127 | ...
128 | sess.run(..., options=run_options) 129 | ``` 130 | 131 | 132 | 133 | 134 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument("--eval", action="store_true") 7 | 8 | args = parser.parse_args() 9 | 10 | 11 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from bunch import Bunch 2 | 3 | 4 | config = { 5 | 'NUM_EPOCHS': 4, 6 | "BATCH_SIZE": 2, 7 | "GRADIENT_ACCUMULATION_N_STEPS": 18, 8 | "BUFFER_SIZE": 1, 9 | "INITIAL_LR": 0.0003, 10 | "WARMUP_STEPS": 4000, 11 | "INPUT_SEQ_LEN": 512, 12 | "OUTPUT_SEQ_LEN": 72, 13 | "MAX_EXAMPLE_LEN": None, 14 | "VOCAB_SIZE": 30522, 15 | "NUM_LAYERS": 8, 16 | "D_MODEL": 768, 17 | "D_FF": 2048, 18 | "NUM_HEADS": 8, 19 | "DROPOUT_RATE": 0.1, 20 | "LOGDIR": 'log', 21 | "CHECKPOINTDIR": 'checkpoint2' 22 | } 23 | 24 | config = Bunch(config) 25 | 26 | 27 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/data/__init__.py -------------------------------------------------------------------------------- /data/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | import tensorflow_datasets as tfds 7 | 8 | from functools import partial 9 | from tensorflow.keras import backend as K 10 | from ops.tokenization import tokenizer 11 | 12 | from config import config 13 | 14 | 15 | # Special Tokens 16 | UNK_ID = 100 17 | CLS_ID = 101 18 | SEP_ID = 102 19 | MASK_ID = 103 20 | 21 | 22 | def pad(l, n, pad=0): 23 | """ 24 | Pad the list 'l' to have size 'n' using 'padding_element' 25 | """ 26 | pad_with = (0, max(0, n - len(l))) 27 | return np.pad(l, pad_with, mode='constant', constant_values=pad) 28 | 29 | 30 | def encode(sent_1, sent_2, tokenizer, input_seq_len, output_seq_len): 31 | """ 32 | Encode the text to the BERT expected format 33 | 34 | 'input_seq_len' is used to truncate the the article length 35 | 'output_seq_len' is used to truncate the the summary length 36 | 37 | BERT has the following special tokens: 38 | 39 | [CLS] : The first token of every sequence. A classification token 40 | which is normally used in conjunction with a softmax layer for classification 41 | tasks. For anything else, it can be safely ignored. 42 | 43 | [SEP] : A sequence delimiter token which was used at pre-training for 44 | sequence-pair tasks (i.e. Next sentence prediction). Must be used when 45 | sequence pair tasks are required. When a single sequence is used it is just appended at the end. 46 | 47 | [MASK] : Token used for masked words. Only used for pre-training. 48 | 49 | Additionally BERT requires additional inputs to work correctly: 50 | - Mask IDs 51 | - Segment IDs 52 | 53 | The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to. 54 | Sentence Embeddings is just a numeric class to distinguish between pairs of sentences. 
55 | """ 56 | tokens_1 = tokenizer.tokenize(sent_1.numpy()) 57 | tokens_2 = tokenizer.tokenize(sent_2.numpy()) 58 | 59 | # Account for [CLS] and [SEP] with "- 2" 60 | if len(tokens_1) > input_seq_len - 2: 61 | tokens_1 = tokens_1[0:(input_seq_len - 2)] 62 | if len(tokens_2) > (output_seq_len + 1) - 2: 63 | tokens_2 = tokens_2[0:((output_seq_len + 1) - 2)] 64 | 65 | tokens_1 = ["[CLS]"] + tokens_1 + ["[SEP]"] 66 | tokens_2 = ["[CLS]"] + tokens_2 + ["[SEP]"] 67 | 68 | input_ids_1 = tokenizer.convert_tokens_to_ids(tokens_1) 69 | input_ids_2 = tokenizer.convert_tokens_to_ids(tokens_2) 70 | 71 | input_mask_1 = [1] * len(input_ids_1) 72 | input_mask_2 = [1] * len(input_ids_2) 73 | 74 | input_ids_1 = pad(input_ids_1, input_seq_len, 0) 75 | input_ids_2 = pad(input_ids_2, output_seq_len + 1, 0) 76 | input_mask_1 = pad(input_mask_1, input_seq_len, 0) 77 | input_mask_2 = pad(input_mask_2, output_seq_len + 1, 0) 78 | 79 | input_type_ids_1 = [0] * len(input_ids_1) 80 | input_type_ids_2 = [0] * len(input_ids_2) 81 | 82 | return input_ids_1, input_mask_1, input_type_ids_1, input_ids_2, input_mask_2, input_type_ids_2 83 | 84 | 85 | def tf_encode(tokenizer, input_seq_len, output_seq_len): 86 | """ 87 | Operations inside `.map()` run in graph mode and receive a graph 88 | tensor that do not have a `numpy` attribute. 89 | The tokenizer expects a string or Unicode symbol to encode it into integers. 90 | Hence, you need to run the encoding inside a `tf.py_function`, 91 | which receives an eager tensor having a numpy attribute that contains the string value. 92 | """ 93 | def f(s1, s2): 94 | encode_ = partial(encode, tokenizer=tokenizer, input_seq_len=input_seq_len, output_seq_len=output_seq_len) 95 | return tf.py_function(encode_, [s1, s2], [tf.int32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32]) 96 | 97 | return f 98 | 99 | 100 | def filter_max_length(x, x1, x2, y, y1, y2, max_length=config.MAX_EXAMPLE_LEN): 101 | predicate = tf.logical_and( 102 | tf.size(x[0]) <= max_length, 103 | tf.size(y[0]) <= max_length 104 | ) 105 | return predicate 106 | 107 | 108 | def pipeline(examples, tokenizer, cache=False): 109 | """ 110 | Prepare a Dataset to return the following elements 111 | x_ids, x_mask, x_segments, y_ids, y_maks, y_segments 112 | """ 113 | 114 | dataset = examples.map(tf_encode(tokenizer, config.INPUT_SEQ_LEN, config.OUTPUT_SEQ_LEN), num_parallel_calls=tf.data.experimental.AUTOTUNE) 115 | 116 | if config.MAX_EXAMPLE_LEN is not None: 117 | dataset = dataset.filter(filter_max_length) 118 | 119 | if cache: 120 | dataset = dataset.cache() 121 | 122 | dataset = dataset.shuffle(config.BUFFER_SIZE).padded_batch(config.BATCH_SIZE, padded_shapes=([-1], [-1], [-1], [-1], [-1], [-1])) 123 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 124 | 125 | return dataset 126 | 127 | 128 | def load_cnn_dailymail(tokenizer=tokenizer): 129 | """ 130 | Load the CNN/DM data from tensorflow Datasets 131 | """ 132 | examples, metadata = tfds.load('cnn_dailymail', with_info=True, as_supervised=True) 133 | train, val, test = examples['train'], examples['validation'], examples['test'] 134 | train_dataset = pipeline(train, tokenizer) 135 | val_dataset = pipeline(val, tokenizer) 136 | test_dataset = pipeline(test, tokenizer) 137 | 138 | # grab information regarding the number of examples 139 | metadata = json.loads(metadata.as_json) 140 | 141 | n_test_examples = int(metadata['splits'][0]['statistics']['numExamples']) 142 | n_train_examples = int(metadata['splits'][1]['statistics']['numExamples']) 143 | n_val_examples = 
int(metadata['splits'][2]['statistics']['numExamples']) 144 | 145 | return train_dataset, val_dataset, test_dataset, n_train_examples, n_val_examples, n_test_examples 146 | 147 | -------------------------------------------------------------------------------- /layers/bert.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_hub as hub 3 | 4 | from tensorflow.keras import backend as K 5 | 6 | 7 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1" 8 | 9 | 10 | 11 | class BertLayer(tf.keras.layers.Layer): 12 | """ 13 | Custom Keras layer, integrating BERT from tf-hub 14 | """ 15 | def __init__(self, url=BERT_MODEL_URL, d_embedding=768, n_fine_tune_layers=0, **kwargs): 16 | self.url = url 17 | self.n_fine_tune_layers = n_fine_tune_layers 18 | self.d_embedding = d_embedding 19 | 20 | super(BertLayer, self).__init__(**kwargs) 21 | 22 | def build(self, input_shape): 23 | 24 | self.bert = hub.Module( 25 | self.url, 26 | trainable=False, 27 | name="{}_bert_module".format(self.name) 28 | ) 29 | 30 | trainable_vars = self.bert.variables 31 | 32 | # Remove unused layers 33 | trainable_vars = [var for var in trainable_vars if not "/cls/" in var.name] 34 | 35 | # Select how many layers to fine tune 36 | trainable_vars = [] 37 | 38 | # Add to trainable weights 39 | for var in trainable_vars: 40 | self._trainable_weights.append(var) 41 | 42 | for var in self.bert.variables: 43 | if var not in self._trainable_weights: 44 | self._non_trainable_weights.append(var) 45 | 46 | super(BertLayer, self).build(input_shape) 47 | 48 | def call(self, inputs): 49 | inputs = [K.cast(x, dtype="int32") for x in inputs] 50 | 51 | input_ids, input_mask, segment_ids = inputs 52 | 53 | bert_inputs = dict(input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids) 54 | result = self.bert(inputs=bert_inputs, signature="tokens", as_dict=True)["sequence_output"] 55 | return result 56 | 57 | def compute_output_shape(self, input_shape): 58 | return (input_shape[0], input_shape[1], self.d_embedding) 59 | 60 | -------------------------------------------------------------------------------- /layers/normalization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | 7 | from tensorflow.keras import backend as K 8 | 9 | from tensorflow.python.keras import constraints 10 | from tensorflow.python.keras import initializers 11 | from tensorflow.python.keras import regularizers 12 | from tensorflow.python.ops import nn 13 | 14 | 15 | class LayerNormalization(tf.keras.layers.Layer): 16 | """ 17 | Layer normalization layer (Ba et al., 2016). 18 | Normalize the activations of the previous layer for each given example in a 19 | batch independently, rather than across a batch like Batch Normalization. 20 | i.e. applies a transformation that maintains the mean activation within each 21 | example close to 0 and the activation standard deviation close to 1. 22 | Arguments: 23 | axis: Integer or List/Tuple. The axis that should be normalized 24 | (typically the features axis). 25 | epsilon: Small float added to variance to avoid dividing by zero. 26 | center: If True, add offset of `beta` to normalized tensor. 27 | If False, `beta` is ignored. 28 | scale: If True, multiply by `gamma`. 29 | If False, `gamma` is not used. 30 | When the next layer is linear (also e.g. 
`nn.relu`), 31 | this can be disabled since the scaling 32 | will be done by the next layer. 33 | beta_initializer: Initializer for the beta weight. 34 | gamma_initializer: Initializer for the gamma weight. 35 | beta_regularizer: Optional regularizer for the beta weight. 36 | gamma_regularizer: Optional regularizer for the gamma weight. 37 | beta_constraint: Optional constraint for the beta weight. 38 | gamma_constraint: Optional constraint for the gamma weight. 39 | trainable: Boolean, if `True` the variables will be marked as trainable. 40 | Input shape: 41 | Arbitrary. Use the keyword argument `input_shape` 42 | (tuple of integers, does not include the samples axis) 43 | when using this layer as the first layer in a model. 44 | Output shape: 45 | Same shape as input. 46 | References: 47 | - [Layer Normalization](https://arxiv.org/abs/1607.06450) 48 | """ 49 | 50 | def __init__(self, 51 | axis=-1, 52 | epsilon=1e-3, 53 | center=True, 54 | scale=True, 55 | beta_initializer='zeros', 56 | gamma_initializer='ones', 57 | beta_regularizer=None, 58 | gamma_regularizer=None, 59 | beta_constraint=None, 60 | gamma_constraint=None, 61 | trainable=True, 62 | name=None, 63 | **kwargs): 64 | super(LayerNormalization, self).__init__( 65 | name=name, trainable=trainable, **kwargs) 66 | if isinstance(axis, (list, tuple)): 67 | self.axis = axis[:] 68 | elif isinstance(axis, int): 69 | self.axis = axis 70 | else: 71 | raise ValueError('Expected an int or a list/tuple of ints for the argument \'axis\', but received instead: %s' % axis) 72 | 73 | self.epsilon = epsilon 74 | self.center = center 75 | self.scale = scale 76 | self.beta_initializer = initializers.get(beta_initializer) 77 | self.gamma_initializer = initializers.get(gamma_initializer) 78 | self.beta_regularizer = regularizers.get(beta_regularizer) 79 | self.gamma_regularizer = regularizers.get(gamma_regularizer) 80 | self.beta_constraint = constraints.get(beta_constraint) 81 | self.gamma_constraint = constraints.get(gamma_constraint) 82 | 83 | self.supports_masking = True 84 | 85 | def build(self, input_shape): 86 | ndims = len(input_shape) 87 | if ndims is None: 88 | raise ValueError('Input shape %s has undefined rank.' % input_shape) 89 | 90 | # Convert axis to list and resolve negatives 91 | if isinstance(self.axis, int): 92 | self.axis = [self.axis] 93 | for idx, x in enumerate(self.axis): 94 | if x < 0: 95 | self.axis[idx] = ndims + x 96 | 97 | # Validate axes 98 | for x in self.axis: 99 | if x < 0 or x >= ndims: 100 | raise ValueError('Invalid axis: %d' % x) 101 | if len(self.axis) != len(set(self.axis)): 102 | raise ValueError('Duplicate axis: {}'.format(tuple(self.axis))) 103 | 104 | param_shape = [input_shape[dim] for dim in self.axis] 105 | if self.scale: 106 | self.gamma = self.add_weight( 107 | name='gamma', 108 | shape=param_shape, 109 | initializer=self.gamma_initializer, 110 | regularizer=self.gamma_regularizer, 111 | constraint=self.gamma_constraint, 112 | trainable=True) 113 | else: 114 | self.gamma = None 115 | 116 | if self.center: 117 | self.beta = self.add_weight( 118 | name='beta', 119 | shape=param_shape, 120 | initializer=self.beta_initializer, 121 | regularizer=self.beta_regularizer, 122 | constraint=self.beta_constraint, 123 | trainable=True) 124 | else: 125 | self.beta = None 126 | 127 | def call(self, inputs): 128 | # Compute the axes along which to reduce the mean / variance 129 | input_shape = inputs.shape 130 | ndims = len(input_shape) 131 | 132 | # Calculate the moments on the last axis (layer activations). 
133 | mean, variance = nn.moments(inputs, self.axis, keep_dims=True) 134 | 135 | # Broadcasting only necessary for norm where the axis is not just 136 | # the last dimension 137 | broadcast_shape = [1] * ndims 138 | for dim in self.axis: 139 | broadcast_shape[dim] = input_shape.dims[dim].value 140 | def _broadcast(v): 141 | if (v is not None and len(v.shape) != ndims and self.axis != [ndims - 1]): 142 | return array_ops.reshape(v, broadcast_shape) 143 | return v 144 | scale, offset = _broadcast(self.gamma), _broadcast(self.beta) 145 | 146 | # Compute layer normalization using the batch_normalization function. 147 | outputs = nn.batch_normalization( 148 | inputs, 149 | mean, 150 | variance, 151 | offset=offset, 152 | scale=scale, 153 | variance_epsilon=self.epsilon) 154 | 155 | # If some components of the shape got lost due to adjustments, fix that. 156 | outputs.set_shape(input_shape) 157 | 158 | return outputs 159 | 160 | def compute_output_shape(self, input_shape): 161 | return input_shape 162 | 163 | def get_config(self): 164 | config = { 165 | 'axis': self.axis, 166 | 'epsilon': self.epsilon, 167 | 'center': self.center, 168 | 'scale': self.scale, 169 | 'beta_initializer': initializers.serialize(self.beta_initializer), 170 | 'gamma_initializer': initializers.serialize(self.gamma_initializer), 171 | 'beta_regularizer': regularizers.serialize(self.beta_regularizer), 172 | 'gamma_regularizer': regularizers.serialize(self.gamma_regularizer), 173 | 'beta_constraint': constraints.serialize(self.beta_constraint), 174 | 'gamma_constraint': constraints.serialize(self.gamma_constraint) 175 | } 176 | base_config = super(LayerNormalization, self).get_config() 177 | return dict(list(base_config.items()) + list(config.items())) 178 | 179 | -------------------------------------------------------------------------------- /layers/transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow_hub as hub 6 | 7 | from tensorflow.keras import backend as K 8 | from tensorflow.python.keras.initializers import Constant 9 | 10 | from layers.normalization import LayerNormalization 11 | from ops.encoding import positional_encoding 12 | from ops.attention import scaled_dot_product_attention 13 | 14 | 15 | class MultiHeadAttention(tf.keras.layers.Layer): 16 | """ 17 | Multi-head attention consists of four parts: 18 | * Linear layers and split into heads. 19 | * Scaled dot-product attention. 20 | * Concatenation of heads. 21 | * Final linear layer. 22 | 23 | Each multi-head attention block gets three inputs; 24 | Q (query), K (key), V (value). 25 | These are put through linear (Dense) layers and split up into multiple heads. 26 | 27 | Instead of one single attention head, Q, K, and V 28 | are split into multiple heads because it allows the 29 | model to jointly attend to information at different 30 | positions from different representational spaces. 31 | 32 | After the split each head has a reduced dimensionality, 33 | so the total computation cost is the same as a single head 34 | attention with full dimensionality. 
35 | """ 36 | def __init__(self, d_model, num_heads): 37 | super(MultiHeadAttention, self).__init__() 38 | self.num_heads = num_heads 39 | self.d_model = d_model 40 | 41 | assert d_model % self.num_heads == 0 42 | 43 | self.depth = d_model // self.num_heads 44 | 45 | def build(self, input_shape): 46 | 47 | self.wq = tf.keras.layers.Dense(self.d_model) 48 | self.wk = tf.keras.layers.Dense(self.d_model) 49 | self.wv = tf.keras.layers.Dense(self.d_model) 50 | 51 | self.dense = tf.keras.layers.Dense(self.d_model) 52 | 53 | for i in [self.wq, self.wk, self.wv, self.dense]: 54 | i.build(input_shape) 55 | for weight in i.trainable_weights: 56 | if weight not in self._trainable_weights: 57 | self._trainable_weights.append(weight) 58 | for weight in i.non_trainable_weights: 59 | if weight not in self._non_trainable_weights: 60 | self._non_trainable_weights.append(weight) 61 | 62 | super(MultiHeadAttention, self).build(input_shape) 63 | 64 | def split_heads(self, x, batch_size): 65 | """Split the last dimension into (num_heads, depth). 66 | Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) 67 | """ 68 | x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) 69 | return tf.transpose(x, perm=[0, 2, 1, 3]) 70 | 71 | def call(self, v, k, q, mask): 72 | batch_size = tf.shape(q)[0] 73 | 74 | q = self.wq(q) # (batch_size, seq_len, d_model) 75 | k = self.wk(k) # (batch_size, seq_len, d_model) 76 | v = self.wv(v) # (batch_size, seq_len, d_model) 77 | 78 | q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) 79 | k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) 80 | v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) 81 | 82 | # scaled_attention.shape == (batch_size, num_heads, seq_len_v, depth) 83 | # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) 84 | scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask) 85 | 86 | scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch_size, seq_len_v, num_heads, depth) 87 | 88 | concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model)) # (batch_size, seq_len_v, d_model) 89 | 90 | output = self.dense(concat_attention) # (batch_size, seq_len_v, d_model) 91 | 92 | return output, attention_weights 93 | 94 | 95 | def point_wise_feed_forward_network(d_model, dff): 96 | """ 97 | Point wise feed forward network consists of two 98 | fully-connected layers with a ReLU activation in between. 99 | """ 100 | return tf.keras.Sequential([ 101 | tf.keras.layers.Dense(dff, activation='relu'), # (batch_size, seq_len, dff) 102 | tf.keras.layers.Dense(d_model) # (batch_size, seq_len, d_model) 103 | ]) 104 | 105 | 106 | 107 | class EncoderLayer(tf.keras.layers.Layer): 108 | """ 109 | Each encoder layer consists of sublayers: 110 | 111 | * Multi-head attention (with padding mask) 112 | * Point wise feed forward networks. 113 | 114 | Each of these sublayers has a residual connection around it followed by a layer normalization. 115 | Residual connections help in avoiding the vanishing gradient problem in deep networks. 116 | 117 | The output of each sublayer is `LayerNorm(x + Sublayer(x))`. 118 | The normalization is done on the d_model (last) axis. There are N encoder layers in the transformer. 
119 | """ 120 | def __init__(self, d_model, num_heads, dff, rate=0.1): 121 | self.d_model = d_model 122 | self.num_heads = num_heads 123 | self.dff = dff 124 | self.rate = rate 125 | super(EncoderLayer, self).__init__() 126 | 127 | def build(self, input_shape): 128 | 129 | self.mha = MultiHeadAttention(self.d_model, self.num_heads) 130 | self.ffn = point_wise_feed_forward_network(self.d_model, self.dff) 131 | 132 | self.layernorm1 = LayerNormalization(epsilon=1e-6) 133 | self.layernorm2 = LayerNormalization(epsilon=1e-6) 134 | 135 | self.dropout1 = tf.keras.layers.Dropout(self.rate) 136 | self.dropout2 = tf.keras.layers.Dropout(self.rate) 137 | 138 | for i in [self.mha, self.ffn, self.layernorm1, self.layernorm2, self.dropout1, self.dropout2]: 139 | i.build(input_shape) 140 | for weight in i.trainable_weights: 141 | if weight not in self._trainable_weights: 142 | self._trainable_weights.append(weight) 143 | for weight in i.non_trainable_weights: 144 | if weight not in self._non_trainable_weights: 145 | self._non_trainable_weights.append(weight) 146 | 147 | super(EncoderLayer, self).build(input_shape) 148 | 149 | def call(self, x, training, mask): 150 | 151 | attn_output, _ = self.mha(x, x, x, mask) # (batch_size, input_seq_len, d_model) 152 | attn_output = self.dropout1(attn_output, training=training) 153 | out1 = self.layernorm1(x + attn_output) # (batch_size, input_seq_len, d_model) 154 | 155 | ffn_output = self.ffn(out1) # (batch_size, input_seq_len, d_model) 156 | ffn_output = self.dropout2(ffn_output, training=training) 157 | out2 = self.layernorm2(out1 + ffn_output) # (batch_size, input_seq_len, d_model) 158 | 159 | return out2 160 | 161 | 162 | class DecoderLayer(tf.keras.layers.Layer): 163 | """ 164 | Each decoder layer consists of sublayers: 165 | 166 | * Masked multi-head attention (with look ahead mask and padding mask) 167 | * Multi-head attention (with padding mask). V (value) and K (key) 168 | receive the encoder output as inputs. Q (query) receives the output from the masked multi-head attention sublayer. 169 | * Point wise feed forward networks 170 | 171 | Each of these sublayers has a residual connection around it followed by a layer normalization. 172 | The output of each sublayer is LayerNorm(x + Sublayer(x)). The normalization is done on the d_model (last) axis. 173 | 174 | As Q receives the output from decoder's first attention block, 175 | and K receives the encoder output, the attention weights represent 176 | the importance given to the decoder's input based on the encoder's output. 177 | In other words, the decoder predicts the next word by looking at the encoder 178 | output and self-attending to its own output. 
179 | """ 180 | def __init__(self, d_model, num_heads, dff, rate=0.1): 181 | self.d_model = d_model 182 | self.num_heads = num_heads 183 | self.dff = dff 184 | self.rate = rate 185 | super(DecoderLayer, self).__init__() 186 | 187 | def build(self, input_shape): 188 | 189 | self.mha1 = MultiHeadAttention(self.d_model, self.num_heads) 190 | self.mha2 = MultiHeadAttention(self.d_model, self.num_heads) 191 | 192 | self.ffn = point_wise_feed_forward_network(self.d_model, self.dff) 193 | 194 | self.layernorm1 = LayerNormalization(epsilon=1e-6) 195 | self.layernorm2 = LayerNormalization(epsilon=1e-6) 196 | self.layernorm3 = LayerNormalization(epsilon=1e-6) 197 | 198 | self.dropout1 = tf.keras.layers.Dropout(self.rate) 199 | self.dropout2 = tf.keras.layers.Dropout(self.rate) 200 | self.dropout3 = tf.keras.layers.Dropout(self.rate) 201 | 202 | for i in [self.mha1, self.mha2, self.ffn, self.layernorm1, self.layernorm2, self.layernorm3, self.dropout1, self.dropout2, self.dropout3]: 203 | i.build(input_shape) 204 | for weight in i.trainable_weights: 205 | if weight not in self._trainable_weights: 206 | self._trainable_weights.append(weight) 207 | for weight in i.non_trainable_weights: 208 | if weight not in self._non_trainable_weights: 209 | self._non_trainable_weights.append(weight) 210 | 211 | super(DecoderLayer, self).build(input_shape) 212 | 213 | 214 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask): 215 | # enc_output.shape == (batch_size, input_seq_len, d_model) 216 | 217 | attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask) # (batch_size, target_seq_len, d_model) 218 | attn1 = self.dropout1(attn1, training=training) 219 | out1 = self.layernorm1(attn1 + x) 220 | 221 | attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask) # (batch_size, target_seq_len, d_model) 222 | attn2 = self.dropout2(attn2, training=training) 223 | out2 = self.layernorm2(attn2 + out1) # (batch_size, target_seq_len, d_model) 224 | 225 | ffn_output = self.ffn(out2) # (batch_size, target_seq_len, d_model) 226 | ffn_output = self.dropout3(ffn_output, training=training) 227 | out3 = self.layernorm3(ffn_output + out2) # (batch_size, target_seq_len, d_model) 228 | 229 | return out3, attn_weights_block1, attn_weights_block2 230 | 231 | 232 | class Encoder(tf.keras.layers.Layer): 233 | """ 234 | The Encoder consists of: 235 | 236 | 1. Input Embedding 237 | 2. Positional Encoding 238 | 3. N encoder layers 239 | 240 | The input is put through an embedding which is summed with the positional encoding. 241 | The output of this summation is the input to the encoder layers. 242 | The output of the encoder is the input to the decoder. 
243 | """ 244 | def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, rate=0.1): 245 | super(Encoder, self).__init__() 246 | 247 | self.d_model = d_model 248 | self.num_layers = num_layers 249 | self.num_heads = num_heads 250 | self.dff = dff 251 | self.input_vocab_size = input_vocab_size 252 | self.rate = rate 253 | 254 | def build(self, input_shape): 255 | 256 | self.embedding = tf.keras.layers.Embedding(self.input_vocab_size, self.d_model) 257 | self.pos_encoding = positional_encoding(self.input_vocab_size, self.d_model) 258 | 259 | 260 | self.enc_layers = [ 261 | EncoderLayer(self.d_model, self.num_heads, self.dff, self.rate) 262 | for _ in range(self.num_layers) 263 | ] 264 | 265 | self.dropout = tf.keras.layers.Dropout(self.rate) 266 | 267 | for i in [self.embedding] + self.enc_layers + [self.dropout]: 268 | i.build(input_shape) 269 | for weight in i.trainable_weights: 270 | if weight not in self._trainable_weights: 271 | self._trainable_weights.append(weight) 272 | for weight in i.non_trainable_weights: 273 | if weight not in self._non_trainable_weights: 274 | self._non_trainable_weights.append(weight) 275 | 276 | super(Encoder, self).build(input_shape) 277 | 278 | def call(self, x, training, mask): 279 | 280 | seq_len = tf.shape(x)[1] 281 | 282 | # adding embedding and position encoding. 283 | x = self.embedding(x) # (batch_size, input_seq_len, d_model) 284 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) 285 | x += self.pos_encoding[:, :seq_len, :] 286 | 287 | x = self.dropout(x, training=training) 288 | 289 | for i in range(self.num_layers): 290 | x = self.enc_layers[i](x, training, mask) 291 | 292 | return x # (batch_size, input_seq_len, d_model) 293 | 294 | 295 | class Decoder(tf.keras.layers.Layer): 296 | """ 297 | The Decoder consists of: 298 | 1. Output Embedding 299 | 2. Positional Encoding 300 | 3. N decoder layers 301 | 302 | The target is put through an embedding which is summed with the positional encoding. 303 | The output of this summation is the input to the decoder layers. 304 | The output of the decoder is the input to the final linear layer. 
305 | 306 | If `pretrained_embeddings` are available, we use it as a word embedding matrix and do not perform further training 307 | """ 308 | def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, rate=0.1): 309 | super(Decoder, self).__init__() 310 | 311 | self.d_model = d_model 312 | self.num_layers = num_layers 313 | self.num_heads = num_heads 314 | self.dff = dff 315 | self.target_vocab_size = target_vocab_size 316 | self.rate = rate 317 | 318 | # if pretrained_embeddings is None: 319 | # self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model) 320 | # else: 321 | # self.embedding = tf.keras.layers.Embedding( 322 | # target_vocab_size, d_model, trainable=False, 323 | # embeddings_initializer=Constant(pretrained_embeddings) 324 | # ) 325 | 326 | def build(self, input_shape): 327 | 328 | self.pos_encoding = positional_encoding(self.target_vocab_size, self.d_model) 329 | 330 | self.dec_layers = [ 331 | DecoderLayer(self.d_model, self.num_heads, self.dff, self.rate) 332 | for _ in range(self.num_layers) 333 | ] 334 | self.dropout = tf.keras.layers.Dropout(self.rate) 335 | 336 | for i in self.dec_layers + [self.dropout]: 337 | i.build(input_shape) 338 | for weight in i.trainable_weights: 339 | if weight not in self._trainable_weights: 340 | self._trainable_weights.append(weight) 341 | for weight in i.non_trainable_weights: 342 | if weight not in self._non_trainable_weights: 343 | self._non_trainable_weights.append(weight) 344 | 345 | super(Decoder, self).build(input_shape) 346 | 347 | def call(self, x, enc_output, training, look_ahead_mask, padding_mask): 348 | 349 | seq_len = tf.shape(x)[1] 350 | attention_weights = {} 351 | 352 | # if not input_alreay_embedded: 353 | # x = self.embedding(x) # (batch_size, target_seq_len, d_model) 354 | 355 | x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) 356 | x += self.pos_encoding[:, :seq_len, :] 357 | 358 | x = self.dropout(x, training=training) 359 | 360 | for i in range(self.num_layers): 361 | 362 | # dv = f"/device:GPU:{str(next(selector))}" 363 | # print(f"With device ) 364 | # with tf.device(): 365 | 366 | x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask) 367 | 368 | attention_weights['decoder_layer{}_block1'.format(i+1)] = block1 369 | attention_weights['decoder_layer{}_block2'.format(i+1)] = block2 370 | 371 | # x.shape == (batch_size, target_seq_len, d_model) 372 | return x, attention_weights 373 | 374 | 375 | 376 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/models/__init__.py -------------------------------------------------------------------------------- /models/abstractive_summarizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import tensorflow as tf 4 | import tensorflow_hub as hub 5 | 6 | from tqdm import tqdm 7 | from random import randint 8 | from tensorflow.keras import backend as K 9 | from tensorflow.python.keras.initializers import Constant 10 | 11 | from layers.transformer import Encoder 12 | from layers.transformer import Decoder 13 | from layers.bert import BertLayer, BERT_MODEL_URL 14 | from layers.transformer import Decoder 15 | 16 | from ops.masking import create_masks 17 | from ops.masking import create_look_ahead_mask 
18 | from ops.masking import create_padding_mask 19 | from ops.masking import mask_timestamp 20 | from ops.masking import tile_and_mask_diagonal 21 | 22 | from ops.session import initialize_vars 23 | from ops.metrics import calculate_rouge 24 | from ops.tensor import with_column 25 | from ops.regularization import label_smoothing 26 | from ops.optimization import noam_scheme 27 | from ops.tokenization import tokenizer 28 | from ops.tokenization import convert_idx_to_token_tensor 29 | 30 | from data.load import UNK_ID 31 | from data.load import CLS_ID 32 | from data.load import SEP_ID 33 | from data.load import MASK_ID 34 | 35 | from config import config 36 | 37 | 38 | logger = logging.getLogger() 39 | logger.setLevel(logging.INFO) 40 | 41 | 42 | warmup_steps = config.WARMUP_STEPS 43 | initial_lr = config.INITIAL_LR 44 | 45 | 46 | 47 | def _embedding_from_bert(): 48 | """ 49 | Extract the preratined word embeddings from a BERT model 50 | Returns a numpy matrix with the embeddings 51 | """ 52 | logger.info("Extracting pretrained word embeddings weights from BERT") 53 | 54 | with tf.device("/device:CPU:0"): 55 | bert = hub.Module(BERT_MODEL_URL, trainable=False, name="embeddings_from_bert_module") 56 | 57 | with tf.Session() as sess: 58 | initialize_vars(sess) 59 | embedding_matrix = sess.run(bert.variable_map['bert/embeddings/word_embeddings']) 60 | 61 | tf.reset_default_graph() 62 | 63 | logger.info(f"Embedding matrix shape '{embedding_matrix.shape}'") 64 | return embedding_matrix 65 | 66 | 67 | class AbstractiveSummarization(tf.keras.Model): 68 | """ 69 | Pretraining-Based Natural Language Generation for Text Summarization 70 | https://arxiv.org/pdf/1902.09243.pdf 71 | """ 72 | def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, input_seq_len, output_seq_len, rate=0.1): 73 | super(AbstractiveSummarization, self).__init__() 74 | 75 | self.input_seq_len = input_seq_len 76 | self.output_seq_len = output_seq_len 77 | 78 | self.vocab_size = vocab_size 79 | 80 | self.bert = BertLayer(d_embedding=d_model, trainable=False) 81 | 82 | embedding_matrix = _embedding_from_bert() 83 | 84 | self.embedding = tf.keras.layers.Embedding( 85 | vocab_size, d_model, trainable=False, 86 | embeddings_initializer=Constant(embedding_matrix) 87 | ) 88 | 89 | self.decoder = Decoder(num_layers, d_model, num_heads, dff, vocab_size, rate) 90 | 91 | self.final_layer = tf.keras.layers.Dense(vocab_size) 92 | 93 | def encode(self, ids, mask, segment_ids): 94 | # (batch_size, seq_len, d_bert) 95 | return self.bert((ids, mask, segment_ids)) 96 | 97 | def draft_summary(self, enc_output, look_ahead_mask, padding_mask, target_ids=None, training=True): 98 | 99 | logging.info("Building:'Draft summary'") 100 | 101 | # (batch_size, seq_len) 102 | dec_input = target_ids 103 | 104 | # (batch_size, seq_len, d_bert) 105 | embeddings = self.embedding(target_ids) 106 | 107 | # (batch_size, seq_len, d_bert), (_) 108 | dec_output, attention_dist = self.decoder(embeddings, enc_output, training, look_ahead_mask, padding_mask) 109 | 110 | # (batch_size, seq_len, vocab_len) 111 | logits = self.final_layer(dec_output) 112 | 113 | # (batch_size, seq_len) 114 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 115 | 116 | return logits, preds, attention_dist 117 | 118 | def draft_summary_greedy(self, enc_output, look_ahead_mask, padding_mask, training=False): 119 | """ 120 | Inference call, builds a draft summary autoregressively 121 | """ 122 | 123 | logging.info("Building: 'Greedy Draft Summary'") 124 | 125 | N = 
tf.shape(enc_output)[0] 126 | T = tf.shape(enc_output)[1] 127 | 128 | # (batch_size, 1) 129 | dec_input = tf.ones([N, 1], dtype=tf.int32) * CLS_ID 130 | 131 | summary, dec_outputs, dec_logits, attention_dists = [], [], [], [] 132 | 133 | summary += [dec_input] 134 | dec_logits += [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1])] 135 | 136 | for i in tqdm(range(0, self.output_seq_len - 1)): 137 | 138 | # (batch_size, i+1, d_bert) 139 | embeddings = self.embedding(dec_input) 140 | 141 | # (batch_size, i+1, d_bert), (_) 142 | dec_output, attention_dist = self.decoder(embeddings, enc_output, training, look_ahead_mask, padding_mask) 143 | 144 | # (batch_size, 1, d_bert) 145 | dec_output_i = dec_output[:, -1: ,:] 146 | 147 | # (batch_size, 1, d_bert) 148 | dec_outputs += [dec_output_i] 149 | attention_dists += [{k: v[:, -1:, :] for k, v in attention_dist.items()}] 150 | 151 | # (batch_size, 1, vocab_len) 152 | logits = self.final_layer(dec_output_i) 153 | 154 | dec_logits += [logits] 155 | 156 | # (batch_size, 1) 157 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 158 | 159 | summary += [preds] 160 | 161 | # (batch_size, i+2) 162 | dec_input = with_column(dec_input, i+1, preds) 163 | 164 | # (batch_size, seq_len, vocab_len) 165 | dec_logits = tf.concat(dec_logits, axis=1) 166 | 167 | # (batch_size, seq_len) 168 | summary = tf.concat(summary, axis=1) 169 | 170 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 171 | return dec_logits, summary, attention_dists 172 | 173 | def refined_summary_iter(self, enc_output, target, padding_mask, training=True): 174 | """ 175 | Iterative version of refined summary using teacher forcing 176 | """ 177 | 178 | logging.info("Building: 'Refined Summary'") 179 | 180 | N = tf.shape(enc_output)[0] 181 | 182 | dec_inp_ids, dec_inp_mask, dec_inp_segment_ids = target 183 | 184 | dec_outputs, attention_dists = [], [] 185 | 186 | for i in tqdm(range(1, self.output_seq_len)): 187 | 188 | # (batch_size, seq_len) 189 | dec_inp_ids_ = mask_timestamp(dec_inp_ids, i, MASK_ID) 190 | 191 | # (batch_size, seq_len, d_bert) 192 | context_vectors = self.bert((dec_inp_ids_, dec_inp_mask, dec_inp_segment_ids)) 193 | 194 | # (batch_size, seq_len, d_bert), (_) 195 | dec_output, attention_dist = self.decoder( 196 | context_vectors, 197 | enc_output, 198 | training, 199 | look_ahead_mask=None, 200 | padding_mask=padding_mask 201 | ) 202 | 203 | # (batch_size, 1, seq_len) 204 | dec_outputs += [dec_output[:,i:i+1,:]] 205 | attention_dists += [{k: v[:, i:i+1, :] for k, v in attention_dist.items()}] 206 | 207 | # (batch_size, seq_len - 1, d_bert) 208 | dec_outputs = tf.concat(dec_outputs, axis=1) 209 | 210 | # (batch_size, seq_len - 1, vocab_len) 211 | logits = self.final_layer(dec_outputs) 212 | 213 | # (batch_size, seq_len, vocab_len), accommodate for initial [CLS] 214 | logits = tf.concat( 215 | [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1]), logits], 216 | axis=1 217 | ) 218 | 219 | # (batch_size, seq_len) 220 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 221 | 222 | return logits, preds, attention_dists 223 | 224 | def refined_summary(self, enc_output, target, padding_mask, training=True): 225 | 226 | logging.info("Building: 'Refined Summary'") 227 | 228 | N = tf.shape(enc_output)[0] 229 | T = self.output_seq_len 230 | 231 | # (batch_size, seq_len) x3 232 | dec_inp_ids, dec_inp_mask, dec_inp_segment_ids = target 233 | 234 | # since we are using teacher forcing we do not need an autoregressice 
mechanism here 235 | 236 | # (batch_size x (seq_len - 1), seq_len) 237 | dec_inp_ids = tile_and_mask_diagonal(dec_inp_ids, mask_with=MASK_ID) 238 | 239 | # (batch_size x (seq_len - 1), seq_len) 240 | dec_inp_mask = tf.tile(dec_inp_mask, [T-1, 1]) 241 | 242 | # (batch_size x (seq_len - 1), seq_len) 243 | dec_inp_segment_ids = tf.tile(dec_inp_segment_ids, [T-1, 1]) 244 | 245 | # (batch_size x (seq_len - 1), seq_len, d_bert) 246 | enc_output = tf.tile(enc_output, [T-1, 1, 1]) 247 | 248 | # (batch_size x (seq_len - 1), 1, 1, seq_len) 249 | padding_mask = tf.tile(padding_mask, [T-1, 1, 1, 1]) 250 | 251 | # (batch_size x (seq_len - 1), seq_len, d_bert) 252 | context_vectors = self.bert((dec_inp_ids, dec_inp_mask, dec_inp_segment_ids)) 253 | 254 | # (batch_size x (seq_len - 1), seq_len, d_bert), (_) 255 | dec_outputs, attention_dists = self.decoder( 256 | context_vectors, 257 | enc_output, 258 | training, 259 | look_ahead_mask=None, 260 | padding_mask=padding_mask 261 | ) 262 | 263 | # (batch_size x (seq_len - 1), seq_len - 1, d_bert) 264 | dec_outputs = dec_outputs[:, 1:, :] 265 | 266 | # (batch_size x (seq_len - 1), (seq_len - 1)) 267 | diag = tf.linalg.set_diag(tf.zeros([T-1, T-1]), tf.ones([T-1])) 268 | diag = tf.tile(diag, [N, 1]) 269 | 270 | where = tf.not_equal(diag, 0) 271 | indices = tf.where(where) 272 | 273 | # (batch_size x (seq_len - 1), d_bert) 274 | dec_outputs = tf.gather_nd(dec_outputs, indices) 275 | 276 | # (batch_size, seq_len - 1, d_bert) 277 | dec_outputs = tf.reshape(dec_outputs, [N, T-1, -1]) 278 | 279 | # (batch_size, seq_len - 1, vocab_len) 280 | logits = self.final_layer(dec_outputs) 281 | 282 | # (batch_size, seq_len, vocab_len), accommodate for initial [CLS] 283 | logits = tf.concat( 284 | [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1]), logits], 285 | axis=1 286 | ) 287 | 288 | # (batch_size, seq_len) 289 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 290 | 291 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 292 | return logits, preds, attention_dists 293 | 294 | def refined_summary_greedy(self, enc_output, draft_summary, padding_mask, training=False): 295 | """ 296 | Inference call, builds a refined summary 297 | 298 | It first masks each word in the summary draft one by one, 299 | then feeds the draft to BERT to generate context vectors. 
300 | """ 301 | 302 | logging.info("Building: 'Greedy Refined Summary'") 303 | 304 | refined_summary = draft_summary 305 | refined_summary_mask = tf.cast(tf.math.equal(draft_summary, 0), tf.float32) 306 | refined_summary_segment_ids = tf.zeros(tf.shape(draft_summary)) 307 | 308 | N = tf.shape(draft_summary)[0] 309 | T = tf.shape(draft_summary)[1] 310 | 311 | dec_outputs, dec_logits, attention_dists = [], [], [] 312 | 313 | dec_logits += [tf.tile(tf.expand_dims(tf.one_hot([CLS_ID], self.vocab_size), axis=0), [N, 1, 1])] 314 | 315 | for i in tqdm(range(1, self.output_seq_len)): 316 | 317 | # (batch_size, seq_len) 318 | refined_summary_ = mask_timestamp(refined_summary, i, MASK_ID) 319 | 320 | # (batch_size, seq_len, d_bert) 321 | context_vectors = self.bert((refined_summary_, refined_summary_mask, refined_summary_segment_ids)) 322 | 323 | # (batch_size, seq_len, d_bert), (_) 324 | dec_output, attention_dist = self.decoder( 325 | context_vectors, 326 | enc_output, 327 | training=training, 328 | look_ahead_mask=None, 329 | padding_mask=padding_mask 330 | ) 331 | 332 | # (batch_size, 1, vocab_len) 333 | dec_output_i = dec_output[:, i:i+1 ,:] 334 | 335 | dec_outputs += [dec_output_i] 336 | attention_dists += [{k: v[:, i:i+1, :] for k, v in attention_dist.items()}] 337 | 338 | # (batch_size, 1, vocab_len) 339 | logits = self.final_layer(dec_output_i) 340 | 341 | dec_logits += [logits] 342 | 343 | # (batch_size, 1) 344 | preds = tf.to_int32(tf.argmax(logits, axis=-1)) 345 | 346 | # (batch_size, seq_len) 347 | refined_summary = with_column(refined_summary, i, preds) 348 | 349 | # (batch_size, seq_len, vocab_len) 350 | dec_logits = tf.concat(dec_logits, axis=1) 351 | 352 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 353 | return dec_logits, refined_summary, attention_dists 354 | 355 | def call(self, inp, tar=None, training=False): 356 | """ 357 | Run the model for training/inference 358 | For training, the target is needed (ids, mask, segments) 359 | """ 360 | 361 | if training: 362 | return self.fit(inp, tar) 363 | 364 | else: 365 | return self.predict(inp) 366 | 367 | 368 | def fit(self, inp, tar): 369 | """ 370 | __call__ for training; uses teacher forcing for both the draft 371 | and the defined decoder 372 | """ 373 | # (batch_size, seq_len) x3 374 | input_ids, input_mask, input_segment_ids = inp 375 | 376 | # (batch_size, seq_len + 1) x3 377 | target_ids, target_mask, target_segment_ids = tar 378 | 379 | # (batch_size, 1, 1, seq_len), (_), (batch_size, 1, 1, seq_len) 380 | combined_mask, dec_padding_mask = create_masks(input_ids, target_ids[:, :-1]) 381 | 382 | # (batch_size, seq_len, d_bert) 383 | enc_output = self.encode(input_ids, input_mask, input_segment_ids) 384 | 385 | # (batch_size, seq_len , vocab_len), (batch_size, seq_len), (_) 386 | logits_draft_summary, preds_draft_summary, draft_attention_dist = self.draft_summary( 387 | enc_output=enc_output, 388 | look_ahead_mask=combined_mask, 389 | padding_mask=dec_padding_mask, 390 | target_ids=target_ids[:, :-1], 391 | training=True 392 | ) 393 | 394 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 395 | logits_refined_summary, preds_refined_summary, refined_attention_dist = self.refined_summary( 396 | enc_output=enc_output, 397 | target=(target_ids[:, :-1], target_mask[:, :-1], target_segment_ids[:, :-1]), 398 | padding_mask=dec_padding_mask, 399 | training=True 400 | ) 401 | 402 | return logits_draft_summary, logits_refined_summary 403 | 404 | 405 | def predict(self, inp): 406 | """ 407 | __call__ for 
inference; uses teacher forcing for both the draft 408 | and the defined decoder 409 | """ 410 | # (batch_size, seq_len) x3 411 | input_ids, input_mask, input_segment_ids = inp 412 | 413 | dec_padding_mask = create_padding_mask(input_ids) 414 | 415 | # (batch_size, seq_len, d_bert) 416 | enc_output = self.encode(input_ids, input_mask, input_segment_ids) 417 | 418 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 419 | logits_draft_summary, preds_draft_summary, draft_attention_dist = self.draft_summary_greedy( 420 | enc_output=enc_output, 421 | look_ahead_mask=None, 422 | padding_mask=dec_padding_mask 423 | ) 424 | 425 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 426 | logits_refined_summary, preds_refined_summary, refined_attention_dist = self.refined_summary_greedy( 427 | enc_output=enc_output, 428 | padding_mask=dec_padding_mask, 429 | draft_summary=preds_draft_summary 430 | ) 431 | 432 | return logits_draft_summary, preds_draft_summary, draft_attention_dist, logits_refined_summary, preds_refined_summary, refined_attention_dist 433 | 434 | 435 | def train(model, xs, ys, gradient_accumulation=False): 436 | 437 | logging.info("Building Training Graph") 438 | logging.info(f"w/ Gradient Accumulation: {str(gradient_accumulation)}") 439 | 440 | # (batch_size, seq_len + 1) x3 441 | target_ids, _, _ = ys 442 | 443 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len, vocab_len) 444 | logits_draft_summary, logits_refined_summary = model(xs, ys, True) 445 | 446 | target_ids_ = label_smoothing(tf.one_hot(target_ids, depth=model.vocab_size)) 447 | 448 | # use right shifted target, (batch_size, seq_len) 449 | loss_draft = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_draft_summary, labels=target_ids_[:, 1:, :]) 450 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, 1:], 0)) 451 | mask = tf.cast(mask, dtype=loss_draft.dtype) 452 | loss_draft *= mask 453 | 454 | # use non-shifted target (we want to predict the masked word), (batch_size, seq_len) 455 | loss_refined = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_refined_summary, labels=target_ids_[:, :-1, :]) 456 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, :-1], 0)) 457 | mask = tf.cast(mask, dtype=loss_refined.dtype) 458 | loss_refined *= mask 459 | 460 | # (batch_size, seq_len) 461 | loss = loss_draft + loss_refined 462 | # scalar 463 | loss = tf.reduce_mean(loss) 464 | 465 | global_step = tf.train.get_or_create_global_step() 466 | learning_rate = noam_scheme(initial_lr, global_step, warmup_steps) 467 | optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.9, beta2=0.98, epsilon=1e-9) 468 | 469 | tf.summary.scalar('learning_rate', learning_rate, family='train') 470 | tf.summary.scalar('loss_draft', tf.reduce_mean(loss_draft * mask), family='train') 471 | tf.summary.scalar('loss_refined', tf.reduce_mean(loss_refined * mask), family='train') 472 | tf.summary.scalar("loss", loss, family='train') 473 | tf.summary.scalar("global_step", global_step, family='train') 474 | 475 | summaries = tf.summary.merge_all() 476 | 477 | if gradient_accumulation: 478 | 479 | tvs = tf.trainable_variables() 480 | 481 | accumulation_variables = [ 482 | tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) 483 | for tv in tvs 484 | ] 485 | 486 | zero_op = [tv.assign(tf.zeros_like(tv)) for tv in accumulation_variables] 487 | 488 | gradients_vs = optimizer.compute_gradients( 489 | loss=loss, 490 | var_list=tvs, 491 | colocate_gradients_with_ops=True, 492 | 
aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N 493 | ) 494 | 495 | accumlation_op = [accumulation_variables[i].assign_add(gv[0]) for i, gv in enumerate(gradients_vs) if gv[0] is not None] 496 | 497 | # pass list of (gradient, variable) pairs 498 | train_op = optimizer.apply_gradients([(accumulation_variables[i], gv[1]) for i, gv in enumerate(gradients_vs)], global_step) 499 | 500 | return loss, zero_op, accumlation_op, train_op, global_step, summaries 501 | 502 | else: 503 | 504 | train_op = optimizer.minimize( 505 | loss, 506 | global_step=global_step, 507 | colocate_gradients_with_ops=True, 508 | aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N 509 | ) 510 | 511 | return loss, train_op, global_step, summaries 512 | 513 | 514 | def eval(model, xs, ys): 515 | 516 | logging.info("Building Evaluation Graph") 517 | 518 | # (batch_size, seq_len + 1) x3 519 | target_ids, _, _ = ys 520 | 521 | # (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_), (batch_size, seq_len, vocab_len), (batch_size, seq_len), (_) 522 | logits_draft_summary, preds_draft_summary, _, logits_refined_summary, preds_refined_summary, _ = model(xs) 523 | 524 | target_ids_ = label_smoothing(tf.one_hot(target_ids, depth=model.vocab_size)) 525 | 526 | # use right shifted target, (batch_size, seq_len) 527 | loss_draft = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_draft_summary, labels=target_ids_[:, 1:, :]) 528 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, 1:], 0)) 529 | mask = tf.cast(mask, dtype=loss_draft.dtype) 530 | loss_draft *= mask 531 | 532 | # use non-shifted target (we want to predict the masked word), (batch_size, seq_len) 533 | loss_refined = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits_refined_summary, labels=target_ids_[:, :-1, :]) 534 | mask = tf.math.logical_not(tf.math.equal(target_ids[:, :-1], 0)) 535 | mask = tf.cast(mask, dtype=loss_refined.dtype) 536 | loss_refined *= mask 537 | 538 | # (batch_size, seq_len) 539 | loss = loss_draft + loss_refined 540 | 541 | # scalar 542 | loss = tf.reduce_mean(loss) 543 | 544 | # monitor a random sample 545 | n = tf.random_uniform((), 0, tf.shape(xs[0])[0] - 1, tf.int32) 546 | 547 | x_rnd = convert_idx_to_token_tensor(xs[0][n]) 548 | y_rnd = convert_idx_to_token_tensor(target_ids[n, :-1]) 549 | y_hat_rnd = convert_idx_to_token_tensor(preds_refined_summary[n]) 550 | 551 | r1_val, r2_val, rl_val, r_vag = calculate_rouge(y_rnd, y_hat_rnd) 552 | 553 | tf.summary.text("input", x_rnd) 554 | tf.summary.text("target", y_rnd) 555 | tf.summary.text("prediction", y_hat_rnd) 556 | 557 | tf.summary.scalar('ROUGE-1', r1_val, family='eval') 558 | tf.summary.scalar('ROUGE-2', r2_val, family='eval') 559 | tf.summary.scalar("ROUGE-L", rl_val, family='eval') 560 | tf.summary.scalar("R-AVG", r_vag, family='eval') 561 | 562 | tf.summary.scalar('loss_draft', tf.reduce_mean(loss_draft * mask), family='eval') 563 | tf.summary.scalar('loss_refined', tf.reduce_mean(loss_refined * mask), family='eval') 564 | tf.summary.scalar("loss", loss, family='eval') 565 | 566 | summaries = tf.summary.merge_all() 567 | 568 | # (batch_size, seq_len), (batch_size, seq_len), scalar, object 569 | return target_ids[:, :-1], preds_refined_summary, loss, summaries 570 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | import tensorflow as tf 5 | 
import tensorflow_hub as hub 6 | 7 | from tensorflow.keras import backend as K 8 | from layers.transformer import Encoder 9 | from layers.transformer import Decoder 10 | 11 | 12 | class Transformer(tf.keras.Model): 13 | """ 14 | Transformer consists of the encoder, decoder and a final linear layer. 15 | The output of the decoder is the input to the linear layer and its output is returned. 16 | """ 17 | def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, rate=0.1): 18 | super(Transformer, self).__init__() 19 | 20 | self.encoder = Encoder(num_layers, d_model, num_heads, dff, input_vocab_size, rate) 21 | 22 | self.decoder = Decoder(num_layers, d_model, num_heads, dff, target_vocab_size, rate) 23 | 24 | self.final_layer = tf.keras.layers.Dense(target_vocab_size) 25 | 26 | def call(self, inp, tar, training, enc_padding_mask, look_ahead_mask, dec_padding_mask): 27 | 28 | enc_output = self.encoder(inp, training, enc_padding_mask) # (batch_size, inp_seq_len, d_model) 29 | 30 | # dec_output.shape == (batch_size, tar_seq_len, d_model) 31 | dec_output, attention_weights = self.decoder(tar, enc_output, training, look_ahead_mask, dec_padding_mask) 32 | 33 | final_output = self.final_layer(dec_output) # (batch_size, tar_seq_len, target_vocab_size) 34 | 35 | return final_output, attention_weights 36 | -------------------------------------------------------------------------------- /ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/ops/__init__.py -------------------------------------------------------------------------------- /ops/attention.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def scaled_dot_product_attention(q, k, v, mask): 5 | """Calculate the attention weights. 6 | q, k, v must have matching leading dimensions. 7 | k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. 8 | The mask has different shapes depending on its type(padding or look ahead) 9 | but it must be broadcastable for addition. 10 | 11 | The mask is multiplied with -1e9 (close to negative infinity). 12 | This is done because the mask is summed with the scaled matrix 13 | multiplication of Q and K and is applied immediately before a softmax. 14 | The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output. 15 | 16 | Args: 17 | q: query shape == (..., seq_len_q, depth) 18 | k: key shape == (..., seq_len_k, depth) 19 | v: value shape == (..., seq_len_v, depth_v) 20 | mask: Float tensor with shape broadcastable 21 | to (..., seq_len_q, seq_len_k). Defaults to None. 22 | 23 | Returns: 24 | output, attention_weights 25 | """ 26 | 27 | matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) 28 | 29 | # scale matmul_qk 30 | dk = tf.cast(tf.shape(k)[-1], tf.float32) 31 | scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) 32 | 33 | # add the mask to the scaled tensor. 34 | if mask is not None: 35 | scaled_attention_logits += (mask * -1e9) 36 | 37 | # softmax is normalized on the last axis (seq_len_k) so that the scores 38 | # add up to 1. 
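    # Illustrative note on the masking step above (the numbers are made up):
    # for a single query with raw scores [2.0, 1.0, x] where the third key is
    # padding, adding mask * -1e9 gives roughly [2.0, 1.0, -1e9]; the softmax
    # below turns that into approximately [0.73, 0.27, 0.00], so the padded
    # position contributes essentially nothing to the weighted sum over v.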
39 | attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) 40 | 41 | output = tf.matmul(attention_weights, v) # (..., seq_len_v, depth_v) 42 | 43 | return output, attention_weights 44 | 45 | -------------------------------------------------------------------------------- /ops/beam_search.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.util import nest 3 | 4 | # Default value for INF 5 | INF = 1. * 1e7 6 | 7 | 8 | class _StateKeys(object): 9 | """Keys to dictionary storing the state of the beam search loop.""" 10 | 11 | # Variable storing the loop index. 12 | CUR_INDEX = "CUR_INDEX" 13 | 14 | # Top sequences that are alive for each batch item. Alive sequences are ones 15 | # that have not generated an EOS token. Sequences that reach EOS are marked as 16 | # finished and moved to the FINISHED_SEQ tensor. 17 | # Has shape [batch_size, beam_size, CUR_INDEX + 1] 18 | ALIVE_SEQ = "ALIVE_SEQ" 19 | # Log probabilities of each alive sequence. Shape [batch_size, beam_size] 20 | ALIVE_LOG_PROBS = "ALIVE_LOG_PROBS" 21 | # Dictionary of cached values for each alive sequence. The cache stores 22 | # the encoder output, attention bias, and the decoder attention output from 23 | # the previous iteration. 24 | ALIVE_CACHE = "ALIVE_CACHE" 25 | 26 | # Top finished sequences for each batch item. 27 | # Has shape [batch_size, beam_size, CUR_INDEX + 1]. Sequences that are 28 | # shorter than CUR_INDEX + 1 are padded with 0s. 29 | FINISHED_SEQ = "FINISHED_SEQ" 30 | # Scores for each finished sequence. Score = log probability / length norm 31 | # Shape [batch_size, beam_size] 32 | FINISHED_SCORES = "FINISHED_SCORES" 33 | # Flags indicating which sequences in the finished sequences are finished. 34 | # At the beginning, all of the sequences in FINISHED_SEQ are filler values. 35 | # True -> finished sequence, False -> filler. Shape [batch_size, beam_size] 36 | FINISHED_FLAGS = "FINISHED_FLAGS" 37 | 38 | 39 | class SequenceBeamSearch(object): 40 | """Implementation of beam search loop.""" 41 | 42 | def __init__(self, symbols_to_logits_fn, vocab_size, batch_size, 43 | beam_size, alpha, max_decode_length, eos_id): 44 | self.symbols_to_logits_fn = symbols_to_logits_fn 45 | self.vocab_size = vocab_size 46 | self.batch_size = batch_size 47 | self.beam_size = beam_size 48 | self.alpha = alpha 49 | self.max_decode_length = max_decode_length 50 | self.eos_id = eos_id 51 | 52 | def search(self, initial_ids, initial_cache): 53 | """Beam search for sequences with highest scores.""" 54 | state, state_shapes = self._create_initial_state(initial_ids, initial_cache) 55 | 56 | finished_state = tf.while_loop( 57 | self._continue_search, self._search_step, loop_vars=[state], 58 | shape_invariants=[state_shapes], parallel_iterations=1, back_prop=False) 59 | finished_state = finished_state[0] 60 | 61 | alive_seq = finished_state[_StateKeys.ALIVE_SEQ] 62 | alive_log_probs = finished_state[_StateKeys.ALIVE_LOG_PROBS] 63 | finished_seq = finished_state[_StateKeys.FINISHED_SEQ] 64 | finished_scores = finished_state[_StateKeys.FINISHED_SCORES] 65 | finished_flags = finished_state[_StateKeys.FINISHED_FLAGS] 66 | 67 | # Account for corner case where there are no finished sequences for a 68 | # particular batch item. In that case, return alive sequences for that batch 69 | # item. 
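    # In other words: tf.reduce_any(finished_flags, 1) is True for a batch item
    # only if at least one of its beams has emitted EOS; rows where it is False
    # (e.g. max_decode_length was hit before any EOS appeared) fall back to the
    # alive beams and their log probabilities instead.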
70 | finished_seq = tf.where( 71 | tf.reduce_any(finished_flags, 1), finished_seq, alive_seq) 72 | finished_scores = tf.where( 73 | tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs) 74 | return finished_seq, finished_scores 75 | 76 | def _create_initial_state(self, initial_ids, initial_cache): 77 | """Return initial state dictionary and its shape invariants. 78 | 79 | Args: 80 | initial_ids: initial ids to pass into the symbols_to_logits_fn. 81 | int tensor with shape [batch_size, 1] 82 | initial_cache: dictionary storing values to be passed into the 83 | symbols_to_logits_fn. 84 | 85 | Returns: 86 | state and shape invariant dictionaries with keys from _StateKeys 87 | """ 88 | # Current loop index (starts at 0) 89 | cur_index = tf.constant(0) 90 | 91 | # Create alive sequence with shape [batch_size, beam_size, 1] 92 | alive_seq = _expand_to_beam_size(initial_ids, self.beam_size) 93 | alive_seq = tf.expand_dims(alive_seq, axis=2) 94 | 95 | # Create tensor for storing initial log probabilities. 96 | # Assume initial_ids are prob 1.0 97 | initial_log_probs = tf.constant( 98 | [[0.] + [-float("inf")] * (self.beam_size - 1)]) 99 | alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1]) 100 | 101 | # Expand all values stored in the dictionary to the beam size, so that each 102 | # beam has a separate cache. 103 | alive_cache = nest.map_structure( 104 | lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache) 105 | 106 | # Initialize tensor storing finished sequences with filler values. 107 | finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32) 108 | 109 | # Set scores of the initial finished seqs to negative infinity. 110 | finished_scores = tf.ones([self.batch_size, self.beam_size]) * -INF 111 | 112 | # Initialize finished flags with all False values. 113 | finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool) 114 | 115 | # Create state dictionary 116 | state = { 117 | _StateKeys.CUR_INDEX: cur_index, 118 | _StateKeys.ALIVE_SEQ: alive_seq, 119 | _StateKeys.ALIVE_LOG_PROBS: alive_log_probs, 120 | _StateKeys.ALIVE_CACHE: alive_cache, 121 | _StateKeys.FINISHED_SEQ: finished_seq, 122 | _StateKeys.FINISHED_SCORES: finished_scores, 123 | _StateKeys.FINISHED_FLAGS: finished_flags 124 | } 125 | 126 | # Create state invariants for each value in the state dictionary. Each 127 | # dimension must be a constant or None. A None dimension means either: 128 | # 1) the dimension's value is a tensor that remains the same but may 129 | # depend on the input sequence to the model (e.g. batch size). 130 | # 2) the dimension may have different values on different iterations. 131 | state_shape_invariants = { 132 | _StateKeys.CUR_INDEX: tf.TensorShape([]), 133 | _StateKeys.ALIVE_SEQ: tf.TensorShape([None, self.beam_size, None]), 134 | _StateKeys.ALIVE_LOG_PROBS: tf.TensorShape([None, self.beam_size]), 135 | _StateKeys.ALIVE_CACHE: nest.map_structure( 136 | _get_shape_keep_last_dim, alive_cache), 137 | _StateKeys.FINISHED_SEQ: tf.TensorShape([None, self.beam_size, None]), 138 | _StateKeys.FINISHED_SCORES: tf.TensorShape([None, self.beam_size]), 139 | _StateKeys.FINISHED_FLAGS: tf.TensorShape([None, self.beam_size]) 140 | } 141 | 142 | return state, state_shape_invariants 143 | 144 | def _continue_search(self, state): 145 | """Return whether to continue the search loop. 
146 | 147 | The loops should terminate when 148 | 1) when decode length has been reached, or 149 | 2) when the worst score in the finished sequences is better than the best 150 | score in the alive sequences (i.e. the finished sequences are provably 151 | unchanging) 152 | 153 | Args: 154 | state: A dictionary with the current loop state. 155 | 156 | Returns: 157 | Bool tensor with value True if loop should continue, False if loop should 158 | terminate. 159 | """ 160 | i = state[_StateKeys.CUR_INDEX] 161 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS] 162 | finished_scores = state[_StateKeys.FINISHED_SCORES] 163 | finished_flags = state[_StateKeys.FINISHED_FLAGS] 164 | 165 | not_at_max_decode_length = tf.less(i, self.max_decode_length) 166 | 167 | # Calculate largest length penalty (the larger penalty, the better score). 168 | max_length_norm = _length_normalization(self.alpha, self.max_decode_length) 169 | # Get the best possible scores from alive sequences. 170 | best_alive_scores = alive_log_probs[:, 0] / max_length_norm 171 | 172 | # Compute worst score in finished sequences for each batch element 173 | finished_scores *= tf.cast(finished_flags, 174 | tf.float32) # set filler scores to zero 175 | lowest_finished_scores = tf.reduce_min(finished_scores, axis=1) 176 | 177 | # If there are no finished sequences in a batch element, then set the lowest 178 | # finished score to -INF for that element. 179 | finished_batches = tf.reduce_any(finished_flags, 1) 180 | lowest_finished_scores += (1.0 - 181 | tf.cast(finished_batches, tf.float32)) * -INF 182 | 183 | worst_finished_score_better_than_best_alive_score = tf.reduce_all( 184 | tf.greater(lowest_finished_scores, best_alive_scores) 185 | ) 186 | 187 | return tf.logical_and( 188 | not_at_max_decode_length, 189 | tf.logical_not(worst_finished_score_better_than_best_alive_score) 190 | ) 191 | 192 | def _search_step(self, state): 193 | """Beam search loop body. 194 | 195 | Grow alive sequences by a single ID. Sequences that have reached the EOS 196 | token are marked as finished. The alive and finished sequences with the 197 | highest log probabilities and scores are returned. 198 | 199 | A sequence's finished score is calculating by dividing the log probability 200 | by the length normalization factor. Without length normalization, the 201 | search is more likely to return shorter sequences. 202 | 203 | Args: 204 | state: A dictionary with the current loop state. 205 | 206 | Returns: 207 | new state dictionary. 208 | """ 209 | # Grow alive sequences by one token. 210 | new_seq, new_log_probs, new_cache = self._grow_alive_seq(state) 211 | # Collect top beam_size alive sequences 212 | alive_state = self._get_new_alive_state(new_seq, new_log_probs, new_cache) 213 | 214 | # Combine newly finished sequences with existing finished sequences, and 215 | # collect the top k scoring sequences. 216 | finished_state = self._get_new_finished_state(state, new_seq, new_log_probs) 217 | 218 | # Increment loop index and create new state dictionary 219 | new_state = {_StateKeys.CUR_INDEX: state[_StateKeys.CUR_INDEX] + 1} 220 | new_state.update(alive_state) 221 | new_state.update(finished_state) 222 | return [new_state] 223 | 224 | def _grow_alive_seq(self, state): 225 | """Grow alive sequences by one token, and collect top 2*beam_size sequences. 226 | 227 | 2*beam_size sequences are collected because some sequences may have reached 228 | the EOS token. 2*beam_size ensures that at least beam_size sequences are 229 | still alive. 
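    For example, with beam_size = 3 each alive beam can contribute at most one
    EOS continuation per step, so at most 3 of the 2 * 3 = 6 kept candidates can
    be newly finished and at least 3 non-EOS candidates always survive.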
230 | 231 | Args: 232 | state: A dictionary with the current loop state. 233 | Returns: 234 | Tuple of 235 | (Top 2*beam_size sequences [batch_size, 2 * beam_size, cur_index + 1], 236 | Scores of returned sequences [batch_size, 2 * beam_size], 237 | New alive cache, for each of the 2 * beam_size sequences) 238 | """ 239 | i = state[_StateKeys.CUR_INDEX] 240 | alive_seq = state[_StateKeys.ALIVE_SEQ] 241 | alive_log_probs = state[_StateKeys.ALIVE_LOG_PROBS] 242 | alive_cache = state[_StateKeys.ALIVE_CACHE] 243 | 244 | beams_to_keep = 2 * self.beam_size 245 | 246 | # Get logits for the next candidate IDs for the alive sequences. Get the new 247 | # cache values at the same time. 248 | flat_ids = _flatten_beam_dim(alive_seq) # [batch_size * beam_size] 249 | flat_cache = nest.map_structure(_flatten_beam_dim, alive_cache) 250 | 251 | flat_logits, flat_cache = self.symbols_to_logits_fn(flat_ids, i) 252 | 253 | # Unflatten logits to shape [batch_size, beam_size, vocab_size] 254 | logits = _unflatten_beam_dim(flat_logits, self.batch_size, self.beam_size) 255 | new_cache = nest.map_structure( 256 | lambda t: _unflatten_beam_dim(t, self.batch_size, self.beam_size), 257 | flat_cache) 258 | 259 | # Convert logits to normalized log probs 260 | candidate_log_probs = _log_prob_from_logits(logits) 261 | 262 | # Calculate new log probabilities if each of the alive sequences were 263 | # extended # by the the candidate IDs. 264 | # Shape [batch_size, beam_size, vocab_size] 265 | log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2) 266 | 267 | # Each batch item has beam_size * vocab_size candidate sequences. For each 268 | # batch item, get the k candidates with the highest log probabilities. 269 | flat_log_probs = tf.reshape(log_probs, 270 | [-1, self.beam_size * self.vocab_size]) 271 | topk_log_probs, topk_indices = tf.nn.top_k(flat_log_probs, k=beams_to_keep) 272 | 273 | # Extract the alive sequences that generate the highest log probabilities 274 | # after being extended. 275 | topk_beam_indices = topk_indices // self.vocab_size 276 | topk_seq, new_cache = _gather_beams( 277 | [alive_seq, new_cache], topk_beam_indices, self.batch_size, 278 | beams_to_keep) 279 | 280 | # Append the most probable IDs to the topk sequences 281 | topk_ids = topk_indices % self.vocab_size 282 | topk_ids = tf.expand_dims(topk_ids, axis=2) 283 | topk_seq = tf.concat([topk_seq, topk_ids], axis=2) 284 | return topk_seq, topk_log_probs, new_cache 285 | 286 | def _get_new_alive_state(self, new_seq, new_log_probs, new_cache): 287 | """Gather the top k sequences that are still alive. 288 | 289 | Args: 290 | new_seq: New sequences generated by growing the current alive sequences 291 | int32 tensor with shape [batch_size, 2 * beam_size, cur_index + 1] 292 | new_log_probs: Log probabilities of new sequences 293 | float32 tensor with shape [batch_size, beam_size] 294 | new_cache: Dict of cached values for each sequence. 
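      Note: before the top-k gather below, the log probabilities of sequences
      whose newest token equals eos_id are pushed to -INF, so beams that just
      finished can never be selected as "alive" again.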
295 | 296 | Returns: 297 | Dictionary with alive keys from _StateKeys: 298 | {Top beam_size sequences that are still alive (don't end with eos_id) 299 | Log probabilities of top alive sequences 300 | Dict cache storing decoder states for top alive sequences} 301 | """ 302 | # To prevent finished sequences from being considered, set log probs to -INF 303 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id) 304 | new_log_probs += tf.cast(new_finished_flags, tf.float32) * -INF 305 | 306 | top_alive_seq, top_alive_log_probs, top_alive_cache = _gather_topk_beams( 307 | [new_seq, new_log_probs, new_cache], new_log_probs, self.batch_size, 308 | self.beam_size) 309 | 310 | return { 311 | _StateKeys.ALIVE_SEQ: top_alive_seq, 312 | _StateKeys.ALIVE_LOG_PROBS: top_alive_log_probs, 313 | _StateKeys.ALIVE_CACHE: top_alive_cache 314 | } 315 | 316 | def _get_new_finished_state(self, state, new_seq, new_log_probs): 317 | """Combine new and old finished sequences, and gather the top k sequences. 318 | 319 | Args: 320 | state: A dictionary with the current loop state. 321 | new_seq: New sequences generated by growing the current alive sequences 322 | int32 tensor with shape [batch_size, beam_size, i + 1] 323 | new_log_probs: Log probabilities of new sequences 324 | float32 tensor with shape [batch_size, beam_size] 325 | 326 | Returns: 327 | Dictionary with finished keys from _StateKeys: 328 | {Top beam_size finished sequences based on score, 329 | Scores of finished sequences, 330 | Finished flags of finished sequences} 331 | """ 332 | i = state[_StateKeys.CUR_INDEX] 333 | finished_seq = state[_StateKeys.FINISHED_SEQ] 334 | finished_scores = state[_StateKeys.FINISHED_SCORES] 335 | finished_flags = state[_StateKeys.FINISHED_FLAGS] 336 | 337 | # First append a column of 0-ids to finished_seq to increment the length. 338 | # New shape of finished_seq: [batch_size, beam_size, i + 1] 339 | finished_seq = tf.concat( 340 | [finished_seq, 341 | tf.zeros([self.batch_size, self.beam_size, 1], tf.int32)], axis=2) 342 | 343 | # Calculate new seq scores from log probabilities. 344 | length_norm = _length_normalization(self.alpha, i + 1) 345 | new_scores = new_log_probs / length_norm 346 | 347 | # Set the scores of the still-alive seq in new_seq to large negative values. 348 | new_finished_flags = tf.equal(new_seq[:, :, -1], self.eos_id) 349 | new_scores += (1. - tf.cast(new_finished_flags, tf.float32)) * -INF 350 | 351 | # Combine sequences, scores, and flags. 352 | finished_seq = tf.concat([finished_seq, new_seq], axis=1) 353 | finished_scores = tf.concat([finished_scores, new_scores], axis=1) 354 | finished_flags = tf.concat([finished_flags, new_finished_flags], axis=1) 355 | 356 | # Return the finished sequences with the best scores. 357 | top_finished_seq, top_finished_scores, top_finished_flags = ( 358 | _gather_topk_beams([finished_seq, finished_scores, finished_flags], 359 | finished_scores, self.batch_size, self.beam_size)) 360 | 361 | return { 362 | _StateKeys.FINISHED_SEQ: top_finished_seq, 363 | _StateKeys.FINISHED_SCORES: top_finished_scores, 364 | _StateKeys.FINISHED_FLAGS: top_finished_flags 365 | } 366 | 367 | 368 | def sequence_beam_search( 369 | symbols_to_logits_fn, initial_ids, initial_cache, vocab_size, beam_size, 370 | alpha, max_decode_length, eos_id): 371 | """Search for sequence of subtoken ids with the largest probability. 372 | 373 | Args: 374 | symbols_to_logits_fn: A function that takes in ids, index, and cache as 375 | arguments. 
The passed in arguments will have shape: 376 | ids -> [batch_size * beam_size, index] 377 | index -> [] (scalar) 378 | cache -> nested dictionary of tensors [batch_size * beam_size, ...] 379 | The function must return logits and new cache. 380 | logits -> [batch * beam_size, vocab_size] 381 | new cache -> same shape/structure as inputted cache 382 | initial_ids: Starting ids for each batch item. 383 | int32 tensor with shape [batch_size] 384 | initial_cache: dict containing starting decoder variables information 385 | vocab_size: int size of tokens 386 | beam_size: int number of beams 387 | alpha: float defining the strength of length normalization 388 | max_decode_length: maximum length to decoded sequence 389 | eos_id: int id of eos token, used to determine when a sequence has finished 390 | 391 | Returns: 392 | Top decoded sequences [batch_size, beam_size, max_decode_length] 393 | sequence scores [batch_size, beam_size] 394 | """ 395 | batch_size = tf.shape(initial_ids)[0] 396 | sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, batch_size, 397 | beam_size, alpha, max_decode_length, eos_id) 398 | return sbs.search(initial_ids, initial_cache) 399 | 400 | 401 | def _log_prob_from_logits(logits): 402 | return logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True) 403 | 404 | 405 | def _length_normalization(alpha, length): 406 | """Return length normalization factor.""" 407 | return tf.pow(((5. + tf.cast(length, tf.float32)) / 6.), alpha) 408 | 409 | 410 | def _expand_to_beam_size(tensor, beam_size): 411 | """Tiles a given tensor by beam_size. 412 | 413 | Args: 414 | tensor: tensor to tile [batch_size, ...] 415 | beam_size: How much to tile the tensor by. 416 | 417 | Returns: 418 | Tiled tensor [batch_size, beam_size, ...] 419 | """ 420 | tensor = tf.expand_dims(tensor, axis=1) 421 | tile_dims = [1] * tensor.shape.ndims 422 | tile_dims[1] = beam_size 423 | 424 | return tf.tile(tensor, tile_dims) 425 | 426 | 427 | def _shape_list(tensor): 428 | """Return a list of the tensor's shape, and ensure no None values in list.""" 429 | # Get statically known shape (may contain None's for unknown dimensions) 430 | shape = tensor.get_shape().as_list() 431 | 432 | # Ensure that the shape values are not None 433 | dynamic_shape = tf.shape(tensor) 434 | for i in range(len(shape)): # pylint: disable=consider-using-enumerate 435 | if shape[i] is None: 436 | shape[i] = dynamic_shape[i] 437 | return shape 438 | 439 | 440 | def _get_shape_keep_last_dim(tensor): 441 | shape_list = _shape_list(tensor) 442 | 443 | # Only the last 444 | for i in range(len(shape_list) - 1): 445 | shape_list[i] = None 446 | 447 | if isinstance(shape_list[-1], tf.Tensor): 448 | shape_list[-1] = None 449 | return tf.TensorShape(shape_list) 450 | 451 | 452 | def _flatten_beam_dim(tensor): 453 | """Reshapes first two dimensions in to single dimension. 454 | 455 | Args: 456 | tensor: Tensor to reshape of shape [A, B, ...] 457 | 458 | Returns: 459 | Reshaped tensor of shape [A*B, ...] 460 | """ 461 | shape = _shape_list(tensor) 462 | shape[0] *= shape[1] 463 | shape.pop(1) # Remove beam dim 464 | return tf.reshape(tensor, shape) 465 | 466 | 467 | def _unflatten_beam_dim(tensor, batch_size, beam_size): 468 | """Reshapes first dimension back to [batch_size, beam_size]. 469 | 470 | Args: 471 | tensor: Tensor to reshape of shape [batch_size*beam_size, ...] 472 | batch_size: Tensor, original batch size. 473 | beam_size: int, original beam size. 474 | 475 | Returns: 476 | Reshaped tensor of shape [batch_size, beam_size, ...] 
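    For example, a tensor of shape [6, ...] with batch_size=2 and beam_size=3
    is reshaped to [2, 3, ...]; this is the inverse of _flatten_beam_dim.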
477 | """ 478 | shape = _shape_list(tensor) 479 | new_shape = [batch_size, beam_size] + shape[1:] 480 | return tf.reshape(tensor, new_shape) 481 | 482 | 483 | def _gather_beams(nested, beam_indices, batch_size, new_beam_size): 484 | """Gather beams from nested structure of tensors. 485 | 486 | Each tensor in nested represents a batch of beams, where beam refers to a 487 | single search state (beam search involves searching through multiple states 488 | in parallel). 489 | 490 | This function is used to gather the top beams, specified by 491 | beam_indices, from the nested tensors. 492 | 493 | Args: 494 | nested: Nested structure (tensor, list, tuple or dict) containing tensors 495 | with shape [batch_size, beam_size, ...]. 496 | beam_indices: int32 tensor with shape [batch_size, new_beam_size]. Each 497 | value in beam_indices must be between [0, beam_size), and are not 498 | necessarily unique. 499 | batch_size: int size of batch 500 | new_beam_size: int number of beams to be pulled from the nested tensors. 501 | 502 | Returns: 503 | Nested structure containing tensors with shape 504 | [batch_size, new_beam_size, ...] 505 | """ 506 | # Computes the i'th coodinate that contains the batch index for gather_nd. 507 | # Batch pos is a tensor like [[0,0,0,0,],[1,1,1,1],..]. 508 | batch_pos = tf.range(batch_size * new_beam_size) // new_beam_size 509 | batch_pos = tf.reshape(batch_pos, [batch_size, new_beam_size]) 510 | 511 | # Create coordinates to be passed to tf.gather_nd. Stacking creates a tensor 512 | # with shape [batch_size, beam_size, 2], where the last dimension contains 513 | # the (i, j) gathering coordinates. 514 | coordinates = tf.stack([batch_pos, beam_indices], axis=2) 515 | 516 | return nest.map_structure( 517 | lambda state: tf.gather_nd(state, coordinates), nested) 518 | 519 | 520 | def _gather_topk_beams(nested, score_or_log_prob, batch_size, beam_size): 521 | """Gather top beams from nested structure.""" 522 | _, topk_indexes = tf.nn.top_k(score_or_log_prob, k=beam_size) 523 | return _gather_beams(nested, topk_indexes, batch_size, beam_size) 524 | -------------------------------------------------------------------------------- /ops/data.py: -------------------------------------------------------------------------------- 1 | import re 2 | import time 3 | import glob 4 | import struct 5 | import shutil 6 | import logging 7 | import pickle 8 | import numpy as np 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | 12 | from collections import namedtuple 13 | from itertools import count 14 | from functools import partial 15 | from functools import reduce 16 | from functools import wraps 17 | 18 | from tensorflow.core.example import example_pb2 19 | from bert.tokenization import FullTokenizer 20 | from tqdm import tqdm 21 | 22 | from utils.decorators import timeit 23 | from config import config 24 | 25 | 26 | logging.basicConfig(level=logging.INFO) 27 | 28 | logger = logging.getLogger() 29 | 30 | 31 | 32 | SENTENCE_START = '' 33 | SENTENCE_END = '' 34 | 35 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1" 36 | 37 | 38 | InputExample = namedtuple('InputExample', ['guid', 'text_a', 'text_b']) 39 | 40 | InputFeatures = namedtuple('InputFeatures', ['guid', 'tokens', 'input_ids', 'input_mask', 'input_type_ids']) 41 | 42 | 43 | def pad(l, n, pad): 44 | """ 45 | Pad the list 'l' to have size 'n' using 'padding_element' 46 | """ 47 | return l + [pad] * (n - len(l)) 48 | 49 | 50 | def calc_num_batches(total_num, batch_size): 51 | """ 52 | Calculates 
the number of batches, allowing for remainders. 53 | """ 54 | return total_num // batch_size + int(total_num % batch_size != 0) 55 | 56 | 57 | def convert_single_example(tokenizer, example, max_seq_len=config.SEQ_LEN): 58 | """ 59 | Convert `text` to the Bert input format 60 | """ 61 | tokens = tokenizer.tokenize(example.text_a) 62 | 63 | if len(tokens) > max_seq_len - 2: 64 | tokens = tokens[0:(max_seq_len - 2)] 65 | 66 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 67 | 68 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 69 | input_type_ids = [0] * len(input_ids) 70 | input_mask = [1] * len(input_ids) 71 | 72 | input_ids = pad(input_ids, max_seq_len, 0) 73 | input_type_ids = pad(input_type_ids, max_seq_len, 0) 74 | input_mask = pad(input_mask, max_seq_len, 0) 75 | 76 | return tokens, input_ids, input_mask, input_type_ids 77 | 78 | 79 | def convert_examples_to_features(tokenizer, examples, max_seq_len=config.SEQ_LEN): 80 | """ 81 | Convert raw features to Bert specific representation 82 | """ 83 | converter = partial(convert_single_example, tokenizer=tokenizer, max_seq_len=max_seq_len) 84 | examples = [converter(example=example) for example in examples] 85 | return examples 86 | 87 | 88 | def create_tokenizer_from_hub_module(bert_hub_url): 89 | """ 90 | Get the vocab file and casing info from the Hub module. 91 | """ 92 | bert_module = hub.Module(bert_hub_url) 93 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True) 94 | 95 | with tf.Session() as sess: 96 | vocab_file, do_lower_case = sess.run([ 97 | tokenization_info["vocab_file"], 98 | tokenization_info["do_lower_case"] 99 | ]) 100 | 101 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 102 | 103 | 104 | def abstract2sents(abstract): 105 | """ 106 | # Use the and tags in abstract to get a list of sentences. 
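    (In the chunked CNN/DailyMail files each abstract sentence is wrapped in
    <s> ... </s> markers; the pattern below captures the text preceding every
    closing </s> tag.)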
107 | """ 108 | sentences_pattern = re.compile(r"(.+?)<\/s>") 109 | sentences = sentences_pattern.findall(abstract) 110 | return sentences 111 | 112 | 113 | def _load_single(file): 114 | """ 115 | Opens and prepares a single chunked file 116 | """ 117 | article_texts = [] 118 | abstract_texts = [] 119 | 120 | with open(file, 'rb') as f: 121 | 122 | while True: 123 | 124 | len_bytes = f.read(8) 125 | 126 | if not len_bytes: 127 | break 128 | 129 | str_len = struct.unpack('q', len_bytes)[0] 130 | str_bytes = struct.unpack('%ds' % str_len, f.read(str_len))[0] 131 | example = example_pb2.Example.FromString(str_bytes) 132 | 133 | article_text = example.features.feature['article'].bytes_list.value[0].decode('unicode_escape').strip() 134 | abstract_text = example.features.feature['abstract'].bytes_list.value[0].decode('unicode_escape').strip() 135 | abstract_text = ' '.join([sent.strip() for sent in abstract2sents(abstract_text)]) 136 | 137 | article_texts.append(article_text) 138 | abstract_texts.append(abstract_text) 139 | 140 | return article_texts, abstract_texts 141 | 142 | 143 | def load_data(files): 144 | """ 145 | Reads binary data and returns chuncks of [(articles, summaries)] 146 | """ 147 | logger.info(f"'{len(files)}' files found") 148 | data = [_load_single(file) for file in files] 149 | articles = sum([a for a, _ in data], []) 150 | summaries = sum([b for _, b in data], []) 151 | return articles, summaries 152 | 153 | 154 | def convert_single_example(example, tokenizer, max_seq_len=MAX_SEQ_LEN): 155 | """ 156 | Convert `text` to the Bert input format 157 | """ 158 | tokens = tokenizer.tokenize(example) 159 | 160 | if len(tokens) > max_seq_len - 2: 161 | tokens = tokens[0:(max_seq_len - 2)] 162 | 163 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 164 | 165 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 166 | input_type_ids = [0] * len(input_ids) 167 | input_mask = [1] * len(input_ids) 168 | 169 | input_ids = pad(input_ids, max_seq_len, 0) 170 | input_type_ids = pad(input_type_ids, max_seq_len, 0) 171 | input_mask = pad(input_mask, max_seq_len, 0) 172 | 173 | return tokens, input_ids, input_mask, input_type_ids 174 | 175 | 176 | def generator_fn(sents1, sents2, tokenizer): 177 | """ 178 | Generates training / evaluation data 179 | raw_data: list of (abstracts, raw_data) 180 | tokenizer: tokenizer to separate the different tokens 181 | yields 182 | xs: tuple of 183 | x: list of source token ids in a sent 184 | x_seqlen: int. sequence length of x 185 | sent1: str. raw source (=input) sentence 186 | ys: tuple of 187 | decoder_input: decoder_input: list of encoded decoder inputs 188 | y: list of target token ids in a sent 189 | y_seqlen: int. sequence length of y 190 | sent2: str. target sentence 191 | """ 192 | for article, summary in zip(sents1, sents2): 193 | tokens_x, input_ids_x, input_mask_x, input_type_ids_x = convert_single_example(article, tokenizer) 194 | tokens_y, input_ids_y, input_mask_y, input_type_ids_y = convert_single_example(summary, tokenizer) 195 | 196 | x_seqlen, y_seqlen = len(tokens_x), len(tokens_y) 197 | 198 | yield (input_ids_x, input_mask_x, input_type_ids_x, x_seqlen, article), (input_ids_y, input_mask_y, input_type_ids_y, y_seqlen, summary) 199 | 200 | 201 | 202 | def input_fn(sents1, sents2, tokenizer, batch_size, shuffle=False): 203 | """ 204 | Batchify data 205 | raw_data [(artciles, abstracts)] 206 | batch_size: scalar 207 | shuffle: boolean 208 | 209 | Returns 210 | xs: tuple of 211 | x: int32 tensor. (N, T1) 212 | x_seqlens: int32 tensor. 
(N,) 213 | sents1: str tensor. (N,) 214 | ys: tuple of 215 | decoder_input: int32 tensor. (N, T2) 216 | y: int32 tensor. (N, T2) 217 | y_seqlen: int32 tensor. (N, ) 218 | sents2: str tensor. (N,) 219 | """ 220 | shapes = ( 221 | ([None], [None], [None], (), ()), 222 | ([None], [None], [None], (), ()) 223 | ) 224 | 225 | types = ( 226 | (tf.int32, tf.int32, tf.int32, tf.int32, tf.string), 227 | (tf.int32, tf.int32, tf.int32, tf.int32, tf.string) 228 | ) 229 | 230 | paddings = ( 231 | (0, 0, 0, 0, ''), 232 | (0, 0, 0, 0, '') 233 | ) 234 | 235 | dataset = tf.data.Dataset.from_generator( 236 | partial(generator_fn, tokenizer=tokenizer), 237 | output_shapes=shapes, 238 | output_types=types, 239 | args=(sents1, sents2)) # <- arguments for generator_fn. converted to np string arrays 240 | 241 | if shuffle: # for training 242 | dataset = dataset.shuffle(128*batch_size) 243 | 244 | dataset = dataset.repeat() # iterate forever 245 | dataset = dataset.padded_batch(batch_size, shapes, paddings).prefetch(1) 246 | 247 | return dataset 248 | 249 | 250 | @timeit 251 | def prepare_data(inputpath, batch_size, shuffle=False): 252 | """ 253 | """ 254 | files = glob.glob(inputpath + '*.bin')[:2] 255 | 256 | sents1, sents2 = load_data(files) 257 | 258 | tokenizer = create_tokenizer_from_hub_module(BERT_MODEL_URL) 259 | 260 | batches = input_fn(sents1, sents2, tokenizer, batch_size, shuffle=shuffle) 261 | 262 | num_batches = calc_num_batches(len(sents1), batch_size) 263 | 264 | return batches, num_batches, len(sents1) 265 | 266 | 267 | -------------------------------------------------------------------------------- /ops/encoding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def get_angles(pos, i, d_model): 6 | angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model)) 7 | return pos * angle_rates 8 | 9 | 10 | def positional_encoding(position, d_model): 11 | """ 12 | The positional encoding vector is added to the embedding vector. 13 | Embeddings represent a token in a d-dimensional space where tokens 14 | with similar meaning will be closer to each other. 15 | But the embeddings do not encode the relative position of words in a sentence. 16 | 17 | So after adding the positional encoding, words will be closer to each other 18 | based on the similarity of their meaning and their position in the sentence, 19 | in the d-dimensional space. 20 | 21 | >>> pos_encoding = positional_encoding(50, 512) 22 | >>> print (pos_encoding.shape) 23 | 24 | >>> plt.pcolormesh(pos_encoding[0].eval(), cmap='RdBu') 25 | >>> plt.xlabel('Depth') 26 | >>> plt.xlim((0, 512)) 27 | >>> plt.ylabel('Position') 28 | >>> plt.colorbar() 29 | >>> plt.show() 30 | """ 31 | angle_rads = get_angles( 32 | np.arange(position)[:, np.newaxis], 33 | np.arange(d_model)[np.newaxis, :], 34 | d_model 35 | ) 36 | 37 | # apply sin to even indices in the array; 2i 38 | sines = np.sin(angle_rads[:, 0::2]) 39 | 40 | # apply cos to odd indices in the array; 2i+1 41 | cosines = np.cos(angle_rads[:, 1::2]) 42 | 43 | pos_encoding = np.concatenate([sines, cosines], axis=-1) 44 | 45 | pos_encoding = pos_encoding[np.newaxis, ...] 
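    # The angles above follow the usual sinusoidal formulation,
    #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model)),
    # except that this implementation concatenates the sine half and the cosine
    # half along the depth axis rather than interleaving them; the new leading
    # axis makes the final shape (1, position, d_model), which broadcasts over
    # the batch dimension when added to the embeddings.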
46 | 47 | return tf.cast(pos_encoding, dtype=tf.float32) 48 | 49 | -------------------------------------------------------------------------------- /ops/masking.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def create_padding_mask(seq): 5 | """ 6 | Mask all the pad tokens in the batch of sequence. 7 | It ensures that the model does not treat padding as the input. 8 | The mask indicates where pad value 0 is present: 9 | it outputs a 1 at those locations, and a 0 otherwise. 10 | """ 11 | seq = tf.cast(tf.math.equal(seq, 0), tf.float32) 12 | 13 | # add extra dimensions so that we can add the padding 14 | # to the attention logits. 15 | return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) 16 | 17 | 18 | def create_look_ahead_mask(size): 19 | """ 20 | The look-ahead mask is used to mask the future tokens in a sequence. 21 | In other words, the mask indicates which entries should not be used. 22 | 23 | This means that to predict the third word, only the first and second word will be used. 24 | Similarly to predict the fourth word, only the first, second and the third word will be used and so on. 25 | """ 26 | mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) 27 | return mask # (seq_len, seq_len) 28 | 29 | 30 | def create_masks(inp, tar): 31 | # Encoder padding mask 32 | # enc_padding_mask = create_padding_mask(inp) 33 | 34 | # Used in the 2nd attention block in the decoder. 35 | # This padding mask is used to mask the encoder outputs. 36 | dec_padding_mask = create_padding_mask(inp) 37 | 38 | # Used in the 1st attention block in the decoder. 39 | # It is used to pad and mask future tokens in the input received by 40 | # the decoder. 41 | look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1]) 42 | dec_target_padding_mask = create_padding_mask(tar) 43 | combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask) 44 | 45 | # return enc_padding_mask, combined_mask, dec_padding_mask 46 | return combined_mask, dec_padding_mask 47 | 48 | 49 | def mask_timestamp(x, i, mask_with): 50 | """ 51 | Masks each word in the summary draft one by one with the [MASK] token 52 | At t-th time step the t-th word of input summary is 53 | masked, and the decoder predicts the refined word given other 54 | words of the summary. 55 | 56 | x :: (N, T) 57 | return :: (N, T) 58 | """ 59 | 60 | N, T = tf.shape(x)[0], tf.shape(x)[1] 61 | 62 | left = x[:, :i] 63 | right = x[:, i+1:] 64 | 65 | mask = tf.ones([N, 1], dtype=x.dtype) * mask_with 66 | 67 | masked = tf.concat([left, mask, right], axis=1) 68 | 69 | return masked 70 | 71 | 72 | def tile_and_mask_diagonal(x, mask_with): 73 | """ 74 | Masks each word in the summary draft one by one with the [MASK] token 75 | At t-th time step the t-th word of input summary is 76 | masked, and the decoder predicts the refined word given other 77 | words of the summary. 
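    Small worked example (101 and 103 stand in for the [CLS] and [MASK] ids):
    for x = [[101, 7, 8, 9]] the function produces

        [[101, 103,   8,   9],
         [101,   7, 103,   9],
         [101,   7,   8, 103]]

    one row per masked position, with the leading [CLS] column left untouched.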
78 | 79 | x :: (N, T) 80 | returrn :: (N, T-1, T) 81 | 82 | We do not mask the first and last postition (corresponding to [CLS] 83 | """ 84 | 85 | N, T = tf.shape(x)[0], tf.shape(x)[1] 86 | 87 | first = tf.reshape(tf.tile(x[:, 0], [T-1]), [N, T-1, 1]) 88 | 89 | x = x[:, 1:] 90 | T = T - 1 91 | 92 | masked = tf.reshape(tf.tile(x, [1, T]), [N, T, T]) 93 | 94 | diag = tf.ones([N, T], dtype=masked.dtype) * mask_with 95 | masked = tf.linalg.set_diag(masked, diag) 96 | 97 | masked = tf.concat([first, masked], axis=2) 98 | 99 | masked = tf.reshape(masked, [N*T, T+1]) 100 | 101 | return masked 102 | 103 | 104 | -------------------------------------------------------------------------------- /ops/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | 4 | from rouge import Rouge 5 | 6 | 7 | rouge = Rouge() 8 | 9 | 10 | def calculate_rouge(y, y_hat): 11 | """ 12 | Calculate ROUGE scores between the target 'y' and 13 | the model prediction 'y_hat' 14 | """ 15 | 16 | def f(a, b): 17 | rouges = rouge.get_scores(a.decode("utf-8") , b.decode("utf-8") )[0] 18 | r1_val, r2_val, rl_val = rouges['rouge-1']["f"], rouges['rouge-2']["f"], rouges['rouge-l']["f"] 19 | r_avg = np.mean([r1_val, r2_val, rl_val], dtype=np.float64) 20 | return r1_val, r2_val, rl_val, r_avg 21 | 22 | return tf.py_func(f, [y, y_hat], [tf.float64, tf.float64, tf.float64, tf.float64]) 23 | -------------------------------------------------------------------------------- /ops/optimization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def noam_scheme(init_lr, global_step, warmup_steps=4000.): 5 | '''Noam scheme learning rate decay 6 | init_lr: initial learning rate. scalar. 7 | global_step: scalar. 8 | warmup_steps: scalar. During warmup_steps, learning rate increases 9 | until it reaches init_lr. 10 | ''' 11 | step = tf.cast(global_step + 1, dtype=tf.float32) 12 | return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5) -------------------------------------------------------------------------------- /ops/regularization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def label_smoothing(inputs, epsilon=0.1): 5 | '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567. 6 | inputs: 3d tensor. [N, T, V], where V is the number of vocabulary. 7 | epsilon: Smoothing rate. 
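    Each probability is mapped to (1 - epsilon) * inputs + epsilon / V, where V
    is the size of the last (vocabulary) dimension.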
8 | 9 | For example, 10 | 11 | ``` 12 | import tensorflow as tf 13 | inputs = tf.convert_to_tensor([[[0, 0, 1], 14 | [0, 1, 0], 15 | [1, 0, 0]], 16 | [[1, 0, 0], 17 | [1, 0, 0], 18 | [0, 1, 0]]], tf.float32) 19 | 20 | outputs = label_smoothing(inputs) 21 | 22 | with tf.Session() as sess: 23 | print(sess.run([outputs])) 24 | 25 | >> 26 | [array([[[ 0.03333334, 0.03333334, 0.93333334], 27 | [ 0.03333334, 0.93333334, 0.03333334], 28 | [ 0.93333334, 0.03333334, 0.03333334]], 29 | [[ 0.93333334, 0.03333334, 0.03333334], 30 | [ 0.93333334, 0.03333334, 0.03333334], 31 | [ 0.03333334, 0.93333334, 0.03333334]]], dtype=float32)] 32 | ``` 33 | ''' 34 | V = inputs.get_shape().as_list()[-1] # number of channels 35 | return ((1-epsilon) * inputs) + (epsilon / V) 36 | 37 | -------------------------------------------------------------------------------- /ops/session.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import logging 3 | 4 | from tensorflow.keras import backend as K 5 | 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | 9 | 10 | def initialize_vars(sess): 11 | sess.run(tf.local_variables_initializer()) 12 | sess.run(tf.global_variables_initializer()) 13 | sess.run(tf.tables_initializer()) 14 | K.set_session(sess) 15 | 16 | 17 | def save_variable_specs(fpath): 18 | ''' 19 | Saves information about variables such as 20 | their name, shape, and total parameter number 21 | fpath: string. output file path 22 | Writes 23 | a text file named fpath. 24 | ''' 25 | def _get_size(shp): 26 | '''Gets size of tensor shape 27 | shp: TensorShape 28 | Returns 29 | size 30 | ''' 31 | size = 1 32 | for d in range(len(shp)): 33 | size *=shp[d] 34 | return size 35 | 36 | params, num_params = [], 0 37 | for v in tf.global_variables(): 38 | params.append("{}==={}".format(v.name, v.shape)) 39 | num_params += _get_size(v.shape) 40 | print("num_params: ", num_params) 41 | with open(fpath, 'w') as fout: 42 | fout.write("num_params: {}\n".format(num_params)) 43 | fout.write("\n".join(params)) 44 | logging.info("Variables info has been saved.") -------------------------------------------------------------------------------- /ops/tensor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def with_column(x, i, column): 5 | """ 6 | Given a tensor `x`, change its i-th column with `column` 7 | x :: (N, T) 8 | return :: (N, T) 9 | """ 10 | 11 | N, T = tf.shape(x)[0], tf.shape(x)[1] 12 | 13 | left = x[:, :i] 14 | right = x[:, i+1:] 15 | 16 | return tf.concat([left, column, right], axis=1) 17 | -------------------------------------------------------------------------------- /ops/tokenization.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_hub as hub 3 | 4 | from bert.tokenization import FullTokenizer 5 | 6 | 7 | BERT_MODEL_URL = "https://tfhub.dev/google/bert_uncased_L-12_H-768_A-12/1" 8 | 9 | 10 | def create_bert_tokenizer(vocab_file, do_lower_case=True): 11 | """ 12 | Return a BERT FullTokenizer 13 | """ 14 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 15 | 16 | 17 | def create_tokenizer_from_hub_module(bert_hub_url, bert_module=None): 18 | """ 19 | Get the vocab file and casing info from the Hub module. 
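    Illustrative usage (the exact word pieces depend on the uncased BERT vocab):

    >>> tokenizer = create_tokenizer_from_hub_module(BERT_MODEL_URL)
    >>> tokenizer.tokenize("Hello world!")
    ['hello', 'world', '!']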
20 | """ 21 | if bert_module is None: 22 | bert_module = hub.Module(bert_hub_url) 23 | 24 | tokenization_info = bert_module(signature="tokenization_info", as_dict=True) 25 | 26 | with tf.Session() as sess: 27 | vocab_file, do_lower_case = sess.run([ 28 | tokenization_info["vocab_file"], 29 | tokenization_info["do_lower_case"] 30 | ]) 31 | 32 | return FullTokenizer(vocab_file=vocab_file, do_lower_case=do_lower_case) 33 | 34 | 35 | tokenizer = create_tokenizer_from_hub_module(BERT_MODEL_URL) 36 | 37 | 38 | def convert_idx_to_token_tensor(inputs, tokenizer=tokenizer): 39 | ''' 40 | Converts int32 tensor to string tensor. 41 | inputs: 1d int32 tensor. indices. 42 | tokenizer :: [int] -> str 43 | Returns 44 | 1d string tensor. 45 | ''' 46 | def f(inputs): 47 | return ' '.join(tokenizer.convert_ids_to_tokens(inputs)) 48 | 49 | return tf.py_func(f, [inputs], tf.string) 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-hub 2 | tensorflow-datasets==1.0.2 3 | tqdm==4.31.1 4 | bert-tensorflow==1.0.1 5 | bunch 6 | rouge==0.3.2 7 | pyyaml 8 | 9 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import math 4 | import shutil 5 | import logging 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import tensorflow as tf 10 | import tensorflow_hub as hub 11 | import tensorflow_datasets as tfds 12 | 13 | from tensorflow.keras import backend as K 14 | from tensorflow.python.keras.initializers import Constant 15 | 16 | from tqdm import tqdm 17 | from config import config 18 | from arguments import args 19 | 20 | from data.load import load_cnn_dailymail 21 | 22 | from random import randint 23 | from rouge import Rouge 24 | 25 | from ops.tokenization import tokenizer 26 | from ops.tokenization import convert_idx_to_token_tensor 27 | 28 | from ops.session import initialize_vars 29 | from ops.session import save_variable_specs 30 | 31 | from ops.metrics import calculate_rouge 32 | from ops.tensor import with_column 33 | from ops.regularization import label_smoothing 34 | from ops.optimization import noam_scheme 35 | 36 | from models.abstractive_summarizer import AbstractiveSummarization 37 | from models.abstractive_summarizer import train 38 | from models.abstractive_summarizer import eval 39 | 40 | 41 | logger = logging.getLogger() 42 | logger.setLevel(logging.INFO) 43 | 44 | tf.logging.set_verbosity(tf.logging.INFO) 45 | tf.enable_resource_variables() 46 | 47 | logging.info('Job Configuration:\n' + str(config)) 48 | 49 | 50 | model = AbstractiveSummarization( 51 | num_layers=config.NUM_LAYERS, 52 | d_model=config.D_MODEL, 53 | num_heads=config.NUM_HEADS, 54 | dff=config.D_FF, 55 | vocab_size=config.VOCAB_SIZE, 56 | input_seq_len=config.INPUT_SEQ_LEN, 57 | output_seq_len=config.OUTPUT_SEQ_LEN, 58 | rate=config.DROPOUT_RATE 59 | ) 60 | 61 | 62 | train_dataset, val_dataset, test_dataset, n_train_examples, n_val_examples, n_test_examples = load_cnn_dailymail() 63 | 64 | n_train_batches = n_train_examples // config.BATCH_SIZE 65 | n_val_batches = n_val_examples // config.BATCH_SIZE 66 | n_test_batches = n_test_examples // config.BATCH_SIZE 67 | 68 | logging.info(f"'{n_train_examples}' training examples, '{n_train_batches}' batches") 69 | logging.info(f"'{n_val_examples}' validation examples, 
'{n_val_batches}' batches") 70 | logging.info(f"'{n_test_examples}' testing examples, '{n_test_batches}' batches") 71 | 72 | 73 | train_iterator = train_dataset.make_initializable_iterator() 74 | train_stream = train_iterator.get_next() 75 | 76 | xs, ys = train_stream[:3], train_stream[3:] 77 | train_loss, zero_op, accumlation_op, train_op, global_step, train_summaries = train(model, xs, ys, gradient_accumulation=True) 78 | 79 | if args.eval: 80 | val_iterator = val_dataset.make_initializable_iterator() 81 | val_stream = val_iterator.get_next() 82 | 83 | xs, ys = val_stream[:3], val_stream[3:] 84 | y, y_hat, eval_loss, eval_summaries = eval(model, xs, ys) 85 | 86 | saver = tf.train.Saver(max_to_keep=config.NUM_EPOCHS) 87 | 88 | # config_tf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True) 89 | config_tf = tf.ConfigProto(allow_soft_placement=True) 90 | config_tf.gpu_options.allow_growth=True 91 | 92 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True) 93 | 94 | 95 | with tf.Session(config=config_tf) as sess: 96 | 97 | if os.path.isdir(config.LOGDIR): 98 | shutil.rmtree(config.LOGDIR) 99 | 100 | os.mkdir(config.LOGDIR) 101 | 102 | ckpt = tf.train.latest_checkpoint(config.CHECKPOINTDIR) 103 | 104 | rouge = Rouge() 105 | 106 | if ckpt is None: 107 | logging.info("Initializing from scratch") 108 | sess.run(tf.global_variables_initializer()) 109 | save_variable_specs(os.path.join(config.LOGDIR, "specs")) 110 | else: 111 | saver.restore(sess, ckpt) 112 | 113 | summary_writer_train = tf.summary.FileWriter(os.path.join(config.LOGDIR), sess.graph) 114 | summary_writer_eval = tf.summary.FileWriter(os.path.join(config.LOGDIR, 'eval'), sess.graph) 115 | 116 | initialize_vars(sess) 117 | 118 | _gs = sess.run(global_step) 119 | 120 | sess.run(train_iterator.initializer) 121 | if args.eval: 122 | sess.run(val_iterator.initializer) 123 | 124 | total_steps = int(config.NUM_EPOCHS * (n_train_batches / config.GRADIENT_ACCUMULATION_N_STEPS)) 125 | 126 | logger.info(f"Running Training Job for '{total_steps}' steps") 127 | 128 | for i in tqdm(range(_gs, total_steps+1)): 129 | 130 | # gradient accumulation mechanism 131 | sess.run(zero_op) 132 | 133 | for i in range(config.GRADIENT_ACCUMULATION_N_STEPS): 134 | sess.run(accumlation_op) 135 | 136 | _loss, _, _gs, _summary = sess.run([train_loss, train_op, global_step, train_summaries], options=run_options) 137 | 138 | epoch = math.ceil(_gs / n_train_batches) 139 | 140 | summary_writer_train.add_summary(_summary, _gs) 141 | summary_writer_train.flush() 142 | 143 | if (_gs % n_train_batches == 0): 144 | 145 | if args.eval: 146 | 147 | logger.info(f"Epoch '{epoch}' done") 148 | logger.info(f"Current training step: '{_gs}") 149 | 150 | _y, _y_hat, _eval_summary = sess.run([y, y_hat, eval_summaries]) 151 | 152 | summary_writer_eval.add_summary(_eval_summary, 0) 153 | summary_writer_eval.flush() 154 | 155 | # monitor a random sample 156 | rnd = randint(0, _y.shape[0] - 1) 157 | 158 | y_rnd = ' '.join(tokenizer.convert_ids_to_tokens(_y[rnd])) 159 | y_hat_rnd = ' '.join(tokenizer.convert_ids_to_tokens(_y_hat[rnd])) 160 | 161 | rouges = rouge.get_scores(y_rnd, y_hat_rnd)[0] 162 | r1_val, r2_val, rl_val = rouges['rouge-1']["f"], rouges['rouge-2']["f"], rouges['rouge-l']["f"] 163 | 164 | print('Target:') 165 | print(y_rnd) 166 | print('Prediction:') 167 | print(y_hat_rnd) 168 | 169 | print(f"ROUGE-1 '{r1_val}'") 170 | print(f"ROUGE-2 '{r2_val}'") 171 | print(f"ROUGE-L '{rl_val}'") 172 | print(f"ROUGE-AVG '{np.mean([r1_val, r2_val, 
rl_val])}'", '\n--\n') 173 | 174 | logging.info("Checkpoint: Saving Model") 175 | 176 | model_output = f"abstractive_summarization_2019_epoch_{epoch}_loss_{str(round(_loss, 4))}" 177 | 178 | ckpt_name = os.path.join(config.CHECKPOINTDIR, model_output) 179 | 180 | saver.save(sess, ckpt_name, global_step=_gs) 181 | 182 | logging.info(f"After training '{_gs}' steps, '{ckpt_name}' has been saved.") 183 | 184 | model_output = f"abstractive_summarization_2019_final" 185 | ckpt_name = os.path.join(config.CHECKPOINTDIR, model_output) 186 | saver.save(sess, ckpt_name, global_step=_gs) 187 | 188 | summary_writer_train.close() 189 | summary_writer_eval.close() 190 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/raufer/bert-summarization/2302fc8c4117070d234b21e02e51e20dd66c4f6f/utils/__init__.py -------------------------------------------------------------------------------- /utils/decorators.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import wraps 3 | 4 | 5 | def timeit(f): 6 | @wraps(f) 7 | def timed(*args, **kw): 8 | ts = time.time() 9 | result = f(*args, **kw) 10 | te = time.time() 11 | print(f"'{f.__name__}' {round(te - ts, 2)} s") 12 | return result 13 | return timed 14 | 15 | -------------------------------------------------------------------------------- /utils/recipes.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | 4 | def flatten(ls): 5 | """ 6 | Flatten one level of nesting 7 | """ 8 | return chain.from_iterable(ls) 9 | 10 | --------------------------------------------------------------------------------