├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config └── model.config └── src ├── aggregators.py ├── bert_core.py ├── create_tf_record_data.py ├── data_preprocessing.py ├── encoders.py ├── metric.py ├── tokenizer.py ├── train.py ├── trainer.py └── twinbertgnn.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Text Encoder via Graph Neural Network 2 | Code for the BERT-based implementation of the TextGNN model from the WWW 2021 paper: [TextGNN: Improving Text Encoder via Graph Neural Network in Sponsored Search](https://arxiv.org/abs/2101.06323) 3 | 4 | 5 | ## Requirements: 6 | * **TensorFlow 2.2.0** 7 | * Python 3.7 8 | * CUDA 10.1+ (for GPU) 9 | * HuggingFace Transformers 10 | * wandb (Weights & Biases, for logging) 11 | 12 | ## Example Training Command 13 | $ python train.py --do_train --do_eval --train_data_size 400000000 --train_data_path ../data/QK_Neighbor/Teacher/ --eval_train_data_path ../data/QK_Neighbor/Teacher_Eval/ --eval_data_path ../data/QK_Neighbor/Validation/ --config_path ../config/model.config --output_dir ../outputs/model --logging_dir ../logging/model --per_device_train_batch_size 512 --per_device_eval_batch_size 512 --evaluate_during_training --overwrite_output_dir --learning_rate 1e-4 --warmup_steps 2000 --num_train_epochs 2.0 --pretrained_bert_name bert-base-uncased --eval_steps 10000 --logging_steps 10000 --save_steps 10000 14 | 15 | ## Example Inference Command 16 | $ python train.py --do_predict --test_data_path ../data/QK_Neighbor/Test/ --config_path ../config/model.config --output_dir ../outputs/model --logging_dir ../logging/model 17 | 18 | ## Acknowledgements: 19 | This codebase was heavily adapted from the HuggingFace Transformers repository: https://github.com/huggingface/transformers. 20 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours.
If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /config/model.config: -------------------------------------------------------------------------------- 1 | hidden_size:int 768 2 | num_hidden_layers:int 3 3 | num_attention_heads:int 8 4 | intermediate_size:int 3072 5 | max_position_embeddings:int 512 6 | type_vocab_size:int 2 7 | hidden_dropout_prob:float 0.1 8 | attention_probs_dropout_prob:float 0.1 9 | initializer_range:float 0.02 10 | hidden_act:str gelu 11 | pooler_type:str weightpooler 12 | sim_type:str feedforward 13 | crossing_res:bool True 14 | res_size:int 2048 15 | res_bn:bool False 16 | comb_type:str concat 17 | use_two_bert:bool False 18 | use_two_gnn:bool True 19 | vocab_size:int 30522 20 | embedding_type:str bpe 21 | triletter_max_letters_in_word:int 20 22 | quantization_side:str neither 23 | tanh_pooler:bool False 24 | downscale:int 0 25 | max_seq_len:int 16 26 | max_seq_len_doc:int 16 27 | post_processing:bool False 28 | loss_type:str mse 29 | ckpt_layer_mapping:str 0,1,2 30 | a_fanouts:str 3 31 | b_fanouts:str 3 32 | gnn_model:str gat 33 | weighted_gnn_type:str click,impression,ctr 34 | hidden_dims:str 768 35 | gnn_acts:str leaky_relu,leaky_relu 36 | gnn_add_residual:bool False 37 | aggregator:str meanpool 38 | gnn_concat_residual:bool True 39 | use_residual:bool False 40 | head_nums:str 4 41 | agg_dropout:float 0.0 42 | agg_concat:bool True 43 | agg_bias:bool True 44 | weight_decay:float 0.0 45 | agg_model_size:str small 46 | num_classes:int 2 47 | bert_trainable:bool True 48 | tb_loss:bool False 49 | use_two_crossings:bool False -------------------------------------------------------------------------------- /src/aggregators.py: -------------------------------------------------------------------------------- 1 | ## GraphSage aggregators 2 | import logging 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from transformers.modeling_tf_utils import shape_list 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class MeanAggregator(tf.keras.layers.Layer): 12 | 13 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 17 | self.concat = config.agg_concat 18 | self.add_bias = config.agg_bias 19 | self.out_dim = out_dim 20 | self.identity_act = identity_act 21 | 22 | self.self_weights = tf.keras.layers.Dense( 23 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="self_weight" 24 | ) 25 | self.neigh_weights = tf.keras.layers.Dense( 26 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="neigh_weights" 27 | ) 28 | 29 | if not self.identity_act: 30 | self.act = tf.keras.layers.Activation(activation) 31 | 32 | def build(self, input_shape): 33 | if self.add_bias: 34 | with tf.name_scope('bias'): 35 | self.bias = self.add_weight( 36 | "weight", 37 | shape=[self.out_dim * 2 if self.concat else self.out_dim], 38 | initializer='zeros', 39 | ) 40 | super().build(input_shape) 41 | 42 | def call(self, inputs, training=False): 43 | self_vecs, neigh_vecs = inputs 44 | 45 | neigh_vecs = self.dropout(neigh_vecs, training=training) 46 | self_vecs = self.dropout(self_vecs, training=training) 47 | neigh_means = tf.reduce_mean(neigh_vecs, axis=1) 48 | 49 | from_neighs = self.neigh_weights(neigh_means) 50 | from_self = self.self_weights(self_vecs) 51 | 52 | if not self.concat: 53 | output = tf.add_n([from_self, from_neighs]) 54 | else: 55 | output = 
tf.concat([from_self, from_neighs], axis=1) 56 | 57 | if self.add_bias: 58 | output += self.bias 59 | 60 | if self.identity_act: return output 61 | return self.act(output) 62 | 63 | 64 | class GCNAggregator(tf.keras.layers.Layer): 65 | 66 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 67 | super().__init__(**kwargs) 68 | 69 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 70 | self.out_dim = out_dim 71 | 72 | if identity_act: 73 | self.neigh_weights = tf.keras.layers.Dense( 74 | self.out_dim, use_bias=config.agg_bias, kernel_initializer='glorot_uniform', 75 | name="neigh_weights" 76 | ) 77 | else: 78 | self.neigh_weights = tf.keras.layers.Dense( 79 | self.out_dim, use_bias=config.agg_bias, activation=activation, kernel_initializer='glorot_uniform', 80 | name="neigh_weights" 81 | ) 82 | 83 | def call(self, inputs, training=False): 84 | self_vecs, neigh_vecs = inputs 85 | 86 | neigh_vecs = self.dropout(neigh_vecs, training=training) 87 | self_vecs = self.dropout(self_vecs, training=training) 88 | means = tf.reduce_mean(tf.concat([neigh_vecs, 89 | tf.expand_dims(self_vecs, axis=1)], axis=1), axis=1) 90 | 91 | return self.neigh_weights(means) 92 | 93 | 94 | class MaxPoolingAggregator(tf.keras.layers.Layer): 95 | 96 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 97 | super().__init__(**kwargs) 98 | 99 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 100 | self.concat = config.agg_concat 101 | self.add_bias = config.agg_bias 102 | self.out_dim = out_dim 103 | self.identity_act = identity_act 104 | self.hidden_dim = 512 if config.agg_model_size == 'small' else 1024 105 | 106 | self.mlp_layers = [] 107 | self.mlp_layers.append(tf.keras.layers.Dense( 108 | self.hidden_dim, activation='relu', kernel_initializer='glorot_uniform', 109 | kernel_regularizer=tf.keras.regularizers.l2(config.weight_decay), name="neigh_mlp" 110 | )) 111 | 112 | self.self_weights = tf.keras.layers.Dense( 113 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="self_weight" 114 | ) 115 | self.neigh_weights = tf.keras.layers.Dense( 116 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="neigh_weights" 117 | ) 118 | 119 | if not self.identity_act: 120 | self.act = tf.keras.layers.Activation(activation) 121 | 122 | def build(self, input_shape): 123 | if self.add_bias: 124 | with tf.name_scope('bias'): 125 | self.bias = self.add_weight( 126 | "weight", 127 | shape=[self.out_dim * 2 if self.concat else self.out_dim], 128 | initializer='zeros', 129 | ) 130 | super().build(input_shape) 131 | 132 | def call(self, inputs, training=False): 133 | self_vecs, neigh_vecs = inputs 134 | 135 | for l in self.mlp_layers: 136 | neigh_vecs = self.dropout(neigh_vecs, training=training) 137 | neigh_vecs = l(neigh_vecs) 138 | neigh_vecs = tf.reduce_max(neigh_vecs, axis=1) 139 | 140 | from_neighs = self.neigh_weights(neigh_vecs) 141 | from_self = self.self_weights(self_vecs) 142 | 143 | if not self.concat: 144 | output = tf.add_n([from_self, from_neighs]) 145 | else: 146 | output = tf.concat([from_self, from_neighs], axis=1) 147 | 148 | if self.add_bias: 149 | output += self.bias 150 | 151 | if self.identity_act: return output 152 | return self.act(output) 153 | 154 | 155 | class MeanPoolingAggregator(tf.keras.layers.Layer): 156 | 157 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 158 | super().__init__(**kwargs) 159 | 160 | self.dropout = 
tf.keras.layers.Dropout(config.agg_dropout) 161 | self.concat = config.agg_concat 162 | self.add_bias = config.agg_bias 163 | self.out_dim = out_dim 164 | self.identity_act = identity_act 165 | self.hidden_dim = 512 if config.agg_model_size == 'small' else 1024 166 | 167 | self.mlp_layers = [] 168 | self.mlp_layers.append(tf.keras.layers.Dense( 169 | self.hidden_dim, activation='relu', kernel_initializer='glorot_uniform', 170 | kernel_regularizer=tf.keras.regularizers.l2(config.weight_decay), name="neigh_mlp" 171 | )) 172 | 173 | self.self_weights = tf.keras.layers.Dense( 174 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="self_weight" 175 | ) 176 | self.neigh_weights = tf.keras.layers.Dense( 177 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="neigh_weights" 178 | ) 179 | 180 | if not self.identity_act: 181 | self.act = tf.keras.layers.Activation(activation) 182 | 183 | def build(self, input_shape): 184 | if self.add_bias: 185 | with tf.name_scope('bias'): 186 | self.bias = self.add_weight( 187 | "weight", 188 | shape=[self.out_dim * 2 if self.concat else self.out_dim], 189 | initializer='zeros', 190 | ) 191 | super().build(input_shape) 192 | 193 | def call(self, inputs, training=False): 194 | self_vecs, neigh_vecs = inputs 195 | 196 | for l in self.mlp_layers: 197 | neigh_vecs = self.dropout(neigh_vecs, training=training) 198 | neigh_vecs = l(neigh_vecs) 199 | neigh_vecs = tf.reduce_mean(neigh_vecs, axis=1) 200 | 201 | from_neighs = self.neigh_weights(neigh_vecs) 202 | from_self = self.self_weights(self_vecs) 203 | 204 | if not self.concat: 205 | output = tf.add_n([from_self, from_neighs]) 206 | else: 207 | output = tf.concat([from_self, from_neighs], axis=1) 208 | 209 | if self.add_bias: 210 | output += self.bias 211 | 212 | if self.identity_act: return output 213 | return self.act(output) 214 | 215 |
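# Shape walk-through (illustrative comment only, not executed by this module). The
# pooling aggregators in this file take a tuple (self_vecs, neigh_vecs) shaped
# [batch, in_dim] and [batch, num_neighbors, in_dim]. A minimal sketch, assuming
# out_dim=768, agg_concat=True and 3 sampled neighbors (the `config` object here
# is hypothetical):
#
#   agg = MeanPoolingAggregator(config, out_dim=768)
#   out = agg((self_vecs, neigh_vecs))   # self_vecs: [B, 768], neigh_vecs: [B, 3, 768]
#
# Inside the call, neigh_vecs pass through Dense(hidden_dim) and a mean over the
# neighbor axis before the Dense(out_dim) projection; with agg_concat=True the self
# and neighbor projections are concatenated into [B, 2 * out_dim], and with
# agg_concat=False they are summed into [B, out_dim].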
216 | class TwoMaxLayerPoolingAggregator(tf.keras.layers.Layer): 217 | 218 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 219 | super().__init__(**kwargs) 220 | 221 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 222 | self.concat = config.agg_concat 223 | self.add_bias = config.agg_bias 224 | self.out_dim = out_dim 225 | self.identity_act = identity_act 226 | self.hidden_dim_1 = 512 if config.agg_model_size == 'small' else 1024 227 | self.hidden_dim_2 = 256 if config.agg_model_size == 'small' else 512 228 | 229 | self.mlp_layers = [] 230 | self.mlp_layers.append(tf.keras.layers.Dense( 231 | self.hidden_dim_1, activation='relu', kernel_initializer='glorot_uniform', 232 | kernel_regularizer=tf.keras.regularizers.l2(config.weight_decay), name="neigh_mlp_1" 233 | )) 234 | self.mlp_layers.append(tf.keras.layers.Dense( 235 | self.hidden_dim_2, activation='relu', kernel_initializer='glorot_uniform', 236 | kernel_regularizer=tf.keras.regularizers.l2(config.weight_decay), name="neigh_mlp_2" 237 | )) 238 | 239 | self.self_weights = tf.keras.layers.Dense( 240 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="self_weight" 241 | ) 242 | self.neigh_weights = tf.keras.layers.Dense( 243 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="neigh_weights" 244 | ) 245 | 246 | if not self.identity_act: 247 | self.act = tf.keras.layers.Activation(activation) 248 | 249 | def build(self, input_shape): 250 | if self.add_bias: 251 | with tf.name_scope('bias'): 252 | self.bias = self.add_weight( 253 | "weight", 254 | shape=[self.out_dim * 2 if self.concat else self.out_dim], 255 | initializer='zeros', 256 | ) 257 | super().build(input_shape) 258 | 259 | def call(self, inputs, training=False): 260 | self_vecs, neigh_vecs = inputs 261 | 262 | for l in self.mlp_layers: 263 | neigh_vecs = self.dropout(neigh_vecs, training=training) 264 | neigh_vecs = l(neigh_vecs) 265 | neigh_vecs = tf.reduce_max(neigh_vecs, axis=1) 266 | 267 | from_neighs = self.neigh_weights(neigh_vecs) 268 | from_self = self.self_weights(self_vecs) 269 | 270 | if not self.concat: 271 | output = tf.add_n([from_self, from_neighs]) 272 | else: 273 | output = tf.concat([from_self, from_neighs], axis=1) 274 | 275 | if self.add_bias: 276 | output += self.bias 277 | 278 | if self.identity_act: return output 279 | return self.act(output) 280 | 281 | 282 | class SeqAggregator(tf.keras.layers.Layer): 283 | 284 | def __init__(self, config, out_dim, activation='relu', identity_act=False, **kwargs): 285 | super().__init__(**kwargs) 286 | 287 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 288 | self.concat = config.agg_concat 289 | self.add_bias = config.agg_bias 290 | self.out_dim = out_dim 291 | self.identity_act = identity_act 292 | self.hidden_dim = 128 if config.agg_model_size == 'small' else 256 293 | 294 | self.lstm = tf.keras.layers.LSTM(self.hidden_dim) 295 | 296 | self.self_weights = tf.keras.layers.Dense( 297 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="self_weight" 298 | ) 299 | self.neigh_weights = tf.keras.layers.Dense( 300 | self.out_dim, use_bias=False, kernel_initializer='glorot_uniform', name="neigh_weights" 301 | ) 302 | 303 | if not self.identity_act: 304 | self.act = tf.keras.layers.Activation(activation) 305 | 306 | def build(self, input_shape): 307 | if self.add_bias: 308 | with tf.name_scope('bias'): 309 | self.bias = self.add_weight( 310 | "weight", 311 | shape=[self.out_dim * 2 if self.concat else self.out_dim], 312 | initializer='zeros', 313 | ) 314 | super().build(input_shape) 315 | 316 | def call(self, inputs, training=False): 317 | self_vecs, neigh_vecs = inputs 318 | 319 | mask = tf.cast(tf.sign(tf.reduce_max(tf.abs(neigh_vecs), axis=2)), dtype=tf.bool) 320 | batch_size = shape_list(mask)[0] 321 | mask = tf.concat([tf.ones([batch_size, 1], dtype=tf.bool), mask[:, 1:]], axis=1) 322 | 323 | rnn_outputs = self.lstm(inputs=neigh_vecs, mask=mask) 324 | 325 | from_neighs = self.neigh_weights(rnn_outputs) 326 | from_self = self.self_weights(self_vecs) 327 | 328 | if not self.concat: 329 | output = tf.add_n([from_self, from_neighs]) 330 | else: 331 | output = tf.concat([from_self, from_neighs], axis=1) 332 | 333 | if self.add_bias: 334 | output += self.bias 335 | 336 | if self.identity_act: return output 337 | return self.act(output) 338 | 339 | 340 | class NodePredict(tf.keras.layers.Layer): 341 | def __init__(self, config, **kwargs): 342 | super().__init__(**kwargs) 343 | 344 | self.dense = tf.keras.layers.Dense( 345 | config.num_classes, kernel_initializer='glorot_uniform', name="dense" 346 | ) 347 | self.dropout = tf.keras.layers.Dropout(config.agg_dropout) 348 | 349 | def call(self, inputs, training=False): 350 | node_preds = self.dense(inputs) 351 | node_preds = self.dropout(node_preds, training=training) 352 | return node_preds 353 | 354 | 355 | aggregators = { 356 | 'gcn': GCNAggregator, 357 | 'mean': MeanAggregator, 358 | 'meanpool': MeanPoolingAggregator, 359 | 'maxpool': MaxPoolingAggregator, 360 | 'twomaxpool': TwoMaxLayerPoolingAggregator, 361 | 'seq': SeqAggregator, 362 | 'nodepred': NodePredict 363 | } 364 |
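# Usage sketch (illustrative comment only). The `aggregators` registry above is keyed by
# the `aggregator` entry in config/model.config (e.g. "meanpool" -> MeanPoolingAggregator)
# and is presumably consumed by the GNN encoder in encoders.py, which is not shown in this
# excerpt. A hypothetical lookup and instantiation (config, hidden_dim, self_vecs and
# neigh_vecs are placeholders) would be:
#
#   agg_cls = get(config.aggregator) or MeanPoolingAggregator   # fall back if the key is unknown
#   layer_agg = agg_cls(config, out_dim=hidden_dim, activation='leaky_relu')
#   h = layer_agg((self_vecs, neigh_vecs), training=True)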
365 | 366 | def get(aggregator): 367 | return aggregators.get(aggregator) -------------------------------------------------------------------------------- /src/bert_core.py: -------------------------------------------------------------------------------- 1 | """Generalize BERT model with Triletter or Simpiied BPE encoder""" 2 | 3 | import tensorflow as tf 4 | from transformers import BertConfig, TFBertMainLayer, TFBertPreTrainedModel 5 | from transformers.modeling_tf_bert import TFBertEncoder, TFBertPooler, TFBertEmbeddings, TFBertPredictionHeadTransform 6 | from transformers.modeling_tf_utils import keras_serializable, shape_list, get_initializer 7 | from transformers.tokenization_utils import BatchEncoding 8 | 9 | 10 | class TFBertEmbeddingsSimple(tf.keras.layers.Layer): 11 | """Construct the embeddings with only word embeddings. 12 | """ 13 | 14 | def __init__(self, config, **kwargs): 15 | super(TFBertEmbeddingsSimple, self).__init__(**kwargs) 16 | self.vocab_size = config.vocab_size 17 | self.hidden_size = config.hidden_size 18 | self.initializer_range = config.initializer_range 19 | 20 | self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings, config.hidden_size, embeddings_initializer=get_initializer(self.initializer_range), name='position_embeddings') 21 | 22 | self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name='LayerNorm') 23 | self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) 24 | 25 | def build(self, input_shape): 26 | with tf.name_scope("word_embeddings"): 27 | self.word_embeddings = self.add_weight( 28 | "weight", 29 | shape=[self.vocab_size, self.hidden_size], 30 | initializer=get_initializer(self.initializer_range) 31 | ) 32 | 33 | def call(self, inputs, training=False): 34 | input_ids, position_ids, token_type_ids, inputs_embeds = inputs 35 | input_shape = shape_list(input_ids) 36 | 37 | if inputs_embeds is None: 38 | inputs_embeds = tf.gather(self.word_embeddings, input_ids) 39 | position_embeddings = self.position_embeddings(position_ids) 40 | 41 | embeddings = inputs_embeds + position_embeddings 42 | embeddings = self.LayerNorm(embeddings) 43 | embeddings = self.dropout(embeddings, training=training) 44 | return embeddings 45 | 46 | 47 | class TriletterEmbeddings(tf.keras.layers.Layer): 48 | """Comparing with TriletterEmbeddingsSimple, this one has position encoding, name is a little bit ugly, but not breaking anything 49 | """ 50 | 51 | def __init__(self, config, **kwargs): 52 | super(TriletterEmbeddings, self).__init__(**kwargs) 53 | self.triletter_max_letters_in_word = config.triletter_max_letters_in_word # 20, so 20 triletters 54 | self.triletter_embeddings = tf.keras.layers.Embedding(config.vocab_size + 1, config.hidden_size, 55 | mask_zero=True, name='triletter_embeddings') 56 | self.position_embeddings = tf.keras.layers.Embedding(config.max_position_embeddings + 1, 57 | config.hidden_size, mask_zero=True, 58 | name='position_embeddings') 59 | 60 | def call(self, inputs, training=False): 61 | input_ids, position_ids, token_type_ids, inputs_embeds = inputs 62 | triletter_max_seq_len = shape_list(input_ids)[1] // self.triletter_max_letters_in_word 63 | position_embeddings = self.position_embeddings(position_ids) 64 | 65 | embeddings = self.triletter_embeddings(input_ids) # [N, 12*[20], hidden_size] 66 | 67 | embeddings = tf.reshape(embeddings, [-1, triletter_max_seq_len, self.triletter_max_letters_in_word, 68 | shape_list(embeddings)[-1]]) 69 | embeddings = 
tf.reshape(tf.reduce_sum(embeddings, axis=2), 70 | [-1, triletter_max_seq_len, shape_list(embeddings)[-1]]) 71 | 72 | embeddings = embeddings + position_embeddings 73 | 74 | return embeddings 75 | 76 | 77 | class BERTCore(tf.keras.layers.Layer): 78 | config_class = BertConfig 79 | 80 | def __init__(self, config, **kwargs): 81 | super(BERTCore, self).__init__(**kwargs) 82 | self.config = config 83 | 84 | self.num_hidden_layers = self.config.num_hidden_layers 85 | self.initializer_range = self.config.initializer_range 86 | self.output_attentions = self.config.output_attentions 87 | self.output_hidden_states = self.config.output_hidden_states 88 | 89 | if self.config.embedding_type == 'triletter': 90 | self.embeddings = TriletterEmbeddings(self.config) 91 | elif self.config.embedding_type == 'bpe_simple': 92 | self.embeddings = TFBertEmbeddingsSimple(self.config, name="embeddings") 93 | else: 94 | self.embeddings = TFBertEmbeddings(self.config, name="embeddings") 95 | 96 | self.encoder = TFBertEncoder(self.config, name="encoder") 97 | self.pooler = TFBertPooler(self.config, name='pooler') 98 | 99 | def get_input_embeddings(self): 100 | return self.embeddings 101 | 102 | def _prune_heads(self, heads_to_prune): 103 | """ Prunes heads of the model. 104 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 105 | See base class PreTrainedModel 106 | """ 107 | raise NotImplementedError 108 | 109 | def call( 110 | self, 111 | inputs, 112 | attention_mask=None, 113 | token_type_ids=None, 114 | position_ids=None, 115 | head_mask=None, 116 | inputs_embeds=None, 117 | output_attentions=None, 118 | output_hidden_states=None, 119 | training=False, 120 | ): 121 | if isinstance(inputs, (tuple, list)): 122 | input_ids = inputs[0] 123 | attention_mask = inputs[1] if len(inputs) > 1 else attention_mask 124 | token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids 125 | position_ids = inputs[3] if len(inputs) > 3 else position_ids 126 | head_mask = inputs[4] if len(inputs) > 4 else head_mask 127 | inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds 128 | output_attentions = inputs[6] if len(inputs) > 6 else output_attentions 129 | output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states 130 | assert len(inputs) <= 8, "Too many inputs." 131 | elif isinstance(inputs, (dict, BatchEncoding)): 132 | input_ids = inputs.get("input_ids") 133 | attention_mask = inputs.get("attention_mask", attention_mask) 134 | token_type_ids = inputs.get("token_type_ids", token_type_ids) 135 | position_ids = inputs.get("position_ids", position_ids) 136 | head_mask = inputs.get("head_mask", head_mask) 137 | inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) 138 | output_attentions = inputs.get("output_attentions", output_attentions) 139 | output_hidden_states = inputs.get("output_hidden_states", output_hidden_states) 140 | assert len(inputs) <= 8, "Too many inputs." 
141 | else: 142 | input_ids = inputs 143 | 144 | output_attentions = output_attentions if output_attentions is not None else self.output_attentions 145 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states 146 | 147 | if input_ids is not None and inputs_embeds is not None: 148 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 149 | elif input_ids is not None: 150 | input_shape = shape_list(input_ids) 151 | elif inputs_embeds is not None: 152 | input_shape = shape_list(inputs_embeds)[:-1] 153 | else: 154 | raise ValueError("You have to specify either input_ids or inputs_embeds") 155 | 156 | if attention_mask is None: 157 | if type(self.embeddings) == TriletterEmbeddings: 158 | attention_mask = tf.ones( 159 | [input_shape[0], input_shape[1] // self.embeddings.triletter_max_letters_in_word]) 160 | else: 161 | attention_mask = tf.fill(input_shape, 1) 162 | if token_type_ids is None: 163 | if type(self.embeddings) == TriletterEmbeddings: 164 | token_type_ids = tf.zeros( 165 | [input_shape[0], input_shape[1] // self.embeddings.triletter_max_letters_in_word]) 166 | else: 167 | token_type_ids = tf.fill(input_shape, 0) 168 | if position_ids is None: 169 | if type(self.embeddings) == TriletterEmbeddings: 170 | position_ids = (tf.range(input_shape[1] // self.embeddings.triletter_max_letters_in_word, 171 | dtype=tf.int32) + 1)[tf.newaxis, :] 172 | else: 173 | position_ids = tf.range(int(input_shape[1]), dtype=tf.int32)[tf.newaxis, :] 174 | position_ids = tf.where(attention_mask == 0, tf.zeros_like(position_ids), position_ids) 175 | 176 | extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :] 177 | extended_attention_mask = tf.cast(extended_attention_mask, tf.float32) 178 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 179 | 180 | if head_mask is not None: 181 | raise NotImplementedError 182 | else: 183 | head_mask = [None] * self.num_hidden_layers 184 | 185 | embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training) 186 | encoder_outputs = self.encoder( 187 | [embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states], 188 | training=training, 189 | ) 190 | sequence_output = encoder_outputs[0] 191 | pooled_output = self.pooler(sequence_output) 192 | 193 | outputs = (sequence_output, pooled_output,) + encoder_outputs[ 194 | 1:] # add hidden_states and attentions if they are here 195 | 196 | return outputs # sequence_output, pooled_output, (hidden_states), (attentions) 197 | 198 | 199 | class BERTModel(TFBertPreTrainedModel): 200 | def prune_heads(self, heads_to_prune): 201 | raise NotImplementedError 202 | 203 | def __init__(self, config): 204 | super(BERTModel, self).__init__(config) 205 | self.config = config 206 | self.bert = BERTCore(self.config, name="bert") 207 | 208 | def call(self, inputs, **kwargs): 209 | outputs = self.bert(inputs, **kwargs) 210 | return outputs 211 | -------------------------------------------------------------------------------- /src/create_tf_record_data.py: -------------------------------------------------------------------------------- 1 | """Script for pre-processing raw data""" 2 | 3 | import logging 4 | import argparse 5 | from data_preprocessing import TriLetterExtractor, process_datasets_to_file 6 | from tokenizer import TwinBertTokenizer 7 | from transformers import BertTokenizer 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | parser 
= argparse.ArgumentParser() 12 | parser.add_argument('--teacher_data_path', type=str, default="../../../data/QK_Graph/QK_ANN_Neighbor_Teacher.tsv", help="Path to teacher data") 13 | parser.add_argument('--teacher_data_save_path', type=str, default="../data/QK_ANN_Neighbor/Teacher/", help="Path to save teacher data") 14 | parser.add_argument('--teacher_eval_data_path', type=str, default="../../../data/QK_Graph/QK_ANN_Neighbor_Teacher.tsv", help="Path to teacher eval data") 15 | parser.add_argument('--teacher_eval_data_save_path', type=str, default="../data/QK_ANN_Neighbor/Teacher_Eval/", help="Path to save teacher eval data") 16 | parser.add_argument('--validation_data_path', type=str, default="../../../data/QK_Graph/QK_ANN_Neighbor_Validation.tsv", help="Path to validation data") 17 | parser.add_argument('--validation_data_save_path', type=str, default="../data/QK_ANN_Neighbor/Validation/", help="Path to save validation data") 18 | parser.add_argument('--vocab_path', type=str, default="../config/l3g.txt", help="Path to vocab file, only used by triletter tokenizer") 19 | parser.add_argument('--chunksize', type=int, default=1e6, help="Pandas loading chunksize") 20 | parser.add_argument('--skip_chunk', type=int, default=0, help="Number of chunks to skip") 21 | parser.add_argument('--n_chunk', type=int, default=0, help="Number of chunks to process") 22 | parser.add_argument('--tokenizer_type', type=str, default="bpe", help="Tokenizer type") 23 | parser.add_argument('--max_n_letters', type=int, default=20, help="Only used by triletter tokenizer") 24 | parser.add_argument('--max_seq_len', type=int, default=16, help="Max length of sequence") 25 | parser.add_argument('--tokenizer_name', type=str, default="bert-base-uncased", help="Pre-trained Bert tokenizer name") 26 | parser.add_argument('--a_fanouts', type=str, default="3", help="a fanouts") 27 | parser.add_argument('--b_fanouts', type=str, default="3", help="b fanouts") 28 | args = parser.parse_args() 29 | 30 | 31 | def main(): 32 | logging.basicConfig( 33 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S", 35 | level=logging.INFO, 36 | ) 37 | 38 | if args.tokenizer_type == "triletter": 39 | extractor = TriLetterExtractor(args.vocab_path) 40 | else: 41 | extractor = BertTokenizer.from_pretrained(args.tokenizer_name) 42 | 43 | try: 44 | process_datasets_to_file(args.teacher_data_path, extractor, args.teacher_data_save_path, max_seq_len=args.max_seq_len, max_n_letters=args.max_n_letters, int_label=False, chunksize=args.chunksize, a_fanouts=list(map(int, args.a_fanouts.split(","))), b_fanouts=list(map(int, args.b_fanouts.split(","))), skip_chunk=args.skip_chunk, n_chunk=args.n_chunk) 45 | except Exception as e: 46 | logger.info(e) 47 | logger.info("Cannot load from raw teacher data") 48 | 49 | 50 | try: 51 | process_datasets_to_file(args.teacher_eval_data_path, extractor, args.teacher_eval_data_save_path, max_seq_len=args.max_seq_len, max_n_letters=args.max_n_letters, int_label=False, chunksize=args.chunksize, top=1000000, convert_to_int=True, a_fanouts=list(map(int, args.a_fanouts.split(","))), b_fanouts=list(map(int, args.b_fanouts.split(","))), n_chunk=args.n_chunk) 52 | except Exception as e: 53 | logger.info(e) 54 | logger.info("Cannot load from raw teacher eval data") 55 | 56 | 57 | try: 58 | print("Start processing validation data") 59 | process_datasets_to_file(args.validation_data_path, extractor, args.validation_data_save_path, max_seq_len=args.max_seq_len, max_n_letters=args.max_n_letters, 
int_label=True, chunksize=args.chunksize, a_fanouts=list(map(int, args.a_fanouts.split(","))), b_fanouts=list(map(int, args.b_fanouts.split(",")))) 60 | except Exception as e: 61 | logger.info(e) 62 | logger.info("Cannot load from raw validation data") 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /src/data_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Helper functions for data pre-processing""" 2 | 3 | from dataclasses import dataclass, asdict 4 | from typing import Optional, Union, List 5 | import re 6 | import collections 7 | import logging 8 | import tensorflow as tf 9 | import pandas as pd 10 | import json 11 | import os 12 | import glob 13 | import numpy as np 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | # Class for triletter vocab dictionary 19 | class L3G: 20 | dict = {} 21 | invdict = {} 22 | 23 | def __init__(self, l3g_path): 24 | with open(l3g_path, 'r', encoding='utf-8') as fp: 25 | i = 1 # note that the dictionary now increases all trigram index by 1!!! 26 | while True: 27 | s = fp.readline().strip('\n\r') 28 | if s == '': 29 | break 30 | self.dict[s] = i 31 | self.invdict[i] = s 32 | i += 1 33 | return 34 | 35 | 36 | # Triletter Encoder Class 37 | class TriLetterExtractor: 38 | def __init__(self, l3g_path, dim=49292): 39 | self.l3ginst = L3G(l3g_path) 40 | self.dimension = dim 41 | self.content = [] 42 | self.n_seq = 0 43 | self.invalid = re.compile('[^a-zA-Z0-9 ]') 44 | self.multispace = re.compile(' +') 45 | self.max_token_num = 12 46 | 47 | def extract_features(self, qstr, max_n_letters=20, max_seq_len=12): 48 | qseq, qmask = self.extract_from_sentence(qstr, max_n_letters, max_seq_len) # add word index if needed 49 | return qseq, qmask 50 | 51 | def extract_from_words(self, words, max_n_letters=20, max_seq_len=12): 52 | valid_words = [] 53 | for _, word in enumerate(words): 54 | word = self.invalid.sub('', word) 55 | word = word.strip() 56 | if word != '': 57 | valid_words.append(word) 58 | return self._from_words_to_id_sequence(valid_words, max_n_letters, max_seq_len) 59 | 60 | def extract_from_sentence(self, text, max_n_letters=20, max_seq_len=12): 61 | step1 = text.lower() 62 | step2 = self.invalid.sub('', step1) 63 | step3 = self.multispace.sub(' ', step2) 64 | step4 = step3.strip() 65 | words = step4.split(' ') 66 | return self._from_words_to_id_sequence(words, max_n_letters, max_seq_len) 67 | 68 | def _from_words_to_id_sequence(self, words, max_n_letters=20, max_seq_len=12): 69 | n_seq = min(len(words), max_seq_len) 70 | n_letter = max_n_letters 71 | feature_seq = [0] * (max_seq_len * max_n_letters) 72 | seq_mask = [0] * max_seq_len 73 | for i in range(n_seq): 74 | if words[i] == '': 75 | words[i] = '#' 76 | word = '#' + words[i] + '#' 77 | n_letter = min(len(word) - 2, max_n_letters) 78 | for j in range(n_letter): 79 | s = word[j:(j + 3)] 80 | if s in self.l3ginst.dict: 81 | feature_seq[i * max_n_letters + j] = self.l3ginst.dict[s] 82 | seq_mask[i] = 1 83 | return feature_seq, seq_mask 84 | 85 | 86 | @dataclass 87 | class InputExample: 88 | text_a: str 89 | text_b: Optional[str] = None 90 | label: Optional[Union[int, float]] = None 91 | text_a_neighbors: Optional[List[str]] = None 92 | text_b_neighbors: Optional[List[str]] = None 93 | text_a_neighbors_impression: Optional[List[int]] = None 94 | text_b_neighbors_impression: Optional[List[int]] = None 95 | text_a_neighbors_click: Optional[List[int]] 
= None 96 | text_b_neighbors_click: Optional[List[int]] = None 97 | 98 | def to_json_string(self): 99 | return json.dumps(asdict(self), indent=2) + "\n" 100 | 101 | 102 | @dataclass 103 | class InputFeatures: 104 | input_ids_a: List[int] 105 | input_ids_b: List[int] 106 | 107 | attention_mask_a: Optional[List[int]] = None 108 | attention_mask_b: Optional[List[int]] = None 109 | 110 | input_ids_a_neighbor: Optional[List[List[int]]] = None 111 | input_ids_b_neighbor: Optional[List[List[int]]] = None 112 | 113 | attention_mask_a_neighbor: Optional[List[List[int]]] = None 114 | attention_mask_b_neighbor: Optional[List[List[int]]] = None 115 | 116 | impression_a_neighbor: Optional[List[int]] = None 117 | impression_b_neighbor: Optional[List[int]] = None 118 | 119 | click_a_neighbor: Optional[List[int]] = None 120 | click_b_neighbor: Optional[List[int]] = None 121 | 122 | label: Optional[Union[int, float]] = None 123 | 124 | def to_json_string(self): 125 | return json.dumps(asdict(self)) + "\n" 126 | 127 | 128 | # Read a single pandas row into a InputExample 129 | def get_example_from_row(row, a_fanouts=[], b_fanouts=[], int_label=True, read_label=True, read_neigh_weights=True): 130 | text_a_neighbors = None 131 | text_b_neighbors = None 132 | text_a_neighbors_impression = None 133 | text_b_neighbors_impression = None 134 | text_a_neighbors_click = None 135 | text_b_neighbors_click = None 136 | 137 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 138 | layer_node = 1 139 | text_a_neighbors = [] 140 | if read_neigh_weights: 141 | text_a_neighbors_impression = [] 142 | text_a_neighbors_click = [] 143 | 144 | for layer in range(len(a_fanouts)): 145 | layer_node *= a_fanouts[layer] 146 | neighbors = row['Q_Neighbor_' + str(layer)].split('!!@@##$$') 147 | if read_neigh_weights: 148 | neighbors_impression = 0 if neighbors[0] == "" else list(map(int, row['Q_Neighbor_Impression_' + str(layer)].split('!!@@##$$'))) 149 | neighbors_click = 0 if neighbors[0] == "" else list(map(int, row['Q_Neighbor_Click_' + str(layer)].split('!!@@##$$'))) 150 | 151 | if len(neighbors) == 1 and neighbors[0] == "": 152 | text_a_neighbors += ["[PAD]"] * layer_node 153 | if read_neigh_weights: 154 | text_a_neighbors_impression += [1] * layer_node 155 | text_a_neighbors_click += [0] * layer_node 156 | else: 157 | neigh_length = min(len(neighbors), layer_node) 158 | text_a_neighbors += neighbors[:neigh_length] + ["[PAD]"] * (layer_node - neigh_length) 159 | if read_neigh_weights: 160 | text_a_neighbors_impression += neighbors_impression[:neigh_length] + [1] * (layer_node - neigh_length) 161 | text_a_neighbors_click += neighbors_click[:neigh_length] + [0] * (layer_node - neigh_length) 162 | 163 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 164 | layer_node = 1 165 | text_b_neighbors = [] 166 | if read_neigh_weights: 167 | text_b_neighbors_impression = [] 168 | text_b_neighbors_click = [] 169 | 170 | for layer in range(len(b_fanouts)): 171 | layer_node *= b_fanouts[layer] 172 | neighbors = row['K_Neighbor_' + str(layer)].split('!!@@##$$') 173 | if read_neigh_weights: 174 | neighbors_impression = 0 if neighbors[0] == "" else list(map(int, row['K_Neighbor_Impression_' + str(layer)].split('!!@@##$$'))) 175 | neighbors_click = 0 if neighbors[0] == "" else list(map(int, row['K_Neighbor_Click_' + str(layer)].split('!!@@##$$'))) 176 | 177 | if len(neighbors) == 1 and neighbors[0] == "": 178 | text_b_neighbors += ["[PAD]"] * layer_node 179 | if read_neigh_weights: 180 | text_b_neighbors_impression += [1] * layer_node 181 | 
text_b_neighbors_click += [0] * layer_node 182 | else: 183 | neigh_length = min(len(neighbors), layer_node) 184 | text_b_neighbors += neighbors[:neigh_length] + ["[PAD]"] * (layer_node - neigh_length) 185 | if read_neigh_weights: 186 | text_b_neighbors_impression += neighbors_impression[:neigh_length] + [1] * (layer_node - neigh_length) 187 | text_b_neighbors_click += neighbors_click[:neigh_length] + [0] * (layer_node - neigh_length) 188 | 189 | return InputExample( 190 | text_a=row['Query'], 191 | text_b=row['Keyword'], 192 | label=(row['QK_Rel'] if int_label else row['RoBERTaScore']) if read_label else None, 193 | text_a_neighbors=text_a_neighbors, 194 | text_b_neighbors=text_b_neighbors, 195 | text_a_neighbors_impression=text_a_neighbors_impression, 196 | text_b_neighbors_impression=text_b_neighbors_impression, 197 | text_a_neighbors_click=text_a_neighbors_click, 198 | text_b_neighbors_click=text_b_neighbors_click 199 | ) 200 | 201 | 202 | # Convert pandas DataFrame into a list of InputExamples 203 | def get_examples_from_pd(data, int_label=True, a_fanouts=[], b_fanouts=[]): 204 | return [get_example_from_row(data.iloc[i], a_fanouts=a_fanouts, b_fanouts=b_fanouts, int_label=int_label) for i in 205 | range(len(data))] 206 | 207 | 208 | # Convert a list of InputExamples into a list of InputFeatures 209 | def convert_examples_to_features(examples, extractor, max_seq_len=12, max_n_letters=20, a_fanouts=[], b_fanouts=[]): 210 | def label_from_example(ex: InputExample) -> Union[int, float, None]: 211 | label = ex.label 212 | return None if (label is None or label < 0) else label 213 | 214 | features = [] 215 | 216 | if type(extractor) == TriLetterExtractor: 217 | for example in examples: 218 | input_ids_a, attention_mask_a = extractor.extract_features(example.text_a, max_n_letters=max_n_letters, 219 | max_seq_len=max_seq_len) 220 | input_ids_b, attention_mask_b = extractor.extract_features(example.text_b, max_n_letters=max_n_letters, 221 | max_seq_len=max_seq_len) 222 | input_ids_a_neighbor = None 223 | input_ids_b_neighbor = None 224 | attention_mask_a_neighbor = None 225 | attention_mask_b_neighbor = None 226 | impression_a_neighbor = None 227 | impression_b_neighbor = None 228 | click_a_neighbor = None 229 | click_b_neighbor = None 230 | 231 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 232 | input_ids_a_neighbor = [] 233 | attention_mask_a_neighbor = [] 234 | impression_a_neighbor = example.text_a_neighbors_impression 235 | click_a_neighbor = example.text_a_neighbors_click 236 | 237 | for text in example.text_a_neighbors: 238 | input_ids, attention_mask = extractor.extract_features(text, max_n_letters=max_n_letters, 239 | max_seq_len=max_seq_len) 240 | input_ids_a_neighbor.append(input_ids) 241 | attention_mask_a_neighbor.append(attention_mask) 242 | 243 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 244 | input_ids_b_neighbor = [] 245 | attention_mask_b_neighbor = [] 246 | impression_b_neighbor = example.text_b_neighbors_impression 247 | click_b_neighbor = example.text_b_neighbors_click 248 | 249 | for text in example.text_b_neighbors: 250 | input_ids, attention_mask = extractor.extract_features(text, max_n_letters=max_n_letters, 251 | max_seq_len=max_seq_len) 252 | input_ids_b_neighbor.append(input_ids) 253 | attention_mask_b_neighbor.append(attention_mask) 254 | 255 | feature = { 256 | 'input_ids_a': input_ids_a, 257 | 'attention_mask_a': attention_mask_a, 258 | 'input_ids_b': input_ids_b, 259 | 'attention_mask_b': attention_mask_b, 260 | 'label': label_from_example(example), 261 | 
'input_ids_a_neighbor': input_ids_a_neighbor, 262 | 'attention_mask_a_neighbor': attention_mask_a_neighbor, 263 | 'input_ids_b_neighbor': input_ids_b_neighbor, 264 | 'attention_mask_b_neighbor': attention_mask_b_neighbor, 265 | 'impression_a_neighbor': impression_a_neighbor, 266 | 'impression_b_neighbor': impression_b_neighbor, 267 | 'click_a_neighbor': click_a_neighbor, 268 | 'click_b_neighbor': click_b_neighbor 269 | } 270 | features.append(InputFeatures(**feature)) 271 | else: 272 | labels = [label_from_example(example) for example in examples] 273 | text_a_list = [example.text_a for example in examples] 274 | text_b_list = [example.text_b for example in examples] 275 | 276 | batch_encoding_a = extractor(text_a_list, max_length=max_seq_len, pad_to_max_length=True, 277 | return_token_type_ids=False, truncation=True) 278 | batch_encoding_b = extractor(text_b_list, max_length=max_seq_len, pad_to_max_length=True, 279 | return_token_type_ids=False, truncation=True) 280 | 281 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 282 | a_neighbors = [extractor(example.text_a_neighbors, max_length=max_seq_len, pad_to_max_length=True, 283 | return_token_type_ids=False, truncation=True) for example in examples] 284 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 285 | b_neighbors = [extractor(example.text_b_neighbors, max_length=max_seq_len, pad_to_max_length=True, 286 | return_token_type_ids=False, truncation=True) for example in examples] 287 | 288 | for i in range(len(text_a_list)): 289 | inputs = {k + "_a": batch_encoding_a[k][i] for k in batch_encoding_a} 290 | inputs.update({k + "_b": batch_encoding_b[k][i] for k in batch_encoding_b}) 291 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 292 | inputs.update({k + "_a_neighbor": a_neighbors[i][k] for k in a_neighbors[i]}) 293 | inputs.update({"impression_a_neighbor": examples[i].text_a_neighbors_impression, 294 | "click_a_neighbor": examples[i].text_a_neighbors_click}) 295 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 296 | inputs.update({k + "_b_neighbor": b_neighbors[i][k] for k in b_neighbors[i]}) 297 | inputs.update({"impression_b_neighbor": examples[i].text_b_neighbors_impression, 298 | "click_b_neighbor": examples[i].text_b_neighbors_click}) 299 | feature = InputFeatures(**inputs, label=labels[i]) 300 | features.append(feature) 301 | 302 | return features 303 | 304 | 305 | # Pandas parsing functions to prevent data error 306 | def convert_int(x): 307 | try: 308 | return int(x) 309 | except Exception as e: 310 | print(e) 311 | return -99 312 | 313 | 314 | def convert_float(x): 315 | try: 316 | return float(x) 317 | except Exception as e: 318 | print(e) 319 | return -99.0 320 | 321 | 322 | def convert_float_to_int(x): 323 | try: 324 | y = float(x) 325 | return 0 if y < 0.5 else 1 326 | except Exception as e: 327 | print(e) 328 | return -99 329 | 330 | 331 | # Helper function to create tf dataset features 332 | def create_int_feature(values): 333 | f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) 334 | return f 335 | 336 | 337 | def create_float_feature(values): 338 | f = tf.train.Feature(float_list=tf.train.FloatList(value=list(values))) 339 | return f 340 | 341 | 342 | def feature_to_dict(f, int_label=True, a_fanouts=[], b_fanouts=[]): 343 | f_s = {} 344 | f_s["input_ids_a_0"] = tf.convert_to_tensor(np.array(f.input_ids_a), dtype=tf.int32)[tf.newaxis,:] 345 | f_s["input_ids_b_0"] = tf.convert_to_tensor(np.array(f.input_ids_a), dtype=tf.int32)[tf.newaxis,:] 346 | f_s["attention_mask_a_0"] = 
tf.convert_to_tensor(np.array(f.input_ids_a), dtype=tf.int32)[tf.newaxis,:] 347 | f_s["attention_mask_b_0"] = tf.convert_to_tensor(np.array(f.input_ids_a), dtype=tf.int32)[tf.newaxis,:] 348 | 349 | 350 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 351 | layer_node = 1 352 | for layer in range(len(a_fanouts)): 353 | layer_node *= a_fanouts[layer] 354 | f_s['input_ids_a_' + str(layer+1)] = tf.convert_to_tensor(np.array(f.input_ids_a_neighbor), dtype=tf.int32)[tf.newaxis,:,:] 355 | f_s['attention_mask_a_' + str(layer+1)] = tf.convert_to_tensor(np.array(f.attention_mask_a_neighbor), dtype=tf.int32)[tf.newaxis,:,:] 356 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 357 | layer_node = 1 358 | for layer in range(len(b_fanouts)): 359 | layer_node *= b_fanouts[layer] 360 | f_s['input_ids_b_' + str(layer+1)] = tf.convert_to_tensor(np.array(f.input_ids_b_neighbor), dtype=tf.int32)[tf.newaxis,:,:] 361 | f_s['attention_mask_b_' + str(layer+1)] = tf.convert_to_tensor(np.array(f.attention_mask_b_neighbor), dtype=tf.int32)[tf.newaxis,:,:] 362 | return f_s 363 | 364 | # Read raw tsv file by chunks and save to tf record files 365 | def process_datasets_to_file(data_path: str, extractor, write_filename: str, max_seq_len=12, max_n_letters=20, 366 | int_label=True, chunksize=1e6, top=None, convert_to_int=False, a_fanouts=[], b_fanouts=[], skip_chunk=0, n_chunk=0): 367 | names = ['Query', 'Keyword', 'QK_Rel' if (int_label or convert_to_int) else 'RoBERTaScore'] 368 | converters = {"Query": str, "Keyword": str} 369 | 370 | if int_label: 371 | converters["QK_Rel"] = convert_int 372 | elif convert_to_int: 373 | converters["QK_Rel"] = convert_float_to_int 374 | else: 375 | converters["RoBERTaScore"] = convert_float 376 | 377 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 378 | layer_node = 1 379 | for layer in range(len(a_fanouts)): 380 | names.append('Q_Neighbor_' + str(layer)) 381 | converters['Q_Neighbor_' + str(layer)] = str 382 | names.append('Q_Neighbor_Impression_' + str(layer)) 383 | converters['Q_Neighbor_Impression_' + str(layer)] = str 384 | names.append('Q_Neighbor_Click_' + str(layer)) 385 | converters['Q_Neighbor_Click_' + str(layer)] = str 386 | 387 | names.append('Q_Dist') 388 | converters['Q_Dist'] = convert_int 389 | 390 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 391 | layer_node = 1 392 | for layer in range(len(b_fanouts)): 393 | names.append('K_Neighbor_' + str(layer)) 394 | converters['K_Neighbor_' + str(layer)] = str 395 | names.append('K_Neighbor_Impression_' + str(layer)) 396 | converters['K_Neighbor_Impression_' + str(layer)] = str 397 | names.append('K_Neighbor_Click_' + str(layer)) 398 | converters['K_Neighbor_Click_' + str(layer)] = str 399 | 400 | names.append('K_Dist') 401 | converters['K_Dist'] = convert_int 402 | 403 | names.append('rand') 404 | names.append('rank') 405 | converters['rand'] = convert_float 406 | converters['rank'] = convert_int 407 | 408 | logger.info(names) 409 | logger.info(converters) 410 | 411 | int_label = int_label or convert_to_int 412 | 413 | chunk_n = skip_chunk 414 | 415 | count = 0 416 | for chunk in pd.read_csv(data_path, sep='\t', header=None, names=names, error_bad_lines=False, 417 | converters=converters, chunksize=chunksize, skiprows=int(skip_chunk * chunksize)): 418 | writer = tf.io.TFRecordWriter(os.path.join(write_filename, 'data_%d.tf_record' % chunk_n)) 419 | if int_label: 420 | chunk = chunk.loc[chunk["QK_Rel"] >= 0] 421 | else: 422 | chunk = chunk.loc[chunk["RoBERTaScore"] >= 0.0] 423 | 424 | logger.info("Process chunk %d" % chunk_n) 425 | 426 
| logger.info(len(chunk)) 427 | examples = get_examples_from_pd(chunk, int_label=int_label, a_fanouts=a_fanouts, b_fanouts=b_fanouts) 428 | logger.info("Finish loading examples") 429 | features = convert_examples_to_features(examples, extractor, max_n_letters=max_n_letters, 430 | max_seq_len=max_seq_len, a_fanouts=a_fanouts, b_fanouts=b_fanouts) 431 | logger.info("Finish converting features") 432 | 433 | for i, f in enumerate(features): 434 | logger.info(count) 435 | count += 1 436 | f_s = collections.OrderedDict() 437 | f_s["input_ids_a"] = create_int_feature(f.input_ids_a) 438 | f_s["input_ids_b"] = create_int_feature(f.input_ids_b) 439 | f_s["attention_mask_a"] = create_int_feature(f.attention_mask_a) 440 | f_s["attention_mask_b"] = create_int_feature(f.attention_mask_b) 441 | f_s["label"] = create_int_feature([f.label]) if int_label else create_float_feature([f.label]) 442 | 443 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 444 | layer_node = 1 445 | n = 0 446 | for layer in range(len(a_fanouts)): 447 | layer_node *= a_fanouts[layer] 448 | for k in range(layer_node): 449 | f_s['input_ids_a_' + str(layer) + '_' + str(k)] = create_int_feature(f.input_ids_a_neighbor[n]) 450 | f_s['attention_mask_a_' + str(layer) + '_' + str(k)] = create_int_feature(f.attention_mask_a_neighbor[n]) 451 | f_s['impression_a_neighbor_' + str(layer) + '_' + str(k)] = create_int_feature([f.impression_a_neighbor[n]]) 452 | f_s['click_a_neighbor_' + str(layer) + '_' + str(k)] = create_int_feature( 453 | [f.click_a_neighbor[n]]) 454 | n += 1 455 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 456 | layer_node = 1 457 | n = 0 458 | for layer in range(len(b_fanouts)): 459 | layer_node *= b_fanouts[layer] 460 | for k in range(layer_node): 461 | f_s['input_ids_b_' + str(layer) + '_' + str(k)] = create_int_feature(f.input_ids_b_neighbor[n]) 462 | f_s['attention_mask_b_' + str(layer) + '_' + str(k)] = create_int_feature(f.attention_mask_b_neighbor[n]) 463 | f_s['impression_b_neighbor_' + str(layer) + '_' + str(k)] = create_int_feature( 464 | [f.impression_b_neighbor[n]]) 465 | f_s['click_b_neighbor_' + str(layer) + '_' + str(k)] = create_int_feature( 466 | [f.click_b_neighbor[n]]) 467 | n += 1 468 | 469 | tf_example = tf.train.Example(features=tf.train.Features(feature=f_s)) 470 | writer.write(tf_example.SerializeToString()) 471 | chunk_n += 1 472 | if chunk_n - skip_chunk == n_chunk: 473 | break 474 | 475 | if top and count >= top: 476 | break 477 | 478 | writer.close() 479 | 480 | 481 | # Read and parse tf record dataset into desired input format 482 | def read_preprocessed_datasets(filepath: str, max_seq_len=16, max_n_letters=20, is_triletter=False, int_label=True, a_fanouts=[], b_fanouts=[], neigh_weights=False): 483 | filenames = glob.glob(os.path.join(filepath, '*.tf_record')) 484 | 485 | def atoi(text): 486 | return int(text) if text.isdigit() else text 487 | 488 | def natural_keys(text): 489 | return [atoi(c) for c in re.split('(\d+)', text)] 490 | 491 | filenames.sort(key=natural_keys) 492 | dataset = tf.data.TFRecordDataset(filenames) 493 | length = max_n_letters * max_seq_len if is_triletter else max_seq_len 494 | 495 | def _decode_record(record): 496 | name_to_features = {"input_ids_a": tf.io.FixedLenFeature([length], tf.int64), 497 | "input_ids_b": tf.io.FixedLenFeature([length], tf.int64), 498 | "attention_mask_a": tf.io.FixedLenFeature([max_seq_len], tf.int64), 499 | "attention_mask_b": tf.io.FixedLenFeature([max_seq_len], tf.int64), 500 | "label": tf.io.FixedLenFeature([], tf.int64) if int_label else 
tf.io.FixedLenFeature([], 501 | tf.float32) 502 | } 503 | 504 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 505 | layer_node = 1 506 | for layer in range(len(a_fanouts)): 507 | layer_node *= a_fanouts[layer] 508 | for k in range(layer_node): 509 | name_to_features['input_ids_a_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([length], tf.int64) 510 | name_to_features['attention_mask_a_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([max_seq_len], tf.int64) 511 | if neigh_weights: 512 | name_to_features['impression_a_neighbor_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([], tf.int64) 513 | name_to_features['click_a_neighbor_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([], tf.int64) 514 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 515 | layer_node = 1 516 | for layer in range(len(b_fanouts)): 517 | layer_node *= b_fanouts[layer] 518 | for k in range(layer_node): 519 | name_to_features['input_ids_b_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([length], tf.int64) 520 | name_to_features['attention_mask_b_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([max_seq_len], tf.int64) 521 | if neigh_weights: 522 | name_to_features['impression_b_neighbor_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([], tf.int64) 523 | name_to_features['click_b_neighbor_' + str(layer) + '_' + str(k)] = tf.io.FixedLenFeature([], tf.int64) 524 | example = tf.io.parse_single_example(record, name_to_features) 525 | 526 | for name in list(example.keys()): 527 | if name != "label" or int_label: 528 | example[name] = tf.cast(example[name], tf.int32) 529 | 530 | input_ids_a = example["input_ids_a"] 531 | attention_mask_a = example["attention_mask_a"] 532 | input_ids_b = example["input_ids_b"] 533 | attention_mask_b = example["attention_mask_b"] 534 | 535 | return_dict = { 536 | "input_ids_a_0": input_ids_a, 537 | "input_ids_b_0": input_ids_b, 538 | "attention_mask_a_0": attention_mask_a, 539 | "attention_mask_b_0": attention_mask_b, 540 | } 541 | 542 | if len(a_fanouts) > 0 and a_fanouts[0] > 0: 543 | layer_node = 1 544 | for layer in range(len(a_fanouts)): 545 | layer_node *= a_fanouts[layer] 546 | tmp_ids = [] 547 | tmp_mask = [] 548 | tmp_impression = [] 549 | tmp_click = [] 550 | 551 | for i in range(layer_node): 552 | tmp_ids.append(example['input_ids_a_' + str(layer) + '_' + str(i)]) 553 | tmp_mask.append(example['attention_mask_a_' + str(layer) + '_' + str(i)]) 554 | if neigh_weights: 555 | tmp_impression.append(example['impression_a_neighbor_' + str(layer) + '_' + str(i)]) 556 | tmp_click.append(example['click_a_neighbor_' + str(layer) + '_' + str(i)]) 557 | return_dict['input_ids_a_' + str(layer+1)] = tf.stack(tmp_ids) 558 | return_dict['attention_mask_a_' + str(layer+1)] = tf.stack(tmp_mask) 559 | if neigh_weights: 560 | return_dict['impression_a_' + str(layer+1)] = tf.stack(tmp_impression) 561 | return_dict['click_a_' + str(layer+1)] = tf.stack(tmp_click) 562 | 563 | if len(b_fanouts) > 0 and b_fanouts[0] > 0: 564 | layer_node = 1 565 | for layer in range(len(b_fanouts)): 566 | layer_node *= b_fanouts[layer] 567 | tmp_ids = [] 568 | tmp_mask = [] 569 | tmp_impression = [] 570 | tmp_click = [] 571 | 572 | for i in range(layer_node): 573 | tmp_ids.append(example['input_ids_b_' + str(layer) + '_' + str(i)]) 574 | tmp_mask.append(example['attention_mask_b_' + str(layer) + '_' + str(i)]) 575 | if neigh_weights: 576 | tmp_impression.append(example['impression_b_neighbor_' + str(layer) + '_' + str(i)]) 577 | tmp_click.append(example['click_b_neighbor_' + str(layer) 
+ '_' + str(i)])
578 |                 return_dict['input_ids_b_' + str(layer+1)] = tf.stack(tmp_ids)
579 |                 return_dict['attention_mask_b_' + str(layer+1)] = tf.stack(tmp_mask)
580 |                 if neigh_weights:
581 |                     return_dict['impression_b_' + str(layer+1)] = tf.stack(tmp_impression)
582 |                     return_dict['click_b_' + str(layer+1)] = tf.stack(tmp_click)
583 |         return (
584 |             return_dict,
585 |             example["label"],
586 |         )
587 |
588 |     return dataset.map(_decode_record)
--------------------------------------------------------------------------------
/src/encoders.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from transformers.modeling_tf_utils import shape_list, get_initializer
3 | try:
4 |     from aggregators import *
5 | except ImportError:
6 |     from .aggregators import *
7 |
8 | # simple concat: concatenate each node vector with its flattened neighbor vectors and pass through a dense layer per hop
9 | class SimpleConcat(tf.keras.layers.Layer):
10 |     def __init__(self, config, fanouts, **kwargs):
11 |         super().__init__(**kwargs)
12 |         self.fanouts = fanouts
13 |         self.num_layers = len(self.fanouts) if (len(self.fanouts) > 0 and self.fanouts[0] > 0) else 0
14 |         self.activations = config.gnn_acts
15 |         self.dense_layers = []
16 |         for l in range(self.num_layers):
17 |             self.dense_layers.append(tf.keras.layers.Dense(
18 |                 config.hidden_dims[l], activation=self.activations[l], kernel_initializer='glorot_uniform',
19 |                 name="dense_%d" % l
20 |             ))
21 |
22 |     def call(self, inputs, training=False):
23 |         if self.num_layers == 0:
24 |             return inputs["bert_0"]
25 |
26 |         hidden = inputs
27 |         for layer in range(self.num_layers):
28 |             next_hidden = {}
29 |             for hop in range(self.num_layers - layer):
30 |                 neighbor = tf.reshape(hidden["bert_" + str(hop+1)], [-1, self.fanouts[hop] * shape_list(hidden["bert_" + str(hop+1)])[-1]])
31 |                 seq = tf.concat([hidden["bert_" + str(hop)], neighbor], axis=-1)
32 |                 h = self.dense_layers[layer](seq, training=training)
33 |                 next_hidden["bert_" + str(hop)] = h
34 |             hidden = next_hidden
35 |         return hidden["bert_0"]
36 |
37 |
38 | # graphsage
39 | class GraphSAGE(tf.keras.layers.Layer):
40 |     def __init__(self, config, fanouts, **kwargs):
41 |         super().__init__(**kwargs)
42 |         self.config = config
43 |         self.fanouts = fanouts
44 |         self.num_layers = len(self.fanouts) if (len(self.fanouts) > 0 and self.fanouts[0] > 0) else 0
45 |         self.activations = config.gnn_acts
46 |
47 |         self.aggregator_class = aggregators.get(self.config.aggregator)
48 |         self.aggs = []
49 |         for layer in range(self.num_layers):
50 |             activation = self.activations[layer]
51 |             if layer == self.num_layers - 1:
52 |                 self.aggs.append(
53 |                     self.aggregator_class(self.config, self.config.hidden_dims[layer], activation=activation,
54 |                                           identity_act=True, name='agg_%d' % layer))
55 |             else:
56 |                 self.aggs.append(
57 |                     self.aggregator_class(self.config, self.config.hidden_dims[layer], activation=activation,
58 |                                           identity_act=False, name='agg_%d' % layer))
59 |
60 |     def call(self, inputs, training=False):
61 |         if self.num_layers == 0:
62 |             return inputs["bert_0"]
63 |
64 |         dim0 = shape_list(inputs["bert_1"])[-1]
65 |         dims = [dim0] + self.config.hidden_dims[len(self.config.hidden_dims) - self.num_layers:]
66 |
67 |         hidden = inputs
68 |
69 |         for layer in range(self.num_layers):
70 |             aggregator = self.aggs[layer]
71 |             next_hidden = {}
72 |             for hop in range(self.num_layers - layer):
73 |                 neigh_shape = [-1, self.fanouts[hop], dims[layer]]
74 |                 h = aggregator((hidden["bert_" + str(hop)], tf.reshape(hidden["bert_" + str(hop+1)], neigh_shape)))
75 |                 next_hidden["bert_" + str(hop)] = h
76 |             hidden = next_hidden
77 |
78 |         return hidden["bert_0"]
79 |
80 |
81 | # 
self attention head 82 | class AttentionHead(tf.keras.layers.Layer): 83 | def __init__(self, out_size, activation=tf.nn.leaky_relu, residual=False, **kwargs): 84 | super().__init__(**kwargs) 85 | self.out_size = out_size 86 | self.feature_conv = tf.keras.layers.Conv1D(self.out_size, 1, use_bias=False) 87 | self.f1_conv = tf.keras.layers.Conv1D(1, 1) 88 | self.f2_conv = tf.keras.layers.Conv1D(1, 1) 89 | if isinstance(activation, str): 90 | if activation == "leaky_relu": 91 | activation = tf.nn.leaky_relu 92 | elif activation == "relu": 93 | activation = tf.nn.relu 94 | self.activation = tf.keras.layers.Activation(activation) 95 | self.residual = residual 96 | 97 | def build(self, input_shape): 98 | with tf.name_scope("attn_head"): 99 | self.bias = self.add_weight( 100 | "bias", shape=[self.out_size], initializer="zero" 101 | ) 102 | if self.residual: 103 | if input_shape[-1] != self.out_size: 104 | self.res_conv = tf.keras.layers.Conv1D(self.out_size, 1) 105 | super().build(self) 106 | 107 | def call(self, seq, training=False): 108 | seq_fts = self.feature_conv(seq) 109 | f_1 = self.f1_conv(seq_fts) 110 | f_2 = self.f2_conv(seq_fts) 111 | logits = f_1 + tf.transpose(f_2, [0, 2, 1]) 112 | coefs = tf.nn.softmax(tf.nn.leaky_relu(logits)) 113 | vals = tf.matmul(coefs, seq_fts) 114 | ret = tf.nn.bias_add(vals, self.bias) 115 | 116 | # residual connection 117 | if self.residual: 118 | if shape_list(seq)[-1] != shape_list(ret)[-1]: 119 | ret = ret + self.res_conv(seq) 120 | else: 121 | ret = ret + seq 122 | return self.activation(ret) 123 | 124 | 125 | # gat 126 | class GAT(tf.keras.layers.Layer): 127 | def __init__(self, config, fanouts, **kwargs): 128 | super().__init__(**kwargs) 129 | self.config = config 130 | self.fanouts = fanouts 131 | self.num_layers = len(self.fanouts) if (len(self.fanouts) > 0 and self.fanouts[0] > 0) else 0 132 | self.activations = config.gnn_acts 133 | self.neighbor_num = self.fanouts[0] 134 | self.attention_heads = [] 135 | for layer, head_num in enumerate(self.config.head_nums): 136 | heads = [] 137 | for i in range(head_num): 138 | heads.append(AttentionHead(self.config.hidden_dims[layer], self.activations[layer], self.config.use_residual)) 139 | self.attention_heads.append(heads) 140 | 141 | def call(self, inputs, training=False): 142 | if self.num_layers == 0: 143 | return inputs["bert_0"] 144 | 145 | dim0 = shape_list(inputs["bert_1"])[-1] 146 | node_feats = tf.expand_dims(inputs["bert_0"], 1) 147 | neighbor_feats = tf.reshape(inputs["bert_1"], [-1, self.neighbor_num, dim0]) 148 | seq = tf.concat([node_feats, neighbor_feats], 1) 149 | 150 | for layer, head_num in enumerate(self.config.head_nums): 151 | hidden = [] 152 | for i in range(head_num): 153 | hidden_val = self.attention_heads[layer][i](seq) 154 | hidden.append(hidden_val) 155 | seq = tf.concat(hidden, -1) 156 | 157 | out = hidden 158 | out = tf.add_n(out) / self.config.head_nums[-1] 159 | out = tf.slice(out, [0, 0, 0], [-1, 1, self.config.hidden_dims[-1]]) 160 | return tf.reshape(out, [-1, self.config.hidden_dims[-1]]) 161 | 162 | 163 | encoders = { 164 | 'simple': SimpleConcat, 165 | 'graphsage': GraphSAGE, 166 | 'gat': GAT, 167 | } 168 | 169 | 170 | def get(encoder): 171 | return encoders.get(encoder) -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | """A bunch of helper functions to generate a dictionary for evaluation metrics.""" 2 | 3 | from scipy.stats import pearsonr, 
spearmanr 4 | from sklearn.metrics import f1_score, roc_auc_score, average_precision_score 5 | import numpy as np 6 | 7 | 8 | def simple_accuracy(preds, labels): 9 | return (preds == labels).mean() 10 | 11 | 12 | def acc_and_f1(preds, labels): 13 | acc = simple_accuracy(preds, labels) 14 | f1 = f1_score(y_true=labels, y_pred=preds) 15 | return { 16 | "acc": acc, 17 | "f1": f1, 18 | "acc_and_f1": (acc + f1) / 2 19 | } 20 | 21 | 22 | def pearson_and_spearman(preds, labels): 23 | pearson_corr = pearsonr(preds, labels)[0] 24 | spearman_corr = spearmanr(preds, labels)[0] 25 | 26 | return { 27 | "pearson": pearson_corr, 28 | "spearmanr": spearman_corr, 29 | "corr": (pearson_corr + spearman_corr) / 2 30 | } 31 | 32 | 33 | def auc(preds, labels): 34 | roc_auc = roc_auc_score(y_true=labels, y_score=preds) 35 | pr_auc = average_precision_score(y_true=labels, y_score=preds) 36 | return { 37 | "roc_auc": roc_auc, 38 | "pr_auc": pr_auc 39 | } 40 | 41 | 42 | def metrics(preds, labels): 43 | results = acc_and_f1(np.argmax(preds, axis=1), labels) 44 | results.update(auc(preds[:, 1], labels)) 45 | return results 46 | -------------------------------------------------------------------------------- /src/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from transformers import BertTokenizer 3 | 4 | 5 | class TwinBertTokenizer(BertTokenizer): 6 | def __init__(self, vocab_file, add_cls_tokens=False, **kwargs): 7 | super(TwinBertTokenizer, self).__init__(vocab_file=vocab_file, **kwargs) 8 | self.add_cls_tokens = add_cls_tokens 9 | 10 | def build_inputs_with_special_tokens( 11 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 12 | ) -> List[int]: 13 | if self.add_cls_tokens: 14 | return [self.cls_token_id] + token_ids_0 15 | return token_ids_0 16 | 17 | def get_special_tokens_mask( 18 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 19 | ) -> List[int]: 20 | if self.add_cls_tokens: 21 | return [1] + ([0] * len(token_ids_0)) 22 | return [0] * len(token_ids_0) 23 | 24 | def create_token_type_ids_from_sequences( 25 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 26 | ) -> List[int]: 27 | if self.add_cls_tokens: 28 | return [0] * (1 + len(token_ids_0)) 29 | return [0] * len(token_ids_0) 30 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """Main script for training""" 2 | 3 | import logging 4 | import os 5 | import math 6 | import numpy as np 7 | import tensorflow as tf 8 | from dataclasses import dataclass, field 9 | from data_preprocessing import read_preprocessed_datasets, read_preprocessed_datasets_twinbert, read_preprocessed_datasets_beijing 10 | from metric import metrics 11 | from twinbertgnn import TwinBERTGNN 12 | from transformers import TFTrainingArguments, HfArgumentParser, EvalPrediction 13 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 14 | from trainer import TFTrainer 15 | from typing import Dict, Optional 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | @dataclass 21 | class DataTrainingArguments: 22 | overwrite_cache: bool = field( 23 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 24 | ) 25 | 26 | train_data_path: Optional[str] = field( 27 | default='../data/QK_Neighbor/Teacher/', metadata={"help": "Path 
to training data"} 28 | ) 29 | 30 | train_data_size: Optional[int] = field( 31 | default=None, metadata={"help": "Training data size"} 32 | ) 33 | 34 | train_data_int_label: bool = field( 35 | default=False, metadata={"help": "Whether the training dataset use an int label or float roberta score"} 36 | ) 37 | 38 | do_eval_train: bool = field( 39 | default=False, metadata={"help": "Whether to run evaluation on a small training dataset."} 40 | ) 41 | 42 | eval_train_data_path: Optional[str] = field( 43 | default='../data/QK_Neighbor/Training/', metadata={"help": "Path to training eval data"} 44 | ) 45 | 46 | eval_data_path: Optional[str] = field( 47 | default='../data/QK_Neighbor/Validation/', metadata={"help": "Path to eval data"} 48 | ) 49 | 50 | test_data_path: Optional[str] = field( 51 | default='../data/QK_Neighbor/Validation/', metadata={"help": "Path to eval data"} 52 | ) 53 | 54 | is_twinbert_format: bool = field( 55 | default=False, metadata={"help": "Whether the dataset is from the previous twinbert model format"} 56 | ) 57 | 58 | finetune: bool = field( 59 | default=False, metadata={"help": "Whether the this training is finetune"} 60 | ) 61 | 62 | finetune_previous_epoch: Optional[int] = field( 63 | default=10, metadata={"help": "Number of pretrained epochs"} 64 | ) 65 | 66 | neigh_weights: bool = field( 67 | default=False, metadata={"help": "Whether to read the neighbor weights from TF datasets"} 68 | ) 69 | 70 | 71 | @dataclass 72 | class ModelArguments: 73 | model_checkpoint_path: str = field( 74 | default=None, metadata={"help": "Path to pre-trained model weights checkpoint"} 75 | ) 76 | 77 | is_tf_checkpoint: bool = field( 78 | default=True, metadata={"help": "Whether the checkpoint is from a torch model or twinbert tf model"} 79 | ) 80 | 81 | config_path: Optional[str] = field( 82 | default="../config/model.config", metadata={"help": "Path to model config"} 83 | ) 84 | 85 | checkpoint_dict: Optional[str] = field( 86 | default='../config/twinbert_checkpoint_dict.txt', 87 | metadata={"help": "Mapping between model weights to checkpoint weights"} 88 | ) 89 | 90 | pretrained_bert_name: Optional[str] = field( 91 | default=None, metadata={"help": "Pretrained Bert name, used by bpe tokenizer models"} 92 | ) 93 | 94 | 95 | def main(): 96 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TFTrainingArguments)) 97 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 98 | 99 | print("Load format, is twinbert?") 100 | print(data_args.is_twinbert_format) 101 | 102 | if ( 103 | os.path.exists(training_args.output_dir) 104 | and os.listdir(training_args.output_dir) 105 | and training_args.do_train 106 | and not training_args.overwrite_output_dir 107 | ): 108 | raise ValueError( 109 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
110 | ) 111 | 112 | logging.basicConfig( 113 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 114 | datefmt="%m/%d/%Y %H:%M:%S", 115 | level=logging.INFO, 116 | ) 117 | 118 | logger.info( 119 | "n_gpu: %s, distributed training: %s, 16-bits training: %s", 120 | training_args.n_gpu, 121 | bool(training_args.n_gpu > 1), 122 | training_args.fp16, 123 | ) 124 | logger.info("Training/evaluation parameters %s", training_args) 125 | 126 | with training_args.strategy.scope(): 127 | if model_args.model_checkpoint_path: 128 | model = TwinBERTGNN.load_from_checkpoint(model_args.config_path, model_args.model_checkpoint_path, 129 | model_args.checkpoint_dict, 130 | is_tf_checkpoint=model_args.is_tf_checkpoint) 131 | elif model_args.pretrained_bert_name: 132 | model = TwinBERTGNN.load_from_bert_pretrained(model_args.config_path, model_args.pretrained_bert_name) 133 | else: 134 | model = TwinBERTGNN(model_args.config_path) 135 | model(model.dummy_input, training=False) 136 | 137 | logger.info("Model finished loading.") 138 | if (training_args.do_train and not os.path.exists(data_args.train_data_path)) or (training_args.do_eval and (not os.path.exists(data_args.eval_train_data_path) or not os.path.exists(data_args.eval_data_path))) or (training_args.do_predict and not os.path.exists(data_args.test_data_path)): 139 | raise ValueError( 140 | f"Please run the create_tf_record_data.py script to first generate tf record files." 141 | ) 142 | 143 | is_triletter = model.config.embedding_type == 'triletter' 144 | 145 | train_dataset = ( 146 | read_preprocessed_datasets(data_args.train_data_path, 147 | max_n_letters=model.config.triletter_max_letters_in_word, 148 | max_seq_len=model.config.max_seq_len, 149 | is_triletter=is_triletter, int_label=data_args.train_data_int_label, 150 | a_fanouts=model.config.a_fanouts, 151 | b_fanouts=model.config.b_fanouts, neigh_weights=data_args.neigh_weights) 152 | if training_args.do_train 153 | else None 154 | ) 155 | 156 | eval_train_dataset = ( 157 | read_preprocessed_datasets(data_args.eval_train_data_path, 158 | max_n_letters=model.config.triletter_max_letters_in_word, 159 | max_seq_len=model.config.max_seq_len, 160 | is_triletter=is_triletter, int_label=True, 161 | a_fanouts=model.config.a_fanouts, 162 | b_fanouts=model.config.b_fanouts, neigh_weights=data_args.neigh_weights) 163 | if data_args.do_eval_train 164 | else None 165 | ) 166 | 167 | eval_dataset = ( 168 | read_preprocessed_datasets(data_args.eval_data_path, 169 | max_n_letters=model.config.triletter_max_letters_in_word, 170 | max_seq_len=model.config.max_seq_len, 171 | is_triletter=is_triletter, int_label=True, a_fanouts=model.config.a_fanouts, 172 | b_fanouts=model.config.b_fanouts, neigh_weights=data_args.neigh_weights) 173 | if training_args.do_eval 174 | else None 175 | ) 176 | 177 | test_dataset = ( 178 | read_preprocessed_datasets(data_args.test_data_path, 179 | max_n_letters=model.config.triletter_max_letters_in_word, 180 | max_seq_len=model.config.max_seq_len, 181 | is_triletter=is_triletter, int_label=True, a_fanouts=model.config.a_fanouts, 182 | b_fanouts=model.config.b_fanouts, neigh_weights=data_args.neigh_weights) 183 | if training_args.do_predict 184 | else None 185 | ) 186 | 187 | def compute_metrics(p: EvalPrediction) -> Dict: 188 | return metrics(p.predictions, p.label_ids) 189 | 190 | trainer = TFTrainer( 191 | model=model, 192 | args=training_args, 193 | train_dataset=train_dataset, 194 | eval_train_dataset=eval_train_dataset, 195 | eval_dataset=eval_dataset, 196 | 
compute_metrics=compute_metrics, 197 | train_size=data_args.train_data_size, 198 | ) 199 | 200 | if training_args.do_train: 201 | logger.info("*** Start Training ***") 202 | trainer.train(finetune=data_args.finetune, previous_epoch=data_args.finetune_previous_epoch, do_eval_train=data_args.do_eval_train) 203 | trainer.save_model() 204 | 205 | results = {} 206 | if training_args.do_eval: 207 | logger.info("*** Evaluate ***") 208 | 209 | trainer.gradient_accumulator.reset() 210 | 211 | with trainer.args.strategy.scope(): 212 | trainer.train_steps = math.ceil(trainer.num_train_examples / trainer.args.train_batch_size) 213 | optimizer, lr_scheduler = trainer.get_optimizers() 214 | iterations = optimizer.iterations 215 | folder = os.path.join(trainer.args.output_dir, PREFIX_CHECKPOINT_DIR) 216 | ckpt = tf.train.Checkpoint(optimizer=optimizer, model=trainer.model) 217 | trainer.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=trainer.args.save_total_limit) 218 | 219 | if trainer.model.ckpt_manager.latest_checkpoint: 220 | logger.info( 221 | "Checkpoint file %s found and restoring from checkpoint", trainer.model.ckpt_manager.latest_checkpoint 222 | ) 223 | 224 | ckpt.restore(trainer.model.ckpt_manager.latest_checkpoint).expect_partial() 225 | 226 | result = trainer.evaluate() 227 | output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt") 228 | 229 | with open(output_eval_file, "w") as writer: 230 | logger.info("***** Eval results *****") 231 | 232 | for key, value in result.items(): 233 | logger.info(" %s = %s", key, value) 234 | writer.write("%s = %s\n" % (key, value)) 235 | 236 | results.update(result) 237 | 238 | return results 239 | 240 | if training_args.do_predict: 241 | logger.info("*** Inference ***") 242 | trainer.gradient_accumulator.reset() 243 | 244 | with trainer.args.strategy.scope(): 245 | trainer.num_train_examples = 0 246 | trainer.train_steps = math.ceil(trainer.num_train_examples / trainer.args.train_batch_size) 247 | optimizer, lr_scheduler = trainer.get_optimizers() 248 | iterations = optimizer.iterations 249 | folder = os.path.join(trainer.args.output_dir, PREFIX_CHECKPOINT_DIR) 250 | ckpt = tf.train.Checkpoint(optimizer=optimizer, model=trainer.model) 251 | trainer.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, 252 | max_to_keep=trainer.args.save_total_limit) 253 | 254 | if trainer.model.ckpt_manager.latest_checkpoint: 255 | logger.info( 256 | "Checkpoint file %s found and restoring from checkpoint", 257 | trainer.model.ckpt_manager.latest_checkpoint 258 | ) 259 | 260 | ckpt.restore(trainer.model.ckpt_manager.latest_checkpoint).expect_partial() 261 | predictions = trainer.predict(test_dataset) 262 | path_prediction = os.path.join(training_args.output_dir, "predictions.npz") 263 | np.savez(path_prediction, predictions=predictions.predictions, labels=predictions.label_ids) 264 | 265 | for key, value in predictions.metrics.items(): 266 | logger.info(" %s = %s", key, value) 267 | 268 | 269 | if __name__ == "__main__": 270 | main() 271 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | """Tensorflow trainer class.""" 2 | 3 | import logging 4 | import math 5 | import os 6 | import time 7 | from typing import Callable, Dict, Optional, Tuple 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | from transformers.modeling_tf_utils import TFPreTrainedModel 13 | from 
transformers.optimization_tf import GradientAccumulator, create_optimizer 14 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, EvalPrediction, PredictionOutput, is_wandb_available 15 | from transformers.training_args_tf import TFTrainingArguments 16 | 17 | 18 | if is_wandb_available(): 19 | import wandb 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class TFTrainer: 26 | model: TFPreTrainedModel 27 | args: TFTrainingArguments 28 | train_dataset: Optional[tf.data.Dataset] 29 | eval_dataset: Optional[tf.data.Dataset] 30 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None 31 | prediction_loss_only: bool 32 | tb_writer: Optional[tf.summary.SummaryWriter] = None 33 | optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None 34 | global_step: Optional[int] = None 35 | epoch_logging: Optional[float] = None 36 | 37 | def __init__( 38 | self, 39 | model: TFPreTrainedModel, 40 | args: TFTrainingArguments, 41 | train_dataset: Optional[tf.data.Dataset] = None, 42 | eval_train_dataset: Optional[tf.data.Dataset] = None, 43 | eval_dataset: Optional[tf.data.Dataset] = None, 44 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, 45 | prediction_loss_only=False, 46 | tb_writer: Optional[tf.summary.SummaryWriter] = None, 47 | optimizers: Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule] = None, 48 | train_size: Optional[int] = None, 49 | ): 50 | self.model = model 51 | self.args = args 52 | self.train_dataset = train_dataset 53 | self.eval_train_dataset = eval_train_dataset 54 | self.eval_dataset = eval_dataset 55 | self.compute_metrics = compute_metrics 56 | self.prediction_loss_only = prediction_loss_only 57 | self.optimizers = optimizers 58 | self.gradient_accumulator = GradientAccumulator() 59 | self.global_step = 0 60 | self.epoch_logging = 0 61 | 62 | self.num_train_examples = train_size 63 | 64 | if tb_writer is not None: 65 | self.tb_writer = tb_writer 66 | else: 67 | self.tb_writer = tf.summary.create_file_writer(self.args.logging_dir) 68 | if is_wandb_available(): 69 | self._setup_wandb() 70 | else: 71 | logger.info( 72 | "You are instantiating a Trainer but W&B is not installed. To use wandb logging, " 73 | "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface." 
74 | ) 75 | 76 | def get_train_tfdataset(self) -> tf.data.Dataset: 77 | if self.train_dataset is None: 78 | raise ValueError("Trainer: training requires a train_dataset.") 79 | 80 | if self.num_train_examples is None: 81 | self.num_train_examples = self.train_dataset.reduce(tf.constant(0), lambda x, _: x + 1).numpy() 82 | 83 | if self.args.max_steps > 0: 84 | self.train_steps = self.args.max_steps 85 | else: 86 | self.train_steps: int = math.ceil(self.num_train_examples / self.args.train_batch_size) 87 | 88 | ds = ( 89 | self.train_dataset.cache() 90 | .shuffle(1000000) 91 | .batch(self.args.train_batch_size, drop_remainder=self.args.dataloader_drop_last) 92 | .prefetch(tf.data.experimental.AUTOTUNE) 93 | ) 94 | 95 | if self.args.max_steps > 0: 96 | self.train_dataset = self.train_dataset.repeat(-1) 97 | 98 | return self.args.strategy.experimental_distribute_dataset(ds) 99 | 100 | def get_eval_train_tfdataset(self, eval_train_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset: 101 | if eval_train_dataset is None and self.eval_train_dataset is None: 102 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 103 | 104 | eval_train_dataset = eval_train_dataset if eval_train_dataset is not None else self.eval_train_dataset 105 | ds = ( 106 | eval_train_dataset.cache() 107 | .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last) 108 | .prefetch(tf.data.experimental.AUTOTUNE) 109 | ) 110 | 111 | return self.args.strategy.experimental_distribute_dataset(ds) 112 | 113 | def get_eval_tfdataset(self, eval_dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset: 114 | if eval_dataset is None and self.eval_dataset is None: 115 | raise ValueError("Trainer: evaluation requires an eval_dataset.") 116 | 117 | eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset 118 | ds = ( 119 | eval_dataset.cache() 120 | .batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last) 121 | .prefetch(tf.data.experimental.AUTOTUNE) 122 | ) 123 | 124 | return self.args.strategy.experimental_distribute_dataset(ds) 125 | 126 | def get_test_tfdataset(self, test_dataset: tf.data.Dataset) -> tf.data.Dataset: 127 | ds = test_dataset.batch(self.args.eval_batch_size, drop_remainder=self.args.dataloader_drop_last) 128 | 129 | return self.args.strategy.experimental_distribute_dataset(ds) 130 | 131 | def get_optimizers( 132 | self, 133 | ) -> Tuple[tf.keras.optimizers.Optimizer, tf.keras.optimizers.schedules.LearningRateSchedule]: 134 | """ 135 | Setup the optimizer and the learning rate scheduler. 136 | We provide a reasonable default that works well. 137 | If you want to use something else, you can pass a tuple in the Trainer's init, 138 | or override this method in a subclass. 139 | """ 140 | if self.optimizers is not None: 141 | return self.optimizers 142 | 143 | optimizer, scheduler = create_optimizer( 144 | self.args.learning_rate, 145 | self.train_steps * self.args.num_train_epochs, 146 | self.args.warmup_steps, 147 | adam_epsilon=self.args.adam_epsilon, 148 | weight_decay_rate=self.args.weight_decay, 149 | ) 150 | 151 | return optimizer, scheduler 152 | 153 | def _setup_wandb(self): 154 | """ 155 | Setup the optional Weights & Biases (`wandb`) integration. 156 | One can override this method to customize the setup if needed. 
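For example, launching training with WANDB_PROJECT=my-textgnn-runs (a placeholder project name) set in the environment logs runs under that project instead of the default "TwinBertGNN" used below.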
Find more information at https://docs.wandb.com/huggingface 157 | You can also override the following environment variables: 158 | Environment: 159 | WANDB_PROJECT: 160 | (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project 161 | WANDB_DISABLED: 162 | (Optional): boolean - defaults to false, set to "true" to disable wandb entirely 163 | """ 164 | logger.info('Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"') 165 | wandb.init(project=os.getenv("WANDB_PROJECT", "TwinBertGNN"), config=vars(self.args)) 166 | 167 | @tf.function 168 | def _evaluate_steps(self, per_replica_features, per_replica_labels): 169 | """ 170 | One step evaluation across replica. 171 | Args: 172 | per_replica_features: the batched features. 173 | per_replica_labels: the batched labels. 174 | Returns: 175 | The loss corresponding to the given batch. 176 | """ 177 | per_replica_loss, per_replica_logits = self.args.strategy.experimental_run_v2( 178 | self._run_model, args=(per_replica_features, per_replica_labels, False) 179 | ) 180 | 181 | try: 182 | reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0) 183 | except ValueError: 184 | reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None) 185 | 186 | return reduced_loss, per_replica_logits 187 | 188 | def _prediction_train_loop( 189 | self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None 190 | ) -> PredictionOutput: 191 | """ 192 | Prediction/evaluation loop, shared by `evaluate()` and `predict()`. 193 | Works both with or without labels. 194 | """ 195 | 196 | prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only 197 | 198 | logger.info("***** Running %s *****", description) 199 | logger.info(" Batch size = %d", self.args.eval_batch_size) 200 | 201 | label_ids: np.ndarray = None 202 | preds: np.ndarray = None 203 | 204 | step: int = 1 205 | 206 | for features, labels in dataset: 207 | step = tf.convert_to_tensor(step, dtype=tf.int64) 208 | loss, logits = self._evaluate_steps(features, labels) 209 | loss = tf.reduce_mean(loss) 210 | 211 | if not prediction_loss_only: 212 | if isinstance(logits, tuple): 213 | logits = logits[0] 214 | 215 | if isinstance(labels, tuple): 216 | labels = labels[0] 217 | 218 | if self.args.n_gpu > 1: 219 | for val in logits.values: 220 | if preds is None: 221 | preds = val.numpy() 222 | else: 223 | preds = np.append(preds, val.numpy(), axis=0) 224 | 225 | for val in labels.values: 226 | if label_ids is None: 227 | label_ids = val.numpy() 228 | else: 229 | label_ids = np.append(label_ids, val.numpy(), axis=0) 230 | else: 231 | if preds is None: 232 | preds = logits.numpy() 233 | else: 234 | preds = np.append(preds, logits.numpy(), axis=0) 235 | 236 | if label_ids is None: 237 | label_ids = labels.numpy() 238 | else: 239 | label_ids = np.append(label_ids, labels.numpy(), axis=0) 240 | 241 | step += 1 242 | 243 | if self.compute_metrics is not None and preds is not None and label_ids is not None: 244 | metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 245 | else: 246 | metrics = {} 247 | 248 | metrics["eval_train_loss"] = loss.numpy() 249 | 250 | for key in list(metrics.keys()): 251 | if not key.startswith("eval_train_"): 252 | metrics[f"eval_train_{key}"] = metrics.pop(key) 253 | 254 | return PredictionOutput(predictions=preds, 
label_ids=label_ids, metrics=metrics) 255 | 256 | def _prediction_loop( 257 | self, dataset: tf.data.Dataset, description: str, prediction_loss_only: Optional[bool] = None 258 | ) -> PredictionOutput: 259 | """ 260 | Prediction/evaluation loop, shared by `evaluate()` and `predict()`. 261 | Works both with or without labels. 262 | """ 263 | 264 | prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only 265 | 266 | logger.info("***** Running %s *****", description) 267 | logger.info(" Batch size = %d", self.args.eval_batch_size) 268 | 269 | label_ids: np.ndarray = None 270 | preds: np.ndarray = None 271 | 272 | step: int = 1 273 | 274 | for features, labels in dataset: 275 | step = tf.convert_to_tensor(step, dtype=tf.int64) 276 | loss, logits = self._evaluate_steps(features, labels) 277 | loss = tf.reduce_mean(loss) 278 | 279 | if not prediction_loss_only: 280 | if isinstance(logits, tuple): 281 | logits = logits[0] 282 | 283 | if isinstance(labels, tuple): 284 | labels = labels[0] 285 | 286 | if self.args.n_gpu > 1: 287 | for val in logits.values: 288 | if preds is None: 289 | preds = val.numpy() 290 | else: 291 | preds = np.append(preds, val.numpy(), axis=0) 292 | 293 | for val in labels.values: 294 | if label_ids is None: 295 | label_ids = val.numpy() 296 | else: 297 | label_ids = np.append(label_ids, val.numpy(), axis=0) 298 | else: 299 | if preds is None: 300 | preds = logits.numpy() 301 | else: 302 | preds = np.append(preds, logits.numpy(), axis=0) 303 | 304 | if label_ids is None: 305 | label_ids = labels.numpy() 306 | else: 307 | label_ids = np.append(label_ids, labels.numpy(), axis=0) 308 | 309 | step += 1 310 | 311 | if self.compute_metrics is not None and preds is not None and label_ids is not None: 312 | metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) 313 | else: 314 | metrics = {} 315 | 316 | metrics["eval_loss"] = loss.numpy() 317 | 318 | for key in list(metrics.keys()): 319 | if not key.startswith("eval_"): 320 | metrics[f"eval_{key}"] = metrics.pop(key) 321 | 322 | return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics) 323 | 324 | def _log(self, logs: Dict[str, float]) -> None: 325 | if self.tb_writer: 326 | with self.tb_writer.as_default(): 327 | for k, v in logs.items(): 328 | tf.summary.scalar(k, v, step=self.global_step) 329 | self.tb_writer.flush() 330 | if is_wandb_available(): 331 | wandb.log(logs, step=self.global_step) 332 | output = {**logs, **{"step": self.global_step}} 333 | logger.info(output) 334 | 335 | def evaluate_train( 336 | self, eval_train_dataset: Optional[tf.data.Dataset] = None, prediction_loss_only: Optional[bool] = None 337 | ) -> Dict[str, float]: 338 | """ 339 | Prediction/evaluation loop, shared by `evaluate()` and `predict()`. 340 | """ 341 | eval_train_ds = self.get_eval_train_tfdataset(eval_train_dataset) 342 | 343 | output = self._prediction_train_loop(eval_train_ds, description="Evaluation Train") 344 | 345 | logs = {**output.metrics} 346 | logs["epoch"] = self.epoch_logging 347 | self._log(logs) 348 | 349 | return output.metrics 350 | 351 | def evaluate( 352 | self, eval_dataset: Optional[tf.data.Dataset] = None, prediction_loss_only: Optional[bool] = None 353 | ) -> Dict[str, float]: 354 | """ 355 | Prediction/evaluation loop, shared by `evaluate()` and `predict()`. 
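Concretely, this wraps the shared `_prediction_loop`: it builds the (distributed) evaluation dataset, computes the metrics (prefixed with `eval_`), logs them together with the current epoch, and returns the metrics dictionary.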
356 | """ 357 | eval_ds = self.get_eval_tfdataset(eval_dataset) 358 | 359 | output = self._prediction_loop(eval_ds, description="Evaluation") 360 | 361 | logs = {**output.metrics} 362 | logs["epoch"] = self.epoch_logging 363 | self._log(logs) 364 | 365 | return output.metrics 366 | 367 | def train(self, finetune=False, previous_epoch=10, do_eval_train=False) -> None: 368 | """ 369 | Train method to train the model. 370 | """ 371 | train_ds = self.get_train_tfdataset() 372 | 373 | if self.args.debug: 374 | tf.summary.trace_on(graph=True, profiler=True) 375 | 376 | self.gradient_accumulator.reset() 377 | 378 | with self.args.strategy.scope(): 379 | optimizer, lr_scheduler = self.get_optimizers() 380 | iterations = optimizer.iterations 381 | folder = os.path.join(self.args.output_dir, PREFIX_CHECKPOINT_DIR) 382 | ckpt = tf.train.Checkpoint(optimizer=optimizer, model=self.model) 383 | self.model.ckpt_manager = tf.train.CheckpointManager(ckpt, folder, max_to_keep=self.args.save_total_limit) 384 | 385 | if self.model.ckpt_manager.latest_checkpoint: 386 | logger.info( 387 | "Checkpoint file %s found and restoring from checkpoint", self.model.ckpt_manager.latest_checkpoint 388 | ) 389 | 390 | ckpt.restore(self.model.ckpt_manager.latest_checkpoint).expect_partial() 391 | if finetune: 392 | logger.info("Reset Optimizer in finetune mode") 393 | optimizer, lr_scheduler = create_optimizer( 394 | self.args.learning_rate, 395 | self.train_steps * self.args.num_train_epochs, 396 | self.args.warmup_steps, 397 | adam_epsilon=self.args.adam_epsilon, 398 | weight_decay_rate=self.args.weight_decay, 399 | ) 400 | iterations = optimizer.iterations 401 | 402 | if iterations.numpy() > 0: 403 | logger.info("Start the training from the last checkpoint") 404 | start_epoch = (iterations.numpy() // self.train_steps) + 1 405 | else: 406 | if finetune: 407 | logger.info("Start the finetune from the last checkpoint") 408 | start_epoch = 1 + previous_epoch 409 | else: 410 | start_epoch = 1 411 | 412 | tf.summary.experimental.set_step(iterations) 413 | 414 | epochs = 1 if self.args.max_steps > 0 else self.args.num_train_epochs 415 | if finetune and self.args.max_steps <= 0: 416 | epochs += previous_epoch 417 | 418 | if self.args.fp16: 419 | policy = tf.keras.mixed_precision.experimental.Policy("mixed_float16") 420 | tf.keras.mixed_precision.experimental.set_policy(policy) 421 | 422 | with self.tb_writer.as_default(): 423 | tf.summary.text("args", self.args.to_json_string()) 424 | 425 | self.tb_writer.flush() 426 | 427 | logger.info("***** Running training *****") 428 | logger.info(" Num examples = %d", self.num_train_examples) 429 | logger.info(" Num Epochs = %d", epochs) 430 | logger.info(" Total optimization steps = %d", self.train_steps) 431 | 432 | for epoch_iter in range(start_epoch, int(epochs + 1)): 433 | time_start = time.time() 434 | for step, training_loss in enumerate(self._training_steps(train_ds, optimizer)): 435 | self.global_step = iterations.numpy() 436 | self.epoch_logging = epoch_iter - 1 + (self.global_step % self.train_steps + 1) / self.train_steps 437 | 438 | if self.args.debug: 439 | logs = {} 440 | logs["loss"] = training_loss.numpy() 441 | logs["epoch"] = self.epoch_logging 442 | self._log(logs) 443 | 444 | if self.global_step == 1 and self.args.debug: 445 | with self.tb_writer.as_default(): 446 | tf.summary.trace_export( 447 | name="training", step=self.global_step, profiler_outdir=self.args.logging_dir 448 | ) 449 | 450 | if self.args.evaluate_during_training and self.global_step % 
self.args.eval_steps == 0: 451 | if do_eval_train: 452 | self.evaluate_train() 453 | self.evaluate() 454 | 455 | if self.global_step % self.args.logging_steps == 0: 456 | logs = {} 457 | logs["loss"] = training_loss.numpy() 458 | logs["learning_rate"] = lr_scheduler(self.global_step).numpy() 459 | logs["epoch"] = self.epoch_logging 460 | self._log(logs) 461 | 462 | if self.global_step % self.args.save_steps == 0: 463 | ckpt_save_path = self.model.ckpt_manager.save() 464 | logger.info("Saving checkpoint for step {} at {}".format(self.global_step, ckpt_save_path)) 465 | 466 | if self.global_step % 1000 == 0: 467 | time_elapse = time.time() - time_start 468 | time_est = time_elapse / (step + 1) * self.train_steps 469 | logger.info('Epoch: %d; Step: %d; Current Epoch Elapse/Estimate: %0.2fs/%0.2fs' % (epoch_iter, step + 1, time_elapse, time_est)) 470 | 471 | if self.global_step % self.train_steps == 0: 472 | break 473 | 474 | def _training_steps(self, ds, optimizer): 475 | """ 476 | Returns a generator over training steps (i.e. parameters update). 477 | """ 478 | for i, loss in enumerate(self._accumulate_next_gradients(ds)): 479 | if i % self.args.gradient_accumulation_steps == 0: 480 | self._apply_gradients(optimizer) 481 | yield loss 482 | 483 | @tf.function 484 | def _apply_gradients(self, optimizer): 485 | """Applies the gradients (cross-replica).""" 486 | self.args.strategy.experimental_run_v2(self._step, args=(optimizer,)) 487 | 488 | def _step(self, optimizer): 489 | """Applies gradients and resets accumulation.""" 490 | gradient_scale = self.gradient_accumulator.step * self.args.strategy.num_replicas_in_sync 491 | gradients = [ 492 | gradient / tf.cast(gradient_scale, gradient.dtype) for gradient in self.gradient_accumulator.gradients 493 | ] 494 | gradients = [(tf.clip_by_value(grad, -self.args.max_grad_norm, self.args.max_grad_norm)) for grad in gradients] 495 | 496 | optimizer.apply_gradients(list(zip(gradients, self.model.trainable_variables))) 497 | self.gradient_accumulator.reset() 498 | 499 | def _accumulate_next_gradients(self, ds): 500 | """Accumulates the gradients from the next element in dataset.""" 501 | iterator = iter(ds) 502 | 503 | @tf.function 504 | def _accumulate_next(): 505 | per_replica_features, per_replica_labels = next(iterator) 506 | 507 | return self._accumulate_gradients(per_replica_features, per_replica_labels) 508 | 509 | while True: 510 | try: 511 | yield _accumulate_next() 512 | except tf.errors.OutOfRangeError: 513 | break 514 | 515 | def _accumulate_gradients(self, per_replica_features, per_replica_labels): 516 | """Accumulates the gradients across all the replica.""" 517 | per_replica_loss = self.args.strategy.experimental_run_v2( 518 | self._forward, args=(per_replica_features, per_replica_labels) 519 | ) 520 | 521 | try: 522 | reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=0) 523 | except ValueError: 524 | reduced_loss = self.args.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, None) 525 | 526 | return reduced_loss 527 | 528 | def _forward(self, features, labels): 529 | """Forwards a training example and accumulates the gradients.""" 530 | per_example_loss, _ = self._run_model(features, labels, True) 531 | gradients = tf.gradients(per_example_loss, self.model.trainable_variables) 532 | gradients = [ 533 | g if g is not None else tf.zeros_like(v) for g, v in zip(gradients, self.model.trainable_variables) 534 | ] 535 | 536 | self.gradient_accumulator(gradients) 537 | 538 | return 
per_example_loss 539 | 540 | def _run_model(self, features, labels, training): 541 | """ 542 | Computes the loss of the given features and labels pair. 543 | Args: 544 | features: the batched features. 545 | labels: the batched labels. 546 | training: run the model in training mode or not 547 | """ 548 | if isinstance(labels, (dict)): 549 | loss, logits = self.model(features, training=training, **labels)[:2] 550 | else: 551 | loss, logits = self.model(features, labels=labels, training=training)[:2] 552 | loss += sum(self.model.losses) * (1.0 / self.args.n_gpu) 553 | 554 | return loss, logits 555 | 556 | def predict(self, test_dataset: tf.data.Dataset) -> PredictionOutput: 557 | """ 558 | Run prediction and return predictions and potential metrics. 559 | Depending on the dataset and your use case, your test dataset may contain labels. 560 | In that case, this method will also return metrics, like in evaluate(). 561 | Args: 562 | test_dataset: something similar to a PT Dataset. This is just 563 | temporary before to have a framework-agnostic approach for datasets. 564 | """ 565 | test_ds = self.get_test_tfdataset(test_dataset) 566 | 567 | return self._prediction_loop(test_ds, description="Prediction") 568 | 569 | def save_model(self, output_dir: Optional[str] = None): 570 | """ 571 | Save the pretrained model. 572 | """ 573 | output_dir = output_dir if output_dir is not None else self.args.output_dir 574 | 575 | logger.info("Saving model in {}".format(output_dir)) 576 | 577 | if not isinstance(self.model, TFPreTrainedModel): 578 | raise ValueError("Trainer.model appears to not be a PreTrainedModel") 579 | 580 | self.model.save_pretrained(self.args.output_dir) -------------------------------------------------------------------------------- /src/twinbertgnn.py: -------------------------------------------------------------------------------- 1 | """TwinBert Implementation""" 2 | 3 | # import torch 4 | import logging 5 | from tensorflow.python.keras import backend as K 6 | from transformers.modeling_tf_bert import TFSequenceClassificationLoss, TFBertPreTrainedModel 7 | from transformers import BertConfig 8 | import tensorflow as tf 9 | from tensorflow.python.keras.saving.hdf5_format import load_attributes_from_hdf5_group 10 | from transformers.modeling_tf_utils import hf_bucket_url 11 | from transformers.file_utils import TF2_WEIGHTS_NAME, cached_path 12 | import h5py 13 | import numpy as np 14 | from transformers.modeling_tf_utils import shape_list 15 | import os 16 | 17 | try: 18 | from bert_core import * 19 | from encoders import * 20 | except ImportError: 21 | from .bert_core import * 22 | from .encoders import * 23 | 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | # TwinBert pooler layer, if 'clspooler', simply use the vector corresponding to the pooled bert output. 
Otherwise use an attention weighting to weight the token vectors 29 | class PoolerLayer(tf.keras.layers.Layer): 30 | def __init__(self, config, **kwargs): 31 | super(PoolerLayer, self).__init__(**kwargs) 32 | self.config = config 33 | if self.config.pooler_type == 'weightpooler': 34 | self.weighting = tf.keras.layers.Dense(1, name='weighted_pooler') 35 | 36 | def call(self, cls_tensor, term_tensor, mask, training=False): 37 | if self.config.pooler_type == 'clspooler': 38 | return cls_tensor 39 | elif self.config.pooler_type == 'weightpooler': 40 | weights = self.weighting(term_tensor) 41 | weights = weights + tf.expand_dims((tf.cast(mask, weights.dtype) - 1.0), axis=2) / 1e-8 42 | weights = tf.nn.softmax(weights, axis=1) 43 | return tf.reduce_sum(tf.multiply(term_tensor, weights), axis=1) 44 | elif self.config.pooler_type == 'average': 45 | inds = tf.cast(mask, tf.float32) 46 | output = tf.transpose(term_tensor, [2, 0, 1]) * inds 47 | token_tensor = tf.reduce_sum(tf.transpose(output, [1, 2, 0]), axis=1) 48 | token_tensor = tf.transpose(token_tensor, [1, 0]) / tf.reduce_sum(inds, axis=-1) 49 | return tf.transpose(token_tensor, [1, 0]) 50 | 51 | 52 | # TwinBert postprocessing layer (downscale, tanh pooling, and quantization) 53 | class PostprocessingLayer(tf.keras.layers.Layer): 54 | def __init__(self, config, **kwargs): 55 | super(PostprocessingLayer, self).__init__(**kwargs) 56 | self.config = config 57 | 58 | def call(self, downscale, vec, is_left=True, training=False): 59 | def quantization(v): 60 | v = tf.round((v + 1) / (2 / 255)) # 2/256, it is 2/255 in production 61 | v = v * (2.0 / 255) - 1 62 | return v 63 | 64 | if self.config.downscale > 0: 65 | vec = downscale(vec) 66 | if self.config.tanh_pooler: 67 | vec = tf.tanh(vec) 68 | if self.config.quantization_side == 'both' or ( 69 | (not is_left) and self.config.quantization_side == 'right') or ( 70 | is_left and self.config.quantization_side == 'left'): 71 | vec = quantization(vec) 72 | return vec 73 | 74 | 75 | # TwinBert Crossing layer to combine Q and K vectors into a score 76 | class CrossingLayer(tf.keras.layers.Layer): 77 | def __init__(self, config, **kwargs): 78 | super(CrossingLayer, self).__init__(**kwargs) 79 | self.config = config 80 | 81 | if self.config.sim_type == 'cosine': 82 | self.logistic = tf.keras.layers.Dense(1, name='logistic') 83 | elif self.config.sim_type == 'feedforward': 84 | self.ff_dense0 = tf.keras.layers.Dense(self.config.res_size, name='ff_dense0') 85 | 86 | dim_scale_gnn = 2 if (self.config.agg_concat and self.config.gnn_model == "graphsage") else 1 87 | gnn_concat_ret_scale = 1 if self.config.gnn_concat_residual else 0 88 | 89 | input_a_dim = self.config.hidden_size * gnn_concat_ret_scale + self.config.hidden_dims[-1] * dim_scale_gnn 90 | input_b_dim = self.config.hidden_size * gnn_concat_ret_scale + self.config.hidden_dims[-1] * dim_scale_gnn 91 | if self.config.a_fanouts[0] == 0: 92 | input_a_dim = self.config.hidden_size 93 | if self.config.b_fanouts[0] == 0: 94 | input_b_dim = self.config.hidden_size 95 | 96 | self.ff_dense1 = tf.keras.layers.Dense((input_a_dim if self.config.comb_type == 'max' else (input_a_dim + input_b_dim)), 97 | name='ff_dense1') 98 | 99 | if self.config.res_bn: 100 | self.res_bn_1 = tf.keras.layers.BatchNormalization(axis=-1, name='batch_norm_1') 101 | self.res_bn_2 = tf.keras.layers.BatchNormalization(axis=-1, name='batch_norm_2') 102 | 103 | self.relu = tf.keras.layers.ReLU() 104 | self.logistic = tf.keras.layers.Dense(2, name='logistic') 105 | 106 | def call(self, 
vec_a, vec_b, training=False): 107 | if self.config.sim_type == 'cosine': 108 | sim_score = tf.reduce_sum( 109 | tf.multiply(tf.math.l2_normalize(vec_a, axis=1), tf.math.l2_normalize(vec_b, axis=1)), axis=1, 110 | keepdims=True) 111 | probabilities = tf.math.sigmoid(self.logistic(sim_score)) 112 | probabilities = tf.stack([1 - probabilities, probabilities], axis=-1) 113 | 114 | elif self.config.sim_type == 'feedforward': 115 | if self.config.comb_type == 'max': 116 | cross_input = tf.math.maximum(vec_a, vec_b) 117 | elif self.config.comb_type == 'concat': 118 | cross_input = tf.concat([vec_a, vec_b], axis=1) 119 | 120 | output = self.ff_dense0(cross_input) 121 | if self.config.res_bn: 122 | output = self.res_bn_1(output, training=training) 123 | 124 | output = self.ff_dense1(self.relu(output)) 125 | 126 | if self.config.res_bn: 127 | output = self.res_bn_2(output, training=training) 128 | 129 | if self.config.crossing_res: 130 | output = self.relu(output + cross_input) 131 | logits = self.logistic(output) 132 | probabilities = tf.nn.softmax(logits, axis=-1) 133 | 134 | return probabilities 135 | 136 | 137 | class TwinBERTGNNCore(tf.keras.layers.Layer): 138 | def __init__(self, config, **kwargs): 139 | super(TwinBERTGNNCore, self).__init__(**kwargs) 140 | self.config = config 141 | 142 | self.bert_encoder_a = BERTCore(self.config, trainable=self.config.bert_trainable) 143 | self.pooler_a = PoolerLayer(self.config, name='pooler_a', trainable=self.config.bert_trainable) 144 | if self.config.post_processing: 145 | self.postprocessing = PostprocessingLayer(self.config, name='postprocessing') 146 | self.crossing = CrossingLayer(self.config, name='crossing') 147 | if self.config.use_two_crossings: 148 | self.tb_crossing = CrossingLayer(self.config, name='twinbert_crossing') 149 | else: 150 | self.tb_crossing = self.crossing 151 | 152 | if self.config.downscale > 0: 153 | self.downscale_a = tf.keras.layers.Dense(self.config.downscale, name='downscale_a', 154 | trainable=self.config.bert_trainable) 155 | else: 156 | self.downscale_a = None 157 | 158 | encoder_class = encoders.get(self.config.gnn_model.lower()) 159 | self.encoder_a = encoder_class(self.config, self.config.a_fanouts) 160 | if self.config.use_two_gnn: 161 | self.encoder_b = encoder_class(self.config, self.config.b_fanouts) 162 | else: 163 | self.encoder_b = self.encoder_a 164 | 165 | if self.config.use_two_bert: 166 | self.bert_encoder_b = BERTCore(self.config, trainable=self.config.bert_trainable) 167 | self.pooler_b = PoolerLayer(self.config, name='pooler_b', trainable=self.config.bert_trainable) 168 | 169 | if self.config.downscale > 0: 170 | self.downscale_b = tf.keras.layers.Dense(self.config.downscale, name='downscale_b', 171 | trainable=self.config.bert_trainable) 172 | else: 173 | self.downscale_b = None 174 | else: 175 | self.bert_encoder_b = self.bert_encoder_a 176 | self.pooler_b = self.pooler_a 177 | if self.config.downscale > 0: 178 | self.downscale_b = self.downscale_a 179 | else: 180 | self.downscale_b = None 181 | 182 | 183 | def call(self, inputs, training=False): 184 | input_ids_as = {"input_ids_a_0": tf.reshape(inputs.get("input_ids_a_0", None), [-1, self.config.max_seq_len])} 185 | attention_mask_as = {"attention_mask_a_0": tf.reshape(inputs.get("attention_mask_a_0", None), [-1, self.config.max_seq_len])} 186 | 187 | if len(self.config.a_fanouts) > 0 and self.config.a_fanouts[0] > 0: 188 | for i in range(1, len(self.config.a_fanouts) + 1): 189 | input_ids_as["input_ids_a_" + str(i)] = 
tf.reshape(inputs.get("input_ids_a_" + str(i), None), 190 | [-1, self.config.max_seq_len]) 191 | attention_mask_as["attention_mask_a_" + str(i)] = tf.reshape( 192 | inputs.get("attention_mask_a_" + str(i), None), [-1, self.config.max_seq_len]) 193 | 194 | input_ids_bs = {"input_ids_b_0": tf.reshape(inputs.get("input_ids_b_0", None), [-1, self.config.max_seq_len])} 195 | attention_mask_bs = {"attention_mask_b_0": tf.reshape(inputs.get("attention_mask_b_0", None), [-1, self.config.max_seq_len])} 196 | 197 | if len(self.config.b_fanouts) > 0 and self.config.b_fanouts[0] > 0: 198 | for i in range(1, len(self.config.b_fanouts) + 1): 199 | input_ids_bs["input_ids_b_" + str(i)] = tf.reshape(inputs.get("input_ids_b_" + str(i), None), 200 | [-1, self.config.max_seq_len]) 201 | attention_mask_bs["attention_mask_b_" + str(i)] = tf.reshape( 202 | inputs.get("attention_mask_b_" + str(i), None), [-1, self.config.max_seq_len]) 203 | 204 | berts_a = {} 205 | berts_b = {} 206 | 207 | # q d q d... 208 | for i in range(len(input_ids_as)): 209 | if i % 2 == 0: 210 | term_tensor, cls_tensor = self.bert_encoder_a(input_ids_as["input_ids_a_" + str(i)], 211 | attention_mask=attention_mask_as[ 212 | "attention_mask_a_" + str(i)], 213 | output_attentions=False, output_hidden_states=False, 214 | training=training) 215 | vec = self.pooler_a(cls_tensor, term_tensor, attention_mask_as["attention_mask_a_" + str(i)], 216 | training=training) 217 | if self.config.post_processing: 218 | vec = self.postprocessing(self.downscale_a, vec, is_left=True, training=training) 219 | else: 220 | term_tensor, cls_tensor = self.bert_encoder_a(input_ids_as["input_ids_a_" + str(i)], 221 | attention_mask=attention_mask_as[ 222 | "attention_mask_a_" + str(i)], 223 | output_attentions=False, output_hidden_states=False, 224 | training=training) 225 | vec = self.pooler_b(cls_tensor, term_tensor, attention_mask_as["attention_mask_a_" + str(i)], 226 | training=training) 227 | if self.config.post_processing: 228 | vec = self.postprocessing(self.downscale_b, vec, is_left=False, training=training) 229 | berts_a["bert_" + str(i)] = vec 230 | 231 | # d q d q... 
232 | for i in range(len(input_ids_bs)): 233 | if i % 2 == 0: 234 | term_tensor, cls_tensor = self.bert_encoder_b(input_ids_bs["input_ids_b_" + str(i)], 235 | attention_mask=attention_mask_bs[ 236 | "attention_mask_b_" + str(i)], 237 | output_attentions=False, output_hidden_states=False, 238 | training=training) 239 | vec = self.pooler_b(cls_tensor, term_tensor, attention_mask_bs["attention_mask_b_" + str(i)], 240 | training=training) 241 | if self.config.post_processing: 242 | vec = self.postprocessing(self.downscale_b, vec, is_left=False, training=training) 243 | else: 244 | term_tensor, cls_tensor = self.bert_encoder_a(input_ids_bs["input_ids_b_" + str(i)], 245 | attention_mask=attention_mask_bs[ 246 | "attention_mask_b_" + str(i)], 247 | output_attentions=False, output_hidden_states=False, 248 | training=training) 249 | vec = self.pooler_a(cls_tensor, term_tensor, attention_mask_bs["attention_mask_b_" + str(i)], 250 | training=training) 251 | if self.config.post_processing: 252 | vec = self.postprocessing(self.downscale_a, vec, is_left=True, training=training) 253 | berts_b["bert_" + str(i)] = vec 254 | 255 | if self.config.gnn_model == 'weighted': 256 | for i in range(self.config.head_nums[0]): 257 | if self.config.weighted_gnn_type[i] == 'ctr': 258 | berts_a["weights_" + str(i)] = tf.cast(inputs.get('click_a_1', None), dtype=tf.float32) / tf.cast(inputs.get('impression_a_1', None), dtype=tf.float32) 259 | berts_b["weights_" + str(i)] = tf.cast(inputs.get('click_b_1', None), dtype=tf.float32) / tf.cast(inputs.get('impression_b_1', None), dtype=tf.float32) 260 | else: 261 | berts_a["weights_" + str(i)] = tf.math.maximum(tf.math.log(tf.cast(inputs.get(self.config.weighted_gnn_type[i] + '_a_1', None), dtype=tf.float32) + 1.0 + 1e-7), 0.0) 262 | berts_b["weights_" + str(i)] = tf.math.maximum(tf.math.log(tf.cast(inputs.get(self.config.weighted_gnn_type[i] + '_b_1', None), dtype=tf.float32) + 1.0 + 1e-7), 0.0) 263 | 264 | # berts_a_input = tf.identity(berts_a["bert_0"]) 265 | # berts_b_input = tf.identity(berts_b["bert_0"]) 266 | # # print(berts_a) 267 | # print("bert_0_a_input") 268 | # print(berts_a_input) 269 | vec_a = self.encoder_a(berts_a) 270 | # print("vec_a") 271 | # print(vec_a) 272 | vec_b = self.encoder_b(berts_b) 273 | if self.config.gnn_concat_residual: 274 | vec_a = tf.concat([vec_a, berts_a["bert_0"]], axis=1) 275 | vec_b = tf.concat([vec_b, berts_b["bert_0"]], axis=1) 276 | elif self.config.gnn_add_residual: 277 | vec_a = vec_a + berts_a["bert_0"] 278 | vec_b = vec_b + berts_b["bert_0"] 279 | 280 | # print("bert_0_a") 281 | # print(berts_a["bert_0"]) 282 | # print("vec_a_res") 283 | # print(vec_a) 284 | # print(berts_a_input - berts_a["bert_0"]) 285 | output = self.crossing(vec_a, vec_b, training=training) 286 | tb_output = None 287 | if self.config.tb_loss: 288 | tb_output = self.tb_crossing(berts_a["bert_0"], berts_b["bert_0"], training=training) 289 | return output, tb_output, vec_a, vec_b 290 | 291 | 292 | class TwinBERTGNN(TFBertPreTrainedModel): 293 | def __init__(self, config_file, **kwargs): 294 | self.config = self.init_config_from_file(config_file) 295 | super(TwinBERTGNN, self).__init__(self.config, **kwargs) 296 | self.twinbertgnncore = TwinBERTGNNCore(self.config, name="twin_bert") 297 | 298 | @property 299 | def dummy_inputs(self): 300 | length = self.config.max_n_letters * self.config.max_seq_len if self.config.embedding_type == "triletter" else self.config.max_seq_len 301 | 302 | input_dict = { 303 | "input_ids_a_0": tf.ones([1, length], dtype=tf.int32), 304 
| "attention_mask_a_0": tf.ones([1, self.config.max_seq_len], dtype=tf.int32), 305 | "inputs_embeds_a_0": None, 306 | "input_ids_b_0": tf.ones([1, length], dtype=tf.int32), 307 | "attention_mask_b_0": tf.ones([1, self.config.max_seq_len], dtype=tf.int32), 308 | "inputs_embeds_b_0": None, 309 | } 310 | 311 | if len(self.config.a_fanouts) > 0 and self.config.a_fanouts[0] > 0: 312 | layer_node = 1 313 | for layer in range(len(self.config.a_fanouts)): 314 | layer_node *= self.config.a_fanouts[layer] 315 | input_dict.update({ 316 | 'input_ids_a_' + str(layer + 1): tf.ones([1, layer_node, length], dtype=tf.int32), 317 | "attention_mask_a_" + str(layer + 1): tf.ones([1, layer_node, self.config.max_seq_len], 318 | dtype=tf.int32), 319 | 'impression_a_' + str(layer + 1): tf.ones([1, layer_node], dtype=tf.int32), 320 | 'click_a_' + str(layer + 1): tf.ones([1, layer_node], dtype=tf.int32) 321 | }) 322 | 323 | if len(self.config.b_fanouts) > 0 and self.config.b_fanouts[0] > 0: 324 | layer_node = 1 325 | for layer in range(len(self.config.b_fanouts)): 326 | layer_node *= self.config.b_fanouts[layer] 327 | input_dict.update({ 328 | 'input_ids_b_' + str(layer + 1): tf.ones([1, layer_node, length], dtype=tf.int32), 329 | "attention_mask_b_" + str(layer + 1): tf.ones([1, layer_node, self.config.max_seq_len], 330 | dtype=tf.int32), 331 | 'impression_b_' + str(layer + 1): tf.ones([1, layer_node], dtype=tf.int32), 332 | 'click_b_' + str(layer + 1): tf.ones([1, layer_node], dtype=tf.int32) 333 | }) 334 | 335 | return input_dict 336 | 337 | def prune_heads(self, heads_to_prune): 338 | """ Prunes heads of the model. 339 | heads_to_prune: dict of {layer_num: list of heads to prune in this layer} 340 | See base class PreTrainedModel 341 | """ 342 | raise NotImplementedError 343 | 344 | def compute_loss(self, labels, outputs, loss_type): 345 | probabilities = outputs 346 | 347 | if loss_type == 'mse': 348 | loss_fn = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE) 349 | pred = probabilities[:, 1][:, tf.newaxis] 350 | labels = tf.cast(labels, tf.float32) 351 | elif loss_type == "ssce": 352 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( 353 | reduction=tf.keras.losses.Reduction.NONE 354 | ) 355 | float_labels = tf.cast(labels, tf.float32) 356 | model_label = tf.stack([1 - float_labels, float_labels], axis=-1) 357 | labels = tf.nn.softmax(model_label, -1) 358 | pred = probabilities 359 | else: 360 | if loss_type != "ce": 361 | logger.info('unknown loss type {}; fallback to ce'.format(loss_type)) 362 | loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( 363 | reduction=tf.keras.losses.Reduction.NONE 364 | ) 365 | pred = probabilities 366 | 367 | return loss_fn(labels, pred) 368 | 369 | def call(self, inputs, labels=None, training=False, **kwargs): 370 | outputs = self.twinbertgnncore(inputs, training=training, **kwargs) 371 | 372 | if labels is not None: 373 | loss = self.compute_loss(labels, outputs[0], self.config.loss_type) 374 | if self.config.tb_loss: 375 | loss += self.compute_loss(labels, outputs[1], self.config.loss_type) 376 | outputs = (loss,) + outputs 377 | 378 | return outputs 379 | 380 | @classmethod 381 | def init_config_from_file(cls, filename): 382 | ret = {} 383 | with open(filename, 'r', encoding='utf-8') as fp: 384 | while True: 385 | line = fp.readline().strip('\n\r') 386 | if line == '': 387 | break 388 | tokens = line.split('\t') 389 | name = tokens[0].split(':')[0] 390 | type = tokens[0].split(':')[1] 391 | val = tokens[1] 392 | 393 | if type == 'str': 
394 |                     ret[name] = val
395 |                 elif type == 'int':
396 |                     ret[name] = int(val)
397 |                 elif type == 'float':
398 |                     ret[name] = float(val)
399 |                 elif type == 'bool':
400 |                     ret[name] = (val == 'True')
401 |                 else:
402 |                     print('unrecognized config: ' + line)
403 |         ret = BertConfig.from_dict(ret)
404 |         ret.a_fanouts = list(map(int, ret.a_fanouts.split(","))) if ret.a_fanouts else []
405 |         ret.b_fanouts = list(map(int, ret.b_fanouts.split(","))) if ret.b_fanouts else []
406 |         ret.hidden_dims = list(map(int, ret.hidden_dims.split(","))) if ret.hidden_dims else []
407 |         ret.gnn_acts = ret.gnn_acts.split(",") if ret.gnn_acts else []
408 |         ret.head_nums = list(map(int, ret.head_nums.split(","))) if ret.head_nums else []
409 |         ret.weighted_gnn_type = ret.weighted_gnn_type.split(",") if ret.weighted_gnn_type else []
410 |         return ret
411 |
412 |     @classmethod
413 |     def load_from_checkpoint(cls, config_file, checkpoint_file, checkpoint_dict_file, is_tf_checkpoint=True, **kwargs):
414 |         def _read_checkpoint_dict(filename):
415 |             ret = {}
416 |             with open(filename, 'r', encoding='utf-8') as fp:
417 |                 while True:
418 |                     line = fp.readline().strip('\n\r')
419 |                     if line == '':
420 |                         break
421 |                     tokens = line.split('\t')
422 |                     model_weights_name = tokens[0]
423 |                     ckpt_weights_name = tokens[1]
424 |
425 |                     ret[model_weights_name] = ckpt_weights_name
426 |             return ret
427 |
428 |         model = cls(config_file, **kwargs)
429 |         model(model.dummy_inputs, training=False)
430 |         checkpoint_dict = _read_checkpoint_dict(checkpoint_dict_file)
431 |         # print(checkpoint_dict)
432 |
433 |         weight_value_tuples = []
434 |         w_names = []
435 |         if is_tf_checkpoint:
436 |             tf_checkpoint_reader = tf.train.load_checkpoint(checkpoint_file)
437 |             for w in model.layers[0].weights:
438 |                 w_name = '/'.join(w.name.split('/')[3:])  # drop the outer model scopes from the variable name
439 |
440 |                 if w_name in checkpoint_dict:
441 |                     weight_value_tuples.append((w, tf_checkpoint_reader.get_tensor(checkpoint_dict[w_name])))
442 |                     w_names.append(w_name)
443 |                 else:
444 |                     print(w_name)  # weight with no entry in the checkpoint mapping
445 |         else:
446 |             torch_checkpoint = torch.load(checkpoint_file)
447 |
448 |             for w in model.layers[0].weights:
449 |                 if w.name not in checkpoint_dict:
450 |                     continue
451 |                 if w.name.split('/')[-1] == "kernel:0":
452 |                     weight_value_tuples.append((w, torch_checkpoint[checkpoint_dict[w.name]].transpose(0, 1).numpy()))  # torch Linear weights are [out, in]; Keras kernels are [in, out]
453 |                 else:
454 |                     weight_value_tuples.append((w, torch_checkpoint[checkpoint_dict[w.name]].numpy()))
455 |                 w_names.append(w.name)
456 |
457 |         K.batch_set_value(weight_value_tuples)
458 |
459 |         print("Loaded %d weights" % (len(w_names)))
460 |         print("Loaded weights names are: %s" % (", ".join(w_names)))
461 |
462 |         model(model.dummy_inputs, training=False)
463 |         return model
464 |
465 |     @classmethod
466 |     def load_from_bert_pretrained(cls, config_file, pretrained_model_name='bert-base-uncased', **kwargs):
467 |         model = cls(config_file, **kwargs)
468 |         model(model.dummy_inputs, training=False)
469 |
470 |         ckpt_layer_mapping = {}
471 |         for vind, ckpt_ind in enumerate(model.config.ckpt_layer_mapping.split(',')):
472 |             ckpt_layer_mapping['layer_._{}'.format(vind)] = 'layer_._{}'.format(ckpt_ind)
473 |
474 |         archive_file = hf_bucket_url(pretrained_model_name, filename=TF2_WEIGHTS_NAME, use_cdn=True)
475 |         resolved_archive_file = cached_path(archive_file, cache_dir=None, force_download=False, resume_download=False,
476 |                                             proxies=None)
477 |         f = h5py.File(resolved_archive_file, mode='r')
478 |
479 |         layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
480 |         g = f[layer_names[0]]
481 |         weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
482 |         weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
483 |         weights_map = {'/'.join(name.split('/')[2:]): i for i, name in enumerate(weight_names)}
484 |         weight_value_tuples = []
485 |         w_names = []
486 |         for w in model.layers[0].weights:
487 |             w_name = '/'.join(w.name.split('/')[3:])
488 |             for k in ckpt_layer_mapping:  # remap this model's layer index to the checkpoint layer index
489 |                 if k in w_name:
490 |                     w_name = w_name.replace(k, ckpt_layer_mapping[k])
491 |                     break
492 |
493 |             if w_name in weights_map and w.shape == weight_values[weights_map[w_name]].shape:
494 |                 w_names.append(w_name)
495 |                 weight_value_tuples.append((w, weight_values[weights_map[w_name]]))
496 |
497 |         logger.info("Loaded %d weights" % (len(w_names)))
498 |         logger.info("Loaded weights names are: %s" % (", ".join(w_names)))
499 |
500 |         K.batch_set_value(weight_value_tuples)
501 |
502 |         print("Loaded %d weights" % (len(w_names)))
503 |         print("Loaded weights names are: %s" % (", ".join(w_names)))
504 |
505 |         model(model.dummy_inputs, training=False)
506 |         return model
507 |
508 |     @classmethod
509 |     def from_pretrained(cls, pretrained_model_path, config_path, **kwargs):
510 |         model = cls(config_path, **kwargs)
511 |         model(model.dummy_inputs, training=False)  # build the network with dummy inputs
512 |
513 |         assert os.path.isfile(pretrained_model_path), "Error retrieving file {}".format(pretrained_model_path)
514 |         # 'by_name' allows us to do transfer learning by skipping/adding layers
515 |         # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
516 |         try:
517 |             model.load_weights(pretrained_model_path, by_name=True)
518 |         except OSError:
519 |             raise OSError(
520 |                 "Unable to load weights from h5 file. "
521 |                 "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
522 |             )
523 |
524 |         model(model.dummy_inputs, training=False)  # Make sure restore ops are run
525 |         return model
--------------------------------------------------------------------------------
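
A few illustrative notes on src/twinbertgnn.py follow. In the 'weighted' GNN branch of TwinBERTGNNCore, per-neighbor weights come from the hop-1 counters: either a click-through rate (click divided by impression) or a clipped log transform max(log(count + 1 + 1e-7), 0) of the configured counter. A minimal standalone sketch of the same arithmetic, using made-up counts:

    import tensorflow as tf

    clicks = tf.constant([[3, 0, 12]], dtype=tf.int32)        # hop-1 neighbor clicks (illustrative)
    impressions = tf.constant([[10, 5, 40]], dtype=tf.int32)  # hop-1 neighbor impressions (illustrative)

    # 'ctr' weighting: clicks divided by impressions
    ctr_weights = tf.cast(clicks, tf.float32) / tf.cast(impressions, tf.float32)

    # counter weighting: clipped log of the raw count (here, impressions)
    log_weights = tf.math.maximum(tf.math.log(tf.cast(impressions, tf.float32) + 1.0 + 1e-7), 0.0)

    print(ctr_weights.numpy())  # [[0.3 0.  0.3]]
    print(log_weights.numpy())  # approximately [[2.398 1.792 3.714]]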
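
The model config consumed by TwinBERTGNN.init_config_from_file is plain text with one setting per line in the form name:type, a tab character, then the value; parsing stops at the first blank line, and the list-valued fields (a_fanouts, b_fanouts, hidden_dims, gnn_acts, head_nums, weighted_gnn_type) are comma-separated strings. The excerpt below is only a hypothetical sketch of that format: the values are placeholders, and the shipped config may set additional BertConfig fields.

    max_seq_len:int	16
    max_n_letters:int	20
    embedding_type:str	triletter
    loss_type:str	ce
    tb_loss:bool	False
    gnn_model:str	weighted
    head_nums:str	1
    weighted_gnn_type:str	ctr
    a_fanouts:str	3
    b_fanouts:str	3
    hidden_dims:str	64
    gnn_acts:str	relu
    gnn_add_residual:bool	True
    gnn_concat_residual:bool	False
    post_processing:bool	False
    ckpt_layer_mapping:str	0,1,2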
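
Finally, a hypothetical end-to-end sketch of how the classmethods above compose. It assumes the working directory is src/, a config whose loss_type is the default cross entropy, and network access so that load_from_bert_pretrained can fetch the bert-base-uncased TF2 weights; the real training loop lives in train.py and trainer.py.

    import tensorflow as tf
    from twinbertgnn import TwinBERTGNN

    # Build the two-tower model from the tab-separated config and warm-start the BERT
    # encoders from a HuggingFace TF2 checkpoint (layers remapped via ckpt_layer_mapping).
    model = TwinBERTGNN.load_from_bert_pretrained("../config/model.config",
                                                  pretrained_model_name="bert-base-uncased")

    # dummy_inputs mirrors the expected feature layout: input_ids_{a,b}_k and
    # attention_mask_{a,b}_k per hop k, plus impression/click counters for neighbors.
    features = model.dummy_inputs
    labels = tf.constant([1], dtype=tf.int32)

    # With labels supplied, call() prepends the per-example loss to the core outputs.
    loss, output, tb_output, vec_a, vec_b = model(features, labels=labels, training=False)
    print(loss.numpy(), output.numpy())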