├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── requirements.txt ├── setup.py └── src ├── __init__.py ├── configs.py ├── embeddings.py ├── feature_generation.py └── tfrecords.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.csv 2 | *.pyc 3 | *.txt 4 | *.ipynb -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Carted 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Processing text data at scale with Apache Beam and Cloud Dataflow 2 | 3 | Presents an optimized [Apache Beam](https://beam.apache.org/) pipeline for generating sentence embeddings (runnable on [Cloud Dataflow](https://cloud.google.com/dataflow)). 
This repository 4 | accompanies our blog post: [Improving Dataflow Pipelines for Text Data Processing](https://www.carted.com/blog/improving-dataflow-pipelines-for-text-data-processing/). 5 | 6 | We assume that you already have a billing-enabled [Google Cloud Platform (GCP)](https://cloud.google.com/) project if 7 | you want to run the pipeline on Cloud Dataflow. 8 | 9 | ## Running the code locally 10 | 11 | To run the code locally, first install the dependencies: `pip install -r requirements.txt`. If you cannot 12 | create a [Google Cloud Storage (GCS)](https://cloud.google.com/storage) bucket, download the data from 13 | [here](https://www.kaggle.com/rohitganji13/film-genre-classification-using-nlp). We only need the 14 | `train_data.txt` file for our purposes. Also, note that without a GCS bucket, you cannot 15 | run the pipeline on Cloud Dataflow, which is the main objective of this repository. 16 | 17 | After downloading the dataset, update the paths and command-line 18 | arguments in `main.py` that use GCS. 19 | 20 | Then execute `python main.py -r DirectRunner`. 21 | 22 | ## Running the code on Cloud Dataflow 23 | 24 | 1. Create a GCS bucket and note its name. 25 | 2. Create a folder called `data` inside the bucket. 26 | 3. Copy the `train_data.txt` file into the `data` folder: `gsutil cp train_data.txt gs://<your-bucket>/data`. 27 | 4. Run the following from the terminal: 28 | 29 | ```shell 30 | python main.py \ 31 | --project <your-project> \ 32 | --gcs-bucket <your-bucket> \ 33 | --runner DataflowRunner 34 | ``` 35 | 36 | For more details, please refer to our blog post: [Improving Dataflow Pipelines for Text Data Processing](https://www.carted.com/blog/improving-dataflow-pipelines-for-text-data-processing/). 37 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from apache_beam.io.tfrecordio import WriteToTFRecord 9 | from datetime import datetime 10 | import apache_beam as beam 11 | import pandas as pd 12 | import functools 13 | import argparse 14 | import logging 15 | import pprint 16 | import math 17 | 18 | logging.getLogger().setLevel(logging.INFO) 19 | 20 | from src.feature_generation import generate_features, DecodeFromTextLineDoFn 21 | from src.configs import get_bert_encoder_config 22 | from src.tfrecords import FeaturesToSerializedExampleFn 23 | from src.embeddings import get_text_encodings 24 | 25 | 26 | # Initialize encoder configuration.
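# The preprocessor handle below is TF Hub's BERT English uncased text
# preprocessor; the encoder is the universal-sentence-encoder-CMLM base
# model, whose pooled output is 768-dimensional (hence EMBEDDING_DIM).
# BERT-style encoders accept at most 512 tokens per sequence (hence
# MAX_SEQ_LEN).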
27 | PREPROCESSOR_PATH = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3" 28 | ENCODER_PATH = "https://tfhub.dev/google/universal-sentence-encoder-cmlm/en-base/1" 29 | EMBEDDING_DIM = 768 30 | MAX_SEQ_LEN = 512 31 | 32 | 33 | def main( 34 | project: str, 35 | gcs_bucket: str, 36 | region: str, 37 | machine_type: str, 38 | max_num_workers: str, 39 | runner: str, 40 | chunk_size: int, 41 | ): 42 | job_timestamp = datetime.utcnow().strftime("%y%m%d-%H%M%S") 43 | pipeline_args_dict = { 44 | "job_name": f"dataflow-text-processing-{job_timestamp}", 45 | "machine_type": machine_type, 46 | "num_workers": "1", 47 | "max_num_workers": max_num_workers, 48 | "runner": runner, 49 | "setup_file": "./setup.py", 50 | "project": project, 51 | "region": region, 52 | "gcs_location": f"gs://{gcs_bucket}", 53 | "temp_location": f"gs://{gcs_bucket}/temp", 54 | "staging_location": f"gs://{gcs_bucket}/staging", 55 | "save_main_session": "True", 56 | } 57 | 58 | # Convert the dictionary to a list of (argument, value) tuples and flatten the list. 59 | pipeline_args = [(f"--{k}", v) for k, v in pipeline_args_dict.items()] 60 | pipeline_args = [x for y in pipeline_args for x in y] 61 | 62 | # Load the dataframe for counting the total number of samples it has. For larger 63 | # datasets, this should be performed separately. 64 | train_df = pd.read_csv( 65 | f"gs://{gcs_bucket}/data/train_data.txt", 66 | engine="python", 67 | sep=" ::: ", 68 | names=["id", "movie", "genre", "summary"], 69 | ) 70 | total_examples = len(train_df) 71 | 72 | logging.info( 73 | f"Executing beam pipeline with args:\n{pprint.pformat(pipeline_args_dict)}" 74 | ) 75 | 76 | with beam.Pipeline(argv=pipeline_args) as pipeline: 77 | encoding_config = get_bert_encoder_config( 78 | PREPROCESSOR_PATH, ENCODER_PATH, EMBEDDING_DIM, MAX_SEQ_LEN 79 | ) 80 | configured_encode_examples = functools.partial( 81 | get_text_encodings, config=encoding_config, chunk_size=chunk_size 82 | ) 83 | _ = ( 84 | pipeline 85 | | "Read file" 86 | >> beam.io.ReadFromText( 87 | f"gs://{gcs_bucket}/data/train_data.txt", skip_header_lines=True 88 | ) 89 | | "Parse the file and yield dictionaries" 90 | >> beam.ParDo(DecodeFromTextLineDoFn()) 91 | | "Generate features" 92 | >> beam.ParDo(generate_features, config=encoding_config) 93 | | "Intelligently Batch examples" 94 | >> beam.BatchElements(min_batch_size=chunk_size, max_batch_size=1000) 95 | | "Encode the text features" >> beam.ParDo(configured_encode_examples) 96 | | "Create TF Train example" >> beam.ParDo(FeaturesToSerializedExampleFn()) 97 | | "Write TFRecord to GS Bucket" 98 | >> WriteToTFRecord( 99 | file_path_prefix=f"gs://{gcs_bucket}/tfrecords/", 100 | file_name_suffix=f"{job_timestamp}.tfrecord", 101 | num_shards=math.ceil(total_examples / 50), 102 | ) 103 | ) 104 | 105 | 106 | def parse_arguments(): 107 | parser = argparse.ArgumentParser( 108 | description="Beam pipeline for generating TFRecords from a pandas dataframe." 
109 | ) 110 | parser.add_argument( 111 | "-p", 112 | "--project", 113 | default="carted-gcp", 114 | type=str, 115 | help="The name of the GCP project.", 116 | ) 117 | parser.add_argument( 118 | "-b", 119 | "--gcs-bucket", 120 | default="processing-text-data", 121 | type=str, 122 | help="The Google Cloud Storage bucket name.", 123 | ) 124 | parser.add_argument( 125 | "-reg", "--region", default="us-central1", type=str, help="The GCP region.", 126 | ) 127 | parser.add_argument( 128 | "-m", 129 | "--machine-type", 130 | type=str, 131 | default="n1-standard-1", 132 | help="Machine type for the Dataflow workers.", 133 | ) 134 | parser.add_argument( 135 | "-w", 136 | "--max-num-workers", 137 | default="25", 138 | type=str, 139 | help="Maximum number of workers for Dataflow.", 140 | ) 141 | parser.add_argument( 142 | "-r", 143 | "--runner", 144 | type=str, 145 | choices=["DirectRunner", "DataflowRunner"], 146 | help="The runner for the pipeline.", 147 | ) 148 | parser.add_argument( 149 | "-cs", 150 | "--chunk-size", 151 | type=int, 152 | default=50, 153 | help="Chunk size to use during BERT encoding.", 154 | ) 155 | args = parser.parse_args() 156 | return vars(args) 157 | 158 | 159 | if __name__ == "__main__": 160 | args = parse_arguments() 161 | main(**args) 162 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | apache-beam[gcp]==2.31.0 2 | tensorflow==2.6.0 3 | tensorflow-estimator==2.6.0 4 | keras==2.6.0 5 | sentence_splitter==1.4 6 | seaborn==0.11.2 7 | pandas==1.3.2 8 | tensorflow_hub==0.12.0 9 | tensorflow_text==2.6.0 10 | ml_collections==0.1.0 11 | protobuf==3.18.0 12 | python-snappy==0.6.0 13 | google-apitools==0.5.31 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import setuptools 9 | 10 | 11 | NAME = "processing_text_data" 12 | VERSION = "0.1.0" 13 | REQUIRED_PACKAGES = [ 14 | "apache-beam[gcp]==2.31.0", 15 | "tensorflow==2.6.0", 16 | "tensorflow-estimator==2.6.0", 17 | "keras==2.6.0", 18 | "sentence_splitter==1.4", 19 | "seaborn==0.11.2", 20 | "pandas==1.3.2", 21 | "tensorflow_hub==0.12.0", 22 | "tensorflow_text==2.6.0", 23 | "ml_collections==0.1.0", 24 | "protobuf==3.18.0", 25 | "python-snappy==0.6.0", 26 | "google-apitools==0.5.31", 27 | ] 28 | 29 | 30 | setuptools.setup( 31 | name=NAME, 32 | version=VERSION, 33 | install_requires=REQUIRED_PACKAGES, 34 | packages=setuptools.find_packages(), 35 | include_package_data=True, 36 | ) 37 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/carted/processing-text-data/8e968f2817077be27ebc4aea6431841fe4b1f62d/src/__init__.py -------------------------------------------------------------------------------- /src/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
7 | 8 | import ml_collections 9 | import tensorflow as tf 10 | 11 | 12 | def get_bert_encoder_config( 13 | preprocessor_path: str, 14 | encoder_path: str, 15 | embedding_dim: int, 16 | max_seq_len: int, 17 | trainable: bool = False, 18 | ) -> ml_collections.ConfigDict: 19 | config = ml_collections.ConfigDict() 20 | config.name = "bert" 21 | config.input_type = tf.int32 22 | config.preprocessor_path = preprocessor_path 23 | config.encoder_path = encoder_path 24 | config.encoder_inputs = ["input_word_ids", "input_type_ids", "input_mask"] 25 | config.embedding_dim = embedding_dim 26 | config.max_seq_len = max_seq_len 27 | config.trainable = trainable 28 | config.output_dim = embedding_dim 29 | return config 30 | -------------------------------------------------------------------------------- /src/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | from ml_collections import ConfigDict 9 | from typing import List, Dict, Any 10 | 11 | import tensorflow_text 12 | import tensorflow_hub as hub 13 | import tensorflow as tf 14 | 15 | 16 | def contiguous_group_average_vectors(vectors, groups): 17 | """Works iff sum(groups) == len(vectors) 18 | 19 | Example: 20 | Inputs: vectors: A dense 2D tensor of shape = (13, 3) 21 | groups : A dense 1D tensor with values [2, 5, 1, 4, 1] 22 | indicating that there are 5 groups. 23 | 24 | Objective: Compute a 5x3 matrix where the first row 25 | is the average of the rows 0-1 of `vectors`, 26 | the second row is the average of rows 2-6 of 27 | vectors, the third row is row 7 of vectors, 28 | the fourth row is the average of rows 8-11 of 29 | vectors, and the fifth and final row is row 30 | 12 of vectors. 31 | 32 | Logic: A selection mask matrix is generated 33 | mask = [[1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.] 34 | [0. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.] 35 | [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.] 36 | [0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0.] 37 | [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]] 38 | 39 | This mask is then multiplied with `vectors` to get a 40 | matrix of shape (5, 3) called `summed_vectors` where 41 | each row contains the group sums. 42 | 43 | `summed_vectors` is then divided by `groups` to 44 | obtain the averages. 45 | Author: Nilabhra Roy Chowdhury (@Nilabhra) 46 | """ 47 | groups = tf.expand_dims(tf.cast(groups, dtype=tf.int32), axis=1) 48 | group_cumsum = tf.cumsum(groups) 49 | 50 | mask = tf.repeat( 51 | tf.expand_dims(tf.range(tf.shape(vectors)[0]), axis=0), 52 | repeats=tf.shape(groups)[0], 53 | axis=0, 54 | ) 55 | mask = tf.cast(mask < group_cumsum, dtype=tf.float32) 56 | 57 | def complete_mask(mask): 58 | neg_mask = tf.concat( 59 | (tf.expand_dims(tf.ones_like(mask[0]), axis=0), 1 - mask[:-1]), axis=0 60 | ) 61 | return mask * neg_mask 62 | 63 | mask = tf.cond( 64 | tf.greater(tf.shape(groups)[0], 1), 65 | true_fn=lambda: complete_mask(mask), 66 | false_fn=lambda: mask, 67 | ) 68 | 69 | summed_vectors = tf.matmul(mask, vectors) 70 | averaged_vectors = summed_vectors / tf.cast(groups, dtype=tf.float32) 71 | 72 | return averaged_vectors 73 | 74 | 75 | def get_text_encodings( 76 | examples: List[Dict[str, Any]], config: ConfigDict, chunk_size=50 77 | ): 78 | """Generates average text encodings from text descriptions.
79 | 80 | Many of the utilities used in this function were written by Nilabhra 81 | Roy Chowdhury (@Nilabhra). 82 | """ 83 | # Loading the text preprocessor and setting it as an attribute on the 84 | # first invocation. 85 | if not hasattr(get_text_encodings, "preprocessor"): 86 | get_text_encodings.preprocessor = hub.load(config.preprocessor_path) 87 | 88 | # Loading the encoder and setting it as an attribute on the first invocation. 89 | if not hasattr(get_text_encodings, "encoder"): 90 | get_text_encodings.encoder = hub.load(config.encoder_path) 91 | 92 | def prepare_bert_inputs(tokens: List[tf.RaggedTensor], token_lens: List[tf.Tensor]): 93 | """Pack the tokens w.r.t BERT inputs.""" 94 | max_token_len = tf.reduce_max(token_lens) 95 | packer = hub.KerasLayer( 96 | get_text_encodings.preprocessor.bert_pack_inputs, 97 | arguments={ 98 | "seq_length": tf.math.minimum( 99 | # +2 to consider the [CLS] and [SEP] tokens. 100 | max_token_len + 2, 101 | config.max_seq_len, 102 | ) 103 | }, 104 | ) 105 | bert_inputs = packer([tokens]) 106 | return bert_inputs 107 | 108 | def encode_with_bert(inputs: Dict[str, tf.Tensor]): 109 | """Computes encodings with BERT.""" 110 | bert_outputs = get_text_encodings.encoder(inputs) 111 | return bert_outputs["pooled_output"] 112 | 113 | def compute_text_encoding( 114 | tokens: List[tf.RaggedTensor], token_lens: List[tf.Tensor] 115 | ): 116 | """Packs BERT inputs and then computes text encodings.""" 117 | bert_inputs = prepare_bert_inputs(tokens, token_lens) 118 | bert_outputs = encode_with_bert(bert_inputs) 119 | return bert_outputs 120 | 121 | # Gather text related features. 122 | text_tokens, text_token_lens, text_num_sentences = [], [], [] 123 | for example in examples: 124 | text_tokens.extend(example["summary_tokens"]) 125 | text_token_lens.extend(example["summary_token_lens"]) 126 | text_num_sentences.append(example["summary_num_sentences"]) 127 | 128 | text_tokens = tf.stack(text_tokens) 129 | text_token_lens = tf.stack(text_token_lens) 130 | text_num_sentences = tf.stack(text_num_sentences) 131 | 132 | # Encode with BERT in a batch-wise manner to prevent OOM. 133 | len_idx = len(text_token_lens) 134 | all_bert_encodings = [] 135 | 136 | # Sort sequences to reduce compute waste. 137 | sort_idx = tf.argsort(text_token_lens, direction="DESCENDING", axis=0) 138 | unsort_idx = tf.argsort( 139 | sort_idx, direction="ASCENDING", axis=0 140 | ) # indices to unsort the sorted embeddings 141 | 142 | sorted_all_text_tokens = tf.gather(text_tokens, sort_idx, axis=0) 143 | sorted_all_text_token_lens = tf.gather(text_token_lens, sort_idx, axis=0) 144 | 145 | for idx in range(0, len_idx, chunk_size): 146 | bert_encodings = compute_text_encoding( 147 | sorted_all_text_tokens[idx : idx + chunk_size], 148 | sorted_all_text_token_lens[idx : idx + chunk_size], 149 | ) 150 | all_bert_encodings.append(bert_encodings) 151 | 152 | all_bert_encodings = tf.concat(all_bert_encodings, axis=0) 153 | 154 | # Unsort the encodings. 155 | all_bert_encodings = tf.gather(all_bert_encodings, unsort_idx, axis=0) 156 | 157 | # Perform averaging. 158 | averaged_encodings = contiguous_group_average_vectors( 159 | all_bert_encodings, text_num_sentences 160 | ) 161 | 162 | # Add the encodings to the individual examples. 
163 | for i, encoding in enumerate(averaged_encodings): 164 | examples[i]["summary_average_embeddings"] = encoding 165 | 166 | return examples 167 | -------------------------------------------------------------------------------- /src/feature_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import tensorflow as tf 9 | import tensorflow_hub as hub 10 | import apache_beam as beam 11 | import pandas as pd 12 | 13 | from typing import Union, Dict, Any, List, Tuple 14 | from sentence_splitter import split_text_into_sentences 15 | from ml_collections import ConfigDict 16 | 17 | 18 | def generate_features(example: Union[pd.Series, Dict[str, Any]], config: ConfigDict): 19 | """Generates tokenized text features and labels from the dataset examples. 20 | 21 | Many of the utilities used in this function were written by Nilabhra 22 | Roy Chowdhury (@Nilabhra). 23 | """ 24 | if not hasattr(generate_features, "tokenizer"): 25 | generate_features.tokenizer = hub.load(config.preprocessor_path).tokenize 26 | 27 | def _tokenize_text(text: List[str]) -> Tuple[tf.RaggedTensor, List[int]]: 28 | """Tokenizes a list of sentences. 29 | Args: 30 | text (List[str]): A list of sentences. 31 | Returns: 32 | Tuple[tf.RaggedTensor, List[int]]: Tokenized and indexed sentences, list 33 | containing the number of tokens per sentence. 34 | """ 35 | token_list = generate_features.tokenizer(tf.constant(text)) 36 | token_lens = [tokens.flat_values.shape[-1] for tokens in token_list] 37 | return token_list, token_lens 38 | 39 | text_features = {} 40 | split_sentences = split_text_into_sentences(example["summary"], language="en") 41 | 42 | text_features["summary"] = example["summary"] 43 | tokenized_text_feature, token_lens = _tokenize_text(split_sentences) 44 | text_features["summary_tokens"] = tokenized_text_feature 45 | text_features["summary_token_lens"] = token_lens 46 | text_features["summary_num_sentences"] = len(token_lens) 47 | 48 | text_features["label"] = example["genre"] 49 | return [text_features] 50 | 51 | 52 | class DecodeFromTextLineDoFn(beam.DoFn): 53 | # Mind the space. 54 | def __init__(self, delimiter=" ::: "): 55 | self.delimiter = delimiter 56 | 57 | def process(self, text_line): 58 | splits = text_line.split(self.delimiter) 59 | genre = splits[-2] 60 | summary = splits[-1] 61 | packed_dict = {"summary": summary, "genre": genre} 62 | yield packed_dict 63 | -------------------------------------------------------------------------------- /src/tfrecords.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Carted. 2 | 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree.
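# Record schema: each serialized tf.train.Example produced here carries three
# features -- "summary" (UTF-8 bytes), "summary_average_embeddings" (a float
# list holding one averaged sentence embedding per example), and "label"
# (UTF-8 bytes).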
7 | 8 | import tensorflow as tf 9 | import apache_beam as beam 10 | 11 | from typing import Dict, Union 12 | 13 | 14 | def _bytes_feature(bytes_input: bytes) -> tf.train.Feature: 15 | """Encodes given data as a byte feature.""" 16 | bytes_list = tf.train.BytesList(value=[bytes_input]) 17 | return tf.train.Feature(bytes_list=bytes_list) 18 | 19 | 20 | def _floats_feature(value): 21 | """Returns a float_list from a float / double.""" 22 | if not isinstance(value, list): 23 | value = [value] 24 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 25 | 26 | 27 | def create_tfr_example( 28 | raw_features: Dict[str, Union[tf.Tensor, str]] 29 | ) -> tf.train.Example: 30 | """Creates a tf.train.Example instance from high-level features.""" 31 | feature = { 32 | "summary": _bytes_feature(raw_features["summary"].encode("utf-8")), 33 | "summary_average_embeddings": _floats_feature( 34 | raw_features["summary_average_embeddings"].numpy().tolist() 35 | ), 36 | "label": _bytes_feature(raw_features["label"].encode("utf-8")), 37 | } 38 | # Wrap as a training example. 39 | feature = tf.train.Features(feature=feature) 40 | example = tf.train.Example(features=feature) 41 | return example 42 | 43 | 44 | class FeaturesToSerializedExampleFn(beam.DoFn): 45 | """DoFn class to create a tf.train.Example from high-level features.""" 46 | 47 | def process(self, features): 48 | example = create_tfr_example(features) 49 | yield example.SerializeToString() 50 | --------------------------------------------------------------------------------
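A quick way to validate the pipeline output is to read a few records back with `tf.data`. The sketch below is not part of the repository; it assumes the feature schema from `src/tfrecords.py` and the `EMBEDDING_DIM = 768` constant from `main.py`, and the GCS glob is a placeholder you should point at your own bucket.

```python
import tensorflow as tf

EMBEDDING_DIM = 768  # Assumed to match EMBEDDING_DIM in main.py.

# Feature spec mirroring create_tfr_example() in src/tfrecords.py: two
# byte-string features plus one fixed-length float vector per example.
feature_spec = {
    "summary": tf.io.FixedLenFeature([], tf.string),
    "summary_average_embeddings": tf.io.FixedLenFeature([EMBEDDING_DIM], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.string),
}


def parse_example(serialized):
    return tf.io.parse_single_example(serialized, feature_spec)


# Placeholder glob; point it at the bucket/prefix the pipeline wrote to.
files = tf.io.gfile.glob("gs://<your-bucket>/tfrecords/*.tfrecord")
dataset = tf.data.TFRecordDataset(files).map(parse_example)

for example in dataset.take(1):
    print(example["label"].numpy())  # e.g. b"drama"
    print(example["summary_average_embeddings"].shape)  # (768,)
```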